#!/bin/env python

__author__ = "Randall J. Radmer"
__version__ = "1.0"
__doc__ = """Script to run simple argon simulation."""

import  os, sys, math, time

import simtk.unit as unit
    
from simtk.unit import Quantity
from simtk.unit import angstrom, nanometer
from simtk.unit import femtosecond, picosecond, nanosecond, second
from simtk.unit import joule, kilojoule_per_mole
from simtk.unit import kelvin
from simtk.unit import amu, elementary_charge, mole

from multiprocessing import Process, Queue
from simtk.unit import *
from numpy import *

# ---------------------------------------------------------------------------
# Shared simulation settings (module-level, read by both the parent process
# and the worker processes).
# ---------------------------------------------------------------------------

# Number of argon atoms placed along each axis of the initial lattice.
numParticlesX = 3
numParticlesY = 2
numParticlesZ = 2

# Integration step size.
stepSize = Quantity(1.0, femtosecond)

# Per-axis multipliers applied to sigma_Ar to obtain the lattice spacing.
scaleStepSizeX = 1.0
scaleStepSizeY = 1.0
scaleStepSizeZ = 1.0

# Total simulated time and reporting interval.
totalTime = Quantity(0.1, nanosecond)
stepsPerReport = 10000

# Random seed; the alternative value None is kept here for reference.
#randomSeed=None
randomSeed = 1

verbose = True

# ---------------------------------------------------------------------------
# Argon parameters: mass, charge, and the two values passed to
# NonbondedForce.addParticle after the charge (sigma, epsilon).
# ---------------------------------------------------------------------------
mass_Ar = Quantity(39.9, amu)
q_Ar = Quantity(0.0, elementary_charge)
sigma_Ar = Quantity(3.350, angstrom)
epsilon_Ar = Quantity(0.001603, kilojoule_per_mole)

class worker_data_t:
    """Per-worker configuration handed to a worker process at start-up.

    Attributes:
        index: Integer identifying the worker.  The worker uses it to tag its
            log lines and as the value of the 'CudaDevice' platform property.
    """

    def __init__(self, index=None):
        """Create the record.

        Args:
            index: Optional worker index.  Defaults to None so existing
                callers that construct the object with no arguments and
                assign ``index`` afterwards keep working.
        """
        self.index = index

class work_unit_t:
    """A unit of work sent to a worker through the work queue.

    Attributes:
        operation: String naming what the worker should do.  The worker loop
            in this file recognises 'LANGEVIN' and 'TERMINATE'; any other
            value (including the default 'NONE') is ignored by it.

    For a 'LANGEVIN' unit the worker additionally reads the attributes
    ``temperature``, ``frictionCoefficient``, ``timestep``, ``nsteps``,
    ``positions`` and ``velocities``; callers attach these after construction.
    """

    def __init__(self, operation='NONE'):
        """Create a work unit.

        Args:
            operation: Optional operation name.  Defaults to 'NONE' so
                existing no-argument construction behaves as before.
        """
        self.operation = operation
        
def _build_argon_system(mm):
    """Create the OpenMM System holding the argon lattice and its nonbonded force.

    Args:
        mm: The imported OpenMM module.

    Returns:
        A System with numParticlesX * numParticlesY * numParticlesZ argon
        particles and one periodic-cutoff NonbondedForce attached.
    """
    system = mm.System()

    nb = mm.NonbondedForce()
    nb.setNonbondedMethod(mm.NonbondedForce.CutoffPeriodic)
    # Bare number, no unit attached -- presumably nanometers (OpenMM's
    # default length unit).
    nb.setCutoffDistance(0.8)

    # Track the largest lattice coordinate on each axis to size the box.
    maxX = 0 * angstrom
    maxY = 0 * angstrom
    maxZ = 0 * angstrom
    for ii in range(numParticlesX):
        for jj in range(numParticlesY):
            for kk in range(numParticlesZ):
                system.addParticle(mass_Ar)
                nb.addParticle(q_Ar, sigma_Ar, epsilon_Ar)
                x = sigma_Ar * scaleStepSizeX * ii
                y = sigma_Ar * scaleStepSizeY * jj
                z = sigma_Ar * scaleStepSizeZ * kk
                if x > maxX:
                    maxX = x
                if y > maxY:
                    maxY = y
                if z > maxZ:
                    maxZ = z

    # Box edge = lattice extent plus two lattice spacings of padding.
    boxX = maxX + 2 * sigma_Ar * scaleStepSizeX
    boxY = maxY + 2 * sigma_Ar * scaleStepSizeY
    boxZ = maxZ + 2 * sigma_Ar * scaleStepSizeZ
    zero = 0 * angstrom
    nb.setPeriodicBoxVectors(Quantity((boxX, zero, zero)),
                             Quantity((zero, boxY, zero)),
                             Quantity((zero, zero, boxZ)))

    system.addForce(nb)
    return system


