#!/usr/bin/env python

__author__ = "Randall J. Radmer"
__version__ = "1.0"
__doc__ = """Script to run simple water simulation."""

# When True, main() prints the potential energy at every report interval.
verbose=True

# Passed to importBestOpenMMLib(tryUsingCuda=...) to prefer the GPU library.
tryUsingGpu=True

temperature=300.0     # thermostat target temperature (presumably K, OpenMM's unit -- confirm)
stepSize=0.002        # integration time step in ps (main reports stepSize*steps as ps)
totalSteps=20000      # total number of integration steps to run
stepsPerReport=2000   # steps between energy reports / written PDB frames
# Thermostat random seed. Any falsy value (e.g. None or 0) skips
# setRandomNumberSeed() and leaves the thermostat's default seeding in place.
randomSeed=1

import  os, sys, math, time

import simtk.utils.manageGPUs as gMan
(mm, usingCuda) = gMan.importBestOpenMMLib(tryUsingCuda=tryUsingGpu)

# Rigid water model parameters (the values match the TIP3P water model).
qO   = -0.8340                #partial charge on oxygen
qH   =  0.4170                #partial charge on hydrogen
sigma   = 3.15061 / 10       #nm, oxygen Lennard-Jones sigma
epsilon = 0.6364             #kJ/mol, oxygen Lennard-Jones epsilon
rOH  = 0.9572 / 10           #nm, O-H bond length
aHOH = 104.52 * math.pi/180  #rad, H-O-H angle
ooRestraintK  =   1.0        #kJ/mol/nm^2, O-O harmonic restraint constant

# Offsets of a generated hydrogen from its oxygen:
#   H1 = O + (xOffset, +yOffset, 0),  H2 = O + (xOffset, -yOffset, 0)
# so each O-H vector has length rOH and the two make the angle aHOH.
xOffset=rOH*math.cos(aHOH/2.0)
yOffset=rOH*math.sin(aHOH/2.0)
# The hydrogens differ only in y, so their separation is 2*yOffset
# (= 2*rOH*sin(aHOH/2), ~0.1514 nm).  Using 2*xOffset (~0.1172 nm) would make
# the H1-H2 constraint inconsistent with the H-O-H angle.
rHH=2*yOffset


def _appendHydrogens(pdbData, atomNum, xOff, yOff):
    """Append generated H1/H2 records for the oxygen record at pdbData[-1].

    The hydrogens are placed in the oxygen's z-plane at (+xOff, +yOff) and
    (+xOff, -yOff) from the oxygen, and reuse the oxygen's own residue name
    and residue number.  Returns the atom number of the last record added.
    """
    _, _, resName, resNum, xO, yO, zO = pdbData[-1]
    atomNum += 1
    pdbData.append( (atomNum, 'H1', resName, resNum, xO+xOff, yO+yOff, zO) )
    atomNum += 1
    pdbData.append( (atomNum, 'H2', resName, resNum, xO+xOff, yO-yOff, zO) )
    return atomNum


def loadCoordsHOH(fIn, xOff=None, yOff=None):
    """Read water atoms from PDB lines, generating hydrogens where missing.

    Only ATOM/HETATM records with residue name HOH or WAT are used; every
    other line is ignored.  Coordinates are converted from Angstrom to nm.
    An oxygen whose next water record is another oxygen (or that is the last
    water record) gets H1 and H2 generated for it.

    fIn  -- iterable of PDB text lines (e.g. an open file)
    xOff -- x offset (nm) of generated hydrogens; None means module xOffset
    yOff -- +/- y offset (nm) of generated hydrogens; None means module yOffset

    Returns a list of (atomNum, atomName, resName, resNum, x, y, z) tuples.
    atomNum counts from 1 in output order; resNum is incremented at each oxygen.
    """
    if xOff is None:
        xOff = xOffset
    if yOff is None:
        yOff = yOffset
    pdbData = []
    lastAtomName = ''
    atomNum = 0
    resNum = 0
    for line in fIn:
        if not line.startswith(('HETATM', 'ATOM')):
            continue
        resName = line[17:20]
        if resName not in ('HOH', 'WAT'):
            continue
        atomName = line[12:16].strip()
        if atomName == 'O' and lastAtomName == 'O':
            # The previous oxygen had no hydrogens in the file; build them
            # before this record so atom numbers stay sequential per water.
            atomNum = _appendHydrogens(pdbData, atomNum, xOff, yOff)
        atomNum += 1
        if atomName == 'O':
            resNum += 1
        x = float(line[30:38]) / 10.0
        y = float(line[38:46]) / 10.0
        z = float(line[46:54]) / 10.0
        pdbData.append( (atomNum, atomName, resName, resNum, x, y, z) )
        lastAtomName = atomName
    if lastAtomName == 'O':
        # The final oxygen also had no hydrogens listed after it.
        _appendHydrogens(pdbData, atomNum, xOff, yOff)
    return pdbData


