#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Energy minimization algorithms.

DESCRIPTION

This module provides a facility for energy minimization algorithms that use
iterative calls to OpenMM to compute the energy and force.  As such, it will not be as fast
as optimization methods implemented as OpenMM integrators, but it will be more
flexible.

WARNING

Note that, because there is currently no facility in OpenMM to SHAKE bond constraints
into compliance, constraints are not respected during minimization.  In order to maintain
reasonable geometries, your system must contain bonded terms in addition to constraints.

IMPLEMENTATION NOTES

A Minimizer base class provides utility functions for packing and unpacking the
Quantity-wrapped Nx3 numpy array of atomic positions into a dimensionless
vector, as well as evaluation of dimensionless forms of the energy and gradient
to use as objective functions for optimization.  The derived classes extend this
base class and utilize different optimizers from the scipy.optimize package.

TODO

* Add support for maintaining constraints in systems with constraints.
* Extend support for atomic positions specified as lists.

COPYRIGHT AND LICENSE

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import sys
import math
import numpy
import copy
import time

import simtk.chem.openmm as openmm
import simtk.unit as units

import numpy
import scipy.optimize

#=============================================================================================
# REVISION CONTROL
#=============================================================================================

__version__ = "$Id: optimize.py 487 2009-12-21 23:03:54Z jchodera $" # Subversion $Id$ keyword string: file name, revision, date and author of the last commit

#=============================================================================================
# MODULE CONSTANTS
#=============================================================================================

#=============================================================================================
# Exceptions
#=============================================================================================

class NotImplementedException(Exception):
    """
    Raised when a requested feature exists in the interface but has no implementation yet.

    """
    pass

class ParameterException(Exception):
    """
    Raised when a caller supplies a missing, malformed, or otherwise invalid parameter.

    """
    pass

#=============================================================================================
# Minimizer base class
#=============================================================================================