def _select_cuda_platform(mm, worker_index):
    """Return the 'Cuda' platform, bound by default to device `worker_index`.

    Args:
        mm: The imported OpenMM module.
        worker_index: Index of this worker; used as the 'CudaDevice' value.

    Raises:
        RuntimeError: If no platform named 'Cuda' is registered.
    """
    platforms = dict()
    for platform_index in range(mm.Platform.getNumPlatforms()):
        candidate = mm.Platform.getPlatform(platform_index)
        platforms[candidate.getName()] = candidate

    print("thread %d : platforms = %s" % (worker_index, str(platforms)))

    if 'Cuda' not in platforms:
        # Fail with an explicit message rather than a bare KeyError.
        # For debugging without a GPU, platforms['Reference'] can be
        # selected here instead.
        raise RuntimeError("thread %d : 'Cuda' platform not available "
                           "(found: %s)"
                           % (worker_index, ', '.join(sorted(platforms))))

    platform = platforms['Cuda']
    platform.setPropertyDefaultValue('CudaDevice', str(worker_index))
    return platform


def _run_langevin(mm, system, platform, work_unit, worker_index):
    """Run one LANGEVIN work unit and return it with updated coordinates.

    Reads ``temperature``, ``frictionCoefficient``, ``timestep``,
    ``positions``, ``velocities`` and ``nsteps`` from `work_unit`, integrates
    for ``nsteps`` steps, and overwrites ``positions`` / ``velocities`` on the
    same object with copies of the final state arrays.

    Returns:
        The same `work_unit` object, mutated in place.
    """
    print('thread %d : Creating integrator...' % worker_index)
    integrator = mm.LangevinIntegrator(work_unit.temperature,
                                       work_unit.frictionCoefficient,
                                       work_unit.timestep)
    print('thread %d : integrator created.' % worker_index)

    print('thread %d : Creating context...' % worker_index)
    context = mm.Context(system, integrator, platform)
    print('thread %d : Context created.' % worker_index)
    print('thread %d : Using %s platform'
          % (worker_index, context.getPlatform().getName()))

    # Set initial coordinates and momenta, then run dynamics.
    context.setPositions(work_unit.positions)
    context.setVelocities(work_unit.velocities)
    integrator.step(work_unit.nsteps)

    # Only positions and velocities are returned, so nothing else is requested.
    state = context.getState(getPositions=True,
                             getVelocities=True,
                             getForces=False,
                             getEnergy=False,
                             getParameters=False)

    result = work_unit
    result.positions = state.getPositions(asNumpy=True).copy()
    result.velocities = state.getVelocities(asNumpy=True).copy()

    # Drop the per-job OpenMM objects before the next work unit is serviced.
    del state, context, integrator
    return result


