#!/usr/bin/env python

"""Script to run simple water simulation."""
__author__ = "Randall J. Radmer"
__version__ = "1.0"

import  os, sys, math, time

import simtk.openmm.openmm as mm

import simtk.unit as unit
from simtk.unit import angstrom, nanometer
from simtk.unit import femtosecond, picosecond
from simtk.unit import kilojoule_per_mole
from simtk.unit import amu, elementary_charge
from simtk.unit import kelvin
from simtk.unit import degree


# Total simulation length (in steps), steps between reports, and the
# integration time step (with units)
totalSteps     = 2000
stepsPerReport = 200
stepSize       = 2.0 * femtosecond

# System temperature, with units
temperature = 300 * kelvin

# System pressure, with units (passed to the Monte Carlo barostat)
pressure = 1 * unit.bar

# Partial atomic charges for water
# NOTE(review): these and the LJ/geometry values below appear to be the
# TIP3P parameters — confirm before relying on that.
qO  = -0.8340 * elementary_charge
qH  =  0.4170 * elementary_charge

# Atomic masses for water
massO  = 16.0 * amu
massH  =  1.0 * amu

# Lennard-Jones parameters for oxygen-oxygen interactions
# (hydrogens get no LJ interaction)
sigma   = 3.15061 * angstrom
epsilon = 0.6364  * kilojoule_per_mole

# Water bond length and H-O-H angle; used as rigid constraints
rOH   = 0.9572  * angstrom
aHOH  = 104.52  * degree

# Small attractive force constant to keep waters near each other.
# NOTE(review): not referenced by any code visible in this file.
ooRestraintK  = 1.0 * kilojoule_per_mole/(nanometer*nanometer)

# Distance between the two H atoms in water: the base of the isosceles
# O-H-H triangle, 2 * rOH * sin(aHOH / 2)
rHH = 2*rOH*unit.sin(aHOH/2.0)

verbose=True      # Write more data
randomSeed=None   # Seed for the random number generator used by the
                  # Andersen thermostat.  A value of None lets the system
                  # pick the seed.  Set to an integer if you want
                  # the same trajectory, but understand that this
                  # does not *guarantee* the same trajectory.
                  # For example:  randomSeed=1000


def main(verbose=False):
    """Main: read a PDB file of waters, build an OpenMM system, run it,
    and dump the trajectory as a multi-model PDB file.

    The input and output file names are taken from sys.argv[1] and
    sys.argv[2]; usageError() is called if either is missing.  One model is
    written for the starting configuration and then one after every
    stepsPerReport steps, until exactly totalSteps steps have been run.

    Parameters
    ----------
    verbose : bool
        If true, report the platform name and, for every written model,
        the simulation time, potential energy and periodic box lengths.

    Returns
    -------
    float
        Elapsed wall-clock time in seconds.

    Raises
    ------
    Exception
        If the input holds no water atoms, an atom that is neither 'O' nor
        an 'H*' name, a hydrogen before any oxygen, or more than two
        hydrogens following an oxygen.
    """
    time0 = time.time()
    try:
        inFilename = sys.argv[1]
        outFilename = sys.argv[2]
    except IndexError:
        usageError()

    # Load waters from input pdb file
    with open(inFilename) as fIn:
        pdbData = loadCoordsHOH(fIn)
    if not pdbData:
        raise Exception("no water atoms found in %s" % inFilename)

    # Create water system, which will include force field and general
    # system parameters.  No bond forces are used: waters are kept rigid
    # with constraints.
    system = mm.System()
    nb = mm.NonbondedForce()
    nb.setNonbondedMethod(mm.NonbondedForce.PME)
    initPositions = _addWaters(system, nb, pdbData)

    cutoff, boxVectors = _periodicBoxFromCoords(pdbData)
    nb.setCutoffDistance(cutoff)
    system.addForce(nb)
    system.setDefaultPeriodicBoxVectors(*boxVectors)

    # Create MonteCarloBarostat (constant pressure PBC)
    barostat = mm.MonteCarloBarostat(pressure, temperature, 10)
    system.addForce(barostat)

    # Add temperature coupling to system
    thermostat = mm.AndersenThermostat(temperature, 1.0)
    system.addForce(thermostat)
    if randomSeed:
        thermostat.setRandomNumberSeed(randomSeed)

    # Select simple integrator
    integrator = mm.VerletIntegrator(stepSize)
    # Make a context for this system and set the starting configuration
    context = mm.Context(system, integrator)
    context.setPositions(initPositions)
    if verbose:
        platform = context.getPlatform()
        sys.stdout.write('Using %s platform\n' % platform.getName())

    sys.stdout.write("Test will run to %s\n" % (stepSize*totalSteps))

    # Run simulation, and dump output.  The output file is opened only now,
    # so a failure while building the system does not create/truncate it,
    # and the with-block closes it even if the run raises.
    with open(outFilename, 'w') as fOut:
        step = 0
        frameNumber = 0
        _writeReport(fOut, step, frameNumber, pdbData,
                     _getReportState(context), 0*picosecond, verbose)
        while step < totalSteps:
            # Never step past totalSteps, even when it is not a multiple
            # of stepsPerReport.
            nSteps = min(stepsPerReport, totalSteps - step)
            integrator.step(nSteps)
            step += nSteps
            frameNumber += 1
            state = _getReportState(context)
            _writeReport(fOut, step, frameNumber, pdbData,
                         state, state.getTime(), verbose)
        fOut.write("END\n")

    return time.time() - time0