class Minimizer(object):
    """
    Base class for minimizers that evaluate the potential through a cached OpenMM Context.

    Provides conversion between unit-bearing Nx3 coordinates and dimensionless 3N vectors,
    plus a dimensionless objective (_objective, energy in kJ/mol) and gradient (_gradient,
    in kJ/mol/nm) suitable for use with scipy.optimize.  Subclasses implement minimize().

    """

    def __init__(self, system, platform=None):
        """
        Initialize a minimization object.

        ARGUMENTS

        system (simtk.chem.openmm.System) - the system whose energy is to be minimized

        OPTIONAL ARGUMENTS

        platform (simtk.chem.openmm.Platform) - if provided, this Platform object will be used (default: None)
             otherwise, OpenMM will automatically select the Platform to use

        NOTES

        For efficiency, an OpenMM Context object is created and held until this class is destroyed.
        A reference to (not a copy of) the System object is kept, so the System should not be
        modified after this object is constructed.

        INTERNAL TESTS

        >>> # Create a Minimizer object.
        >>> minimizer = Minimizer(system, platform=platform)
        >>> # Pack coordinates.
        >>> x = minimizer._pack_coordinates(coordinates)
        >>> # Unpack coordinates.
        >>> unpacked_coordinates = minimizer._unpack_coordinates(x)
        >>> # Compute objective function.
        >>> objective = minimizer._objective(x)
        >>> # Compute gradient of objective function.
        >>> gradient = minimizer._gradient(x)
        >>> # Clean up.
        >>> del minimizer

        """

        # Output frequency, in calls to _output_callback(); a non-positive value disables periodic output.
        self.output_frequency = -1 # no output
        # Number of calls to _output_callback() so far.  Initialized here (and not only in the
        # subclasses' minimize()) so that _objective() can be called directly, as the doctest
        # above does, without raising AttributeError.
        self.iteration_count = 0
        # Lowest objective value seen by _objective() and the packed coordinates that produced it.
        self.best_objective = None
        self.best_x = None

        # Create a copy of the system to be minimized.
        # self.system = copy.deepcopy(system) # TODO: Make this work
        self.system = system # HACK for now - this could be unsafe if user changes system object after initialization
        self.natoms = system.getNumParticles()

        # Initialize an integrator with arbitrary parameters.  It is only stepped by
        # _apply_constraints(); the tiny timestep keeps that step from moving atoms appreciably.
        temperature = 1.0 * units.kelvin
        collision_rate = 90.0 / units.picosecond
        timestep = 1.0e-6 * units.femtosecond
        self.integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)

        # Create an OpenMM Context and cache it for efficiency.
        if platform is None:
            # Allow OpenMM to choose best platform.
            self.context = openmm.Context(self.system, self.integrator)
        else:
            # Create Context using specified platform.
            self.context = openmm.Context(self.system, self.integrator, platform)

        # Fundamental units in which the dimensionless optimization variables are expressed.
        self.length_unit = units.nanometers
        self.energy_unit = units.kilojoules_per_mole

        return

    def _apply_constraints(self, x):
        """
        Attempt to bring packed coordinates into compliance with constraints.

        Takes a single step of dynamics with the cached Langevin integrator, whose timestep
        (1.0e-6 fs, set in __init__) is negligibly small rather than exactly zero, and relies
        on the integrator enforcing constraints during that step.  Because the integrator is
        a Langevin (stochastic) integrator, the coordinates may also be perturbed very slightly.

        This is a hack, since there is currently no other way to do this in OpenMM.

        WARNING: This is experimental.

        ARGUMENTS

        x (3N numpy array without units) - packed coordinates in length_unit

        RETURNS

        x (3N numpy array without units) - packed coordinates after the dynamics step

        """

        # Push unit-bearing coordinates to the Context.
        self.context.setPositions(self._unpack_coordinates(x))

        # Take a single, negligibly small step of dynamics.
        self.integrator.step(1)

        # Retrieve the updated coordinates and pack them again.
        state = self.context.getState(getPositions=True)
        coordinates = state.getPositions(asNumpy=True)
        return self._pack_coordinates(coordinates)

    @staticmethod
    def _format_quantity(quantity):
        """
        Format a unit-bearing scalar as its numeric value followed by its unit.

        Magnitudes below 1e6 use fixed-point notation ("%12.3f"); larger ones use
        scientific notation ("%14.3e").

        """
        value = quantity / quantity.unit
        if abs(value) < 1.0e6:
            string = "%12.3f" % value
        else:
            string = "%14.3e" % value
        return string + " " + str(quantity.unit)

    def _output_callback(self, x, force_display=False):
        """
        Count a call and, when due, print the potential energy and gradient norm at x.

        ARGUMENTS

        x (3N numpy array without units) - packed coordinates to report on

        OPTIONAL ARGUMENTS

        force_display (bool) - if True, print regardless of output_frequency (default: False)

        """

        # Increment iteration count.
        self.iteration_count += 1

        # Output only when forced or at the requested frequency.
        periodic = (self.output_frequency > 0) and (self.iteration_count % self.output_frequency == 0)
        if not (force_display or periodic):
            return

        # Push the reported coordinates to the Context first.  Previously the Context was left
        # holding whatever positions had been set last, so the printed energy and gradient did
        # not necessarily correspond to the coordinates being reported (e.g. the final xopt).
        self.context.setPositions(self._unpack_coordinates(x))

        # Compute energy and force.
        state = self.context.getState(getEnergy=True, getForces=True)
        energy = state.getPotentialEnergy()
        force = state.getForces(asNumpy=True)

        # Compute gradient norm (strip units for the sum, then re-attach them).
        force_unit = self.energy_unit / self.length_unit
        gnorm = numpy.sqrt(((force / force_unit)**2).sum()) * force_unit # TODO: Simplify when .sum() is implemented for Quantity

        print("iteration %8d : potential %30s, gnorm %30s" % (self.iteration_count, self._format_quantity(energy), self._format_quantity(gnorm)))

        return

    def _pack_coordinates(self, coordinates):
        """
        Pack coordinates object into a (3N)-vector.

        ARGUMENTS

        coordinates (simtk.unit.Quantity with length units wrapping an Nx3 numpy array)

        RETURNS

        x (3N numpy array without units) - coordinates expressed in length_unit

        """

        return (coordinates / self.length_unit).reshape([self.natoms*3])

    def _unpack_coordinates(self, x):
        """
        Unpack coordinates from a (3N)-vector into unit-bearing [N,3]-vector.

        ARGUMENTS

        x (3N numpy array without units) - coordinates expressed in length_unit

        RETURNS

        coordinates (simtk.unit.Quantity wrapping an Nx3 numpy array) - coordinates in length_unit

        """

        return units.Quantity(x.reshape([self.natoms,3]), self.length_unit)

    def _objective(self, x):
        """
        Compute objective function.

        Also records the best (lowest) objective seen so far in best_objective/best_x and
        calls _output_callback(x), which increments iteration_count.

        ARGUMENTS

        x (3N-vector) - packed unitless coordinates

        RETURNS

        objective (float) - potential energy in kJ/mol

        """

        # Push unit-bearing coordinates to the Context.
        self.context.setPositions(self._unpack_coordinates(x))

        # Compute potential energy and make it dimensionless.
        state = self.context.getState(getEnergy=True)
        objective = state.getPotentialEnergy() / self.energy_unit

        # Update best objective found so far.
        if (self.best_objective is None) or (objective < self.best_objective):
            self.best_objective = objective
            self.best_x = x.copy()

        # Print output if requested.
        self._output_callback(x)

        return objective

    def _gradient(self, x):
        """
        Compute gradient of objective function.

        ARGUMENTS

        x (3N numpy array - no units) - packed coordinates in nanometers

        RETURNS

        gradient (3N numpy array - no units) - gradient (negative force) in kJ/mol/nm

        """

        # Push unit-bearing coordinates to the Context.
        self.context.setPositions(self._unpack_coordinates(x))

        # Compute force.
        state = self.context.getState(getForces=True)
        force = state.getForces(asNumpy=True)

        # The gradient of the potential is the negative of the force.
        return - (force / self.energy_unit * self.length_unit).reshape([self.natoms*3])

