from Physical import *
from Forces import *
from Propagator import *
from IO import *
from ForceField import *

import math
import numpy
import os
import string
import sys


def CartesianToModespace(Q_S_M, sqmH, xa, xb):
    """Express the displacement xb - xa in the mode basis Q_S_M.

    Evaluates  c_s = Q_S^T M^{1/2} (xb - xa).

    Q_S_M -- matrix whose columns are the selected mode vectors (Q_S)
    sqmH  -- the mass-matrix square root M^{1/2}
    xa    -- reference coordinates x0
    xb    -- coordinates x
    Returns the projected coefficients c_s.
    """
    # Mass-weighted displacement M^{1/2} (x - x0).
    weighted_disp = numpy.dot(sqmH, xb - xa)
    # Project onto the modes: Q_S^T M^{1/2} (x - x0).
    return numpy.dot(Q_S_M.transpose(), weighted_disp)

def sqrtmass(rd, cd, md):
    """Return an rd x cd matrix with sqrt(md[i]) on the diagonal, zeros elsewhere.

    The script uses it to build M^{1/2} from the masses.

    rd, cd -- number of rows and columns of the result
    md     -- indexable sequence of masses; only the first min(rd, cd)
              entries are read
    """
    mmat = numpy.zeros((rd, cd))
    # Only the main diagonal is non-zero, so fill it directly instead of
    # visiting all rd*cd entries in Python.
    ndiag = min(rd, cd)
    idx = numpy.arange(ndiag)
    mmat[idx, idx] = numpy.sqrt([md[i] for i in range(ndiag)])
    return mmat

    

def _is_multilevel(value):
    """Return True when value gives one entry per propagation level (list or tuple)."""
    return isinstance(value, (list, tuple))


def _build_mts_chain(scheme, cyclelength, dt, forcefield, params):
    """Resolve the outer level and the inner-level argument chain for an MTS run.

    Returns (outertime, outerscheme, outerforcefield, chain).  ``chain`` starts
    with the outer scheme's parameter dict and then holds, per inner level,
    [scheme name,] timestep, force field, parameter dict -- the extra
    positional arguments later splatted into ``propFactory.create``.

    Raises ValueError when the number of force fields differs from the number
    of propagation levels.
    """
    multi_cycle = _is_multilevel(cyclelength)
    if multi_cycle:
        levels = len(cyclelength) + 1
        outertime = cyclelength[0]
    else:
        # A single cycle length means two levels of propagation.
        levels = 2
        outertime = cyclelength

    multi_scheme = _is_multilevel(scheme)
    outerscheme = scheme[0] if multi_scheme else scheme

    # The number of force fields provided must equal the number of levels.
    if len(forcefield) != levels:
        raise ValueError(
            "[MDL] Error in propagate(): %d levels of propagation with %d force fields."
            % (levels, len(forcefield)))
    outerforcefield = forcefield[0]

    if multi_scheme:
        chain = (params.get(outerscheme, {}),)
    else:
        # A single scheme name receives the whole params dict.
        chain = (params,)

    for i in range(1, levels):
        has_name = multi_scheme and i < len(scheme)
        if has_name:
            chain += (scheme[i],)
        # A level without an explicit cycle length (always the last one) runs at dt.
        if multi_cycle and i < len(cyclelength):
            chain += (cyclelength[i],)
        else:
            chain += (dt,)
        chain += (forcefield[i],)
        # Parameters can only be looked up for levels that have a scheme name.
        chain += (params.get(scheme[i], {}) if has_name else {},)
    return outertime, outerscheme, outerforcefield, chain