def _addWaters(system, nb, pdbData):
    """Add every parsed water atom to system and nb; return start positions.

    Each 'O' atom starts a new molecule.  The first and second 'H*' atoms
    after it are tied to it with O-H constraints of length rOH, and the
    second one also to the first hydrogen with an H-H constraint of length
    rHH, so the molecule is rigid.  Intramolecular nonbonded interactions
    are removed with zero-strength exceptions.

    Returns the positions as a Quantity list in nanometers.

    Raises Exception for an unknown atom name, a hydrogen seen before any
    oxygen, or a third hydrogen following one oxygen.
    """
    initPositions = []*nanometer
    lastOxygen = None
    for count, (atomNum, atomName, resName, resNum, xyz) in enumerate(pdbData):
        if atomName == 'O':
            # Add an oxygen atom
            system.addParticle(massO)
            nb.addParticle(qO, sigma, epsilon)
            lastOxygen = count
        elif atomName.startswith('H'):
            if lastOxygen is None:
                raise Exception("hydrogen before any oxygen:"
                                " atomNum=%d, resNum=%d, resName=%s, atomName=%s"
                                % (atomNum, resNum, resName, atomName))
            # Add a hydrogen atom: charge only, epsilon=0.0 disables LJ
            system.addParticle(massH)
            nb.addParticle(qH, 1.0, 0.0)
            if count == lastOxygen+1:
                # Last oxygen and hydrogen number 1
                system.addConstraint(lastOxygen, count, rOH)        # O-H1
                # Exception: chargeProd=0.0, sigma=1.0, epsilon=0.0
                nb.addException(lastOxygen, count, 0.0, 1.0, 0.0)   # O-H1
            elif count == lastOxygen+2:
                # Last oxygen and hydrogen number 2
                system.addConstraint(lastOxygen, count, rOH)        # O-H2
                nb.addException(lastOxygen, count, 0.0, 1.0, 0.0)   # O-H2

                # Hydrogen number 1 and hydrogen number 2
                system.addConstraint(count-1, count, rHH)           # H1-H2
                nb.addException(count-1, count, 0.0, 1.0, 0.0)      # H1-H2
            else:
                s = "too many hydrogens:"
                s += " atomNum=%d, resNum=%d, resName=%s, atomName=%s" \
                    % (atomNum, resNum, resName, atomName)
                raise Exception(s)
        else:
            raise Exception("bad atom : %s" % atomName)
        initPositions.append(xyz)
    return initPositions


def _periodicBoxFromCoords(pdbData):
    """Derive the nonbonded cutoff and default box vectors from coordinates.

    The rectangular box spans the axis-aligned extent of all atoms plus
    1 angstrom per axis.  The cutoff is half the smallest extent, which keeps
    it below half of every initial box edge.
    NOTE(review): the barostat may later shrink the box below twice the
    cutoff, which OpenMM rejects — consider a margin if that happens.

    Returns (cutoff, (a, b, c)) with a, b, c the three box vectors.
    """
    zeroA = 0*angstrom
    oneA = 1*angstrom
    xyzList = [entry[4] for entry in pdbData]
    delX, delY, delZ = [max(xyz[axis] for xyz in xyzList) -
                        min(xyz[axis] for xyz in xyzList)
                        for axis in range(3)]
    cutoff = min(delX, delY, delZ)/2
    boxVectors = ((delX+oneA, zeroA,      zeroA),
                  (zeroA,      delY+oneA, zeroA),
                  (zeroA,      zeroA,      delZ+oneA))
    return cutoff, boxVectors


def _getReportState(context):
    """Return a State holding the positions and energies needed for a report."""
    return context.getState(getPositions=True,
                            getVelocities=False,
                            getForces=False,
                            getEnergy=True,
                            getParameters=False)


