#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
OpenMM worker process pool for servicing requests from a work queue

DESCRIPTION

This module provides the capability for managing a pool of worker processes that service
a work queue for running OpenMM jobs on local or remote machines.  These processes are
intended to make use of GPU resources.

EXAMPLES

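Create a worker pool on localhost, submit a work unit, and shut the pool down.
WorkerPool is not implemented yet, so the examples that need a pool are skipped by doctest.

>>> worker_pool = WorkerPool(hostlist=['localhost'])    # doctest: +SKIP
>>> work_unit = WorkUnit(operation='LANGEVIN')
>>> worker_pool.getWorkQueue().put(work_unit)           # doctest: +SKIP
>>> worker_pool.shutdown()                              # doctest: +SKIP

Shutdown is also intended to happen automatically when the pool is cleaned up.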

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import numpy

#=============================================================================================
# Exceptions
#=============================================================================================

class NotImplementedException(NotImplementedError):
    """
    Exception denoting that the requested feature has not yet been implemented.

    Derives from the builtin NotImplementedError (and therefore from Exception), so
    callers may catch it as either.  When raised without arguments it carries the
    default message "Feature not implemented.".

    >>> str(NotImplementedException())
    'Feature not implemented.'

    """

    def __init__(self, *args):
        # Supply a default message instead of printing one; code in a class body
        # runs at import time, and a bare `return` there is a SyntaxError.
        if not args:
            args = ("Feature not implemented.",)
        super(NotImplementedException, self).__init__(*args)

#=============================================================================================
# Work unit data.
#=============================================================================================

class WorkUnit(object):
    """
    Work unit for Worker object.

    A work unit names the operation a Worker should perform (e.g. 'LANGEVIN', or
    'TERMINATE' to stop the worker).  Any further inputs the operation needs are
    attached to the instance as plain attributes.

    >>> work_unit = WorkUnit(operation='LANGEVIN')
    >>> work_unit.operation
    'LANGEVIN'

    """

    # Class-level default; used when no (or a false) operation is supplied.
    operation = "NONE"

    def __init__(self, operation='NONE'):
        """\
        Initialize work unit.

        OPTIONAL ARGUMENTS
          operation (string) - name of the operation to perform; a false value
            (e.g. None or '') leaves the class default "NONE" in place

        """

        if operation: self.operation = operation

        return

    def __repr__(self):
        """\
        Produce a human-readable string representation of work unit.
        """

        return "Work unit\noperation: %s\n" % self.operation

    # Backward-compatible alias for code that called work_unit.repr() directly.
    repr = __repr__

#=============================================================================================
# Worker data
#=============================================================================================

class WorkerData(object):
    """\
    Worker-specific information for instantiating a Worker object.

    Worker reads `uid` and `verbose` from this object, and Worker.run passes
    `index` to the Cuda platform's 'CudaDevice' property.
    """

    def __init__(self, uid=None, verbose=False, index=0):
        """\
        Initialize worker-specific information structure.

        OPTIONAL ARGUMENTS
          uid - unique identifier for this Worker; when None, id(self) is used
          verbose (bool) - whether the Worker prints progress messages (default: False)
          index (int) - GPU device index handed to the Cuda platform (default: 0)

        """

        # Compare against None explicitly so falsy identifiers such as 0 are kept.
        if uid is not None:
            self.uid = uid
        else:
            self.uid = id(self) # assign unique ID automatically

        self.verbose = verbose
        self.index = index

        return

#=============================================================================================
# Worker
#=============================================================================================

class Worker(object):
    """\
    GPU worker process for servicing requests for OpenMM operations from a work queue.

    Work units are read from the work queue one at a time.  A work unit whose
    `operation` is 'TERMINATE' acts as a poison pill and stops the worker.  A
    'LANGEVIN' work unit is simulated, and the updated work unit is placed on the
    result queue.

    NOTE(review): run() uses the name `mm` (presumably `simtk.openmm`), which is not
    imported anywhere in this module -- TODO add the import.
    """

    def __init__(self, worker_data, work_queue, result_queue):
        """
        Initialize the Worker object, associating it with work and result queues.

        ARGUMENTS
          worker_data (WorkerData) - worker-specific information
          work_queue (multiprocessing.Queue) - work queue to monitor and service
          result_queue (multiprocessing.Queue) - result queue where results are to be deposited

        """
        import socket

        # Store configuration data.
        self.worker_data = worker_data # NOTE: Do we want to make deep copies?
        self.work_queue = work_queue
        self.result_queue = result_queue

        # Store some worker-specific data for convenience.
        self.hostname = socket.gethostname()
        try:
            self.ip_address = socket.gethostbyname(self.hostname)
        except socket.error:
            # The IP address is only used in diagnostic messages; an unresolvable
            # hostname should not prevent the worker from being created.
            self.ip_address = None

        # Fall back to sensible defaults if worker_data lacks these optional fields.
        self.uid = getattr(worker_data, 'uid', id(worker_data)) # unique identifier
        self.verbose = getattr(worker_data, 'verbose', False) # verbosity

        return

    def run(self):
        """
        Begin servicing the work queue until we explicitly receive a poison pill.

        Supported operations:
          'TERMINATE' - stop servicing the queue and return
          'LANGEVIN'  - run Langevin dynamics (see _run_langevin) and put the
                        updated work unit on the result queue
        Any other operation is skipped (and reported when verbose).

        RAISES
          RuntimeError - if no 'Cuda' platform is available

        """
        uid = self.uid
        verbose = self.verbose

        if verbose:
            print("worker %s : starting on host %s (%s)" % (uid, self.hostname, self.ip_address))

        # Create dictionary of platforms, keyed by platform name.
        platforms = dict()
        for platform_index in range(mm.Platform.getNumPlatforms()):
            candidate = mm.Platform.getPlatform(platform_index)
            platforms[candidate.getName()] = candidate

        if verbose:
            print("worker %s : found platforms = %s" % (uid, str(platforms)))

        # Set GPU device number on Cuda platform; fail with a clear message if absent.
        if 'Cuda' not in platforms:
            raise RuntimeError("worker %s : no 'Cuda' platform available (found %s)"
                               % (uid, sorted(platforms)))
        platform = platforms['Cuda']
        platform.setPropertyDefaultValue('CudaDevice', str(getattr(self.worker_data, 'index', 0)))

        # Service work units indefinitely, until we receive a poison pill.
        while True:
            # Get work unit (blocking call).
            work_unit = self.work_queue.get()

            if verbose:
                print("worker %s : work unit %s" % (uid, repr(work_unit)))
                print("worker %s : work unit operation '%s'" % (uid, work_unit.operation))

            if work_unit.operation == 'TERMINATE':
                if verbose:
                    print("worker %s : terminating." % uid)
                return

            if work_unit.operation == 'LANGEVIN':
                # Place result in result queue.
                self.result_queue.put(self._run_langevin(work_unit, platform))
            elif verbose:
                print("worker %s : ignoring unrecognized operation '%s'" % (uid, work_unit.operation))

    def _run_langevin(self, work_unit, platform):
        """
        Run Langevin dynamics for one work unit and return it updated with the final state.

        Reads work_unit.system, temperature, frictionCoefficient, timestep, positions,
        velocities and nsteps.  Overwrites work_unit.positions and work_unit.velocities
        with the values from state.getPositions/getVelocities(asNumpy=True), and
        returns the same work unit object as the result.

        """
        uid = self.uid
        verbose = self.verbose

        # Create a Langevin integrator.
        if verbose:
            print('worker %s : Creating integrator...' % uid)
        integrator = mm.LangevinIntegrator(work_unit.temperature, work_unit.frictionCoefficient, work_unit.timestep)
        if verbose:
            print('worker %s : integrator created.' % uid)

        # Create a context.
        # NOTE(review): no `system` is defined in this module; the System is assumed
        # to travel with the work unit as work_unit.system -- TODO confirm with callers.
        if verbose:
            print('worker %s : Creating context...' % uid)
        context = mm.Context(work_unit.system, integrator, platform)
        if verbose:
            print('worker %s : Context created.' % uid)
            print('worker %s : Using %s platform' % (uid, context.getPlatform().getName()))

        # Set initial coordinates and momenta.
        context.setPositions(work_unit.positions)
        context.setVelocities(work_unit.velocities)

        # Run dynamics.
        integrator.step(work_unit.nsteps)

        # Store final coordinates and velocities in result.
        state = context.getState(getPositions=True,
                                 getVelocities=True,
                                 getForces=False,
                                 getEnergy=True,
                                 getParameters=False)

        result = work_unit
        result.positions = state.getPositions(asNumpy=True)
        result.velocities = state.getVelocities(asNumpy=True)
        return result

#=============================================================================================
# Worker pool
#=============================================================================================

class WorkerPool(object):
    """
    Manager for pool of workers on local and remote machines.

    NOTE: this class is currently a stub.  Every method except __del__ raises
    NotImplementedException, so the examples below are skipped by doctest.

    EXAMPLES

    Initialize a worker pool on the local machine (2 worker processes on localhost),
    then explicitly shut it down.

    >>> worker_pool = WorkerPool(nworkers=2)  # doctest: +SKIP
    >>> worker_pool.shutdown()  # doctest: +SKIP

    Shutdown automatically occurs on object deletion.

    >>> worker_pool = WorkerPool(nworkers=2)  # doctest: +SKIP
    >>> del worker_pool  # doctest: +SKIP

    Add a host to the worker pool.

    >>> worker_pool = WorkerPool(nworkers=1)  # doctest: +SKIP
    >>> worker_pool.addHost('localhost')  # doctest: +SKIP
    >>> del worker_pool  # doctest: +SKIP

    """

    def __init__(self, nworkers=None, hostlist=None):
        """
        Create a pool of workers.

        OPTIONAL ARGUMENTS
          nworkers (int) - number of workers to spawn
          hostlist (list of strings) - list of hostnames on which workers can be spawned

        RAISES
          NotImplementedException - always (not yet implemented)

        """
        raise NotImplementedException()

    def __del__(self):
        """
        Clean up worker pool.

        Currently a no-op.  It must not raise, because Python also calls it on
        instances whose __init__ failed.

        """
        # TODO: Send poison pill (a 'TERMINATE' work unit) to terminate all processes.
        return

    def shutdown(self):
        """
        Explicitly shut down the worker pool.

        RAISES
          NotImplementedException - always (not yet implemented)

        """
        raise NotImplementedException()

    def addHost(self, hostname):
        """
        Add given host to pool of workers.

        ARGUMENTS
          hostname (string) - hostname to be added to worker pool

        RAISES
          NotImplementedException - always (not yet implemented)

        """
        raise NotImplementedException()

    def removeHost(self, hostname):
        """
        Remove given host from pool of workers.

        ARGUMENTS
          hostname (string) - hostname to be removed from worker pool

        RAISES
          NotImplementedException - always (not yet implemented)

        """
        raise NotImplementedException()

    def getWorkQueue(self):
        """
        Return the work queue.

        RETURNS
          work_queue (multiprocessing.Queue) - work queue for this worker pool

        RAISES
          NotImplementedException - always (not yet implemented)

        """
        raise NotImplementedException()

    def getResultQueue(self):
        """\
        Return the result queue.

        RETURNS
          result_queue (multiprocessing.Queue) - result queue for this worker pool

        RAISES
          NotImplementedException - always (not yet implemented)

        """
        raise NotImplementedException()

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

if __name__ == "__main__":
    # Self-test: execute the interactive examples embedded in this module's docstrings.
    from doctest import testmod
    testmod()