def worker(worker_data, work_queue, result_queue):
    """Worker process entry point: build the system, then service work units.

    Args:
        worker_data: Object with an ``index`` attribute identifying this
            worker (also used as the CUDA device number).
        work_queue: Queue from which work units are taken (blocking).
        result_queue: Queue onto which completed LANGEVIN units are put.

    The loop ends when a work unit with operation 'TERMINATE' is received.
    Units with operation 'LANGEVIN' are simulated and returned through
    `result_queue`; units with any other operation are logged and ignored.

    Raises:
        RuntimeError: If the 'Cuda' platform is not available.
    """
    index = worker_data.index

    # DEBUG
    print("worker thread %d" % index)

    # OpenMM is imported inside the worker process.
    import simtk.chem.openmm as mm

    system = _build_argon_system(mm)
    sys.stdout.flush()

    # Query system
    print("thread %d : system.getNumParticles() = %d"
          % (index, system.getNumParticles()))

    platform = _select_cuda_platform(mm, index)

    # Service work units
    while True:
        # Get work unit (blocking call)
        work_unit = work_queue.get()

        print("thread %d : work unit operation '%s'"
              % (index, work_unit.operation))

        if work_unit.operation == 'TERMINATE':
            print("thread %d : terminating." % index)
            return

        if work_unit.operation == 'LANGEVIN':
            result_queue.put(
                _run_langevin(mm, system, platform, work_unit, index))
    
if __name__ == '__main__':
    # Driver: build the initial lattice, start the worker processes, submit
    # Langevin jobs, collect the results, then shut the workers down.

    # Determine number of atoms
    natoms = numParticlesX * numParticlesY * numParticlesZ
    print("%d atoms total" % natoms)

    # Construct initial positions and (zero) velocities as plain float arrays.
    # The worker hands these arrays to the OpenMM context without units, and
    # OpenMM interprets bare numbers in its default units, so the coordinates
    # are stored in nanometers (storing angstrom values would place the atoms
    # ten times too far apart).
    positions = zeros([natoms, 3], float64)
    velocities = zeros([natoms, 3], float64)

    atom_index = 0
    for ii in range(numParticlesX):
        for jj in range(numParticlesY):
            for kk in range(numParticlesZ):
                x = sigma_Ar * scaleStepSizeX * ii
                y = sigma_Ar * scaleStepSizeY * jj
                z = sigma_Ar * scaleStepSizeZ * kk

                positions[atom_index, 0] = x.value_in_unit(nanometer)
                positions[atom_index, 1] = y.value_in_unit(nanometer)
                positions[atom_index, 2] = z.value_in_unit(nanometer)

                atom_index += 1

    # Create work and result queues.
    work_queue = Queue()
    result_queue = Queue()

    # Initialize worker pool.
    nworkers = 2
    worker_pool = list()
    for worker_index in range(nworkers):
        # Create worker-specific data structure
        worker_data = worker_data_t()
        worker_data.index = worker_index

        # Spawn new process
        process = Process(target=worker,
                          args=(worker_data, work_queue, result_queue))
        worker_pool.append(process)
        process.start()

    # Initialize OpenMM
    # NOTE(review): `mm` is not used below in this block; the import is kept
    # in case loading OpenMM in the parent is intentional -- confirm.
    import simtk.chem.openmm as mm

    # Integrator settings, written with explicit units and then converted to
    # bare numbers in OpenMM's default units (kelvin, 1/picosecond,
    # picosecond) before being placed on the work unit.  In particular the
    # 1 fs timestep becomes 0.001, not 1.0 (which would mean 1 ps).
    temperature = 300.0 * kelvin
    friction_coefficient = 90.0 / picosecond
    timestep = 1.0 * femtosecond

    # Submit jobs
    njobs = 10
    for job_index in range(njobs):
        work_unit = work_unit_t()
        work_unit.operation = 'LANGEVIN'
        work_unit.positions = positions
        work_unit.velocities = velocities
        work_unit.nsteps = 100

        work_unit.temperature = temperature.value_in_unit(kelvin)
        work_unit.frictionCoefficient = friction_coefficient.value_in_unit(
            picosecond ** -1)
        work_unit.timestep = timestep.value_in_unit(picosecond)

        # Submit work unit
        work_queue.put(work_unit)

    # Process results.
    # NOTE(review): this blocks until njobs results arrive; if a worker exits
    # early (e.g. on an exception) this loop will wait indefinitely.
    for job_index in range(njobs):
        # Get result and just discard it.
        result = result_queue.get()
        del result

    # TERMINATE WORKERS

    # Send one poison pill per worker.
    for _ in worker_pool:
        work_unit = work_unit_t()
        work_unit.operation = 'TERMINATE'
        work_queue.put(work_unit)

    # Wait for processes to terminate
    for process in worker_pool:
        process.join()




