#!/usr/bin/env python

#=============================================================================================
# MODULE DOCSTRING
#=============================================================================================

"""
Test energy conservation (Verlet) or temperature average (Langevin)
for various integrators

DESCRIPTION

TODO

COPYRIGHT AND LICENSE

@author Mark Friedrichs <friedrim@stanford.edu>

All code in this repository is released under the GNU General Public License.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 
You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.

"""

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import sys
import os
import os.path
import math
import traceback
import time

import simtk.unit as units
import simtk.openmm as openmm
from ValidationUtilities     import ValidationUtilities
from OpenMMSystemParameters  import OpenMMSystemParameters

from ParameterFileParser     import ValidationParameterFileParser

import EnergyConservationTests

#=============================================================================================
# SUBROUTINES
#=============================================================================================

#=============================================================================================
# Print usage (command-line arguments)
#=============================================================================================

def printUsage( ):
    """Print the supported command-line arguments to stdout.

    Each option is printed on its own line: the option name right-justified
    in a column wide enough for the longest option name, followed by its
    description right-justified in a fixed-width field.

    ARGUMENTS
        none
    """

    firstTab    = 1
    secondTab   = 19
    thirdTab    = 60

    # (option, description) in display order
    usage = [
        ( '-platform',            " Platform (Reference, Cuda, OpenCL)\n" ),
        ( '-controlFile',         " Control file name\n" ),
        ( '-systemsDirectory',    " Systems file directory\n" ),
        ( '-configDirectory',     " Config file directory\n" ),
        ( '-dataDirectory',       " Data file directory\n" ),
        ( '-serializeDirectory',  " Serialize file directory\n" ),
        ( '-testSummaryFileName', " summary file name\n" ),
        ( '-printParam',          " if nonzero, then print parameters and exit\n" ),
        ( '-sleep',               " if nonzero, then sleep the specified time between tests\n" ),
        ( '-deviceId',            " device id for gpu\n" ),
        ( '-verbose',             " if nonzero, then be verbose\n" ),
        ( '-help',                " current output\n" ),
    ]

    # widen the option column when an option name would overflow it
    # ('-testSummaryFileName' is longer than secondTab), keeping all names aligned
    optionTab = max( secondTab, max( len( option ) for option, _ in usage ) + 1 )

    print( "Command line arguments:\n" )
    outputString  = ''.join( ' '.rjust(firstTab) + option.rjust(optionTab) + description.rjust(thirdTab)
                             for option, description in usage )
    outputString += "\n\n"
    print( outputString )
    return

# parse command-line arguments

def parseCommandLine( argumentHash ):
    """Parse sys.argv and store the recognized options in argumentHash (in place).

    Every option except '-help' takes exactly one value, the next command-line
    word. Options are matched by prefix (str.startswith), trying the entries
    of the option table below in order.

    ARGUMENTS
        argumentHash (hash) updated in place; entries already present act as
                            defaults. When argumentHash['Verbose'] is true, each
                            option is echoed (with its argv index) after it is parsed.

    On '-help', on an unrecognized option, or when an option is missing its
    value, the usage message is printed and the process exits with status -1.
    """

    def _isPositive( value ):
        # '-verbose 1' / '-printParam 1' -> True; '0' -> False
        return int( value ) > 0

    def _asIs( value ):
        return value

    # (option prefix, argumentHash key, converter applied to the option's value)
    optionTable = (
        ( '-platform',            'Platform',            _asIs ),
        ( '-systemsDirectory',    'SystemsDirectory',    _asIs ),
        ( '-configDirectory',     'ConfigDirectory',     _asIs ),
        ( '-dataDirectory',       'DataDirectory',       _asIs ),
        ( '-serializeDirectory',  'SerializeDirectory',  _asIs ),
        ( '-deviceId',            'DeviceId',            int ),
        # stored as a number so 'Sleep > 0' comparisons are numeric, not string-vs-int
        ( '-sleep',               'Sleep',               float ),
        ( '-verbose',             'Verbose',             _isPositive ),
        ( '-printParam',          'PrintParam',          _isPositive ),
        ( '-testSummaryFileName', 'TestSummaryFileName', _asIs ),
        ( '-controlFile',         'ControlFileName',     _asIs ),
    )

    argLen = len( sys.argv )
    ii     = 1
    while( ii < argLen ):
        option = sys.argv[ii]

        if( option.startswith('-help') ):
            printUsage()
            sys.exit(-1)

        match = None
        for prefix, key, convert in optionTable:
            if( option.startswith( prefix ) ):
                match = ( key, convert )
                break

        if( match is None ):
            print( "Command line option %s not recognized." % option )
            printUsage()
            sys.exit(-1)

        if( ii + 1 >= argLen ):
            print( "Command line option %s requires a value." % option )
            printUsage()
            sys.exit(-1)

        key, convert      = match
        argumentHash[key] = convert( sys.argv[ii+1] )

        # read the flag live so that '-verbose 1' takes effect immediately
        if( argumentHash.get('Verbose') ):
            print( "%3d %s" % (ii, option) )

        # skip over the option and its value
        ii += 2
    return

