#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Test various combinations of OpenMM Force objects to ensure they work correctly in combinations.

DESCRIPTION

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

TODO

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math

import numpy
import simtk.unit as units
import simtk.openmm as openmm

#=============================================================================================
# SUBROUTINES
#=============================================================================================

# Maximum allowed deviation of each platform from the Reference platform:
#   ENERGY_TOLERANCE     -- absolute difference in total potential energy
#   FORCE_RMSE_TOLERANCE -- root-mean-square (over particles) error of the per-particle force
ENERGY_TOLERANCE = 0.06 * units.kilocalories_per_mole
FORCE_RMSE_TOLERANCE = 0.06 * units.kilocalories_per_mole / units.nanometer

def read_alpha_carbons(pdb_filename):
    """
    Read alpha carbon coordinates from specified PDB file.

    Only ATOM records whose atom-name field (columns 13-16) is exactly ' CA ' are used;
    HETATM and all other records are ignored.

    ARGUMENTS

    pdb_filename (string) - name of PDB file to read coordinates from

    RETURNS

    coordinates (Nx3 numpy float32 array wrapped in simtk.unit.Quantity with units of angstroms) - coordinates
        of alpha carbons, in the order they appear in the file (shape (0,3) if none are found)

    """

    coordinate_list = list() # temporary storage for coordinates
    # The context manager guarantees the file is closed, even if a malformed line makes float() raise.
    with open(pdb_filename, 'r') as infile:
        for line in infile:
            if (line[0:6] == 'ATOM  ') and (line[12:16] == ' CA '):
                # Fixed-column PDB fields: x, y, z occupy columns 31-38, 39-46, 47-54.
                x = float(line[30:38])
                y = float(line[38:46])
                z = float(line[46:54])
                coordinate_list.append((x,y,z))

    # Build the coordinate array in one step; the reshape keeps an (0,3) shape when no CA atoms were read.
    natoms = len(coordinate_list)
    coordinate_array = numpy.array(coordinate_list, numpy.float32).reshape((natoms, 3))
    coordinates = units.Quantity(coordinate_array, units.angstroms)

    return coordinates

def compute_potential_and_force(system, coordinates, platform):
    """
    Compute the energy and force for the given system and coordinates in the designated platform.

    A throwaway Context is created for the evaluation and released before returning.

    ARGUMENTS

    system (simtk.openmm.System) - the system for which the energy is to be computed
    coordinates (simtk.unit.Quantity of Nx3 numpy.array in units of distance) - coordinates for which energy and force are to be computed
    platform (simtk.openmm.Platform) - platform object to be used to compute the energy and force

    RETURNS

    potential (simtk.unit.Quantity in energy/mole) - the potential
    force (simtk.unit.Quantity of Nx3 numpy.array in units of energy/mole/distance) - the force

    Any exception raised by OpenMM while creating the Context or evaluating the state propagates to the caller.

    """

    # A Context requires an integrator, but it is never stepped here, so these parameters
    # do not affect the computed energy or forces.
    temperature = 298.0 * units.kelvin
    collision_rate = 90.0 / units.picosecond
    timestep = 1.0 * units.femtosecond
    integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    context = openmm.Context(system, integrator, platform)
    try:
        # Set positions and evaluate the potential energy and forces.
        context.setPositions(coordinates)
        state = context.getState(getEnergy=True, getForces=True)
        potential = state.getPotentialEnergy()
        force = state.getForces(asNumpy=True)
    finally:
        # Drop our references to the Context and integrator even if evaluation raised.
        del context, integrator

    return [potential, force]