#=============================================================================================
# Simplex minimizer
#=============================================================================================

class SimplexMinimizer(Minimizer):
    """
    Simplex minimizer.

    Uses scipy.optimize.fmin (Nelder-Mead downhill simplex), which needs only energy
    evaluations and no gradients.

    EXAMPLES

    >>> # Initialize a simplex minimizer with default options.
    >>> minimizer = SimplexMinimizer(system, verbose=True, platform=platform)
    >>> # Minimize the initial coordinates.
    >>> minimized_coordinates = minimizer.minimize(coordinates)
    >>> # Clean up to release the Context.
    >>> del minimizer

    """

    def __init__(self, system, verbose=False, platform=None, maximum_evaluations=1000, relative_energy_decrease_tolerance=1.0e-6, relative_coordinate_change_tolerance=1.0e-6):
        """
        Initialize a simplex minimizer.

        ARGUMENTS

        @param system                  the system whose energy is to be minimized

        OPTIONAL ARGUMENTS

        @param verbose                 if True, print scipy's convergence summary and a final energy report (default: False)
        @param platform                OpenMM Platform to use; if None, OpenMM chooses (default: None)
        @param maximum_evaluations     maximum number of allowed function evaluations (passed to fmin as maxfun)
        @param relative_energy_decrease_tolerance    convergence tolerance on the energy, passed to fmin as ftol;
                                       note scipy treats ftol as an absolute tolerance (here in kJ/mol)
        @param relative_coordinate_change_tolerance  convergence tolerance on the coordinates, passed to fmin as xtol;
                                       note scipy treats xtol as an absolute tolerance (here in nm)

        NOTES

        An OpenMM Context object will be created and held until this class is destroyed.

        """

        # Initialize Minimizer base class.
        Minimizer.__init__(self, system, platform=platform)

        # Store parameters.
        self.verbose = verbose
        self.maximum_energy_evaluations = maximum_evaluations
        self.relative_energy_decrease_tolerance = relative_energy_decrease_tolerance
        self.relative_coordinate_change_tolerance = relative_coordinate_change_tolerance

        return

    def minimize(self, coordinates):
        """
        Perform a simplex minimization starting from the given coordinates.

        WARNING

        Note that, because there is currently no facility in OpenMM to SHAKE bond constraints
        into compliance, constraints are not respected during minimization.  In order to maintain
        reasonable geometries, your system must contain bonded terms in addition to constraints.

        ARGUMENTS

        @param coordinates       the coordinates to be minimized
        @paramtype coordinates   numpy array with units

        RETURNS

        coordinates - minimized coordinates (simtk.unit.Quantity wrapping an Nx3 numpy array, in nm)

        """

        # Create dimensionless initial parameter vector.
        x0 = self._pack_coordinates(coordinates)

        # Reset per-run bookkeeping so a previous call does not leak into this one.
        self.iteration_count = 0
        self.best_objective = None
        self.best_x = None

        # Minimize.  fmin prints its own convergence summary only when disp is true, so tie it
        # to the verbose flag instead of always printing.
        xopt = scipy.optimize.fmin(self._objective, x0, ftol=self.relative_energy_decrease_tolerance, xtol=self.relative_coordinate_change_tolerance, maxfun=self.maximum_energy_evaluations, disp=self.verbose)
        if self.verbose:
            self._output_callback(xopt, force_display=True)

        # Re-associate units.
        return self._unpack_coordinates(xopt)