#=============================================================================================
# Get default settings
#=============================================================================================

def getDefaultSettings( ):
    """Return a freshly built hash of default per-test parameters.

    A new dictionary is constructed on every call, so the caller may modify
    the returned hash without affecting later calls.

    ARGUMENTS
        none

    RETURNS
        (hash) parameter name -> default value
    """

    return dict(
        Active                             = 1,
        EquilibrationTotalSteps            = 30000,
        SimulationRandomNumberSeed         = 1994,
        SimulationStepSize                 = 0.001,
        SimulationTotalSteps               = 1000000,
        SimulationTemperature              = 300.0,
        SimulationShakeTolerance           = 1.0e-06,
        SimulationFriction                 = 91,
        SimulationErrorTolerance           = 2.0e-06,
        SimulationConstraintTolerance      = 1.0e-05,
        SimulationStepsBetweenReportsRatio = 0.01,
        Pressure                           = 3.0,
        RandomNumberSeed                   = -1,
        VerletTolerance                    = 2.0e-02,
        LangevinTolerance                  = 3.0e-02,
    )

#=============================================================================================
# Print a summary of the results to a file
#=============================================================================================

def printTestSummary( testSummaryFileName, testHash, testResults ):

    """Print summary of results to file

    For each distinct active test type (argumentHash['TestName'] with
    argumentHash['Active'] > 0), a header line is written. It is followed by
    one line per entry of testResults whose getTestName() matches that type.

    Tests whose name contains 'Verlet' report result.getDrift() and show
    'VerletTolerance' in the header. All other tests report
    result.getRelativeDeltaT() and show 'LangevinTolerance'.

    Tests and results are visited in sorted key order, so the file contents
    do not depend on dictionary iteration order.

    ARGUMENTS
        testSummaryFileName  (string) summary file name (overwritten)
        testHash             (hash)   testHash[fullTestName]    = parameter hash of one test/system pair
        testResults          (hash)   testResults[fullTestName] = result object returned by the test's getResult()
    """

    firstTab          = 40

    # True:  one header per test type; result lines omit the test-name/platform prefix.
    # False: no headers; every result line carries that prefix.
    writeHeaderString = True

    summarizedTestNames = set()

    # 'with' closes the file even if one of the result accessors raises
    with open( testSummaryFileName, 'w' ) as summaryFile:

        for fullTestName in sorted( testHash ):
            argumentHash = testHash[fullTestName]
            testName     = argumentHash['TestName']

            # only active tests, and only one section per test type
            if( not ( argumentHash['Active'] > 0 ) or testName in summarizedTestNames ):
                continue
            summarizedTestNames.add( testName )

            isVerlet = 'Verlet' in testName
            if( isVerlet ):
                tolerance    = argumentHash['VerletTolerance']
                headerString = testName.rjust(firstTab) + " Passed                Drift  Time   ns/day Tolerance=" + ("%10.3e" % tolerance)
            else:
                tolerance    = argumentHash['LangevinTolerance']
                headerString = testName.rjust(firstTab) + " Passed               DeltaT  Time   ns/day Tolerance=" + ("%10.3e" % tolerance)

            if( writeHeaderString ):
                summaryFile.write( headerString + "\n\n" )

            # loop over results and print out results for any tests that match this test type
            for resultName in sorted( testResults ):
                result = testResults[resultName]
                if( result.getTestName() != testName ):
                    continue

                # verboseOff()/verboseRestore() bracket testPassed(), presumably to
                # keep its diagnostic output quiet while building the summary
                result.verboseOff()
                passed = result.testPassed( )
                result.verboseRestore()

                if( writeHeaderString ):
                    outputString = ''
                else:
                    outputString = testName.rjust(firstTab) + result.getPlatform().rjust(8)

                outputString += result.getSystemName().rjust(20) + " " + repr(passed).rjust(6)

                if( isVerlet ):
                    outputString += ("%10.3e" % result.getDrift()).rjust(11)
                else:
                    outputString += ("%10.3e" % result.getRelativeDeltaT()).rjust(11)

                outputString += ("%10.3e" % result.getTotalTime()).rjust(11)
                outputString += ("%10.3e" % result.getNsPerDay()).rjust(11)
                outputString += (" Cnst=%6d" % result.getNumberOfConstraintViolations()).rjust(11)

                summaryFile.write( outputString + "\n" )

            if( writeHeaderString ):
                summaryFile.write( "\n\n" )

    return