def BuildSystem(forcenames, coordinates, mm=None):
    """
    Build a system containing the given list of force terms.

    Every particle gets the same mass and the same per-particle parameters. Bonded terms
    (bonds and angles) connect consecutive particles as a linear chain, with equilibrium
    values set to the average pseudobond length and pseudoangle of the supplied coordinates.

    ARGUMENTS

    forcenames (list of strings) - list of force names to instantiate; each must be one of
        'NonbondedForce', 'NonbondedSoftcoreForce', 'CustomNonbondedForce', 'GBSAOBCForce',
        'GBSAOBCSoftcoreForce', 'HarmonicBondForce', 'HarmonicAngleForce', 'CustomBondForce',
        'CustomAngleForce', 'CustomExternalForce'
    coordinates (simtk.unit.Quantity of distance, Nx3 numpy array) - particle coordinates; N sets the
        number of particles in the system

    OPTIONAL ARGUMENTS

    mm (simtk.openmm implementation) - implementation to use (default: simtk.openmm)

    RETURNS

    system (mm.System, i.e. simtk.openmm.System by default) - the system object
    coordinates (simtk.unit.Quantity of distance, Nx3 numpy array) - coordinates of all particles in system
        (the same object that was passed in)

    RAISES

    ValueError - if an entry of forcenames is not one of the recognized force names

    NOTES

    With fewer than 3 particles the average pseudoangle is the mean of an empty array (NaN);
    with fewer than 2 particles the same holds for the average pseudobond length.

    EXAMPLES

    Create a system of alpha carbons containing NonbondedForce and HarmonicBondForce terms

    >>> coordinates = read_alpha_carbons(pdb_filename) # doctest: +SKIP
    >>> [system, coordinates] = BuildSystem(['NonbondedForce', 'HarmonicBondForce'], coordinates) # doctest: +SKIP

    """

    # Use pyOpenMM by default.
    if mm is None:
        import simtk.openmm
        mm = simtk.openmm

    # Determine number of particles.
    nparticles = coordinates.shape[0]

    # Determine average pseudobond length between consecutive particles.
    r_ab = (coordinates[0:-1,:] - coordinates[1:,:])
    avgbond = (numpy.sqrt(((r_ab / units.angstroms)**2).sum(1))).mean() * units.angstroms

    # Determine average pseudoangle over consecutive particle triples (i, i+1, i+2).
    r_ba = (coordinates[0:-2,:] - coordinates[1:-1,:]) / units.angstroms
    r_bc = (coordinates[2:,:] - coordinates[1:-1,:]) / units.angstroms
    cos_theta = (r_ba * r_bc).sum(1) / numpy.sqrt((r_ba * r_ba).sum(1)) / numpy.sqrt((r_bc * r_bc).sum(1))
    # Floating-point rounding (e.g. with float32 coordinates) can push |cos_theta| slightly past 1 for
    # nearly collinear triples; arccos would then return NaN and poison the average.
    cos_theta = numpy.clip(cos_theta, -1.0, 1.0)
    avgtheta = numpy.arccos(cos_theta).mean() * units.radians

    # Create an empty system object.
    system = mm.System()

    # Add particles to the system.
    mass = 12.0 * units.amu # arbitrary reference mass
    for particle_index in range(nparticles):
        system.addParticle(mass)

    # Create and add force objects (one force object per entry in forcenames).
    for forcename in forcenames:

        if forcename == 'NonbondedForce':
            # Parameters.
            charge      = 0.1 * units.elementary_charge
            sigma       = 3.330445 * units.angstrom
            epsilon     = 0.002772 * units.kilocalorie_per_mole
            # Create force.
            force = mm.NonbondedForce()
            for particle_index in range(nparticles):
                force.addParticle(charge, sigma, epsilon)
            system.addForce(force)

        elif forcename == 'NonbondedSoftcoreForce':
            # Parameters.
            charge      = 0.1 * units.elementary_charge
            sigma       = 3.330445 * units.angstrom
            epsilon     = 0.002772 * units.kilocalorie_per_mole
            vdw_lambda  = 0.2
            # Create force.
            force = mm.NonbondedSoftcoreForce()
            for particle_index in range(nparticles):
                force.addParticle(charge, sigma, epsilon, vdw_lambda)
            system.addForce(force)

        elif forcename == 'CustomNonbondedForce':
            # Parameters.
            charge      = 0.1 * units.elementary_charge
            sigma       = 3.330445 * units.angstrom
            epsilon     = 0.002772 * units.kilocalorie_per_mole
            # Coulomb plus Lennard-Jones written as a custom expression, with geometric-mean epsilon
            # and arithmetic-mean sigma combining rules.
            energy_expression = 'C*chargeprod/r + 4*epsilon*((sigma/r)^12-(sigma/r)^6); chargeprod=charge1*charge2; epsilon=sqrt(epsilon1*epsilon2); sigma=0.5*(sigma1+sigma2)'
            # Electrostatic prefactor C, in kcal/mol * angstrom / e^2.
            C = 332.0 * units.kilocalories_per_mole / units.elementary_charge**2 * units.angstrom
            # Create force.
            force = mm.CustomNonbondedForce(energy_expression)
            force.addGlobalParameter('C', C)
            force.addPerParticleParameter('charge')
            force.addPerParticleParameter('sigma')
            force.addPerParticleParameter('epsilon')
            parameters = [charge, sigma, epsilon]
            for particle_index in range(nparticles):
                force.addParticle(parameters)
            system.addForce(force)

        elif forcename == 'GBSAOBCForce':
            # Parameters.
            charge      = 0.1 * units.elementary_charge
            radius      = 1.5 * units.angstroms
            gbscale     = 0.8
            # Create force.
            force = mm.GBSAOBCForce()
            for particle_index in range(nparticles):
                force.addParticle(charge, radius, gbscale)
            system.addForce(force)

        elif forcename == 'GBSAOBCSoftcoreForce':
            # Parameters.
            charge      = 0.1 * units.elementary_charge
            radius      = 1.5 * units.angstroms
            gbscale     = 0.8
            gb_lambda   = 0.2
            # Create force.
            force = mm.GBSAOBCSoftcoreForce()
            for particle_index in range(nparticles):
                force.addParticle(charge, radius, gbscale, gb_lambda)
            system.addForce(force)

        elif forcename == 'HarmonicBondForce':
            # Parameters.
            r0 = avgbond
            K = 10.0 * units.kilocalories_per_mole / units.nanometers**2
            # Create force: one bond between each pair of consecutive particles.
            force = mm.HarmonicBondForce()
            for particle_index in range(nparticles-1):
                force.addBond(particle_index, particle_index+1, r0, K)
            system.addForce(force)

        elif forcename == 'HarmonicAngleForce':
            # Parameters.
            theta0 = avgtheta
            K = 10.0 * units.kilocalories_per_mole / units.radians**2
            # Create force: one angle for each triple of consecutive particles.
            force = mm.HarmonicAngleForce()
            for particle_index in range(nparticles-2):
                i = particle_index
                j = particle_index + 1
                k = particle_index + 2
                force.addAngle(i, j, k, theta0, K)
            system.addForce(force)

        elif forcename == 'CustomBondForce':
            # Parameters.
            r0 = avgbond
            K = 10.0 * units.kilocalories_per_mole / units.nanometers**2
            energy_expression = '(K/2.0) * (r-r0)^2'
            # Create force: one bond between each pair of consecutive particles.
            force = mm.CustomBondForce(energy_expression)
            force.addPerBondParameter('r0')
            force.addPerBondParameter('K')
            parameters = [r0, K]
            for particle_index in range(nparticles-1):
                force.addBond(particle_index, particle_index+1, parameters)
            system.addForce(force)

        elif forcename == 'CustomAngleForce':
            # Parameters.
            theta0 = avgtheta
            K = 10.0 * units.kilocalories_per_mole / units.radians**2
            energy_expression = '(K/2.0) * (theta-theta0)^2'
            # Create force: one angle for each triple of consecutive particles.
            force = mm.CustomAngleForce(energy_expression)
            force.addPerAngleParameter('theta0')
            force.addPerAngleParameter('K')
            parameters = [theta0, K]
            for particle_index in range(nparticles-2):
                i = particle_index
                j = particle_index + 1
                k = particle_index + 2
                force.addAngle(i, j, k, parameters)
            system.addForce(force)

        elif forcename == 'CustomExternalForce':
            # Parameters.
            K = 1.0 * units.kilojoules_per_mole / units.nanometers**2
            # Add a restraining potential centered at the origin.
            force = mm.CustomExternalForce('(K/2.0) * (x^2 + y^2 + z^2)')
            force.addGlobalParameter('K', K)
            for particle_index in range(nparticles):
                force.addParticle(particle_index, [])
            system.addForce(force)

        else:
            # Previously an unrecognized name was silently skipped, yielding a system missing the
            # requested term (and a vacuously passing comparison).
            raise ValueError("Unknown force name '%s'" % forcename)

    return (system, coordinates)

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