def _writeReport(fOut, step, frameNumber, pdbData, state, simTime, verbose):
    """Append one PDB model for state and, if verbose, echo a status line.

    simTime is the time Quantity shown in the status line.
    """
    eK = state.getKineticEnergy()
    eP = state.getPotentialEnergy()
    coords = state.getPositions()
    if verbose:
        pbvA, pbvB, pbvC = state.getPeriodicBoxVectors()
        sys.stdout.write("%s: Potential Energy = %s, PBV: %.3f A, %.3f A, %.3f A\n" %
                         (simTime.format("%7.1f"), eP.format("%7.3f"),
                          pbvA[0].value_in_unit(angstrom),
                          pbvB[1].value_in_unit(angstrom),
                          pbvC[2].value_in_unit(angstrom)))
        sys.stdout.flush()
    appendCoordsToPDB(fOut, step, eK, eP, pdbData, coords, frameNumber)


def appendCoordsToPDB(fOut, step, eK, eP, pdbData, coords,
                      modelFrameNumber):
    """Append one coordinate set to an open PDB file as a MODEL record.

    Parameters
    ----------
    fOut : file
        Open, writable text file.
    step : int
        Step number written on the REMARK line.
    eK, eP : Quantity
        Kinetic and potential energy written on the REMARK line.
    pdbData : list
        Atom tuples (atomNum, atomName, resName, resNum, xyz) as produced by
        loadCoordsHOH; only the first four fields are used here.
    coords : sequence
        Positions (Quantity vectors), one per entry of pdbData, in order.
    modelFrameNumber : int
        Number written on the MODEL record.
    """
    fOut.write("REMARK  step %6d: EK = %s   EP = %s\n"
               % (step, eK.format("%.3f"), eP.format("%.3f")))
    fOut.write("MODEL     %d\n" % modelFrameNumber)
    for ii, (atomNum, atomName, resName, resNum, xyz) in enumerate(pdbData):
        pos = coords[ii]
        # The atom name occupies PDB columns 13-16.  Names shorter than four
        # characters start in column 14, so pad them with one space.  This
        # matches the old "%-2s" layout for 1-2 character names, but longer
        # names no longer push the later columns out of place.
        nameField = atomName if len(atomName) >= 4 else ' ' + atomName
        # Serial number in columns 7-11 ("%5d" is identical to the old
        # " %4d" for serials below 10000).
        fOut.write("ATOM  %5d %-4s %3s  %4d    %8.3f%8.3f%8.3f\n"
                   % (atomNum, nameField, resName, resNum,
                      pos[0].value_in_unit(angstrom),
                      pos[1].value_in_unit(angstrom),
                      pos[2].value_in_unit(angstrom)))
    fOut.write("ENDMDL\n")


def loadCoordsHOH(fIn):
    """Parse water atoms out of PDB-formatted lines.

    Only ATOM/HETATM records whose residue name (columns 18-20) is 'HOH' or
    'WAT' are kept; every other line is ignored.  Kept atoms are numbered
    from 1 in file order, and a new residue number begins at each atom
    named 'O'.

    Parameters
    ----------
    fIn : iterable of str
        Lines of a PDB file (for example an open text file).

    Returns
    -------
    list of tuple
        One (atomNum, atomName, resName, resNum, xyz) tuple per kept atom;
        xyz is the (x, y, z) position as a Quantity in angstroms.

    Raises
    ------
    Exception
        If an atom name equals either of the two previously kept atom
        names, which indicates a malformed water molecule.
    """
    waterResNames = ('HOH', 'WAT')
    atoms = []
    nAtoms = 0
    nResidues = 0
    for record in fIn:
        if not record.startswith(('HETATM', 'ATOM')):
            continue
        residue = record[17:20]
        if residue not in waterResNames:
            continue
        nAtoms += 1
        name = record[12:16].strip()
        if name == 'O':
            nResidues += 1
        # Compare against (at most) the two most recently kept atoms.
        if name in [entry[1] for entry in atoms[-2:]]:
            raise Exception("bad water molecule near %s..." % record[:27])
        x = float(record[30:38])
        y = float(record[38:46])
        z = float(record[46:54])
        atoms.append((nAtoms, name, residue, nResidues,
                      ((x, y, z) * angstrom)))
    return atoms


def usageError():
    """Print the command-line usage to stdout and exit with status 1."""
    progName = os.path.basename(sys.argv[0])
    message = 'usage: %s inFilename outFilename\n' % progName
    sys.stdout.write(message)
    raise SystemExit(1)

if __name__ == '__main__':
    # Run the simulation with the module-level verbosity and report the
    # wall-clock time main() returns.
    elapsed = main(verbose=verbose)
    sys.stdout.write("Total run time = %.1f Sec\n" % elapsed)