#=============================================================================================
# MAIN
#=============================================================================================

argumentHash                             = {}
validationUtilities                      = ValidationUtilities(False)

# every default directory lives under $PYOPENMM_SOURCE_DIR/test/validation;
# stop with a clear message rather than a TypeError from os.path.join(None, ...)
pyopenmmSourceDirectory                  = os.getenv('PYOPENMM_SOURCE_DIR')
if( pyopenmmSourceDirectory is None ):
    sys.stderr.write( "Environment variable PYOPENMM_SOURCE_DIR is not set.\n" )
    sys.exit(1)

validationDirectory                      = os.path.join(pyopenmmSourceDirectory, 'test', 'validation')

# useXml selects the system file format read in the test loop:
#   xml:  <system>.xml parameters + <system>.txt positions from systems_xml/
#   text: <system>.txt parameter files from systems/
useXml                                   = 1
if( useXml ):
    parameterFileSuffix                  = '.xml'
    systemsDirectory                     = os.path.join(validationDirectory, 'systems_xml')
else:
    parameterFileSuffix                  = '.txt'
    systemsDirectory                     = os.path.join(validationDirectory, 'systems')

configDirectory                          = os.path.join(validationDirectory, 'config')
dataDirectory                            = os.path.join(validationDirectory, 'data')

# list files (looked up in the data directory) naming the tests and the systems to run
testListFileName                         = 'EnergyConservationTestNames.txt'
systemListFileName                       = 'SystemNames.txt'

# default parameters

defaultParameterHash                     = getDefaultSettings()

# defaults for the command-line options; parseCommandLine() overwrites them in place
argumentHash['ValidationDirectory']      = validationDirectory
argumentHash['ConfigDirectory']          = configDirectory
argumentHash['DataDirectory']            = dataDirectory
argumentHash['Sleep']                    = 0
argumentHash['TestListFileName']         = testListFileName
argumentHash['SystemListFileName']       = systemListFileName
argumentHash['SystemsDirectory']         = systemsDirectory
argumentHash['Platform']                 = 'Cuda'
argumentHash['PrintParam']               = False
argumentHash['Verbose']                  = False
argumentHash['ShowTests']                = 0        # no command-line option sets this
argumentHash['TestSummaryFileName']      = 'NA'     # 'NA' -> '<platform>EnergyConservationSummary.txt'

