#!/usr/bin/env python
import pdb
import os
import re
from pylab import sys
from numpy import *
from numpy.random import *
import timeseries

# Command-line interface: <window index k> <number of samples nmax> <random seed>.
if len(sys.argv) < 4:
    sys.exit("usage: %s k nmax seed" % sys.argv[0])
k = int(sys.argv[1])
nmax = int(sys.argv[2])

# Spring constant of each umbrella window; the number of windows follows from it.
# (For a uniform setup use e.g. kharm = [15]*8.)
kharm = [20,2,20,2,20,2,20,2]
nwin = len(kharm)

# The coordinate is treated as periodic with period `perio` (see the wrap in
# harm()); window centres are evenly spaced over [chi_min, chi_max).
perio = 2*pi
chi_min = -perio/2
chi_max = perio/2
delta = (chi_max-chi_min)/nwin
# linspace(..., endpoint=False) always yields exactly nwin centres; arange with a
# float step can yield nwin+1 of them because of rounding in the length computation.
kloc = linspace(chi_min,chi_max,nwin,endpoint=False)

# k == -1 selects the unbiased run: harm() then uses centre 0 and spring constant 0,
# so the header reports that instead of indexing the last window.
if not -1 <= k < nwin:
    sys.exit("window index k must be in [-1, %d], got %d" % (nwin-1, k))
if k == -1:
    print("# %12.5f%12.5f" % (0.0, 0.0))
else:
    print("# %12.5f%12.5f" % (kloc[k],kharm[k]))

############ MC Initialization ############
# Multiplies the energy in the Boltzmann factor exp(-betat*E) (presumably the
# inverse temperature) and scales the harmonic bias in harm().
betat = 1
# Seed numpy's global RNG from the third CLI argument so runs are reproducible.
rseed = int(sys.argv[3])
seed(rseed)

################## Function Definitions ########################

def harm(K,X):
    """Return the harmonic bias energy of window K at position(s) X.

    K == -1 means "no bias": the centre is 0 and the spring constant is 0,
    so the result is zero.  Otherwise the centre is kloc[K] and the spring
    constant is kharm[K].  The distance to the centre is folded back once by
    the period `perio` when it exceeds half a period (minimum image) before
    being squared and scaled by 0.5*betat*spring.
    """
    if K == -1:
        centre, spring = 0, 0
    else:
        centre, spring = kloc[K], kharm[K]
    dist = fabs(X - centre)
    # Boolean mask times perio: subtract one period only where dist > perio/2.
    dist = dist - perio*(dist > perio/2)
    return 0.5*betat*spring*dist**2

def func(X):
    """Unbiased potential 2*(3 + cos(y) + cos(2y) + cos(4y)) with y = X - perio/4.

    Works element-wise on numpy arrays as well as on scalars.
    """
    shifted = X - perio/4
    return 2*(3 + cos(shifted) + cos(2*shifted) + cos(4*shifted))

def en(K,X):
    """Total energy at X in window K: harmonic bias harm(K,X) plus func(X)."""
    bias = harm(K,X)
    return bias + func(X)

def uprob(betat,K,X):
    """Unnormalised Boltzmann weight exp(-betat*en(K,X)).

    Note: the `betat` argument is used only for this exponent; harm() (via en())
    still reads the module-level `betat` for the bias term.
    """
    energy = en(K,X)
    return exp(-betat*energy)

def NIntegrate(x0,x1,nint,k,betat):
    """Integrate uprob(betat, k, x) over [x0, x1] with composite Simpson's rule.

    Parameters:
        x0, x1 -- integration limits.
        nint   -- number of sub-intervals; Simpson's rule needs a positive even value.
        k      -- window index passed through to uprob (-1 = no bias).
        betat  -- exponent scale passed through to uprob.

    Returns the approximate integral as a float.

    Raises ValueError if nint is not a positive even integer (the 1-4-2-...-4-1
    weight pattern is only valid then).
    """
    if nint <= 0 or nint % 2 != 0:
        raise ValueError("Simpson's rule needs a positive even nint, got %r" % (nint,))
    # float() keeps this a true division under Python 2 even for integer limits.
    delta = (x1 - x0)/float(nint)
    total = 0.0
    for i in range(nint + 1):
        # Simpson weights (times 1/3): 1 at the ends, 4 at odd nodes, 2 at even interior nodes.
        if i == 0 or i == nint:
            fac = 1.0/3.0
        elif i % 2 == 0:
            fac = 2.0/3.0
        else:
            fac = 4.0/3.0
        x = x0 + i*delta
        total += fac*uprob(betat,k,x)*delta
    return total

############# Generate data points from harmonic wells #################################

# Samples of the coordinate drawn from the window-k density uprob(betat,k,x)/Zp.
xa = zeros(nmax,float)

# Integrals of the biased (window k) and unbiased (K = -1, zero bias) weights over
# the full interval; -log(Zp/Zp0) is printed (presumably the reduced free energy of the bias).
Zp = NIntegrate(chi_min,chi_max,1000,k,betat)
Zp0 = NIntegrate(chi_min,chi_max,1000,-1,betat)
print("#Zp = %s" % (-log(Zp/Zp0),))

# Envelope for rejection sampling: 10% above the largest density found on a grid
# that spans the same interval the proposals below are drawn from.
xl = arange(chi_min,chi_max,0.01)
maxf = 1.1*max(uprob(betat,k,xl)/Zp)

# Rejection sampling with a flat proposal on [chi_min, chi_max).
n_over = 0  # proposals where the envelope was too low (p > 1)
i = 0
while i < nmax:
    rt = random(2)
    x = chi_min + perio*rt[0]
    u = rt[1]
    p = (uprob(betat,k,x)/Zp)/maxf
    if p > 1:
        n_over += 1
    if u < p:
        xa[i] = x
        i += 1
if n_over:
    # An underestimated envelope biases the samples; report it rather than stay silent.
    sys.stderr.write("# warning: rejection envelope exceeded %d times; samples may be biased\n" % n_over)

# Total energy, and the unbiased part obtained by removing the harmonic bias.
ea = en(k,xa)
ua = ea - harm(k,xa)

gu = timeseries.statisticalInefficiency(ua,ua)
ge = timeseries.statisticalInefficiency(ea,ea)
gx = timeseries.statisticalInefficiency(xa,xa)

print("#correlations in unbiased energies: %s" % (gu,))
print("#correlations in energies: %s" % (ge,))
print("#correlations in coordinates: %s" % (gx,))

for i in range(nmax):
    print("%-d\t%.5f\t%.5f\t%.5f" % (i,xa[i],ea[i],ua[i]))


