#!/usr/bin/env python

#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import os
import sys
import os.path
import copy
import re
import math

import simtk.unit as units
import simtk.openmm
from   ValidationUtilities import ValidationUtilities
from   ParameterFileParser import ValidationParameterFileParser

#=============================================================================================
# Parser for OpenMM validation parameter files
#=============================================================================================

class OpenMMSystemParameters(object):
    """Parses an OpenMM XML-serialized System and, optionally, a parameter file.

    OpenMMSystemParameters deserializes an OpenMM System from an XML file,
    indexes the System's forces by class name, and provides access to each
    force and to the number of terms it contains.  When a position/parameter
    file is supplied it is parsed with ValidationParameterFileParser so that
    named Vec3 arrays (e.g. 'Positions') and scalars can be retrieved.
    """

    #=============================================================================================
    # Construction
    #=============================================================================================

    def __init__(self, xmlSystemFilename, positionFileName, verbose=False):
        """

        ARGUMENTS

        xmlSystemFilename (string)       -  OpenMM xml serialization file
        positionFileName  (string|None)  -  position text file; may be None
        verbose           (boolean)      -  be verbose

        If xmlSystemFilename is not an existing file, a message is printed and
        the process exits with status 1.

        """

        self._name                                 = None
        self._constraints                          = None
        self._positions                            = None
        self._cMMotionRemover                      = None
        self._masses                               = None
        self._verbose                              = verbose
        self._openMM                               = simtk.openmm

        # The XML is read by deserializeSystemFromFile below, so only check
        # that the file exists here instead of opening (and leaking) a handle.
        if not os.path.isfile(xmlSystemFilename):
            print("Could not open file %s" % xmlSystemFilename)
            sys.exit(1)
        if verbose:
            print("Opened file %s" % xmlSystemFilename)

        self._name                                 = self._systemNameFromFilename(xmlSystemFilename)
        self._path                                 = os.path.dirname(xmlSystemFilename)
        self._validationUtilities                  = ValidationUtilities(verbose)
        self._system                               = self._validationUtilities.deserializeSystemFromFile(xmlSystemFilename)
        self._buildForceMap()

        # Express each default box vector in the length unit reported by
        # ValidationUtilities.  A new list is built so this works whether the
        # System returns a list or an immutable sequence.
        lengthUnit                                 = self._validationUtilities.getLengthUnit()
        self._boxVectors                           = [vector / lengthUnit
                                                      for vector in self._system.getDefaultPeriodicBoxVectors()]

        # Always defined, so callers can test it even when no file was given.
        self._positionFileName                     = positionFileName
        if positionFileName is not None:
            self._validationParameterFileParser    = ValidationParameterFileParser(positionFileName)
        else:
            self._validationParameterFileParser    = None

    @staticmethod
    def _systemNameFromFilename(xmlSystemFilename):
        """Return the file's base name with a trailing '.xml' extension removed."""
        baseName = os.path.basename(xmlSystemFilename)
        if baseName.endswith('.xml'):
            baseName = baseName[:-len('.xml')]
        return baseName

    #=============================================================================================
    # Force map
    #=============================================================================================

    def _buildForceMap(self):
        """Index the System's forces by bare class name in self._forceMap.

        If the System holds several forces of the same class, the one with the
        highest force index is kept.
        """
        self._forceMap = {}
        for ii in range(self._system.getNumForces()):
            force     = self._system.getForce(ii)
            # __class__ (not type()) also works for Python 2 old-style classes.
            forceName = force.__class__.__name__
            if self._verbose:
                print("%6d      %60s " % (ii, forceName))
            self._forceMap[forceName] = force

    def _getForce(self, forceName):
        """Return the force registered under forceName, or None if absent."""
        return self._forceMap.get(forceName)

    #=============================================================================================
    # General System information
    #=============================================================================================

    def getName(self):
        """Return the system name (XML base name without the '.xml' suffix)."""
        return self._name

    def getMasses(self):
        """Return the list of particle masses, computed once and cached."""
        if self._masses is None:
            self._numberOfMasses = self._system.getNumParticles()
            self._masses         = [self._system.getParticleMass(ii)
                                    for ii in range(self._numberOfMasses)]
        return self._masses

    def getDefaultBoxVectors(self):
        """Return the System's default periodic box vectors as the System reports them."""
        return self._system.getDefaultPeriodicBoxVectors()

    def getCMMotionRemover(self):
        """Return self._cMMotionRemover.

        NOTE(review): this attribute is initialised to None and never assigned
        elsewhere in this class, so this currently always returns None.
        """
        return self._cMMotionRemover

    #=============================================================================================
    # Standard forces
    #=============================================================================================

    def getNumberOfHarmonicBonds(self):
        """Return the number of bonds in the HarmonicBondForce (0 if absent)."""
        force = self.getHarmonicBondForce()
        return force.getNumBonds() if force is not None else 0

    def getHarmonicBondForce(self):
        """Return the HarmonicBondForce, or None if the System has none."""
        return self._getForce('HarmonicBondForce')

    def getNumberOfHarmonicAngles(self):
        """Return the number of angles in the HarmonicAngleForce (0 if absent)."""
        force = self.getHarmonicAngleForce()
        return force.getNumAngles() if force is not None else 0

    def getHarmonicAngleForce(self):
        """Return the HarmonicAngleForce, or None if the System has none."""
        return self._getForce('HarmonicAngleForce')

    def getNumberOfPeriodicTorsions(self):
        """Return the number of torsions in the PeriodicTorsionForce (0 if absent)."""
        force = self.getPeriodicTorsionForce()
        return force.getNumTorsions() if force is not None else 0

    def getPeriodicTorsionForce(self):
        """Return the PeriodicTorsionForce, or None if the System has none."""
        return self._getForce('PeriodicTorsionForce')

    def getNumberOfRBTorsions(self):
        """Return the number of torsions in the RBTorsionForce (0 if absent)."""
        force = self.getRBTorsionForce()
        return force.getNumTorsions() if force is not None else 0

    def getRBTorsionForce(self):
        """Return the RBTorsionForce, or None if the System has none."""
        return self._getForce('RBTorsionForce')

    def getNumberOfCMAPTorsions(self):
        """Return the number of torsions in the CMAPTorsionForce (0 if absent)."""
        force = self.getCMAPTorsionForce()
        return force.getNumTorsions() if force is not None else 0

    def getCMAPTorsionForce(self):
        """Return the CMAPTorsionForce, or None if the System has none."""
        return self._getForce('CMAPTorsionForce')

    def getNumberOfNonbondeds(self):
        """Return the number of particles in the NonbondedForce (0 if absent)."""
        force = self.getNonbondedForce()
        return force.getNumParticles() if force is not None else 0

    def getNonbondedForce(self):
        """Return the NonbondedForce, or None if the System has none."""
        return self._getForce('NonbondedForce')

    def getNumberOfGbsaObcs(self):
        """Return the number of particles in the GBSAOBCForce (0 if absent)."""
        force = self.getGbsaObcForce()
        return force.getNumParticles() if force is not None else 0

    def getGbsaObcForce(self):
        """Return the GBSAOBCForce, or None if the System has none."""
        return self._getForce('GBSAOBCForce')

    def getGbviForce(self):
        """Return the GBVIForce, or None if the System has none."""
        return self._getForce('GBVIForce')

    def getNumberOfGbvis(self):
        """Return the number of particles in the GBVIForce (0 if absent)."""
        force = self.getGbviForce()
        return force.getNumParticles() if force is not None else 0

    #=============================================================================================
    # AMOEBA forces
    #=============================================================================================

    def getNumberOfAmoebaGeneralizedKirkwoodParticles(self):
        """Return the number of particles in the AmoebaGeneralizedKirkwoodForce (0 if absent)."""
        force = self.getAmoebaGeneralizedKirkwoodForce()
        return force.getNumParticles() if force is not None else 0

    def getAmoebaGeneralizedKirkwoodForce(self):
        """Return the AmoebaGeneralizedKirkwoodForce, or None if absent."""
        return self._getForce('AmoebaGeneralizedKirkwoodForce')

    def getNumberOfAmoebaHarmonicAngles(self):
        """Return the number of angles in the AmoebaHarmonicAngleForce (0 if absent)."""
        force = self.getAmoebaHarmonicAngleForce()
        return force.getNumAngles() if force is not None else 0

    def getAmoebaHarmonicAngleForce(self):
        """Return the AmoebaHarmonicAngleForce, or None if absent."""
        return self._getForce('AmoebaHarmonicAngleForce')

    def getNumberOfAmoebaHarmonicBonds(self):
        """Return the number of bonds in the AmoebaHarmonicBondForce (0 if absent)."""
        force = self.getAmoebaHarmonicBondForce()
        return force.getNumBonds() if force is not None else 0

    def getAmoebaHarmonicBondForce(self):
        """Return the AmoebaHarmonicBondForce, or None if absent."""
        return self._getForce('AmoebaHarmonicBondForce')

    def getNumberOfAmoebaHarmonicInPlaneAngles(self):
        """Return the number of angles in the AmoebaHarmonicInPlaneAngleForce (0 if absent)."""
        force = self.getAmoebaHarmonicInPlaneAngleForce()
        return force.getNumAngles() if force is not None else 0

    def getAmoebaHarmonicInPlaneAngleForce(self):
        """Return the AmoebaHarmonicInPlaneAngleForce, or None if absent."""
        return self._getForce('AmoebaHarmonicInPlaneAngleForce')

    def getNumberOfAmoebaMultipoleParticles(self):
        """Return the number of multipoles in the AmoebaMultipoleForce (0 if absent)."""
        force = self.getAmoebaMultipoleForce()
        return force.getNumMultipoles() if force is not None else 0

    def getAmoebaMultipoleForce(self):
        """Return the AmoebaMultipoleForce, or None if absent."""
        return self._getForce('AmoebaMultipoleForce')

    def getNumberOfAmoebaOutOfPlaneBends(self):
        """Return the number of out-of-plane bends in the AmoebaOutOfPlaneBendForce (0 if absent)."""
        force = self.getAmoebaOutOfPlaneBendForce()
        return force.getNumOutOfPlaneBends() if force is not None else 0

    def getAmoebaOutOfPlaneBendForce(self):
        """Return the AmoebaOutOfPlaneBendForce, or None if absent."""
        return self._getForce('AmoebaOutOfPlaneBendForce')

    def getNumberOfAmoebaPiTorsions(self):
        """Return the number of pi-torsions in the AmoebaPiTorsionForce (0 if absent)."""
        force = self.getAmoebaPiTorsionForce()
        return force.getNumPiTorsions() if force is not None else 0

    def getAmoebaPiTorsionForce(self):
        """Return the AmoebaPiTorsionForce, or None if absent."""
        return self._getForce('AmoebaPiTorsionForce')

    def getNumberOfAmoebaStretchBendForces(self):
        """Return the number of stretch-bends in the AmoebaStretchBendForce (0 if absent)."""
        force = self.getAmoebaStretchBendForce()
        return force.getNumStretchBends() if force is not None else 0

    def getAmoebaStretchBendForce(self):
        """Return the AmoebaStretchBendForce, or None if absent."""
        return self._getForce('AmoebaStretchBendForce')

    def getNumberOfAmoebaTorsions(self):
        """Return the number of torsions in the AmoebaTorsionForce (0 if absent)."""
        force = self.getAmoebaTorsionForce()
        return force.getNumTorsions() if force is not None else 0

    def getAmoebaTorsionForce(self):
        """Return the AmoebaTorsionForce, or None if absent."""
        return self._getForce('AmoebaTorsionForce')

    def getNumberOfAmoebaTorsionTorsions(self):
        """Return the number of torsion-torsions in the AmoebaTorsionTorsionForce (0 if absent)."""
        force = self.getAmoebaTorsionTorsionForce()
        return force.getNumTorsionTorsions() if force is not None else 0

    def getAmoebaTorsionTorsionForce(self):
        """Return the AmoebaTorsionTorsionForce, or None if absent."""
        return self._getForce('AmoebaTorsionTorsionForce')

    def getNumberOfAmoebaUreyBradleys(self):
        """Return the number of interactions in the AmoebaUreyBradleyForce (0 if absent)."""
        force = self.getAmoebaUreyBradleyForce()
        return force.getNumInteractions() if force is not None else 0

    def getAmoebaUreyBradleyForce(self):
        """Return the AmoebaUreyBradleyForce, or None if absent."""
        return self._getForce('AmoebaUreyBradleyForce')

    def getNumberOfAmoebaVdwParticles(self):
        """Return the number of particles in the AmoebaVdwForce (0 if absent)."""
        force = self.getAmoebaVdwForce()
        return force.getNumParticles() if force is not None else 0

    def getAmoebaVdwForce(self):
        """Return the AmoebaVdwForce, or None if absent."""
        return self._getForce('AmoebaVdwForce')

    def getNumberOfAmoebaWcaDispersionParticles(self):
        """Return the number of particles in the AmoebaWcaDispersionForce (0 if absent)."""
        force = self.getAmoebaWcaDispersionForce()
        return force.getNumParticles() if force is not None else 0

    def getAmoebaWcaDispersionForce(self):
        """Return the AmoebaWcaDispersionForce, or None if absent."""
        return self._getForce('AmoebaWcaDispersionForce')

    #=============================================================================================
    # Constraints
    #=============================================================================================

    def getNumberOfConstraints(self):
        """Return the number of constraints in the System."""
        return len(self.getConstraints())

    def getConstraints(self):
        """Return the System's constraints, computed once and cached.

        Each entry is [particle1 index, particle2 index, distance], taken from
        the first three values of System.getConstraintParameters.
        """
        if self._constraints is None:
            constraints = []
            for ii in range(self._system.getNumConstraints()):
                args = self._system.getConstraintParameters(ii)
                constraints.append([args[0], args[1], args[2]])
            # Publish the cache only once it is complete, so an exception
            # part-way through does not leave a truncated list behind.
            self._constraints = constraints
        return self._constraints

    #=============================================================================================
    # Parameter-file accessors
    #
    # These require that a positionFileName was passed to the constructor;
    # otherwise self._validationParameterFileParser is None and the call
    # raises AttributeError.
    #=============================================================================================

    def getPositions(self):
        """Return the 'Positions' Vec3 array from the parameter file."""
        return self._validationParameterFileParser.getVec3Array('Positions')

    def getVec3Array(self, name):
        """Return the Vec3 array called name from the parameter file."""
        return self._validationParameterFileParser.getVec3Array(name)

    def getScalar(self, name):
        """Return the scalar called name from the parameter file."""
        return self._validationParameterFileParser.getScalar(name)