def simplePropagateFuction(myPropagatorClass, vsteps, scheme="Leapfrog", steps=0,
                           cyclelength=-1, dt=0.1, forcefield=None, params=None,
                           ref_eigvals=None, ref_c=None, ref_positions=None):
    """Set up myPropagatorClass for the requested integrator(s) and run it.

    Single time stepping (STS) is used when ``cyclelength`` is -1; otherwise
    the run uses multiple time stepping (MTS) with one level per cycle length
    plus a last level that runs at ``dt``.

    Args:
        myPropagatorClass: holder with ``phys``, ``forces`` and ``io``; its
            ``myPropagator`` is used after ``setPropagator``.
        vsteps: number of steps handed to ``executePropagator``.
        scheme: scheme name, or for MTS a list of names (outer level first).
        steps: unused; kept for interface compatibility.
        cyclelength: -1 for STS; for MTS an outer timestep or a list of
            timesteps (outer level first).
        dt: base timestep; also stored as ``myPropagatorClass.myTimestep``.
        forcefield: for MTS a list with one force field per level; for STS
            it is passed unchanged as the outer force field.
        params: scheme parameters.  With a list of MTS schemes they are keyed
            by scheme name; otherwise the whole dict goes to the outer scheme.
            Top-level 'shake' / 'rattle' == 'on' attach those modifiers.
        ref_eigvals, ref_c, ref_positions: values for SetEigvals, SetRefC and
            SetRefPos.  They default to the script-level ``eigvals``, ``c_s``
            and ``myPhys.positions``.

    Raises:
        ValueError: MTS with a force-field count different from the level count.
    """
    # Fresh containers per call instead of shared mutable defaults.
    if forcefield is None:
        forcefield = []
    if params is None:
        params = {}

    myPropagatorClass.myTimestep = dt

    if cyclelength != -1:  # MTS
        outertime, outerscheme, outerforcefield, chain = _build_mts_chain(
            scheme, cyclelength, dt, forcefield, params)
    else:  # STS
        outertime = dt
        outerscheme = scheme
        outerforcefield = forcefield
        chain = (params,)

    if myPropagatorClass.forces.dirty():
        myPropagatorClass.forces.build()

    # NOTE(review): forces.dirty is called above, but io.dirty is only tested
    # for truth here; if it is a method this always rebuilds -- confirm in IO.
    if myPropagatorClass.io.dirty:
        myPropagatorClass.io.build()

    print(outerscheme)
    if propFactory.getType(outerscheme) == "method":
        # Only initialization happens for method-type schemes; no steps are run.
        print('Calculate initial force for outer scheme ')
        outerforcefield.calculateForces(myPropagatorClass.phys, myPropagatorClass.forces)
        print('Initialize : Updating center of mass and momenta ')
        myPropagatorClass.phys.updateCOM_Momenta()
        print('Initialize : Run IO module to generate initial output ')
        myPropagatorClass.io.run(myPropagatorClass.phys, myPropagatorClass.forces, 0, outertime)
        myPropagatorClass.io.myProp = myPropagatorClass
        return

    # 1. Create the propagator object.
    propagator = propFactory.applyModifiers(
        propFactory.create(outerscheme, outertime, outerforcefield, *chain), outerscheme)
    setPropagator(myPropagatorClass, myPropagatorClass.phys, myPropagatorClass.forces, propagator)

    # 2. Update the kappa values (from eigenvalues / Rayleigh quotients) and
    #    the reference mode coefficients and positions.
    if ref_eigvals is None:
        ref_eigvals = eigvals
    if ref_c is None:
        ref_c = c_s
    if ref_positions is None:
        ref_positions = myPhys.positions
    prop = myPropagatorClass.myPropagator
    prop.SetEigvals(ref_eigvals)
    prop.SetRefC(ref_c)
    prop.SetRefPos(ref_positions)

    if params.get('shake') == 'on':
        prop.adoptPostDriftOrNextModifier(prop.createShakeModifier(0.00001, 30))
    if params.get('rattle') == 'on':
        prop.adoptPostDriftOrNextModifier(prop.createRattleModifier(0.02, 30))

    # 3. Execute the propagator for vsteps steps.
    print("Before executing propagator for %s steps" % (vsteps,))
    prop.SetRefX()
    executePropagator(myPropagatorClass, myPropagatorClass.phys, myPropagatorClass.forces,
                      myPropagatorClass.io, vsteps)
    print("Done propagator")


def EigvalReader(eigvalfilename, count=None, scale=10.0, floor=1.0):
    """Read eigenvalues (one per line) from a text file into a column array.

    Each value is multiplied by ``scale`` and then raised to at least
    ``floor``; the defaults reproduce the x10 scaling and lower bound of 1.
    If the file holds fewer than ``count`` values, the missing entries start
    at 0 and therefore end up equal to ``floor``.  Blank lines are skipped.

    Args:
        eigvalfilename: path of the text file.
        count: number of values to read; defaults to the module-level ``cv``
            (number of common vectors).  A negative count reads every value.
        scale: multiplier applied to each value.
        floor: lower bound applied after scaling.

    Returns:
        numpy float array of shape (count, 1), or (values read, 1) when
        ``count`` is negative.

    Raises:
        ValueError: a non-blank line is not a number.
    """
    if count is None:
        count = cv

    values = []
    with open(eigvalfilename, "r") as eigfile:
        for line in eigfile:
            if 0 <= count <= len(values):
                break
            text = line.strip()
            if text:  # tolerate blank lines, e.g. a trailing newline
                values.append(float(text))

    nrows = len(values) if count < 0 else count
    evals = numpy.zeros((nrows, 1))
    evals[:len(values), 0] = values
    # Scale, then clamp small values up to the floor.
    return numpy.maximum(scale * evals, floor)