# parse command line

parseCommandLine( argumentHash )
systemsDirectory                         = argumentHash['SystemsDirectory']
printParam                               = argumentHash['PrintParam']
verbose                                  = argumentHash['Verbose']
validationUtilities.setVerbose( verbose )

# 0 disables serialization of the test systems
serializeDirectory                       = argumentHash.get('SerializeDirectory', 0)

# platform to run tests on

testPlatform                             = argumentHash['Platform']

# build the hash of (test, system) combinations, each seeded with the default parameters

fullTestListFileName                     = os.path.join( argumentHash['DataDirectory'], argumentHash['TestListFileName'] )
testList                                 = validationUtilities.getListFromFile( fullTestListFileName )

fullSystemListFileName                   = os.path.join( argumentHash['DataDirectory'], argumentHash['SystemListFileName'] )
systemList                               = validationUtilities.getListFromFile( fullSystemListFileName )

testHash                                 = validationUtilities.buildDefaultTestHash( testList, systemList, defaultParameterHash )

# an optional control file (in the config directory) edits the default test hash
if( 'ControlFileName' in argumentHash ):
    fullControlFileName                  = os.path.join( argumentHash['ConfigDirectory'], argumentHash['ControlFileName'] )
    validationUtilities.editTestHashBasedOnControlFile( fullControlFileName, testHash )

if( verbose or printParam ):
    print( "%s" % validationUtilities.printHash( argumentHash, "Arguments\n" ) )
    print( "%s" % validationUtilities.printTestHash( testHash, 1 ) )
    if( printParam ):
        # sys.exit, not the site module's interactive-only 'exit' helper
        sys.exit(0)

if( argumentHash['TestSummaryFileName'] != 'NA' ):
    testSummaryFileName = argumentHash['TestSummaryFileName']
else:
    testSummaryFileName = testPlatform + 'EnergyConservationSummary.txt'

# containers for results:
#   testResults[fullTestName] = result object of a test that ran
#   testNameHash[testName]    = 1 for every test type that was attempted

testResults                              = {}
testNameHash                             = {}

# track number of tests that do not run

numberOfTestsThatDidNotRun               = 0

# loop over (system, test) combinations

# pause between tests; 'Sleep' may hold a string or a number, so normalize it once
sleepSeconds = float( argumentHash.get('Sleep', 0) )