def appendCoordsToPDB(fOut, step, eK, eP, pdbData, coords,
                      modelFrameNumber):
    """Append one coordinate frame to an open PDB file as a MODEL record.

    fOut             -- writable text file object
    step             -- integration step number, written to a REMARK line
    eK, eP           -- kinetic and potential energy, written to the REMARK line
    pdbData          -- (atomNum, atomName, resName, resNum, x, y, z) tuples as
                        built by loadCoordsHOH; only the labels are used here
    coords           -- per-atom (x, y, z) positions in nm, in pdbData order;
                        they are written multiplied by 10 (Angstrom)
    modelFrameNumber -- serial number written on the MODEL line
    """
    fOut.write("REMARK  step %6d: EK = %9.3f kj   EP = %9.3f kj\n"
                % (step, eK, eP))
    fOut.write("MODEL     %d\n" % modelFrameNumber)
    for ii, (atomNum, atomName, resName, resNum, _x, _y, _z) in enumerate(pdbData):
        xyz = coords[ii]
        # "ATOM  %5d" is identical to "ATOM   %4d" for serials below 10000,
        # but keeps every later field in its fixed PDB column for larger ones.
        fOut.write("ATOM  %5d  %-2s  %3s  %4d    %8.3f%8.3f%8.3f\n"
                   % (atomNum, atomName, resName, resNum,
                      10*xyz[0], 10*xyz[1], 10*xyz[2]) )
    fOut.write("ENDMDL\n")


def _buildSystem(pdbData):
    """Build the OpenMM System for the waters in pdbData.

    Each oxygen is followed by up to two hydrogens, which are rigidly
    constrained to it (O-H1, O-H2, H1-H2) with the matching nonbonded
    exceptions.  Every pair of oxygens is tied together by a zero-length
    harmonic restraint, and an Andersen thermostat is added.
    Returns (system, initPositions).  Exits with status 1 on a hydrogen that
    has no preceding oxygen or on a third hydrogen for one oxygen.
    """
    system = mm.System()
    initPositions = []
    oxygenIndices = []
    nb = mm.NonbondedForce()
    lastOxygen = None
    for count, (atomNum, atomName, resName, resNum, x, y, z) in enumerate(pdbData):
        if atomName == 'O':
            #Add an oxygen atom
            system.addParticle(16.0)
            nb.addParticle(qO, sigma, epsilon)
            lastOxygen = count
            oxygenIndices.append(count)
        else:
            #Add an hydrogen atom
            system.addParticle(1.0)
            nb.addParticle(qH, 1.0, 0.0)
            if lastOxygen is None:
                # Without this guard lastOxygen would be unbound here.
                sys.stdout.write("ERROR: Hydrogen before any oxygen\n\tatomNum=%d, resNum=%d, resName=%s, atomName=%s:\n"
                                  % (atomNum, resNum, resName, atomName))
                sys.exit(1)
            if count == lastOxygen+1:
                system.addConstraint(lastOxygen, count, rOH)        #O-H1
                nb.addException(lastOxygen, count, 0.0, 1.0, 0.0)   #O-H1
            elif count == lastOxygen+2:
                system.addConstraint(lastOxygen, count, rOH)        #O-H2
                nb.addException(lastOxygen, count, 0.0, 1.0, 0.0)   #O-H2
                system.addConstraint(count-1, count, rHH)           #H1-H2
                nb.addException(count-1, count, 0.0, 1.0, 0.0)      #H1-H2
            else:
                sys.stdout.write("ERROR: Too many hydrogens\n\tatomNum=%d, resNum=%d, resName=%s, atomName=%s:\n"
                                  % (atomNum, resNum, resName, atomName))
                sys.exit(1)
        initPositions.append( (x, y, z) )
    system.addForce(nb)
    nb.thisown = False

    #Harmonic restraints between each oxygen pair keep the waters together.
    #Particle indices are used directly, so this does not depend on atomNum.
    restraints = mm.HarmonicBondForce()
    for ii, oxygen1 in enumerate(oxygenIndices):
        for oxygen2 in oxygenIndices[ii+1:]:
            restraints.addBond(oxygen1, oxygen2, 0.0, ooRestraintK)
    system.addForce(restraints)
    restraints.thisown = False

    #Add temp coupling to system
    thermostat = mm.AndersenThermostat(temperature, 1.0)
    if randomSeed:
        thermostat.setRandomNumberSeed(randomSeed)
    system.addForce(thermostat)
    thermostat.thisown = False

    return system, initPositions