# Names of all forces to test (every pair of distinct names is tested together).
available_force_names = [
    'NonbondedForce',
    'GBSAOBCForce',
    'HarmonicBondForce',
    'HarmonicAngleForce',
    'CustomNonbondedForce',
    'CustomBondForce',
    'CustomAngleForce',
    'CustomExternalForce',
    'NonbondedSoftcoreForce',
    'GBSAOBCSoftcoreForce',
    ]

# PDB file to take alpha carbons from, located inside the source tree named by the
# PYOPENMM_SOURCE_DIR environment variable. If the variable is unset, fall back to the
# current directory: os.getenv would otherwise return None and os.path.join would raise
# as soon as this module is imported.
pdb_filename = os.path.join(os.getenv('PYOPENMM_SOURCE_DIR', os.curdir), 'test', 'additional-tests', 'systems', 'T4-lysozyme-L99A-implicit', 'receptor.pdb')

if __name__ == "__main__":
    # Every print(...) below takes exactly one parenthesized argument, so it behaves the same as a
    # Python 2 print statement and a Python 3 function call; output that must not end in a
    # newline goes through sys.stdout.write instead.

    debug = False # Don't display extra debug information.

    # List all available platforms
    nplatforms = openmm.Platform.getNumPlatforms()
    print("Available platforms:")
    for platform_index in range(nplatforms):
        platform = openmm.Platform.getPlatform(platform_index)
        print("%5d %s" % (platform_index, platform.getName()))
    print("")

    # Construct a system for every pair of distinct forces.
    print("Building systems with all combinations of forces...")
    alpha_carbons = read_alpha_carbons(pdb_filename)
    print("%d atoms read" % (alpha_carbons.shape[0]))
    testsystems = list()
    navailableforces = len(available_force_names)
    for i in range(navailableforces):
        for j in range(i+1, navailableforces):
            # Retrieve force names.
            forcename1 = available_force_names[i]
            forcename2 = available_force_names[j]

            # Create system
            forcenames = [forcename1, forcename2] # list of forces to use in this system
            name = "[%s + %s]" % (forcename1, forcename2)
            [system, coordinates] = BuildSystem(forcenames, alpha_carbons)
            testsystems.append( (name, system, coordinates) )

    # Compare every platform against the Reference platform on every test system,
    # counting how often the tolerances are exceeded.
    force_unit = units.kilocalories_per_mole / units.nanometers
    tests_failed = 0 # number of times tolerance is exceeded
    tests_passed = 0 # number of times tolerance is not exceeded
    tests_skipped = 0 # number of (system, platform) comparisons skipped because an exception was raised
    print("%32s %16s          %16s          %16s          %16s" % ("platform", "potential", "error", "force mag", "rms error"))
    reference_platform = openmm.Platform.getPlatformByName("Reference")
    for (name, system, coordinates) in testsystems:
        # Print the system name first so any exception below is attributed to it.
        print("%s" % name)

        try:
            [reference_potential, reference_force] = compute_potential_and_force(system, coordinates, reference_platform)
        except Exception as e:
            # Without reference values, no platform can be checked for this system.
            # (A bare except here would also swallow KeyboardInterrupt and SystemExit.)
            print("%32s Caught exception on Reference platform: %s" % ("", e))
            tests_skipped += nplatforms
            continue

        natoms = system.getNumParticles()
        for platform_index in range(nplatforms):
            platform = openmm.Platform.getPlatform(platform_index)
            platform_name = platform.getName()
            sys.stdout.write("%32s " % platform_name)
            try:
                [platform_potential, platform_force] = compute_potential_and_force(system, coordinates, platform)
            except Exception as e:
                print("Caught exception: %s" % e)
                tests_skipped += 1
                continue

            # Compute error in potential.
            potential_error = platform_potential - reference_potential

            # Compute per-atom RMS (magnitude) and RMS error in force.
            force_mse = (((reference_force - platform_force) / force_unit)**2).sum() / natoms * force_unit**2
            force_rmse = units.sqrt(force_mse)

            force_ms = ((platform_force / force_unit)**2).sum() / natoms * force_unit**2
            force_rms = units.sqrt(force_ms)

            print("%16.6f kcal/mol %16.6f kcal/mol %16.6f kcal/mol/nm %16.6f kcal/mol/nm" % (platform_potential / units.kilocalories_per_mole, potential_error / units.kilocalories_per_mole, force_rms / force_unit, force_rmse / force_unit))

            # Mark whether tolerance is exceeded or not.
            test_success = True
            if numpy.isnan(platform_potential / ENERGY_TOLERANCE):
                test_success = False
                print("%32s WARNING: Potential energy is nan for this platform." % (""))
            if abs(potential_error) > ENERGY_TOLERANCE:
                test_success = False
                print("%32s WARNING: Potential energy error (%.3f kcal/mol) exceeds tolerance (%.3f kcal/mol)." % ("", potential_error/units.kilocalories_per_mole, ENERGY_TOLERANCE/units.kilocalories_per_mole))
            if numpy.isnan(force_rmse / FORCE_RMSE_TOLERANCE):
                test_success = False
                print("%32s WARNING: Force contains nan for this platform." % (""))
            if abs(force_rmse) > FORCE_RMSE_TOLERANCE:
                test_success = False
                print("%32s WARNING: Force RMS error (%.3f kcal/mol/nm) exceeds tolerance (%.3f kcal/mol/nm)." % ("", force_rmse/force_unit, FORCE_RMSE_TOLERANCE/force_unit))
                if debug:
                    # One line per atom: reference force components, then platform force components.
                    for atom_index in range(natoms):
                        reference_row = " ".join(["%8.3f" % (reference_force[atom_index,k]/force_unit) for k in range(3)])
                        platform_row = " ".join(["%8.3f" % (platform_force[atom_index,k]/force_unit) for k in range(3)])
                        print("%s  :  %s" % (reference_row, platform_row))

            if test_success:
                tests_passed += 1
            else:
                print("Test failed.")
                tests_failed += 1

    print("%d tests failed" % tests_failed)
    print("%d tests passed" % tests_passed)
    print("%d tests skipped (exception raised)" % tests_skipped)

    # Signal failure of test through the exit status; skipped comparisons do not count as failures.
    sys.exit(1 if tests_failed > 0 else 0)
