# Class for estimating the force coefficients that generated
# a Brownian dynamics trajectory. The Brownian dynamics is (effectively):
#
# D**-1 ( dx_i - F(x_i) ) = R ~ N( 0,1 )
#
# and the force is assumed to be
#
# F(x) = Sum_{k=0}^{r-1} a_k x^k 
#
# where r is the order (must be even).
#
# Daniel L. Ensign
# Prof. Vijay S. Pande
# Department of Chemistry 
# Stanford University
# 22 October 2009

from sympy import Integer, Symbol, vectorize
from mpmath import gamma,  pi, mpf
from numpy import array
from BDDiffusionDistribution import BDDiffusionDistribution
from BDTrajectory import BDTrajectory
from BDObject import BDObject

# SymPy zero, used as the starting value for the symbolic sums below.
Sympy0 = Integer( 0 )

# Set up one of these objects in order to run the Gaussian algorithm
# on a BDTrajectory object at order 'order.'
#
# TO DO: Change this so that trajectory becomes trajectoryList; use
# the data in all trajectories when computing the statistics in 
# BDGaussianAlgorithm._set_order. Could easily be done with a derived
# class. 
#
class BDGaussianAlgorithm( BDObject ) :
	"""Estimate polynomial force coefficients from a BD trajectory.

	The force is F(x) = sum_k a_k x**k for k = 0 .. order-1.

	The exponent sum_i (dx_i - F(x_i))**2 / D is built symbolically.
	The coefficients are then estimated by completing the square in each
	a_k in turn.

	Call do() to run the whole pipeline. Results are stored in
	self.estimates, self.precisions, self.s2, self.Ddistribution,
	self.meanD and self.posteriorWeight.
	"""

	def __init__( self, trajectory, order, verbose = False ):
		"""
		trajectory -- trajectory object; this class reads its 'order',
		              'sdx2', 'sx', 'sxdx', 'x', 'deltax' and len().
		order      -- number of force coefficients (r in the file header).
		verbose    -- forwarded to BDObject.__init__ (presumably controls
		              _vmesg output -- confirm in BDObject).
		"""
		BDObject.__init__( self, verbose )
		self.trajectory = trajectory

		# These must exist BEFORE 'order' is assigned. The order property
		# setter rebuilds the coefficient list, which needs self.x.
		self.x = Symbol( "x" )
		self.D = Symbol( "D", real = True, positive = True )
		self.a = []
		self.f = Sympy0
		self.Ddistribution = None
		self.poly = None

		# Property assignment: pulls statistics from the trajectory and
		# builds self.a / self.f for this order.
		self.order = order

	def _get_order( self ):
		"""Return the number of force coefficients."""
		return self._order

	def _set_order( self, value ):
		"""Set the order, refresh the trajectory statistics and rebuild a_k.

		If the trajectory was computed at a lower order, its order is raised
		first. The coefficient list is rebuilt on EVERY change so that
		len(self.a) always equals the order, including when it decreases.
		"""
		self._order = value
		if self.trajectory.order < value :
			# This will compute additional necessary statistics
			# through BDTrajectory 'set' methods.
			self._vmesg( "Changing %s order from %d to %d" % (self, self.trajectory.order, self.order ) )
			self.trajectory.order = value

		self.constructForceCoeffList()

		self.sdx2 = self.trajectory.sdx2 	# scalar
		self.sx = self.trajectory.sx 		# list
		self.sxdx = self.trajectory.sxdx 	# list
		self._vmesg( "Sum of square displacements: %f" % self.sdx2 )

	order = property( _get_order, _set_order )

	def constructForceCoeffList( self ):
		"""(Re)build the symbols a0..a{order-1} and the polynomial self.f.

		self.a and self.f are rebuilt from scratch, so calling this more
		than once does not duplicate coefficients.
		"""
		self.a = [ Symbol( "a" + str(k) ) for k in range( self.order ) ]
		f = Sympy0
		for k, ak in enumerate( self.a ):
			f += ak*self.x**k
		self.f = f

	def constructExponentialArgument( self ):
		"""Build self.poly = sum_i (dx_i - F(x_i))**2 / D, expanded.

		Uses the precomputed-statistics version. The direct per-point
		version is constructExponentialArgument_long.
		"""
		self.constructExponentialArgument_fast()

	def constructExponentialArgument_long( self ) :
		"""Build self.poly by looping over every trajectory point.

		Assumes len(trajectory.deltax) >= len(trajectory.x) -- TODO confirm
		against BDTrajectory.
		"""
		a = self.a
		order = self.order
		deltax = self.trajectory.deltax
		xtraj = self.trajectory.x

		poly = Sympy0
		for i in range( len( xtraj ) ):
			fi = Sympy0
			for n in range( order ): fi+= a[n]*xtraj[i]**n
			poly += ( deltax[i] - fi )**2
		self.poly = (poly/self.D).expand()
		self._vmesg( "Polynomial: %s\n" % self.poly )

	def constructExponentialArgument_fast( self ):
		"""Build self.poly from the trajectory's summary statistics.

		poly = ( sdx2 - 2 sum_i sxdx[i] a_i + sum_ij sx[i+j] a_i a_j ) / D

		Presumably sx[n] = sum x**n and sxdx[n] = sum x**n dx (TODO confirm
		in BDTrajectory). sx must have at least 2*order-1 entries.
		"""
		sdx2 = self.sdx2
		sx = self.sx
		sxdx = self.sxdx
		a = self.a
		poly = Sympy0 + sdx2
		for i in range( self.order ):
			poly -= 2*sxdx[ i ]*a[i]
		for i in range( self.order ):
			for j in range( self.order ):
				poly += sx[ i+j ]*a[i]*a[j]

		self.poly = (poly/self.D).expand()
		self._vmesg( "Polynomial: %s\n" % self.poly )

	def estimateCoefficients( self ):
		"""Complete the square in each a_k (a0 first) to get the estimates.

		Sets:
		  self.precisions[a_k] -- coefficient of a_k**2 (still carries 1/D)
		  self.estimates[a_k]  -- minimiser of the exponent in a_k, with
		                          later coefficients back-substituted
		  self.s2              -- the constant left after all squares are
		                          completed (still carries 1/D)

		Raises ValueError if some a_k does not appear quadratically in
		self.poly (e.g. the coefficient list does not match the order).
		"""
		# Add.make_args yields the additive terms even when the expression
		# is a single term (where .args would give its factors instead).
		from sympy import Add

		estimates = {}
		precisions = {}
		leftovers = self.poly	# also the residual when order == 0

		for ak in self.a :
			constant = 0
			linear = 0
			quadratic = 0

			for term in Add.make_args( leftovers ):
				if term.has( ak**2 ):
					quadratic += term/ak**2
				elif term.has( ak ):
					linear += term/ak
				else :
					constant += term

			if quadratic == 0 :
				raise ValueError( "coefficient %s does not appear quadratically in the exponent" % ak )

			precisions[ ak ] = quadratic
			estimates[ ak ] = -( linear/(2*quadratic) ).expand()
			leftovers = ( constant - (linear**2)/(4*quadratic) ).expand()
			self._vmesg( "term %s leftovers %s" % ( ak, leftovers ) )

		# Back-substitute. a[k]'s estimate depends only on a[k+1..r-1],
		# because earlier coefficients were eliminated first. So the last
		# estimate is fully determined. Resolve the rest from a[r-2] down to a[0].
		for k in range( self.order-2, -1, -1 ):
			for j in range( self.order-1, k, -1 ):
				estimates[ self.a[k] ] = estimates[ self.a[k] ].subs( self.a[j], estimates[ self.a[j] ] )

		self.estimates = estimates
		self.precisions = precisions
		self.s2 = leftovers 		# needed to get posterior distribution of D

	def constructDiffusionDistribution( self ):
		"""Build the posterior distribution of D and store its mean in self.meanD."""
		ndof = len( self.trajectory ) - self.order
		prec = self.s2*self.D	# strip the 1/D carried by s2
		self.Ddistribution = BDDiffusionDistribution( prec, ndof, verbose = self.verbose )
		self.meanD = self.Ddistribution.mean.evalf()

	def solveForMeans( self ):
		"""Evaluate each coefficient estimate numerically at D = self.meanD."""
		for symbol in self.a :
			self.estimates[ symbol ] = self.estimates[ symbol ].evalf( subs={self.D: self.meanD} )

	def calculatePosteriorWeight( self ) :
		"""Compute self.posteriorWeight for this order.

		Requires numeric estimates, so call solveForMeans() first.
		"""
		# careful here - both s2 and precisions contain factors of 1/D
		# we remove them to try to be consistent with the text
		s2 = self.s2*self.D
		M = len( self.trajectory )
		R = self.order

		Aprod = 1
		musum = 0
		for symbol in self.a :
			Aprod *= self.precisions[ symbol ] * self.D
			musum += self.estimates[symbol]**2

		f1 = .5*pi**(-0.5*M)
		f2 = mpf(Aprod)**-0.5
		f3 = mpf(musum)**(-0.5*R)
		f4 = mpf(s2)**(-0.5*M+0.5*R)
		f5 = gamma(0.5*M - 0.5*R)
		f6 = gamma( 0.5*R)
		self.posteriorWeight = f1*f2*f3*f4*f5*f6

	def do( self ) :
		"""Run the full estimation pipeline."""
		self.constructExponentialArgument()
		self.estimateCoefficients()
		self.constructDiffusionDistribution()

		# now that we have an estimate for D, we can get the coeffs
		self.solveForMeans()

		self.calculatePosteriorWeight()
