#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Test energy minimization on all test systems using the fastest available OpenMM platform.

DESCRIPTION

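Every callable exported by simtk.pyopenmm.extras.testsystems is treated as a
test-system constructor returning [system, coordinates]. Each system is
minimized with LocalEnergyMinimizer for a small, fixed number of energy
evaluations. A system fails if any step raises an error or if the potential
energy increases. The process exits with status 0 when every system passes
and 1 otherwise.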

COPYRIGHT AND LICENSE

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import sys
import math
import copy
import time

import numpy

import simtk.unit as units
import simtk.openmm as openmm
import simtk.pyopenmm.extras.testsystems as testsystems

import doctest

#=============================================================================================
# MAIN
#=============================================================================================

# Print the selected platform and per-system energies when True.
verbose = True

# Parameters for LocalEnergyMinimizer.minimize.
# Minimization terminates once the RMS force falls below `tolerance`. A force
# has units of energy/(mole*length), so the tolerance must too; the previous
# kilojoules/nanometers**2 was dimensionally an energy per area, not a force.
tolerance = 1.0 * units.kilojoules_per_mole / units.nanometers
# Upper bound on the number of energy evaluations the minimizer may perform.
maximum_evaluations = 10

# Select the fastest platform: the one reporting the highest getSpeed().
# numpy.argmax returns the first maximum, so ties go to the lowest index.
nplatforms = openmm.Platform.getNumPlatforms()
platform_speeds = numpy.zeros([nplatforms], numpy.float64)
for platform_index in range(nplatforms):
    platform_speeds[platform_index] = openmm.Platform.getPlatform(platform_index).getSpeed()
platform_index = int(numpy.argmax(platform_speeds))
platform = openmm.Platform.getPlatform(platform_index)

if verbose:
    print('Testing OpenMM platform "%s"' % platform.getName())

# Build the list of systems to test: every callable attribute of the
# testsystems module, paired with its name, in dir() order.
test_systems = []
for name in dir(testsystems):
    constructor = getattr(testsystems, name)
    if callable(constructor):
        test_systems.append((name, constructor))

# To test only a subset of systems instead, replace the list above, e.g.:
#test_system_names = ['SodiumChlorideCrystal']
#test_systems = [ (name, getattr(testsystems, name)) for name in test_system_names ]

# Run the minimizer on every test system to check robustness. A failure on one
# system is reported and recorded, but the remaining systems are still tested.
all_tests_passed = True
for (system_name, system_constructor) in test_systems:
    if verbose:
        print("Testing LocalEnergyMinimizer on system '%s' on platform '%s'..." % (system_name, platform.getName()))

    try:
        # Create the system; each constructor returns [system, coordinates].
        [system, coordinates] = system_constructor()

        # Create a Context. Context requires an integrator, but it is never
        # stepped here; only the minimizer moves the particles.
        timestep = 1.0 * units.femtoseconds
        integrator = openmm.VerletIntegrator(timestep)
        context = openmm.Context(system, integrator, platform)

        # Set coordinates.
        context.setPositions(coordinates)

        # Compute initial energy.
        initial_potential = context.getState(getEnergy=True).getPotentialEnergy()

        # Minimize.
        openmm.LocalEnergyMinimizer.minimize(context, tolerance, maximum_evaluations)

        # Compute final energy.
        final_potential = context.getState(getEnergy=True).getPotentialEnergy()

        # Report energies.
        if verbose:
            print("results after at most %d energy evaluations" % maximum_evaluations)
            print("initial potential : %s" % (str(initial_potential)))
            print("final potential   : %s" % (str(final_potential)))

        # NaN compares False against everything, so the "energy increased"
        # check below would silently pass a blown-up minimization.
        if math.isnan(final_potential.value_in_unit(units.kilojoules_per_mole)):
            raise Exception("Energy is NaN after minimization.")

        # Fail if the energy went up.
        if final_potential > initial_potential:
            # Energies were already printed above when verbose; avoid repeating them.
            if not verbose:
                print("initial potential : %s" % (str(initial_potential)))
                print("final potential   : %s" % (str(final_potential)))
            raise Exception("Energy increased on minimization.")

    except Exception as e:
        # `platform_name` was never defined; referencing it here raised
        # NameError and aborted the whole run on the first failure.
        print("**** FAILED on system %s platform %s ****" % (system_name, platform.getName()))
        print(e)
        all_tests_passed = False

#=============================================================================================
# Report pass or fail in exit code
#=============================================================================================

# Exit status 0 means every system passed; 1 means at least one failed.
sys.exit(0 if all_tests_passed else 1)

   
