import os, sys
import matplotlib.pyplot as pp
import matplotlib.cm as cm
import numpy as np
sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../lib'))
from metrics import DWMetric
from Propagator import LDPropagator
import inspect
from msmbuilder.Serializer import Serializer
from euclid.drift import drift
from euclid.utils import fft_acf, format_block

# Number of propagation steps handed to ``LDPropagator.run`` in ``main``;
# also recorded in description.txt.
trajlength = 200000
# Passed as the last argument of the ``LDPropagator`` constructor in ``main``
# (presumably the thermal energy k_B*T, judging by the name -- TODO confirm).
# Note: description.txt labels this "UNUSED", but main() does pass it on.
kT = 0.75

def main():
    """Run a 2-D Langevin-dynamics simulation and analyse its drift.

    Usage: ``python <this script> OUTPUT_DIRECTORY``

    Propagates ``trajlength`` steps with ``LDPropagator`` (passing ``kT`` and
    the array ``[[-5, 5], [-5, 5]]``, presumably the box bounds -- TODO
    confirm), computes the drift along each dimension with ``drift`` for
    every lag in ``taus``, and writes into OUTPUT_DIRECTORY (created if
    missing):

    * ``drift_x.h5`` / ``drift_y.h5`` -- raw drift arrays plus the lags,
    * ``description.txt``             -- run parameters and drift mean/std,
    * ``traj.png``                    -- trajectory coloured by frame index,
    * ``acf.png``                     -- autocorrelation of x and y,
    * ``drift_plot.png``              -- drift mean and std versus lag.

    Raises:
        SystemExit: if no output directory is given on the command line.
        ValueError: if any output file already exists. This is checked
            before the (expensive) simulation so nothing is overwritten.
    """
    if len(sys.argv) < 2:
        sys.exit('usage: %s OUTPUT_DIRECTORY' % sys.argv[0])
    directory = sys.argv[1]
    scaling = np.array([100.0, 0.0])
    # Leftover from the Brownian-motion variant (``bm``); it is only echoed
    # into plot titles and the description file.
    diffusion_consts = ''

    # make paths
    if not os.path.exists(directory):
        os.makedirs(directory)
    descr_path = os.path.join(directory, 'description.txt')
    acf_path = os.path.join(directory, 'acf.png')
    traj_path = os.path.join(directory, 'traj.png')
    d0_path = os.path.join(directory, 'drift_x.h5')
    d1_path = os.path.join(directory, 'drift_y.h5')
    drift_plot_path = os.path.join(directory, 'drift_plot.png')
    # Explicit loop: ``map`` is lazy on Python 3, so using it for side
    # effects would silently skip these checks.
    for path in (descr_path, acf_path, traj_path, d0_path, d1_path,
                 drift_plot_path):
        if os.path.exists(path):
            raise ValueError("%s exists" % path)

    # Callable handed to LDPropagator (presumably the potential gradient);
    # only the x component is non-zero. Its source goes into description.txt.
    def dV(x):
        return scaling * 5 * np.array([np.cos(5*x[0]), 0.0])
    ld = LDPropagator(2, dV, np.array([[-5, 5], [-5, 5]]), kT)
    ld.run(trajlength)

    taus = [1, 5, 10, 50, 100, 200, 350, 500, 700, 1000, 5000]
    m0, m1 = DWMetric(0), DWMetric(1)
    d0 = drift(ld.trajectory, taus, m0)
    d1 = drift(ld.trajectory, taus, m1)

    # per-lag (row-wise) means and standard deviations
    d0_m, d0_std = _masked_row_stats(d0)
    d1_m, d1_std = _masked_row_stats(d1)

    # save drifts to disk
    e0 = {'Data': d0, 'Taus': taus, 'type': 'Dimension 0 of 1'}
    e1 = {'Data': d1, 'Taus': taus, 'type': 'Dimension 1 of 1'}
    Serializer(e0).SaveToHDF(d0_path)
    Serializer(e1).SaveToHDF(d1_path)

    # NOTE(review): the "UNUSED" labels predate the switch from ``bm`` to
    # LDPropagator -- dV and kT are in fact used above. Text kept unchanged.
    with open(descr_path, 'w') as f:
        description = format_block("""
        trajlength: %s
        taus: %s
        UNUSED dV:%s
        diffusion_consts: %s
        UNUSED kT: %s
        
        drift x means: %s
        drift x std: %s
        drift y means: %s
        drift y std: %s
        """ % (trajlength, str(taus), inspect.getsource(dV), str(diffusion_consts),
               kT, d0_m, d0_std, d1_m, d1_std))
        # Same output as the old ``print >> f, ...``, but valid on Python 3.
        f.write('%s\n' % (description,))

    # make plots
    xyzlist = ld.trajectory['XYZList']
    _plot_trajectory(traj_path, 'Trajectory in 2D: %s' % str(diffusion_consts),
                     xyzlist)
    _plot_acf(acf_path, 'Autocorrelation: %s' % str(diffusion_consts), xyzlist)
    _plot_drift(drift_plot_path, 'drift: %s' % str(diffusion_consts), taus,
                d0_m, d0_std, d1_m, d1_std)