for fullTestName, testParameterHash in testHash.items():

    # 'ShowTests' mode only lists the active tests
    if( argumentHash['ShowTests'] and testParameterHash['Active'] != 0 ):
        print( "%s %s active=%d" % ( testParameterHash['SystemName'], testParameterHash['TestName'], testParameterHash['Active'] ) )
        continue
    if( not ( testParameterHash['Active'] > 0 ) ):
        continue

    # sleep before every test (independent of verbosity)

    if( sleepSeconds > 0 ):
        if( verbose ):
            print( "Sleeping %s seconds." % (str(argumentHash['Sleep'])) )
            sys.stdout.flush()
        time.sleep( sleepSeconds )

    systemName              = testParameterHash['SystemName']
    testName                = testParameterHash['TestName']
    testNameHash[testName]  = 1

    print( '*' * 80 )
    print( "%30s %30s" % (systemName, testName ) )

    # set device id

    if( 'DeviceId' in argumentHash ):
        testParameterHash['DeviceId'] = argumentHash['DeviceId']

    try:
        # read parameter file; done inside the try so that a missing or bad
        # system file counts as a test that did not run instead of aborting the run

        parameterFileName        = systemName + parameterFileSuffix
        fullParameterFileName    = os.path.join(systemsDirectory, parameterFileName)
        if( useXml ):
            positionFileName     = systemName + '.txt'
            fullPositionFileName = os.path.join(systemsDirectory, positionFileName)
            systemParameters     = OpenMMSystemParameters(fullParameterFileName, fullPositionFileName, verbose)
        else:
            systemParameters     = ValidationParameterFileParser(fullParameterFileName, verbose)

        # pick the test class from the test-name prefix; 'VerletIntegratorNoConstraints'
        # must be checked before its prefix 'VerletIntegrator'

        isVariableIntegrator = 0
        if( testName.startswith('VerletIntegratorNoConstraints') ):
            energyConservationTest   = EnergyConservationTests.VerletIntegratorNoConstraintsTest( systemParameters, testParameterHash, verbose )
        elif( testName.startswith('VerletIntegrator') ):
            energyConservationTest   = EnergyConservationTests.VerletIntegratorTest( systemParameters, testParameterHash, verbose )
        elif( testName.startswith('LangevinIntegrator') ):
            energyConservationTest   = EnergyConservationTests.LangevinIntegratorTest( systemParameters, testParameterHash, verbose )
        elif( testName.startswith('VariableVerletIntegrator') ):
            isVariableIntegrator     = 1
            energyConservationTest   = EnergyConservationTests.VariableVerletIntegratorTest( systemParameters, testParameterHash, verbose )
        elif( testName.startswith('VariableLangevinIntegrator') ):
            isVariableIntegrator     = 1
            energyConservationTest   = EnergyConservationTests.VariableLangevinIntegratorTest( systemParameters, testParameterHash, verbose )
        elif( testName.startswith('MonteCarloBarostatLangevinIntegrator') ):
            energyConservationTest   = EnergyConservationTests.MonteCarloBarostatLangevinIntegratorTest( systemParameters, testParameterHash, verbose )
        else:
            energyConservationTest   = None
            print( "%s test not recognized." % testName )
            numberOfTestsThatDidNotRun += 1

        # an unrecognized test has already been counted; nothing to serialize or run
        if( energyConservationTest is not None ):

            # serialize?

            if( serializeDirectory ):
                energyConservationTest.serialize( serializeDirectory )

            # run test

            if( energyConservationTest.isActive() ):
                if( isVariableIntegrator ):
                    energyConservationTest.runVariableTest( testPlatform )
                else:
                    energyConservationTest.runTest( testPlatform )
                fullTestName                 = energyConservationTest.getFullTestName()
                testResults[fullTestName]    = energyConservationTest.getResult( )

                # sleep after the test as well

                if( sleepSeconds > 0 ):
                    if( verbose ):
                        print( "Sleeping %s seconds." % (str(argumentHash['Sleep'])) )
                        sys.stdout.flush()
                    time.sleep( sleepSeconds )

    except Exception:
        # record the failure and keep going with the remaining tests;
        # KeyboardInterrupt/SystemExit still propagate
        print( "%s test did not run for %s" % (testName, systemName) )
        sys.stdout.flush()
        traceback.print_exc()
        numberOfTestsThatDidNotRun      += 1

if( argumentHash['ShowTests'] ):
    sys.exit(0)

# report test results

print( '*' * 80 )
numberOfTestsThatPassed                  = 0
numberOfTestsThatFailed                  = 0

# tally results grouped by test type; a result whose getTestName() is not a
# key of testNameHash is not counted

for testType in testNameHash:
    for result in testResults.values():
        if( result.getTestName() != testType ):
            continue
        if( result.testPassed( ) ):
            numberOfTestsThatPassed += 1
        else:
            numberOfTestsThatFailed += 1

print( "%d tests passed     " % numberOfTestsThatPassed )
print( "%d tests failed     " % numberOfTestsThatFailed )
print( "%d tests did not run" % numberOfTestsThatDidNotRun )

# print summary file

printTestSummary( testSummaryFileName, testHash, testResults )

# exit status 1 if any test failed or did not run, 0 otherwise

if( numberOfTestsThatFailed > 0 or numberOfTestsThatDidNotRun > 0 ):
    sys.exit(1)
else:
    sys.exit(0)