def GetQSMatrix(c_v, phys, io):
    """Return the first c_v eigenvectors as the columns of a matrix.

    ``io.getTextEigenvectorsData(phys)`` supplies the eigenvectors as one
    array of length ``k * phys.positions.size``.  It is reshaped in
    column-major (Fortran) order, so each consecutive run of
    ``positions.size`` entries becomes one column.

    Args:
        c_v: number of columns (modes) to keep; -1 keeps all of them.
            Values larger than the available column count are clipped.
        phys: object with a ``positions`` array (only its size is used).
        io: object providing ``getTextEigenvectorsData(phys)``.

    Returns:
        numpy array of shape (positions.size, min(c_v, k)).
    """
    n_coords = phys.positions.size
    Q_S = io.getTextEigenvectorsData(phys)
    # Integer division: each eigenvector has n_coords entries (// also keeps
    # this an int on Python 3, where / would produce a float).
    Q_S_M = Q_S.reshape(n_coords, Q_S.shape[0] // n_coords, order='F').copy()

    if c_v == -1:
        c_v = Q_S_M.shape[1]

    # Keep only the column vectors (modes) that are common.
    return Q_S_M[:, 0:c_v]



# Command line: <num_common_vectors> <eigenvector_file> <eigenvalue_file>.
# cv is the number of common vectors (modes); GetQSMatrix treats -1 as "all".
# eigvals, c_s and myPhys are module-level names that simplePropagateFuction
# reads when setting up the propagator, so they must stay global.
if len(sys.argv) < 4:
    sys.exit("usage: %s <num_common_vectors> <eigenvector_file> <eigenvalue_file>"
             % sys.argv[0])
cv = int(sys.argv[1])
eigvecfilename = sys.argv[2]
eigvalfilename = sys.argv[3]

myPhys = Physical()
myForces = Forces()
myIO = IO()
myProp = Propagator(myPhys, myForces, myIO)
ef = 'alan.energies.out.0'
dcdf = 'alan.dcd'
# Start from a fresh energies file (removed directly rather than via a shell).
if os.path.exists(ef):
    os.remove(ef)
myIO.files = {'energies': (ef, 1), 'dcdtraj': (dcdf, 1)}

# Initialization of the system.
myIO.readPDBPos(myPhys, "data/minC7eq_a3.pdb")
myIO.readPSF(myPhys, "data/alan_mineq.psf")
myIO.readPAR(myPhys, "data/par_all27_prot_lipid.inp")

print("Before eig reader")
# Alternative text-format reader: myIO.readTextEigenvectors(myPhys, eigvecfilename)
myIO.readEigenvectors(myPhys, eigvecfilename)
print("After eig reader")

myPhys.bc = "Vacuum"
myPhys.temperature = 300
myPhys.exclude = "scaled1-4"
myPhys.seed = 1234

totalsteps = 10000

eigvals = EigvalReader(eigvalfilename)

Q_S_M = GetQSMatrix(cv, myPhys, myIO)
# The matrix M^{1/2}.
sqmH = sqrtmass(len(myPhys.masses), len(myPhys.masses), myPhys.masses)

physB = Physical()
myIO.readPDBPos(physB, "data/armin_a3.pdb")
# NOTE(review): physB is loaded but never used, and c_s is computed from
# myPhys.positions against itself, so the displacement (and hence c_s) is
# zero. Confirm whether physB.positions was meant to be one of the arguments.
c_s = CartesianToModespace(Q_S_M, sqmH, myPhys.positions, myPhys.positions)

for coeff in c_s:
    print(coeff)

ff1 = myForces.makeForceField(myPhys)
ff1.bondedForces("badi")
ff1.nonbondedForces("le")

ff2 = myForces.makeForceField(myPhys)
ff2.bondedForces("badi")
ff2.nonbondedForces("le")

# To test the integrator, the C values are computed at the start (above) and
# then updated by the integrator.

# Alternative run with a Langevin impulse integrator:
# simplePropagateFuction(myProp, 10, scheme=["LangevinImpulse"],
#     steps=1, cyclelength=1, dt=1.0, forcefield=[ff1, ff2],
#     params={'NormModeSampling': {'fixmodes': 66 - cv, 'gamma': 91, 'fdof': 6},
#             'NormModeSamplingMin': {'avModeMass': 3.0}})

# NOTE(review): 66 is presumably the total number of modes of this system,
# so 'fixmodes' = 66 - cv; confirm before reusing with another molecule.
simplePropagateFuction(myProp, totalsteps,
                       scheme=["NormModeSampling", "NormModeSamplingMin"],
                       steps=1, cyclelength=1, dt=1.0, forcefield=[ff1, ff2],
                       params={'NormModeSampling': {'fixmodes': 66 - cv, 'gamma': 91, 'fdof': 6},
                               'NormModeSamplingMin': {'avModeMass': 3.0}})

sys.exit(0)