def _reportFrame(fOut, vmd, step, eK, eP, pdbData, coords, frameNumber,
                 verbose):
    """Print progress if verbose, forward to VMD if connected, append a PDB frame."""
    if verbose:
        sys.stdout.write("%8.2f ps: Potential Energy = %8.1f kj\n" % (stepSize*step, eP))
        sys.stdout.flush()
    if vmd:
        # VMD expects Angstrom; OpenMM positions are in nm.
        coordsA = [(10*x, 10*y, 10*z) for x, y, z in coords]
        vmd.sendEnergies(tstep=step, T=temperature, Etot=eK+eP, Epot=eP)
        vmd.sendCoords(coordsA)
    appendCoordsToPDB(fOut, step, eK, eP, pdbData, coords, frameNumber)


def main(verbose=False):
    """Load waters from a PDB file, simulate them with OpenMM, write a PDB trajectory.

    Reads the input and output file names from sys.argv[1] and sys.argv[2]
    (exits via usageError() if missing).  Writes one MODEL frame for the
    starting state and one every stepsPerReport steps until totalSteps.
    Returns the elapsed wall-clock time in seconds.
    """
    time0 = time.time()
    try:
        inFilename = sys.argv[1]
        outFilename = sys.argv[2]
    except IndexError:
        usageError()

    if usingCuda:
        sys.stdout.write("This simulation will be run using the GPU\n")
    else:
        sys.stdout.write("This simulation will be run without using a GPU\n")

    #load waters from input pdb file
    with open(inFilename) as fIn:
        pdbData = loadCoordsHOH(fIn)

    #Hook for streaming to VMD; no connection is opened here, so vmd stays None.
    vmd = None

    system, initPositions = _buildSystem(pdbData)

    #Select simple integrator and make a context for this system
    integrator = mm.VerletIntegrator(stepSize)
    context = mm.OpenMMContext(system, integrator)
    context.setPositions(initPositions)

    sys.stdout.write("Test will run to %.3f ps\n" % (stepSize*totalSteps))
    # The with-block guarantees the trajectory file is closed even on error.
    with open(outFilename, 'w') as fOut:
        step = 0
        frame = 0
        (simTime, eK, eP,
         coords, velList, forceList) = context.getState(getPositions=True,
                                                        getVelocities=False,
                                                        getForces=False)
        _reportFrame(fOut, vmd, step, eK, eP, pdbData, coords, frame, verbose)
        while step < totalSteps:
            integrator.step(stepsPerReport)
            step += stepsPerReport
            frame += 1
            (simTime, eK, eP,
             coords, velList, forceList) = context.getState(getPositions=True,
                                                            getVelocities=False,
                                                            getForces=False)
            _reportFrame(fOut, vmd, step, eK, eP, pdbData, coords, frame, verbose)
        fOut.write("END\n")
    if vmd:
        vmd.closeConnection()

    return time.time()-time0


def usageError():
    """Print the command-line usage line to stdout, then exit with status 1."""
    progName = os.path.basename(sys.argv[0])
    usage = 'usage: %s inFilename outFilename\n' % progName
    sys.stdout.write(usage)
    sys.exit(1)

if __name__=='__main__':
    # Run the simulation; main() returns the elapsed wall-clock time
    # (a time.time() difference, i.e. seconds).
    dTime=main(verbose=verbose)
    sys.stdout.write("Total run time = %.1f\n" % dTime)