#=============================================================================================
# Conjugate gradients minimizer
#=============================================================================================

class ConjugateGradientsMinimizer(Minimizer):
    """
    Conjugate gradients minimizer.

    Uses scipy.optimize.fmin_cg with analytical gradients computed from OpenMM forces.

    EXAMPLES

    >>> # Initialize minimizer.
    >>> minimizer = ConjugateGradientsMinimizer(system, verbose=True, platform=platform)
    >>> # Minimize the initial coordinates.
    >>> minimized_coordinates = minimizer.minimize(coordinates)
    >>> # Clean up to release the Context.
    >>> del minimizer

    """

    def __init__(self, system, verbose=False, platform=None, maximum_evaluations=1000, gradient_convergence_tolerance=None):
        """
        Initialize a conjugate gradients minimizer.

        ARGUMENTS

        @param system                  the system whose energy is to be minimized

        OPTIONAL ARGUMENTS

        @param verbose                 if True, print a final energy report (default: False)
        @param platform                OpenMM Platform to use; if None, OpenMM chooses (default: None)
        @param maximum_evaluations     passed to fmin_cg as maxiter, so it limits the number of
                                       CG iterations (each may need several energy evaluations)
        @param gradient_convergence_tolerance  gradient-norm tolerance, passed to fmin_cg as gtol.
                                       Either a simtk.unit.Quantity with units of energy/length, or a
                                       plain number taken to be in kJ/mol/nm.  If None (default), a
                                       tolerance of 0 is used, so the gradient criterion never triggers.

        NOTES

        An OpenMM Context object will be created and held until this class is destroyed.

        """

        # Initialize Minimizer base class.
        Minimizer.__init__(self, system, platform=platform)

        # Store parameters.
        self.verbose = verbose
        self.maximum_energy_evaluations = maximum_evaluations
        self.gradient_convergence_tolerance = 0.0
        if gradient_convergence_tolerance is not None:
            if isinstance(gradient_convergence_tolerance, units.Quantity):
                # Express the tolerance in energy_unit / length_unit to strip its units.
                self.gradient_convergence_tolerance = gradient_convergence_tolerance * self.length_unit / self.energy_unit
            else:
                # A plain number would otherwise become a unit-bearing Quantity that scipy cannot
                # compare against; treat it as already expressed in kJ/mol/nm.
                self.gradient_convergence_tolerance = float(gradient_convergence_tolerance)

        return

    def minimize(self, coordinates):
        """
        Perform a conjugate gradients minimization starting from the given coordinates.

        ARGUMENTS

        @param coordinates       the coordinates to be minimized
        @paramtype coordinates   numpy array with units

        RETURNS

        coordinates - minimized coordinates (simtk.unit.Quantity wrapping an Nx3 numpy array, in nm)

        WARNING

        Note that, because there is currently no facility in OpenMM to SHAKE bond constraints
        into compliance, constraints are not respected during minimization.  In order to maintain
        reasonable geometries, your system must contain bonded terms in addition to constraints.

        """

        # Create dimensionless initial parameter vector.
        x0 = self._pack_coordinates(coordinates)

        # Reset per-run bookkeeping so a previous call does not leak into this one.
        self.iteration_count = 0
        self.best_objective = None
        self.best_x = None

        # Minimize.  fmin_cg's default gradient norm is the max-abs (infinity) norm.
        xopt = scipy.optimize.fmin_cg(self._objective, x0, fprime=self._gradient, maxiter=self.maximum_energy_evaluations, gtol=self.gradient_convergence_tolerance, disp=0)
        if self.verbose:
            self._output_callback(xopt, force_display=True)

        # Re-associate units.
        return self._unpack_coordinates(xopt)

#=============================================================================================
# L-BFGS minimizer
#=============================================================================================

