#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Test all testsystems on different platforms to make sure errors in potential energy and forces are small.

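DESCRIPTION

In its current (debugging) form, this script reads an AMBER system from 'system.prmtop' and
'system.crd' (reaction-field electrostatics, 9 Angstrom cutoff, SHAKE on bonds to hydrogen) and
writes its periodic box vectors, per-particle masses, coordinates, and nonbonded parameters,
nonbonded exceptions, and constraints to the text file 'test.dat'. The energy/force comparison
across platforms described above is not performed here.
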
COPYRIGHT

@author John D. Chodera <jchodera@gmail.com>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

TODO

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import os.path
import sys
import math

import simtk.unit as units
import simtk.chem.openmm as openmm
import simtk.chem.openmm.extras.amber as amber
import simtk.chem.openmm.extras.testsystems as testsystems

#=============================================================================================
# MAIN AND TESTS
#=============================================================================================

# Index of the NonbondedForce inside the System returned by amber.readAmberSystem.
# NOTE(review): assumes readAmberSystem always adds its NonbondedForce at index 3 -- TODO confirm
# if that function (or the options passed to it) changes.
_NONBONDED_FORCE_INDEX = 3


def _write_box_vectors(outfile, system):
    """Write the system's three periodic box vectors, one per line, in nanometers."""
    for vector in system.getPeriodicBoxVectors():
        outfile.write("%12.6f %12.6f %12.6f\n" % (vector[0] / units.nanometers, vector[1] / units.nanometers, vector[2] / units.nanometers))


def _write_particles(outfile, system, coordinates, nonbonded_force):
    """Write the particle count, then one line per particle.

    Each line holds: index, mass (amu), x, y, z (nm), charge (e), sigma (nm), epsilon (kJ/mol).
    `coordinates` is indexed as coordinates[atom, axis], as read by amber.readAmberCoordinates.
    """
    num_particles = system.getNumParticles()
    outfile.write("%d\n" % num_particles)
    for atom_index in range(num_particles):
        mass = system.getParticleMass(atom_index) / units.amu
        x = coordinates[atom_index,0] / units.nanometers
        y = coordinates[atom_index,1] / units.nanometers
        z = coordinates[atom_index,2] / units.nanometers
        [charge, sigma, epsilon] = nonbonded_force.getParticleParameters(atom_index)
        charge /= units.elementary_charge
        sigma /= units.nanometers
        epsilon /= units.kilojoules_per_mole
        outfile.write("%8d %8.3f %24.12f %24.12f %24.12f %24.12f %24.12f %24.12f\n" % (atom_index, mass, x, y, z, charge, sigma, epsilon))


def _write_exceptions(outfile, nonbonded_force):
    """Write the exception count, then one line per exception: i, j, chargeprod (e^2), sigma (nm), epsilon (kJ/mol)."""
    num_exceptions = nonbonded_force.getNumExceptions()
    outfile.write("%d\n" % num_exceptions)
    for exception_index in range(num_exceptions):
        [i, j, chargeprod, sigma, epsilon] = nonbonded_force.getExceptionParameters(exception_index)
        chargeprod /= units.elementary_charge**2
        sigma /= units.nanometers
        epsilon /= units.kilojoules_per_mole
        outfile.write("%8d %8d %24.12f %24.12f %24.12f\n" % (i, j, chargeprod, sigma, epsilon))


def _write_constraints(outfile, system):
    """Write the constraint count, then one line per constraint: i, j, distance (nm)."""
    num_constraints = system.getNumConstraints()
    outfile.write("%d\n" % num_constraints)
    for constraint_index in range(num_constraints):
        [i, j, distance] = system.getConstraintParameters(constraint_index)
        outfile.write('%8d %8d %12.6f\n' % (i, j, distance / units.nanometers))


if __name__ == "__main__":
    # Parameters
    prmtop_filename = 'system.prmtop'
    crd_filename = 'system.crd'

    # Create system.
    # Single-argument print(...) behaves identically under Python 2 and Python 3.
    print("Reading system...")
    cutoff = 9.0 * units.angstroms
    system = amber.readAmberSystem(prmtop_filename, nonbondedMethod='reaction-field', nonbondedCutoff=cutoff, shake='h-bonds')
    [coordinates, box_vectors] = amber.readAmberCoordinates(crd_filename, read_box=True)
    system.setPeriodicBoxVectors(box_vectors[0], box_vectors[1], box_vectors[2])

    # DEBUG: dump box vectors, particle/nonbonded parameters, exceptions, and constraints to test.dat.
    # try/finally (rather than `with`, to stay compatible with old Python 2 interpreters) guarantees
    # the file is flushed and closed even if an OpenMM query raises part-way through.
    outfile = open('test.dat', 'w')
    try:
        _write_box_vectors(outfile, system)
        nonbonded_force = system.getForce(_NONBONDED_FORCE_INDEX)
        _write_particles(outfile, system, coordinates, nonbonded_force)
        _write_exceptions(outfile, nonbonded_force)
        _write_constraints(outfile, system)
    finally:
        outfile.close()

    # The script now ends normally with exit status 0. The previous unconditional sys.exit(1)
    # reported failure to the caller even when the dump succeeded.
