
# TODO: implement BDDiffusionDistribution.__mul__ so that two distributions
# can be combined into one.

from BDObject import BDObject
from mpmath import sqrt, log
from sympy import Symbol, exp, gamma

# s2 = precision
# M = number of "degrees of freedom"
class BDDiffusionDistribution( BDObject ):
	"""Distribution p( D | X ) of the diffusion coefficient D.

	As written out by ``__str__``, the density is

	    p( D | X ) = (s2/2)**(M/2) / Gamma(M/2) * D**(-1 - M/2) * exp( -s2/(2 D) )

	i.e. an inverse-gamma density with shape M/2 and scale s2/2.

	Parameters
	----------
	s2 : number
	    "Precision" parameter; must be strictly positive.
	M : integer
	    Number of "degrees of freedom" (trajectory length); must be a
	    positive integer (integral floats such as 3.0 are accepted).
	verbose : bool
	    Forwarded to ``BDObject.__init__``.
	"""

	def __init__( self, s2, M, verbose = False ) :
		BDObject.__init__( self, verbose )

		expr = "Trying to set up distribution with s2 = " +str(s2) + " and M = " + str(M) 
		self._vmesg( expr )

		# Both assignments go through the validating properties below.
		self.s2 = s2
		self.M = M
		# Free symbol used only to render the density in __str__.
		self.D = Symbol( "D" )

	def __str__( self ):
		"""Return the density as a symbolic expression in D."""
		halfM = .5*self.M
		halfs2 = .5*self.s2

		factor1 = self.D**(-1-halfM)
		factor2 = halfs2**halfM
		factor3 = exp( -halfs2/self.D )
		factor4 = 1.0/gamma( halfM ).evalf()
		expression = (factor1*factor2*factor3*factor4).expand() 
		return "p( D | X ) = %s" % str( expression )

	def _get_s2( self ):
		return self._s2
	def _set_s2( self, value ):
		# s2 must be strictly positive: 0 is rejected too.
		if value > 0 :
			self._s2 = value
		else :
			raise ValueError( "Can't set s2 to non-positive value %f" % value )
	s2 = property( _get_s2, _set_s2 )

	def _get_M( self ):
		return self._M
	def _set_M( self, value ):
		if value != int( value ):
			raise ValueError( "Trajectory length M must be an integer (got %f)" % value )
		if value <= 0 :
			raise ValueError( "Trajectory length M cannot be 0 or negative (got %d)" % value )
		self._M = int(value)
	M = property( _get_M, _set_M )

	def moment( self, k ):
		"""Return the k-th raw moment E[ D**k ], or None if it does not exist.

		E[ D**k ] = (s2/2)**k * Gamma(M/2 - k) / Gamma(M/2), which is finite
		only when M/2 - k > 0.  For M <= 2k the moment diverges and None is
		returned.
		"""
		s2 = self.s2
		M = self.M
		# Gamma(M/2 - k) has a pole at M == 2k, so that case must be rejected too.
		if M <= 2*k :
			return None
		# k*log(s2/2) instead of log((s2/2.)**k): the float power can raise
		# OverflowError for large s2 or k, the product cannot.
		log_moment = ( k*log( s2/2. )
			+ log( gamma( .5*M - k ).evalf() )
			- log( gamma( .5*M ).evalf() ) )
		return exp( log_moment )

	# central moments and such

	def _get_mean( self ):
		"""Mean E[D]; None when it does not exist (M <= 2)."""
		m1 = self.moment(1)
		if m1 is None :
			return None
		return m1.evalf()
	mean = property( _get_mean )

	def _get_var( self ):
		"""Variance of D; None when it does not exist (M <= 4)."""
		if self.M <= 4 :
			return None
		m1 = self.moment(1)
		# From the moment formula, E[D**2] / E[D]**2 = (M/2 - 1)/(M/2 - 2), hence
		# var = E[D]**2 / (M/2 - 2).  This avoids subtracting two nearly equal
		# numbers (E[D**2] - E[D]**2) when M is large.
		return ( m1**2 / ( .5*self.M - 2 ) ).evalf()
	var = property( _get_var )

	def _get_std( self ):
		"""Standard deviation of D; None when the variance does not exist."""
		var = self.var
		if var is None :
			return None
		return sqrt( var )
	std = property( _get_std )

	def __mul__( self, other ):
		"""Combining two distributions is not implemented yet.

		Returning NotImplemented, instead of silently producing None, makes
		``a * b`` raise TypeError unless ``other`` handles it via __rmul__.
		"""
		return NotImplemented

	
