#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Test all testsystems on different platforms to make sure errors in potential energy are small.

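DESCRIPTION

Each test system constructor in the `testsystems` module is called to build a
system and its coordinates. The potential energy is evaluated on the Reference
platform and on every available OpenMM platform. For each platform, the energy
and its deviation from the Reference energy are printed in kcal/mol.
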
COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

TODO

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import numpy
import math

import simtk
import simtk.chem.openmm as openmm
import simtk.unit as units

#=============================================================================================
# SUBROUTINES
#=============================================================================================

def compute_potential(system, coordinates, platform):
    """
    Compute the potential energy of the given system and coordinates on the designated platform.

    ARGUMENTS

    system (openmm.System) - the system whose potential energy is evaluated
    coordinates - particle positions, passed unchanged to Context.setPositions()
                  (presumably a simtk.unit.Quantity with length units -- confirm against callers)
    platform (openmm.Platform) - the platform on which the energy is evaluated

    RETURNS

    potential (simtk.unit.Quantity in energy/mole) - the potential energy reported by the Context

    NOTES

    A fresh Context is created on every call. The integrator is only needed
    because the Context constructor requires one. It is never stepped, so its
    parameters do not affect the returned energy.

    """

    # Create a Context. The integrator parameters are irrelevant to the energy
    # because no dynamics are run.
    temperature = 298.0 * units.kelvin
    collision_rate = 90.0 / units.picosecond
    timestep = 1.0 * units.femtosecond
    integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    context = openmm.Context(system, integrator, platform)

    # Set positions.
    context.setPositions(coordinates)

    # Evaluate the potential energy.
    state = context.getState(getEnergy=True)
    potential = state.getPotentialEnergy()

    return potential

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

if __name__ == "__main__":
    import doctest

    # List all available platforms.
    print("Available platforms:")
    for platform_index in range(openmm.Platform.getNumPlatforms()):
        platform = openmm.Platform.getPlatform(platform_index)
        print("%5d %s" % (platform_index, platform.getName()))
    print("")

    # Run any doctests embedded in this module's docstrings.
    # NOTE(review): doctest.testmod() does not select a platform; the message
    # is kept unchanged for output compatibility.
    print('Testing Reference platform...')
    doctest.testmod()

    # Collect the test system constructors: public callables defined in the
    # testsystems module itself. Names that testsystems merely imports from
    # other modules are skipped. They are not test systems, and calling them
    # with no arguments would likely fail or be meaningless.
    import testsystems
    test_systems = []
    for name in dir(testsystems):
        candidate = getattr(testsystems, name)
        if name.startswith('_') or not callable(candidate):
            continue
        if getattr(candidate, '__module__', None) != testsystems.__name__:
            continue
        test_systems.append((name, candidate))

    # Compute the energy error made on each test system by every platform,
    # relative to the Reference platform.
    reference_platform = openmm.Platform.getPlatformByName("Reference")
    for (test_system_name, test_system_constructor) in test_systems:
        [system, coordinates] = test_system_constructor()
        reference_potential = compute_potential(system, coordinates, reference_platform)

        print("%s" % test_system_name)
        for platform_index in range(openmm.Platform.getNumPlatforms()):
            platform = openmm.Platform.getPlatform(platform_index)
            platform_name = platform.getName()
            try:
                platform_potential = compute_potential(system, coordinates, platform)
            except Exception as e:
                # A platform can be registered but unusable on this machine.
                # Report it and continue instead of aborting the whole comparison.
                print("%32s FAILED: %s" % (platform_name, e))
                continue
            error = platform_potential - reference_potential

            print("%32s %16.3f kcal/mol %16.3f kcal/mol" % (platform_name, platform_potential / units.kilocalories_per_mole, error / units.kilocalories_per_mole))
                                                                                        

    