class LBFGSMinimizer(Minimizer):
    """
    L-BFGS (limited-memory BFGS) minimizer.

    Uses scipy.optimize.fmin_l_bfgs_b with analytical gradients computed from OpenMM forces.

    """

    def __init__(self, system, verbose=False, platform=None, maximum_evaluations=100, m=5):
        """
        Initialize a L-BFGS minimizer.

        ARGUMENTS

        @param system                  the system whose energy is to be minimized

        OPTIONAL ARGUMENTS

        @param verbose                 if True, print an energy report after minimization (default: False)
        @param platform                OpenMM Platform to use; if None, OpenMM chooses (default: None)
        @param maximum_evaluations     maximum number of allowed function evaluations (passed as maxfun)
        @param m                       number of correction pairs (stored gradients) kept by L-BFGS (passed as m)

        NOTES

        An OpenMM Context object will be created and held until this class is destroyed.

        """

        # Initialize Minimizer base class.
        Minimizer.__init__(self, system, platform=platform)

        # Store parameters.
        self.verbose = verbose
        self.maximum_energy_evaluations = maximum_evaluations
        self.m = m # number of gradients to store

        return

    def minimize(self, coordinates, constrain=False):
        """
        Perform an L-BFGS minimization starting from the given coordinates.

        ARGUMENTS

        @param coordinates       the coordinates to be minimized
        @paramtype coordinates   numpy array with units

        OPTIONAL ARGUMENTS

        @param constrain         EXPERIMENTAL: if True, alternate a fixed number of L-BFGS runs with
                                 Minimizer._apply_constraints() (default: False)

        RETURNS

        coordinates - minimized coordinates (simtk.unit.Quantity wrapping an Nx3 numpy array, in nm)

        WARNING

        Note that, because there is currently no facility in OpenMM to SHAKE bond constraints
        into compliance, constraints are not respected during minimization.  In order to maintain
        reasonable geometries, your system must contain bonded terms in addition to constraints.

        """

        # Create dimensionless initial parameter vector.
        x0 = self._pack_coordinates(coordinates)

        # Reset per-run bookkeeping so a previous call does not leak into this one.
        self.iteration_count = 0
        self.best_objective = None
        self.best_x = None

        if constrain:
            # EXPERIMENTAL: alternate minimization and constraint application a fixed number of times.
            maxits = 5
            xopt = self._apply_constraints(x0)
            for iteration in range(maxits):
                # fmin_l_bfgs_b returns (x, f, info); only the coordinates are needed.
                xopt = scipy.optimize.fmin_l_bfgs_b(self._objective, xopt, fprime=self._gradient, m=self.m, maxfun=self.maximum_energy_evaluations)[0]
                xopt = self._apply_constraints(xopt)
                if self.verbose:
                    self._output_callback(xopt, force_display=True)
        else:
            xopt = scipy.optimize.fmin_l_bfgs_b(self._objective, x0, fprime=self._gradient, m=self.m, maxfun=self.maximum_energy_evaluations)[0]
            if self.verbose:
                self._output_callback(xopt, force_display=True)

        # Re-associate units.
        return self._unpack_coordinates(xopt)

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

if __name__ == "__main__":
    import doctest
    import testsystems

    # DEBUG: Test only Cuda for now.
    platform = openmm.Platform.getPlatformByName('Cuda')
    print('Testing OpenMM platform "%s"' % platform.getName())

    # Collect every callable exported by the testsystems module as a (name, factory) pair.
    test_systems = []
    for attribute_name in dir(testsystems):
        factory = getattr(testsystems, attribute_name)
        if callable(factory):
            test_systems.append((attribute_name, factory))

    # Minimizers to exercise on each test system, in order, with the label printed before each.
    minimizer_runs = [
        ("L-BFGS", LBFGSMinimizer),
        ("ConjugateGradient", ConjugateGradientsMinimizer),
        ("Simplex", SimplexMinimizer),
    ]

    for (name, test_system) in test_systems:
        print("*******************")
        print("Constructing system '%s'..." % name)
        [system, coordinates] = test_system()

        for (label, minimizer_class) in minimizer_runs:
            print(label)
            # Build a minimizer with default options, minimize, then drop it to release the Context.
            minimizer = minimizer_class(system, verbose=True, platform=platform)
            minimized_coordinates = minimizer.minimize(coordinates)
            del minimizer

    sys.exit(0)

    # NOTE: everything below is unreachable because of the sys.exit(0) above (see the DEBUG
    # note).  It runs the module doctests once per platform and test system; those doctests
    # refer to the module-level names `system`, `coordinates` and `platform` set here.
    for platform_index in range(openmm.Platform.getNumPlatforms()):
        platform = openmm.Platform.getPlatform(platform_index)
        print('Testing OpenMM platform "%s"' % platform.getName())
        for (name, test_system) in test_systems:
            print(name)
            [system, coordinates] = test_system()

            # Run doctests on this test system and platform.
            doctest.testmod()

            