def _masked_row_stats(d):
    """Return ``(means, stds)`` of each row of the 2-D array *d*.

    Negative entries are masked out before averaging (presumably ``drift``
    uses them as "no data" sentinels -- TODO confirm). A row with no
    non-negative entries yields NaN instead of a spurious 0.
    """
    masked = np.ma.masked_less(d, 0)
    return (masked.mean(axis=1).filled(np.nan),
            masked.std(axis=1).filled(np.nan))


def _plot_trajectory(path, title, xyzlist):
    """Scatter the x/y columns of *xyzlist*, coloured by frame index, to *path*."""
    pp.figure()
    pp.title(title)
    frame_index = np.arange(len(xyzlist))
    # 'nipy_spectral' is the new name of the old 'spectral' colormap, which
    # was removed in matplotlib 2.2; passing the name also avoids the
    # since-removed ``cm.get_cmap``.
    pp.scatter(xyzlist[:, 0], xyzlist[:, 1], s=1, c=frame_index,
               cmap='nipy_spectral', edgecolors='none')
    pp.colorbar()
    pp.savefig(path)
    pp.close()


def _plot_acf(path, title, xyzlist):
    """Plot the autocorrelation of the x and y columns on a log time axis."""
    pp.figure()
    pp.title(title)
    pp.semilogx(fft_acf(xyzlist[:, 0]), label='x')
    pp.semilogx(fft_acf(xyzlist[:, 1]), label='y')
    pp.legend()
    pp.ylabel('autocorrelation')
    pp.xlabel('time (steps)')
    pp.savefig(path)
    pp.close()


def _plot_drift(path, title, taus, d0_m, d0_std, d1_m, d1_std):
    """Plot drift mean (top panel) and std (bottom panel) against the lags."""
    pp.figure()
    pp.subplot(211)
    pp.title(title)
    pp.plot(taus, d0_m, '-x', label='mean drift x')
    pp.plot(taus, d1_m, '-x', label='mean drift y')
    pp.xlim(0, 1.05*max(taus))
    pp.xlabel('delta t')
    pp.ylabel('mean drift')
    pp.legend(loc=4)
    pp.subplot(212)
    # Bottom panel shows the standard deviations (previously plotted the
    # means a second time by mistake).
    pp.plot(taus, d0_std, '-x', label='std drift x')
    pp.plot(taus, d1_std, '-x', label='std drift y')
    pp.xlim(0, 1.05*max(taus))
    pp.ylabel('std drift')
    pp.xlabel('delta t')
    pp.legend(loc=4)
    pp.savefig(path)
    pp.close()
    
def bm(dims, length, diffusion_consts):
    """Generate a discretised random-walk (Brownian-motion-style) trajectory.

    Draws ``length`` independent standard-normal increments per dimension
    from the global ``np.random`` state. Each increment is scaled by
    ``diffusion_consts * sqrt(1 / length)``. The positions are the running
    sum of the increments, so the walk starts at the first increment rather
    than at the origin.

    Args:
        dims: number of spatial dimensions.
        length: number of time steps.
        diffusion_consts: scalar, or array broadcastable against
            ``(length, dims)``, multiplying the step size.

    Returns:
        dict with the single key ``'XYZList'`` holding the ``(length, dims)``
        array of positions. This is the key ``main`` reads from a trajectory.
    """
    step_scale = diffusion_consts * np.sqrt(1.0 / length)
    increments = step_scale * np.random.randn(length, dims)
    return dict(XYZList=increments.cumsum(axis=0))


# Script entry point; main() takes the output directory from sys.argv[1].
if __name__ == '__main__':
    main()