#!/usr/local/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
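Test that the computed forces agree with the energy gradient

DESCRIPTION

Reads the list of force tests and the list of test systems from the data
directory, optionally edits the resulting test set with a control file, and
runs every active (system, test) pair on the selected platform. A summary of
the results is written to a file, and the script exits with status 1 if any
test failed or did not run (0 otherwise). Run with -help to list the options.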

COPYRIGHT AND LICENSE

@author Mark Friedrichs <friedrim@stanford.edu>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import sys
import os
import traceback
import os.path
import math
import doctest
import time

import simtk.unit as units
import simtk.openmm as openmm

from ValidationUtilities      import ValidationUtilities             
from ParameterFileParser      import ValidationParameterFileParser
from OpenMMSystemParameters   import OpenMMSystemParameters

import EnergyGradientTests

#=============================================================================================
# Print usage (command-line arguments)
#=============================================================================================

def printUsage( ):
    """Print the supported command-line arguments to stdout

    Each option is printed on its own line: a firstTab-wide margin, the option
    name right-justified to width secondTab, then its description
    right-justified to width thirdTab.
    """

    firstTab    = 1
    secondTab   = 19
    thirdTab    = 60

    # (option, description) in display order
    options     = (
        ( "-platform",            "platform (Reference, Cuda, OpenCL) to run tests on" ),
        ( "-controlFile",         "Control file name" ),
        ( "-systemsDirectory",    "Systems file directory" ),
        ( "-configDirectory",     "Config file directory" ),
        ( "-dataDirectory",       "Data file directory" ),
        ( "-serializeDirectory",  "Serialize file directory" ),
        ( "-deviceId",            "device id for gpu" ),
        ( "-verbose",             "if nonzero, then be verbose" ),
        ( "-printParam",          "if nonzero, then print parameters and exit" ),
        ( "-sleep",               "if nonzero, then sleep the specified number of seconds between tests" ),
        ( "-testSummaryFileName", "summary file name" ),
        ( "-help",                "print this message and exit" ),
    )

    outputString  = "".join( ' '.rjust(firstTab) + option.rjust(secondTab) + (" " + description + "\n").rjust(thirdTab)
                             for option, description in options )
    outputString += "\n\n"
    print( "Command line arguments:\n" )
    print( outputString )
    return

# parse command-line arguments

def parseCommandLine( argumentHash, argv=None ):
    """Parse command-line options into argumentHash

    ARGUMENTS
        argumentHash   (hash) updated in place: each recognized option stores its
                       (converted) value under the key listed in valueOptions below
        argv           (list) argument vector to parse, program name first;
                       defaults to sys.argv (optional, so callers and tests can
                       supply their own arguments)

    Options are matched by prefix, so e.g. '-platformX' is read as '-platform'.
    Every option except -help takes exactly one value. -help, an unrecognized
    option, or an option given without a value prints the usage message and
    exits with status -1.
    """

    if( argv is None ):
        argv = sys.argv

    def isPositive( value ):
        return int( value ) > 0

    def asIs( value ):
        return value

    # (option prefix, argumentHash key, converter applied to the option's value)
    valueOptions = (
        ( '-platform',            'Platform',            asIs ),
        ( '-systemsDirectory',    'SystemsDirectory',    asIs ),
        ( '-configDirectory',     'ConfigDirectory',     asIs ),
        ( '-dataDirectory',       'DataDirectory',       asIs ),
        ( '-serializeDirectory',  'SerializeDirectory',  asIs ),
        # seconds to pass to time.sleep(); stored as a number so callers can
        # compare it against 0 (a raw string compared with 0 is always "greater"
        # under Python 2 and a TypeError under Python 3)
        ( '-sleep',               'Sleep',               float ),
        ( '-deviceId',            'DeviceId',            int ),
        ( '-verbose',             'Verbose',             isPositive ),
        ( '-printParam',          'PrintParam',          isPositive ),
        ( '-testSummaryFileName', 'TestSummaryFileName', asIs ),
        ( '-controlFile',         'ControlFileName',     asIs ),
    )

    argCount = len( argv )
    ii       = 1
    while( ii < argCount ):
        arg      = argv[ii]
        hasValue = ( ii + 1 < argCount )
        print( "arg %d %s" % ( ii, arg ) )

        for prefix, key, convert in valueOptions:
            if( hasValue and arg.startswith( prefix ) ):
                argumentHash[key] = convert( argv[ii+1] )
                break
        else:
            # no value option matched: either -help or an error; both exit
            if( not arg.startswith( '-help' ) ):
                print( "Argument %3d %s not recognized." % ( ii, arg ) )
                sys.stdout.flush()
            printUsage()
            sys.exit(-1)

        # read Verbose on every pass so that '-verbose 1' takes effect immediately
        if( argumentHash.get( 'Verbose' ) ):
            print( "%3d %s" % ( ii, arg ) )

        # skip over the option's value
        ii += 2
    return

#=============================================================================================
# Print a summary of the results to a file
#=============================================================================================

def printTestSummary( testSummaryFileName, testNameHash, testResults ):

    """Print summary of results

    ARGUMENTS
        testSummaryFileName  (string) summary file name (overwritten)
        testNameHash         (hash)   keys are the test names to summarize; values are ignored
        testResults          (hash)   testResults[fullTestName] = testResult; each result must
                                      provide getTestName(), testPassed(), getPlatformName(),
                                      getSystemName(), getRelativeDelta(), getForceStepDelta()
                                      and getTolerance()

    For each test name (in sorted order) a header line is written, then one line
    per matching result, then a footer with the maximum relative delta, the
    geometric mean ("Log Average") of the positive relative deltas, and the
    number of results. When no positive delta exists the Log Average is 0.0.
    """

    firstTab      = 40

    # group results by test name once, in a deterministic order
    resultsByTest = {}
    for fullTestName in sorted( testResults ):
        result = testResults[fullTestName]
        resultsByTest.setdefault( result.getTestName(), [] ).append( result )

    with open( testSummaryFileName, 'w' ) as summaryFile:
        for testName in sorted( testNameHash ):
            headerString = testName.rjust(firstTab) + " Passed    RelativeDelta"
            summaryFile.write( headerString + "\n\n" )

            count        = 0
            logSum       = 0.0
            logCount     = 0
            maxDelta     = -1.0
            for result in resultsByTest.get( testName, [] ):
                relativeDelta  = result.getRelativeDelta()
                outputString   = result.getPlatformName().rjust(8) + result.getSystemName().rjust(firstTab) + " " + repr(result.testPassed()).rjust(6) + " " + ("%10.3e" % relativeDelta).rjust(21) + " Delta " + ("%10.3e" % result.getForceStepDelta()) + " Tol " + ("%10.3e" % result.getTolerance())
                summaryFile.write( outputString + "\n" )

                count         += 1
                if( relativeDelta > maxDelta ):
                    maxDelta = relativeDelta

                # log() is undefined for a zero (exact agreement) or negative
                # delta; exclude those from the geometric mean instead of
                # aborting the whole summary with a ValueError
                if( relativeDelta > 0.0 ):
                    logSum   += math.log( relativeDelta )
                    logCount += 1

            if( logCount > 0 ):
                average = math.exp( logSum / logCount )
            else:
                average = 0.0

            outputString   = "\n%40s Max %10.3e  Log Average %10.3e  Count %d\n" % (testName, maxDelta, average, count)
            summaryFile.write( outputString + "\n\n\n" )

    return

#=============================================================================================
# Build tests objects given name and args
#=============================================================================================

def forceEnergyGradientTestFactory( testName, systemParameters, testParameterHash, verbose ):

    """Instantiate the EnergyGradientTests class whose name is testName

    ARGUMENTS
        testName                          name of a supported test class in EnergyGradientTests
        systemParameters                  passed through to the test constructor
        testParameterHash                 passed through to the test constructor
        verbose                           if true, be verbose (passed through)

    RETURN
        the new test object when testName is one of the supported names below;
        otherwise a message is printed and None is returned
    """

    # only these classes may be instantiated; anything else is rejected
    supportedTestNames = (
        'HarmonicBondForceEnergyGradientTest',
        'HarmonicAngleForceEnergyGradientTest',
        'PeriodicTorsionForceEnergyGradientTest',
        'RBTorsionForceEnergyGradientTest',
        'NonbondedForceEnergyGradientNoCutoffTest',
        'NonbondedForceEnergyGradientCutoffNonPeriodicTest',
        'NonbondedForceEnergyGradientCutoffPeriodicTest',
        'NonbondedForceEnergyGradientEwaldTest',
        'NonbondedForceEnergyGradientPMETest',
        'GbsaObcForceEnergyGradientNoCutoffTest',
        'GbsaObcForceEnergyGradientCutoffNonPeriodicTest',
        'GbsaObcForceEnergyGradientCutoffPeriodicTest',
        'GbviForceEnergyGradientNoCutoffTest',
    )

    if( testName not in supportedTestNames ):
        print( "%s test not recognized." % testName )
        return None

    testClass = getattr( EnergyGradientTests, testName )
    return testClass( systemParameters, testParameterHash, verbose )

#=============================================================================================
# MAIN
#=============================================================================================

# set Validation directory

argumentHash                             = {}
validationUtilities                      = ValidationUtilities(False)

# Locate the validation tree; report a missing environment variable clearly
# instead of failing inside os.path.join with a TypeError on None.
pyopenmmSourceDirectory                  = os.getenv('PYOPENMM_SOURCE_DIR')
if( pyopenmmSourceDirectory is None ):
    sys.stderr.write( "PYOPENMM_SOURCE_DIR is not set; it must point to the PyOpenMM source directory.\n" )
    sys.exit(1)

validationDirectory                      = os.path.join(pyopenmmSourceDirectory, 'test', 'validation')

# systems are read either from XML parameter files (with a separate .txt
# position file) or from the text parameter format
useXml                                   = 1
if( useXml ):
    parameterFileSuffix                  = '.xml'
    systemsDirectory                     = os.path.join(validationDirectory, 'systems_xml')
else:
    parameterFileSuffix                  = '.txt'
    systemsDirectory                     = os.path.join(validationDirectory, 'systems')

configDirectory                          = os.path.join(validationDirectory, 'config')
dataDirectory                            = os.path.join(validationDirectory, 'data')

testListFileName                         = 'ForceEnergyGradientTestNames.txt'
systemListFileName                       = 'SystemNames.txt'

# defaults; parseCommandLine() may override any of these
argumentHash['ValidationDirectory']      = validationDirectory
argumentHash['ConfigDirectory']          = configDirectory
argumentHash['DataDirectory']            = dataDirectory
argumentHash['TestListFileName']         = testListFileName
argumentHash['SystemListFileName']       = systemListFileName
argumentHash['Platform']                 = 'Reference'
argumentHash['SystemsDirectory']         = systemsDirectory
argumentHash['PrintParam']               = False
argumentHash['Verbose']                  = False
argumentHash['Sleep']                    = 0
# 'NA' means "derive the summary file name from the platform"; without this
# default the lookup below raised KeyError unless -testSummaryFileName was given
argumentHash['TestSummaryFileName']      = 'NA'

# parse command line

parseCommandLine( argumentHash )
systemsDirectory                         = argumentHash['SystemsDirectory']
serializeDirectory                       = argumentHash.get( 'SerializeDirectory', 0 )

# -sleep may arrive as a string; compare it numerically (under Python 2 any
# string compares greater than 0, so '-sleep 0' used to sleep anyway)
sleepSeconds                             = float( argumentHash['Sleep'] )

# track number of tests that do not run

numberOfTestsThatDidNotRun               = 0

defaultParameterHash                     = {
                                               'Active'                     :       1,
                                               'Tolerance'                  :       4.0e-02,
                                               'ForceStepDelta'             :       1.0e-03,
                                           }

printParam                               = argumentHash['PrintParam']
verbose                                  = argumentHash['Verbose']
validationUtilities.setVerbose( verbose )

platform                                 = argumentHash['Platform']

fullTestListFileName                     = os.path.join( argumentHash['DataDirectory'], argumentHash['TestListFileName'] )
testList                                 = validationUtilities.getListFromFile( fullTestListFileName )

fullSystemListFileName                   = os.path.join( argumentHash['DataDirectory'], argumentHash['SystemListFileName'] )
systemList                               = validationUtilities.getListFromFile( fullSystemListFileName )

testHash                                 = validationUtilities.buildDefaultTestHash( testList, systemList, defaultParameterHash )
if( 'ControlFileName' in argumentHash ):
    fullControlFileName                  = os.path.join( argumentHash['ConfigDirectory'], argumentHash['ControlFileName'] )
    validationUtilities.editTestHashBasedOnControlFile( fullControlFileName, testHash )

if( argumentHash['TestSummaryFileName'] != 'NA' ):
    testSummaryFileName                  = argumentHash['TestSummaryFileName']
else:
    testSummaryFileName                  = platform + 'ForceEnergyComparisonSummary.txt'

if( verbose or printParam ):
    print( "%s" % validationUtilities.printHash( argumentHash, "Arguments\n" ) )
    print( "%s" % validationUtilities.printTestHash( testHash, 1 ) )
    if( printParam ):
        sys.exit(0)

testResults                              = {}
testNameHash                             = {}

# loop over tests

for testKey, testParameterHash in testHash.items():
    if( testParameterHash['Active'] <= 0 ):
        continue

    systemName              = testParameterHash['SystemName']
    testName                = testParameterHash['TestName']
    testNameHash[testName]  = 1

    print( '*' * 80 )
    print( "%30s %30s" % (systemName, testName ) )

    try:

        # read parameter file; done inside the try so that an unreadable or
        # malformed system file counts as a test that did not run instead of
        # aborting every remaining test

        parameterFileName        = systemName + parameterFileSuffix
        fullParameterFileName    = os.path.join(systemsDirectory, parameterFileName)
        if( useXml ):
            positionFileName     = systemName + '.txt'
            fullPositionFileName = os.path.join(systemsDirectory, positionFileName)
            systemParameters     = OpenMMSystemParameters(fullParameterFileName, fullPositionFileName, verbose)
        else:
            systemParameters     = ValidationParameterFileParser(fullParameterFileName, verbose)

        if( 'DeviceId' in argumentHash ):
            testParameterHash['DeviceId'] = argumentHash['DeviceId']

        # get test and run

        forceTest   = forceEnergyGradientTestFactory( testName, systemParameters, testParameterHash, verbose )
        if( forceTest is not None and forceTest.isActive() ):

            if( serializeDirectory ):
                forceTest.serialize( serializeDirectory )

            forceTest.runTest( platform )
            testResults[forceTest.getFullTestName()] = forceTest.getResult()
            if( sleepSeconds > 0 ):
                if( verbose ):
                    print( "Sleeping %s seconds." % (str(argumentHash['Sleep'])) )
                    sys.stdout.flush()
                time.sleep( sleepSeconds )

    except Exception:
        # report and keep going; KeyboardInterrupt/SystemExit still stop the run
        traceback.print_exc()
        print( "%s test did not run for %s" % (testName, systemName) )
        numberOfTestsThatDidNotRun      += 1

# output results

print( '*' * 80 )
sys.stdout.flush()
numberOfTestsThatPassed                  = 0
numberOfTestsThatFailed                  = 0

for result in testResults.values():
    if( result.testPassed( ) ):
        numberOfTestsThatPassed += 1
    else:
        numberOfTestsThatFailed += 1

print( "\nSummary:\n" )
print( "   %d tests passed"       % numberOfTestsThatPassed )
print( "   %d tests failed"       % numberOfTestsThatFailed )
print( "   %d tests did not run"  % numberOfTestsThatDidNotRun )

# print summary file

printTestSummary( testSummaryFileName, testNameHash, testResults )

# exit w/ flag signalling success

if( numberOfTestsThatFailed > 0 or numberOfTestsThatDidNotRun > 0 ):
    sys.exit(1)
else:
    sys.exit(0)
