#!/usr/bin/python

import sys
def usage() :
	"""Print the command-line synopsis to stderr and exit with status 2.

	Called when the command-line arguments are missing or malformed; never
	returns.
	"""
	write = sys.stderr.write
	write( "usage: %s <traj. file> <x0> <dt> <kT> <max order>\n" % sys.argv[0] )
	write( "<x0> = initial point\n" )
	write( "<dt> = simulation time step (> 0)\n" )
	write( "<kT> = thermal temperature\n" )
	write( "<max order> = largest force-model order to fit (even orders 2..max, >= 2)\n" )
	# Non-zero status so shells/callers can detect the failure; 2 is the
	# conventional exit code for command-line usage errors.
	sys.exit( 2 )

try:
	trajectory = sys.argv[1]
	x0 = float( sys.argv[2] )
	dt = float( sys.argv[3] )
	kT = float( sys.argv[4] )
	maxOrder = int( sys.argv[5] )
except:
	usage()

print "Using trajectory file", trajectory
print "x0", x0
print "dt", dt
print "kT", kT
print "up to order =", maxOrder

from BDTrajectory import *
from Timer import Timer
from numpy import array,loadtxt
from sympy import Symbol, exp, pprint, simplify, log, diff, solve, Rational
from mpmath import gamma, pi, quad, inf, mp

timer = Timer()

# 1. LOAD TRAJECTORY
print "loading trajectory"
trajectory = loadtxt( trajectory, "float" )
print timer

# 2. SCALE TRAJECTORY
scaling = (2.*dt)**-.5
print "scaling trajectory and initial point by *", scaling
trajectory = trajectory*scaling
x0 = x0*scaling
print timer

# 3. COMPUTE Delta-x TRAJECTORY
print "computing delta x"
trajectory = BDTrajectory( trajectory, x0, maxOrder, verbose = True )
deltax = array(trajectory.deltax)
print timer

# 4. BUILD FORCE FUNCTION (FOR EACH ORDER)
results = []
D = Symbol('D', positive=True)
x = Symbol('x')
acoeffs = [ Symbol( 'a%d' % n ) for n in range(maxOrder) ]
for order in range(2,maxOrder+1,2):
	ForceFunction = array([ acoeffs[n] * x**n * D**-Rational(n,2) for n in range(order) ]).sum()
	print "order",order, ": F(x) =", ForceFunction, timer

	# this loop is really slow
	Q = 0
	dscale = D**-Rational(1,2)
	for i in range(len(trajectory)) :
		Q += ( dscale * deltax[i] - ForceFunction.subs( {x:trajectory.x[i]} ) )**2 
		if i % 1000 == 0 : print i
	Q = Q.expand()
	print "Q computed:", timer

	# this should be wrapped into a function
	terms = Q.args
	estimates = {}
	precisions = {}
	
	for an in acoeffs[:order] :
	    constant = 0
	    linear = 0
	    quadratic = 0
	
	    for term in terms :
	        if term.has( an**2 ):
	            quadratic += term/an**2
	        elif term.has( an ): linear += term/an
	        else : constant += term
	
	    precisions[ an ] = quadratic
	    estimates[ an ] = -( linear/(2*quadratic) ).expand()
	    leftovers = ( constant - (linear**2)/(4.0*quadratic) ).expand()
	    terms = leftovers.args
	    #print "term %s leftovers %s" % ( an, leftovers ), timer
	
	# now solve for the mu's. The last one is always (?) determined.
	for k in range( order-2, -1, -1 ): # from a[k-1] to a[0]
	    for j in range( order-1, k, -1 ): # only need a[r], ..., a[k+1]
	        estimates[ acoeffs[k] ] = estimates[ acoeffs[k] ].subs( acoeffs[j], estimates[ acoeffs[j] ] )
	
	for n in range(order):
		print acoeffs[n], ".=", estimates[acoeffs[n]], ", precision =", precisions[acoeffs[n]]
	print "leftovers =", leftovers	
	
	# COMPUTE THE LIKELIHOOD OF D
	estimates = array( [ estimates[ acoeffs[n] ] for n in range(order) ] )
	precisions = array( [ precisions[ acoeffs[n] ] for n in range(order) ] )

	destimatesTable=array( [D**(-.5*n+.5) for n in range(order)] )
	dprecisionsTable=array( [D**n for n in range(order)] )
	precisions=precisions*dprecisionsTable

	print "estimates", estimates
	print "precisions", precisions

	M = len( trajectory )
	halfR= order/2 # assuming this is even!
	halfM = Rational(M,2)
	part1 = 2.0**( -1 - halfM + halfR  )
	part2 = D**( -halfM - Rational(order,4) + Rational( order**2, 4 ) )
	part3 = exp( -.5*leftovers )
	part4 = gamma(halfR)
	part5 = pi**-halfM
	part6 = ( precisions.prod() )**-.5
	part7 = ( (estimates**2).sum() )**-halfR
	likelihood = (part1*part2*part3*part4*part5*part6*part7).evalf()
	print likelihood
	print "p(X|D,H_R) ="
	pprint( likelihood	)

	# COMPUTE THE FUNCTION MAXIMUM
	# strategy: a slightly rearranged version of the likelihood should be 
	# easier for sympy to solve
	if order == 2 :
		estimates = estimates*destimatesTable
		leftovers = leftovers*D
		nvalues = array( range( order ) )
		epart1 =  2*D*M - D*order + 2*D*nvalues*order - D*order**2 - 2*leftovers  
		expression = Rational(1,4) * D**nvalues * epart1 * estimates**2 
		expression = expression.sum()
		solution = solve( expression, D )
		maximum = [ value for value in solution if value>0 and value.is_real ][0]
		print "found maximum likelihood at", maximum
	else :
		print "integrating around", maximum

	# integrate
	print "normalizing ..."
	difference = 1
	tol = 0.00001
	posterior = likelihood
	while difference > tol :
		print "iter!",
		postwt = lambda value: posterior.subs( {D:value} ).evalf()
		norm = quad( postwt, [0,maximum,inf] )
		posterior = posterior/norm
		check = quad( postwt, [0,maximum,inf] )
		print "checking normalization, 1 ==", check
		difference = abs(check-1)
	print "normalization within %f of 1" % tol
	modelweight = likelihood/posterior

	print "p(D|X,H_R) ="
	pprint( posterior )
	
	print "model weight =", modelweight	
	results.append( ( order, modelweight ) )

print "order\tweight\trelative"
relweight = results[0][1]
for order, weight in results :
	print order, weight, weight/relweight 
