#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Test of stripping units from parameters passed to Custom*Force methods.

COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

This code is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import numpy
import math

import simtk
import simtk.chem.openmm as openmm
import simtk.unit as units

#=============================================================================================
# Array of harmonic oscillators
#=============================================================================================

def HarmonicOscillatorArray(K=None, d=None, mass=None, N=None, mm=None):
    """
    Build a row of N noninteracting particles, each held in its own isotropic 3D harmonic well.

    Particle n is placed at x = n*d on the x-axis (y = z = 0) and is restrained by a
    CustomExternalForce with energy (K/2.0) * ((x-x0)^2 + y^2 + z^2), where x0 = n*d for
    that particle, so the initial configuration sits at the center of each well.
    K and x0 are passed to the force's addParticle exactly as given, i.e. with units
    still attached when K and d are unit-bearing quantities.

    @keyword K: harmonic restraining potential (default: 1.0 * units.kilojoules_per_mole/units.nanometer**2)
    @keyword d: distance between harmonic oscillators (default: 1.0 * units.nanometer)
    @keyword mass: particle mass (default: 39.948 * units.amu)
    @keyword N: number of oscillators (default: 5)
    @keyword mm: OpenMM implementation supplying System and CustomExternalForce (default: simtk.chem.openmm)

    @return system: the system
    @type system: a System object

    @return coordinates: initial coordinates for the system
    @type coordinates: an N x 3 float32 numpy array wrapped in a units.Quantity, in angstroms

    EXAMPLES

    Create an array of 3D harmonic oscillators with default parameters.

    >>> [system, coordinates] = HarmonicOscillatorArray()

    Create an array of harmonic oscillators with specified mass, spacing, and spring constant.

    >>> mass = 12.0 * units.amu
    >>> d = 5.0 * units.angstroms
    >>> K = 1.0 * units.kilocalories_per_mole / units.angstroms**2
    >>> N = 10 # number of oscillators
    >>> [system, coordinates] = HarmonicOscillatorArray(K=K, d=d, mass=mass, N=N)

    Test the energy

    >>> # Create a Context.
    >>> temperature = 298.0 * units.kelvin
    >>> collision_rate = 90.0 / units.picosecond
    >>> timestep = 1.0 * units.femtosecond
    >>> integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    >>> context = openmm.Context(system, integrator, platform)
    >>> # Set positions
    >>> context.setPositions(coordinates)
    >>> # Evaluate the potential energy.
    >>> state = context.getState(getEnergy=True)
    >>> print state.getPotentialEnergy().in_units_of(units.kilocalories_per_mole)
    0.0 kcal/mol

    Integrate dynamics

    >>> nsteps = 100 # number of steps to integrate
    >>> integrator.step(nsteps)
    >>> # Retrieve configuration to make sure no coordinates are nan
    >>> state = context.getState(getPositions=True)
    >>> coordinates = state.getPositions(asNumpy=True)
    >>> if numpy.any(numpy.isnan(coordinates / units.nanometers)): raise Exception('some coordinates are nan after integration: %s' % str(coordinates))

    """

    # Fall back to pyOpenMM when the caller does not supply an implementation.
    if mm is None:
        mm = simtk.chem.openmm

    # Fill in defaults for anything the caller left unspecified.
    if K is None:
        K = 1.0 * units.kilojoules_per_mole / units.nanometer**2
    if d is None:
        d = 1.0 * units.nanometer
    if mass is None:
        mass = 39.948 * units.amu
    if N is None:
        N = 5

    particle_indices = range(N)

    # Start from an empty system and give it N identical particles.
    system = mm.System()
    for _ in particle_indices:
        system.addParticle(mass)

    # Lay the particles out along the x-axis, d apart; y and z stay at zero.
    # The array is held in angstroms, and the Quantity item assignment takes
    # the unit-bearing value index * d.
    coordinates = units.Quantity(numpy.zeros([N, 3], numpy.float32), units.angstroms)
    for index in particle_indices:
        coordinates[index, 0] = index * d

    # Attach a restraining potential to each oscillator, centered on its own
    # starting x position. K and x0 are deliberately passed without stripping
    # their units first.
    restraint = mm.CustomExternalForce('(K/2.0) * ((x-x0)^2 + y^2 + z^2)')
    restraint.addPerParticleParameter('K')
    restraint.addPerParticleParameter('x0')
    for index in particle_indices:
        restraint.addParticle(index, [K, d * index])
    system.addForce(restraint)

    return (system, coordinates)


#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

if __name__ == "__main__":
    # When executed as a script, run the doctest examples embedded in this module's docstrings.
    import doctest

    # Test all systems on Reference platform.
    # NOTE: the doctest examples in HarmonicOscillatorArray pass a name `platform` to
    # openmm.Context without defining it themselves, so it is bound here at module
    # level before testmod() runs. Do not rename it or move it below testmod().
    platform = openmm.Platform.getPlatformByName("Reference")
    doctest.testmod()    

