#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
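Compute potential energies of AMBER ligand and complex systems (and of alchemically perturbed variants of the complex) read with the AMBER prmtop/crd loader.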

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import copy
import time

import numpy

import simtk.unit as units
import simtk.openmm as openmm

from alchemy import AbsoluteAlchemicalFactory, AlchemicalState
        
#=============================================================================================
# YANK class
#=============================================================================================

class Yank(object):
    """
    A class for computing receptor-ligand binding free energies through alchemical transformations.

    Options that can be set after initialization but before run() has been called:

    output_directory (string) - destination for datafiles (default: current directory)
    online_analysis (boolean) - if True, will analyze data as simulations proceed (default: False)
    temperature (simtk.unit.Quantity compatible with kelvin) - temperature of simulation (default: 298 K)
    pressure (simtk.unit.Quantity compatible with atmosphere) - pressure of simulation if explicit solvent (default: 1 atm)
    niterations (integer) - number of iterations to run (default: 2000)
    solvent_protocol (list of AlchemicalState) - protocol to use for turning off ligand in solvent
    complex_protocol (list of AlchemicalState) - protocol to use for turning off ligand in complex

    NOTES

    * The simulation methods use the names pyopenmm, ThermodynamicState and ModifiedHamiltonianExchange,
      which this module does not import as shown. NOTE(review): confirm that they are made available
      before run(), run_mpi() or the constructor are used.

    """

    def __init__(self, receptor=None, ligand=None, complex=None, complex_coordinates=None, output_directory=None, verbose=False):
        """
        Create a YANK binding free energy calculation object.

        ARGUMENTS

        receptor (simtk.openmm.System) - the receptor OpenMM system (receptor with implicit solvent forces)
        ligand (simtk.openmm.System) - the ligand OpenMM system (ligand with implicit solvent forces)
        complex_coordinates (simtk.unit.Quantity of coordinates, or list thereof) - coordinates for the complex to initialize replicas with, either a single snapshot or a list of snapshots
        output_directory (String) - the output directory to write files (default: current directory)

        OPTIONAL ARGUMENTS

        verbose (boolean) - if True, will give verbose output
        complex (simtk.openmm.System) - specified System will be used instead of concatenating receptor + ligand (default: None)

        RAISES

        Exception if receptor, ligand, or complex_coordinates is missing.

        NOTES

        * Explicit solvent is not yet supported.

        TODO

        * Automatically use a temporary directory, or prepend a unique string, for each set of output files?

        """

        # Check arguments.
        if (receptor is None) or (ligand is None):
            raise Exception("Yank must be initialized with receptor and ligand System objects.")
        if complex_coordinates is None:
            raise Exception("Yank must be initialized with at least one set of complex coordinates.")

        # Mark as not yet initialized.
        self._initialized = False

        # Set defaults for free energy calculation and protocol.
        self.verbose = verbose # honor the caller's request (was previously hard-coded to False)
        self.online_analysis = False # if True, simulation will be analyzed each iteration
        self.temperature = 298.0 * units.kelvin # simulation temperature
        self.pressure = 1.0 * units.atmosphere # simulation pressure, if explicit solvent; None otherwise
        self.niterations = 2000 # number of production iterations
        self.perform_sanity_checks = True # perform some sanity checks to ensure correct results
        self.platform = None # don't specify a platform

        # Store deep copies of receptor and ligand so later modifications don't leak back to the caller.
        self.receptor = copy.deepcopy(receptor)
        self.ligand = copy.deepcopy(ligand)

        # Don't randomize ligand by default.
        self.randomize_ligand = False

        # Create complex and store atom indices.
        if complex is None:
            # TODO: We need to strip out solvent from ligand system if it is solvated.
            if self.verbose: print("Combining receptor and ligand systems...")
            # Ligand atoms are appended after the receptor atoms; the atom index lists below rely on this order.
            self.complex_pyopenmm = pyopenmm.System(self.receptor) + pyopenmm.System(self.ligand)
            self.complex = self.complex_pyopenmm.asSwig()
        else:
            self.complex = copy.deepcopy(complex)
            self.complex_pyopenmm = pyopenmm.System(self.complex)

        # Set up output directory.
        if output_directory is None:
            output_directory = os.getcwd()
        self.output_directory = output_directory

        if self.verbose: print("receptor has %d atoms; ligand has %d atoms" % (self.receptor.getNumParticles(), self.ligand.getNumParticles()))

        # Determine whether system is periodic.
        self.is_periodic = self.complex_pyopenmm.is_periodic

        # Select default protocols for alchemical transformation.
        self.vacuum_protocol = AbsoluteAlchemicalFactory.defaultVacuumProtocol()
        if self.is_periodic:
            self.solvent_protocol = AbsoluteAlchemicalFactory.defaultSolventProtocolExplicit()
            self.complex_protocol = AbsoluteAlchemicalFactory.defaultComplexProtocolExplicit()
        else:
            self.solvent_protocol = AbsoluteAlchemicalFactory.defaultSolventProtocolImplicit()
            self.complex_protocol = AbsoluteAlchemicalFactory.defaultComplexProtocolImplicit()

        # DEBUG: overrides the complex protocol selected above with the implicit *solvent* protocol.
        # NOTE(review): kept as-is to preserve behavior; confirm whether this override should be removed.
        self.complex_protocol = AbsoluteAlchemicalFactory.defaultSolventProtocolImplicit()

        # Determine atom indices in complex (receptor atoms first, then ligand atoms).
        self.receptor_atoms = list(range(0, self.receptor.getNumParticles()))
        self.ligand_atoms = list(range(self.receptor.getNumParticles(), self.complex.getNumParticles()))

        # Monte Carlo displacement standard deviation for encouraging rapid decorrelation of ligand in the annihilated/decoupled state.
        # TODO: Replace this by a more sophisticated MCMC move set that includes dynamics for replica propagation.
        self.displacement_sigma = 1.0 * units.nanometers # attempt to displace ligand by this stddev will be made each iteration

        # Store complex coordinates.
        # TODO: Make sure each coordinate set is packaged as a numpy array.  For now, we require the user to pass in a list of Quantity objects that contain numpy arrays.
        # TODO: Switch to a special Coordinate object (with box size) to support explicit solvent simulations.
        self.complex_coordinates = copy.deepcopy(complex_coordinates)

        # Type of restraints requested: 'harmonic' (default) or 'flat-bottom'.
        self.restraint_type = 'harmonic'

        return

    def _initialize(self):
        """
        Finalize settings just before running: drop pressure for non-periodic systems, extract
        ligand-only coordinates, and build the protocol dict passed to each simulation.
        """
        self._initialized = True

        # TODO: Run some sanity checks on arguments to see if we can initialize a valid simulation.

        # Turn off pressure if we aren't simulating a periodic system.
        if not self.is_periodic:
            self.pressure = None

        # Extract ligand coordinates from each complex snapshot.
        self.ligand_coordinates = [coordinates[self.ligand_atoms, :] for coordinates in self.complex_coordinates]

        # TODO: Pack up a 'protocol' for modified Hamiltonian exchange simulations. Use this instead in setting up simulations.
        self.protocol = dict()
        self.protocol['number_of_equilibration_iterations'] = 1
        self.protocol['number_of_iterations'] = self.niterations
        self.protocol['verbose'] = self.verbose
        self.protocol['timestep'] = 2.0 * units.femtoseconds
        self.protocol['collision_rate'] = 20.0 / units.picoseconds
        self.protocol['minimize'] = False # DEBUG
        self.protocol['show_mixing_statistics'] = False # this causes slowdown with iteration

        return

    def _create_vacuum_ligand(self):
        """
        Return a SWIG System copy of the ligand with all implicit solvation forces removed.
        """
        vacuum_ligand = pyopenmm.System(self.ligand)
        # Iterate over a snapshot: removing from the list being iterated would skip forces.
        for force in list(vacuum_ligand.forces):
            if type(force) in pyopenmm.IMPLICIT_SOLVATION_FORCES:
                vacuum_ligand.removeForce(force)
        return vacuum_ligand.asSwig()

    def _impose_restraints(self, reference_state):
        """
        Add a receptor-ligand restraint force to self.complex and return its standard state correction.

        ARGUMENTS

        reference_state (ThermodynamicState) - thermodynamic state used to construct the restraint

        RETURNS

        standard_state_correction (float) - correction reported by the restraint (in kT)

        RAISES

        Exception if self.restraint_type is not 'harmonic' or 'flat-bottom'.

        """
        import restraints
        reference_coordinates = self.complex_coordinates[0]
        if self.restraint_type == 'harmonic':
            restraint = restraints.ReceptorLigandRestraint(reference_state, self.complex, reference_coordinates, self.receptor_atoms, self.ligand_atoms)
        elif self.restraint_type == 'flat-bottom':
            restraint = restraints.FlatBottomReceptorLigandRestraint(reference_state, self.complex, reference_coordinates, self.receptor_atoms, self.ligand_atoms)
        else:
            raise Exception("restraint_type of '%s' is not supported." % self.restraint_type)
        # NOTE: this mutates self.complex; calling it twice adds the restraint twice.
        self.complex.addForce(restraint.getRestraintForce())
        return restraint.getStandardStateCorrection()

    def _randomize_complex_coordinates(self, nstates, close_cutoff):
        """
        Replace self.complex_coordinates with one randomized ligand placement per alchemical state.

        ARGUMENTS

        nstates (int) - number of coordinate sets to generate
        close_cutoff (simtk.unit.Quantity with units of distance) - passed to randomize_ligand_position to exclude overlapping placements

        """
        sigma = 20.0 * units.angstrom
        randomized_coordinates = list()
        for state_index in range(nstates):
            # Start from a randomly chosen input snapshot.
            coordinates = self.complex_coordinates[numpy.random.randint(0, len(self.complex_coordinates))]
            new_coordinates = ModifiedHamiltonianExchange.randomize_ligand_position(coordinates, self.receptor_atoms, self.ligand_atoms, sigma, close_cutoff)
            randomized_coordinates.append(new_coordinates)
        self.complex_coordinates = randomized_coordinates

    def run(self):
        """
        Run a free energy calculation.

        TODO: Have CPUs run ligand in vacuum and implicit solvent, while GPUs run ligand in explicit solvent and complex.

        TODO: Add support for explicit solvent.  Would ligand be solvated here?  Or would ligand already be in solvent?

        """

        # Initialize if we haven't yet done so.
        if not self._initialized:
            self._initialize()

        # Create reference thermodynamic state corresponding to experimental conditions.
        reference_state = ThermodynamicState(temperature=self.temperature, pressure=self.pressure)

        #
        # Set up ligand in vacuum simulation.
        #

        vacuum_ligand = self._create_vacuum_ligand()

        if self.verbose: print("Running vacuum simulation...")
        factory = AbsoluteAlchemicalFactory(vacuum_ligand, ligand_atoms=list(range(self.ligand.getNumParticles())))
        systems = factory.createPerturbedSystems(self.vacuum_protocol)
        store_filename = os.path.join(self.output_directory, 'vacuum.nc')
        vacuum_simulation = ModifiedHamiltonianExchange(reference_state, systems, self.ligand_coordinates, store_filename, protocol=self.protocol)
        if self.platform:
            vacuum_simulation.platform = self.platform
        else:
            vacuum_simulation.platform = openmm.Platform.getPlatformByName('Reference')
        vacuum_simulation.nsteps_per_iteration = 500
        # NOTE(review): the vacuum simulation is set up but not run here (unlike run_mpi); confirm intent.
        #vacuum_simulation.run()

        #
        # Set up ligand in solvent simulation.
        #

        if self.verbose: print("Running solvent simulation...")
        factory = AbsoluteAlchemicalFactory(self.ligand, ligand_atoms=list(range(self.ligand.getNumParticles())))
        systems = factory.createPerturbedSystems(self.solvent_protocol)
        store_filename = os.path.join(self.output_directory, 'solvent.nc')
        solvent_simulation = ModifiedHamiltonianExchange(reference_state, systems, self.ligand_coordinates, store_filename, protocol=self.protocol)
        if self.platform:
            solvent_simulation.platform = self.platform
        solvent_simulation.nsteps_per_iteration = 500
        solvent_simulation.run()

        #
        # Set up ligand in complex simulation.
        #

        if self.verbose: print("Setting up complex simulation...")

        # No restraint means no standard state correction; without this default a periodic
        # system would hit an AttributeError when building the metadata below.
        self.standard_state_correction = 0.0
        if not self.is_periodic:
            # Impose restraints to keep the ligand from drifting too far from the protein.
            self.standard_state_correction = self._impose_restraints(reference_state)

        factory = AbsoluteAlchemicalFactory(self.complex, ligand_atoms=self.ligand_atoms)
        systems = factory.createPerturbedSystems(self.complex_protocol, verbose=self.verbose)
        store_filename = os.path.join(self.output_directory, 'complex.nc')

        metadata = dict()
        metadata['standard_state_correction'] = self.standard_state_correction

        if self.randomize_ligand:
            print("Randomizing ligand positions and excluding overlapping configurations...")
            self._randomize_complex_coordinates(len(systems), 3.0 * units.angstrom)

        complex_simulation = ModifiedHamiltonianExchange(reference_state, systems, self.complex_coordinates, store_filename, displacement_sigma=self.displacement_sigma, mc_atoms=self.ligand_atoms, protocol=self.protocol, metadata=metadata)
        complex_simulation.nsteps_per_iteration = 2500
        if self.platform:
            complex_simulation.platform = self.platform

        # Run the simulation.
        if self.verbose: print("Running complex simulation...")
        complex_simulation.run()

        return

    def run_mpi(self, mpi_comm_world, cpuid_gpuid_mapping=None):
        """
        Run a free energy calculation using MPI.

        Ranks listed in cpuid_gpuid_mapping run the complex simulation on their mapped GPU device;
        all other ranks run the vacuum and solvent simulations on the Reference platform.

        ARGUMENTS

        mpi_comm_world - MPI 'world' communicator (provides rank, size, Split() and barrier())

        OPTIONAL ARGUMENTS

        cpuid_gpuid_mapping (dict of int:int) - maps MPI rank to GPU device index (default: {0:0, 1:1, 2:2, 3:3})

        TODO

        * Make a configuration file for CPU:GPU id mapping.

        """
        import socket
        hostname = socket.gethostname()

        rank = mpi_comm_world.rank
        size = mpi_comm_world.size

        # Turn off output from non-root nodes.
        if rank != 0:
            self.verbose = False
            if self._initialized:
                self.protocol['verbose'] = False

        # Make sure random number generators have unique seeds.
        # numpy.random.seed() rejects seeds >= 2**32; staying below 2**31 also keeps randint() valid on 32-bit platforms.
        max_seed = 2**31 - 1
        seed = numpy.random.randint(max_seed - size) + rank
        numpy.random.seed(seed)

        # Specify which CPUs should be attached to specific GPUs for maximum performance.
        cpu_platform_name = 'Reference'
        gpu_platform_name = 'OpenCL'

        if not cpuid_gpuid_mapping:
            # TODO: Determine number of GPUs and set up simple mapping.
            cpuid_gpuid_mapping = { 0:0, 1:1, 2:2, 3:3 }

        # Choose appropriate platform for each device.
        cpuid = rank # use default rank as CPUID (TODO: Improve this)
        if cpuid in cpuid_gpuid_mapping:
            platform = openmm.Platform.getPlatformByName(gpu_platform_name)
            deviceid = cpuid_gpuid_mapping[cpuid]
            platform.setPropertyDefaultValue('OpenCLDeviceIndex', '%d' % deviceid) # select OpenCL device index
            platform.setPropertyDefaultValue('CudaDeviceIndex', '%d' % deviceid) # select Cuda device index
            print("node '%s' MPI_WORLD rank %d/%d cpuid %d platform %s deviceid %d" % (hostname, rank, size, cpuid, gpu_platform_name, deviceid))
        else:
            platform = openmm.Platform.getPlatformByName(cpu_platform_name)
            print("node '%s' MPI_WORLD rank %d/%d running on CPU" % (hostname, rank, size))

        # Set up CPU and GPU communicators.
        gpu_process_list = [x for x in cpuid_gpuid_mapping if x < size]
        this_is_gpu_process = 1 if (cpuid in gpu_process_list) else 0
        comm = mpi_comm_world.Split(color=this_is_gpu_process)

        # Initialize if we haven't yet done so.
        if not self._initialized:
            self._initialize()

        # Create reference thermodynamic state corresponding to experimental conditions.
        reference_state = ThermodynamicState(temperature=self.temperature, pressure=self.pressure)

        if this_is_gpu_process:
            # Run ligand in complex simulation on GPUs.
            self.standard_state_correction = 0.0
            if not self.is_periodic:
                # Impose restraints to keep the ligand from drifting too far from the protein.
                self.standard_state_correction = self._impose_restraints(reference_state)

            # Create alchemically perturbed systems.
            factory = AbsoluteAlchemicalFactory(self.complex, ligand_atoms=self.ligand_atoms)
            systems = factory.createPerturbedSystems(self.complex_protocol, verbose=self.verbose)

            store_filename = os.path.join(self.output_directory, 'complex.nc')

            metadata = dict()
            metadata['standard_state_correction'] = self.standard_state_correction

            if self.randomize_ligand:
                self._randomize_complex_coordinates(len(systems), 1.5 * units.angstrom)

            # Setup Hamiltonian exchange simulation on the GPU platform selected above for this rank.
            complex_simulation = ModifiedHamiltonianExchange(reference_state, systems, self.complex_coordinates, store_filename, displacement_sigma=self.displacement_sigma, mc_atoms=self.ligand_atoms, protocol=self.protocol, mpicomm=comm, metadata=metadata)
            complex_simulation.platform = platform
            complex_simulation.nsteps_per_iteration = 2500
            complex_simulation.run()

        else:
            print("Running on cpu (node %s, rank %d / %d)" % (hostname, rank, size))
            # Suppress terminal output from ligand in solvent and vacuum simulations.
            self.protocol['verbose'] = False

            # Run ligand in vacuum simulation on CPUs.
            vacuum_ligand = self._create_vacuum_ligand()
            factory = AbsoluteAlchemicalFactory(vacuum_ligand, ligand_atoms=list(range(self.ligand.getNumParticles())))
            systems = factory.createPerturbedSystems(self.vacuum_protocol)
            store_filename = os.path.join(self.output_directory, 'vacuum.nc')
            vacuum_simulation = ModifiedHamiltonianExchange(reference_state, systems, self.ligand_coordinates, store_filename, protocol=self.protocol, mpicomm=comm)
            vacuum_simulation.platform = openmm.Platform.getPlatformByName('Reference')
            vacuum_simulation.nsteps_per_iteration = 500
            vacuum_simulation.run()

            # Run ligand in solvent simulation on CPUs.
            # TODO: Have this run on GPUs if explicit solvent.
            factory = AbsoluteAlchemicalFactory(self.ligand, ligand_atoms=list(range(self.ligand.getNumParticles())))
            systems = factory.createPerturbedSystems(self.solvent_protocol)
            store_filename = os.path.join(self.output_directory, 'solvent.nc')
            solvent_simulation = ModifiedHamiltonianExchange(reference_state, systems, self.ligand_coordinates, store_filename, protocol=self.protocol, mpicomm=comm)
            solvent_simulation.platform = openmm.Platform.getPlatformByName('Reference')
            solvent_simulation.nsteps_per_iteration = 500
            solvent_simulation.run()

        # Wait for all nodes to finish.
        mpi_comm_world.barrier()

        return

    @classmethod
    def _extract_u_n(cls, ncfile):
        """
        Extract timeseries of u_n = - log q(x_n).

        ARGUMENTS

        ncfile - NetCDF dataset whose variables 'energies' is indexed [iteration, replica, state] and
                 'states' is indexed [iteration, replica] (the thermodynamic state of each replica)

        RETURNS

        u_n (numpy array of float64, length niterations) - for each iteration, the sum over states
            of the reduced energy of the replica occupying that state, evaluated in that state

        """
        energies = ncfile.variables['energies']
        states = ncfile.variables['states']
        niterations = energies.shape[0]
        nstates = energies.shape[1]

        u_n = numpy.zeros([niterations], numpy.float64)
        for iteration in range(niterations):
            # Deconvolve replicas: row k holds the energies of the replica currently in state k.
            u_kl = numpy.zeros([nstates, nstates], numpy.float64)
            u_kl[states[iteration, :], :] = energies[iteration, :, :]
            u_n[iteration] = numpy.sum(numpy.diagonal(u_kl))

        return u_n

    def analyze(self, verbose=False):
        """
        Analyze the results of a YANK free energy calculation.

        Each of 'vacuum.nc', 'solvent.nc' and 'complex.nc' found in output_directory is analyzed;
        missing files are skipped.

        OPTIONAL ARGUMENTS

        verbose (bool) - if True, will print verbose progress information (default: False)

        RETURNS

        results (dict) - maps phase name to a dict with keys 'DeltaF', 'dDeltaF', 'DeltaH', 'dDeltaH'
            (differences between the first and last alchemical states)

        """

        import analyze
        import pymbar
        import timeseries
        import netCDF4 as netcdf

        # Storage for results.
        results = dict()

        if verbose: print("Analyzing simulation data...")

        # Process each netcdf file in output directory.
        for phase in ['vacuum', 'solvent', 'complex']:
            fullpath = os.path.join(self.output_directory, phase + '.nc')

            # Skip if the file doesn't exist.
            if not os.path.exists(fullpath):
                continue

            print("Opening NetCDF trajectory file '%s' for reading..." % fullpath)
            ncfile = netcdf.Dataset(fullpath, 'r')
            try:
                # Read dimensions.
                niterations = ncfile.variables['positions'].shape[0]
                nstates = ncfile.variables['positions'].shape[1]
                if verbose: print("Read %d iterations, %d states" % (niterations, nstates))

                # Choose number of samples to discard to equilibration.
                u_n = self._extract_u_n(ncfile)
                [nequil, g_t, Neff_max] = timeseries.detectEquilibration(u_n)
                if verbose: print([nequil, Neff_max])

                # Examine mixing statistics.
                analyze.show_mixing_statistics(ncfile, cutoff=0.05, nequil=nequil)

                # Estimate free energies and average enthalpies.
                (Deltaf_ij, dDeltaf_ij) = analyze.estimate_free_energies(ncfile, ndiscard=nequil)
                (DeltaH_i, dDeltaH_i) = analyze.estimate_enthalpies(ncfile, ndiscard=nequil)

                # Accumulate differences between the first and last states.
                last = nstates - 1
                entry = dict()
                entry['DeltaF'] = Deltaf_ij[0, last]
                entry['dDeltaF'] = dDeltaf_ij[0, last]
                entry['DeltaH'] = DeltaH_i[last] - DeltaH_i[0]
                entry['dDeltaH'] = numpy.sqrt(dDeltaH_i[0]**2 + dDeltaH_i[last]**2)
                results[phase] = entry
            finally:
                # Always release the NetCDF file, even if analysis fails.
                ncfile.close()

        return results

#=============================================================================================
# Coordinate readers, energy utilities, and command-line driver
#=============================================================================================

def read_amber_crd(filename, natoms_expected, verbose=False):
    """
    Read an AMBER coordinate file.

    ARGUMENTS

    filename (string) - AMBER crd file to read
    natoms_expected (int) - number of atoms expected

    OPTIONAL ARGUMENTS

    verbose (boolean) - if True, print progress information (default: False)

    RETURNS

    coordinates (numpy-wrapped simtk.unit.Quantity with units of distance) - a single coordinate set

    RAISES

    Exception if the number of atoms read differs from natoms_expected.

    TODO

    Automatically handle box vectors?

    """

    if verbose: print("Reading coordinate sets from '%s'..." % filename)

    # Imported lazily so the module can be used without pyopenmm installed.
    import simtk.pyopenmm.amber.amber_file_parser as amber
    coordinates = amber.readAmberCoordinates(filename, asNumpy=True)

    # Check to make sure number of atoms match expectation.
    natoms = coordinates.shape[0]
    if natoms != natoms_expected:
        raise Exception("Read coordinate set from '%s' that had %d atoms (expected %d)." % (filename, natoms, natoms_expected))

    return coordinates

def read_openeye_crd(filename, natoms_expected, verbose=False):
    """
    Read one or more coordinate sets from a file that OpenEye supports.

    ARGUMENTS

    filename (string) - the coordinate filename to be read
    natoms_expected (int) - number of atoms expected in every coordinate set

    OPTIONAL ARGUMENTS

    verbose (boolean) - if True, print progress information (default: False)

    RETURNS

    coordinates_list (list of numpy array of simtk.unit.Quantity) - list of coordinate sets read, each natoms x 3 in angstroms

    RAISES

    Exception if the file cannot be opened or any molecule has a number of atoms other than natoms_expected.

    """

    if verbose: print("Reading coordinate sets from '%s'..." % filename)

    import openeye.oechem as oe
    imolstream = oe.oemolistream()
    # open() reports failure through its return value; without this check a bad path yields an empty list.
    if not imolstream.open(filename):
        raise Exception("Could not open '%s' for reading." % filename)

    coordinates_list = list()
    try:
        for molecule in imolstream.GetOEGraphMols():
            oecoords = molecule.GetCoords() # oecoords[atom_index] is tuple of atom coordinates, in angstroms
            natoms = len(oecoords)
            if natoms != natoms_expected:
                raise Exception("Read coordinate set from '%s' that had %d atoms (expected %d)." % (filename, natoms, natoms_expected))
            # reshape keeps the (natoms, 3) shape even when natoms == 0.
            positions = numpy.array([oecoords[atom_index] for atom_index in range(natoms)], numpy.float32).reshape(natoms, 3)
            coordinates_list.append(units.Quantity(positions, units.angstroms))
    finally:
        imolstream.close()

    if verbose: print("%d coordinate sets read." % len(coordinates_list))

    return coordinates_list

def read_pdb_crd(filename, natoms_expected, verbose=False):
    """
    Read one or more coordinate sets from a PDB file.
    Multiple coordinate sets (in the form of multiple MODELs) can be read.

    ARGUMENTS

    filename (string) - name of the file to be read
    natoms_expected (int) - number of atoms expected in every coordinate set

    OPTIONAL ARGUMENTS

    verbose (boolean) - if True, print progress information (default: False)

    RETURNS

    coordinates_list (list of numpy array of simtk.unit.Quantity) - list of coordinate sets read, each natoms x 3 in angstroms

    RAISES

    Exception if a coordinate set does not contain exactly natoms_expected ATOM records.

    NOTES

    * Only ATOM records are parsed; HETATM records are ignored.
    * A coordinate set is stored at each ENDMDL/END record that follows at least one ATOM record,
      and at end of file if ATOM records were read since the last stored set.
    * Column layout follows: http://bmerc-www.bu.edu/needle-doc/latest/atom-format.html

    """

    if verbose: print("Reading coordinate sets from '%s'..." % filename)

    coordinates_list = list()
    positions = numpy.zeros([natoms_expected, 3], numpy.float32) # positions[atom_index, dim_index] in angstroms

    def store(count):
        # Store a copy of the current model after checking its atom count.
        if count != natoms_expected:
            raise Exception("Read coordinate set from '%s' that had %d atoms (expected %d)." % (filename, count, natoms_expected))
        coordinates_list.append(units.Quantity(positions.copy(), units.angstroms))

    atom_index = 0
    with open(filename, 'r') as pdbfile:
        for line in pdbfile:
            record = line[0:6]
            if record == "MODEL ":
                # Reset atom counter.
                atom_index = 0
            elif record.rstrip() in ("ENDMDL", "END"):
                # Guard against storing an empty model (e.g. END right after ENDMDL).
                if atom_index > 0:
                    store(atom_index)
                positions[:] = 0.0
                atom_index = 0
            elif record == "ATOM  ":
                if atom_index >= natoms_expected:
                    raise Exception("Read coordinate set from '%s' that had more than %d atoms (expected %d)." % (filename, natoms_expected, natoms_expected))
                positions[atom_index, :] = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                atom_index += 1

    # Store a trailing model that was not terminated by ENDMDL/END.
    if atom_index > 0:
        store(atom_index)

    if verbose: print("%d coordinate sets read." % len(coordinates_list))

    return coordinates_list

def compute_energy_by_force(system, coordinates, platform_name='Reference', timestep=None):
    """
    Compute the potential energy of a System and print its decomposition by Force.

    Each Force is temporarily placed in its own force group so its contribution can be evaluated
    separately; the System's original force groups are restored before returning.

    ARGUMENTS

    system (simtk.openmm.System) - the system to evaluate (at most 32 forces, the number of OpenMM force groups)
    coordinates (simtk.unit.Quantity of coordinates) - particle positions

    OPTIONAL ARGUMENTS

    platform_name (string) - name of the OpenMM Platform to use (default: 'Reference')
    timestep (simtk.unit.Quantity with units of time) - integrator timestep; only needed to build the
        Context, no steps are taken (default: 1 fs)

    RETURNS

    potential_energy (simtk.unit.Quantity with units of energy) - total potential energy of the system

    RAISES

    ValueError if the system has more than 32 forces.

    """
    # Previously relied on a global 'timestep' that only exists when run as a script.
    if timestep is None:
        timestep = 1.0 * units.femtoseconds

    nforces = system.getNumForces()
    if nforces > 32:
        raise ValueError("System has %d forces, but only 32 force groups are available." % nforces)

    # Put each Force in its own group, remembering the caller's grouping so it can be restored.
    original_groups = [system.getForce(force_index).getForceGroup() for force_index in range(nforces)]
    for force_index in range(nforces):
        system.getForce(force_index).setForceGroup(force_index)

    try:
        # Create Context.
        platform = openmm.Platform.getPlatformByName(platform_name)
        integrator = openmm.VerletIntegrator(timestep)
        context = openmm.Context(system, integrator, platform)
        context.setPositions(coordinates)

        # Total potential energy.
        state = context.getState(getEnergy=True)
        total_energy = state.getPotentialEnergy()
        print("%24.5f kcal/mol" % (total_energy / units.kilocalories_per_mole))

        # Energy by force group (bit i of the mask selects group i).
        for force_index in range(nforces):
            force_mask = 1 << force_index
            state = context.getState(getEnergy=True, groups=force_mask)
            group_energy = state.getPotentialEnergy()
            print("force %8d : force_mask = %8o | %24.5f kcal/mol" % (force_index, force_mask, group_energy / units.kilocalories_per_mole))

        del state, context, integrator
    finally:
        for force_index, group in enumerate(original_groups):
            system.getForce(force_index).setForceGroup(group)

    return total_energy

def compute_energy(system, coordinates, platform_name='OpenCL', timestep=None):
    """
    Compute and print the total potential energy of a System.

    ARGUMENTS

    system (simtk.openmm.System) - the system to evaluate
    coordinates (simtk.unit.Quantity of coordinates) - particle positions

    OPTIONAL ARGUMENTS

    platform_name (string) - name of the OpenMM Platform to use (default: 'OpenCL')
    timestep (simtk.unit.Quantity with units of time) - integrator timestep; only needed to build the
        Context, no steps are taken (default: 1 fs)

    RETURNS

    potential_energy (simtk.unit.Quantity with units of energy) - total potential energy

    """
    # Previously relied on a global 'timestep' that only exists when run as a script.
    if timestep is None:
        timestep = 1.0 * units.femtoseconds

    # Create Context.
    platform = openmm.Platform.getPlatformByName(platform_name)
    integrator = openmm.VerletIntegrator(timestep)
    context = openmm.Context(system, integrator, platform)
    context.setPositions(coordinates)

    # Get potential energy.
    state = context.getState(getEnergy=True)
    potential_energy = state.getPotentialEnergy()
    print("%24.5f kcal/mol" % (potential_energy / units.kilocalories_per_mole))

    # Release the Context promptly.
    del state, context, integrator
    return potential_energy

def show_forces(system):
    """
    Print one line per Force in a System: its index and its string representation.

    ARGUMENTS

    system (simtk.openmm.System) - the system whose forces are listed

    """
    for index in range(system.getNumForces()):
        print("force %5d : %s" % (index, str(system.getForce(index))))

def serialize_system(system, filename):
    """
    Write the serialized form of a System to a file, followed by a newline.

    ARGUMENTS

    system (simtk.openmm.System) - system to serialize; its __getstate__() must return a string
    filename (string) - path of the file to write (overwritten if it exists)

    """
    # Serialize before opening, so a failure doesn't leave a truncated file behind.
    serialized = system.__getstate__()
    with open(filename, 'w') as outfile:
        outfile.write(serialized + '\n')
    return

if __name__ == '__main__':
    # USAGE
    #
    #   python compute_energy.py directory
    #
    # EXAMPLES
    #
    #   python compute_energy.py ../../src/examples/p-xylene/
    #
    # The directory must contain ligand.prmtop, ligand.crd, complex.prmtop and complex.crd.

    # Parse command-line arguments.
    if len(sys.argv) < 2:
        sys.stderr.write("usage: python compute_energy.py directory\n")
        sys.exit(1)
    directory = sys.argv[1]
    print("Using specified directory: '%s'" % directory)

    # Create System objects for ligand and complex.
    # (simtk.pyopenmm.amber.amber_file_parser provides the same loader.)
    import amber_file_parser as amber
    gbmodel = 'OBC'
    ligand_prmtop_filename = os.path.join(directory, 'ligand.prmtop')
    print("Reading AMBER ligand prmtop from '%s'..." % ligand_prmtop_filename)
    ligand_system = amber.readAmberSystem(prmtop_filename=ligand_prmtop_filename, shake='h-bonds', gbmodel=gbmodel, flexibleConstraints=False)
    complex_prmtop_filename = os.path.join(directory, 'complex.prmtop')
    print("Reading AMBER complex prmtop from '%s'..." % complex_prmtop_filename)
    complex_system = amber.readAmberSystem(prmtop_filename=complex_prmtop_filename, shake='h-bonds', gbmodel=gbmodel, flexibleConstraints=False)

    # Read ligand and complex coordinates.
    print("Reading AMBER coordinate files...")
    ligand_coordinates = amber.readAmberCoordinates(os.path.join(directory, 'ligand.crd'), asNumpy=True)
    complex_coordinates = amber.readAmberCoordinates(os.path.join(directory, 'complex.crd'), asNumpy=True)

    serialize_system(complex_system, 'complex-python.xml')

    show_forces(complex_system)

    # Compute ligand and complex energy.
    platform_name = 'OpenCL'
    # Module-level so helpers that read a global 'timestep' keep working.
    timestep = 1.0 * units.femtoseconds
    platform = openmm.Platform.getPlatformByName(platform_name)

    integrator = openmm.VerletIntegrator(timestep)
    context = openmm.Context(ligand_system, integrator, platform)
    context.setPositions(ligand_coordinates)
    state = context.getState(getEnergy=True)
    ligand_potential_energy = state.getPotentialEnergy()
    del state, context, integrator
    print("ligand potential energy: %.3f kcal/mol" % (ligand_potential_energy / units.kilocalories_per_mole))

    compute_energy(ligand_system, ligand_coordinates)
    compute_energy(complex_system, complex_coordinates)

    # Ligand atoms are the last ligand_system.getNumParticles() atoms of the complex.
    ligand_atoms = list(range(complex_system.getNumParticles() - ligand_system.getNumParticles(), complex_system.getNumParticles()))
    factory = AbsoluteAlchemicalFactory(complex_system, ligand_atoms=ligand_atoms)

    alchemical_states = [
        AlchemicalState(0.00, 1.00, 1.00, 1.), # fully interacting
        AlchemicalState(0.00, 0.50, 1.00, 1.), # half electrostatics
        AlchemicalState(0.00, 0.00, 1.00, 1.), # no electrostatics
        AlchemicalState(0.00, 0.50, 0.50, 1.), # half LJ
        AlchemicalState(0.00, 0.00, 0.00, 1.), # noninteracting
    ]

    systems = factory.createPerturbedSystems(alchemical_states, verbose=False)

    for (index, system) in enumerate(systems):
        # Prefix without newline so compute_energy's output lands on the same line.
        sys.stdout.write("Alchemical state %5d : " % index)
        compute_energy(system, complex_coordinates)
