#!/usr/bin/env python
#
#
 
"""Load AMBER param force field file."""

__author__ = "Randall J. Radmer"
__version__ = "1.0"
  

import sys, math
#


# Conversion constants
#
# AMBER energies are expressed in thermochemical kilocalories, for which
# 1 kcal = 4.184 kJ exactly.  (4.1868 is the International Table calorie,
# which molecular-mechanics force fields do not use.)
KJ_PER_KCAL = 4.184
KCAL_PER_KJ = 1.0/KJ_PER_KCAL

A_PER_NM = 10.0
NM_PER_A = 1.0/A_PER_NM

RAD_PER_DEG = math.pi/180.
DEG_PER_RAD = 1.0/RAD_PER_DEG

# The 6-12 "RE" entries give a van der Waals radius R, i.e. half the
# distance of the Lennard-Jones minimum (Rmin = 2*R).  Because
# Rmin = 2**(1/6) * sigma, it follows that sigma = 2 * 2**(-1/6) * R.
SIGMA_PER_VDW_RADIUS = 2*math.pow(2., -1./6.)
VDW_RADIUS_PER_SIGMA = 1.0/SIGMA_PER_VDW_RADIUS


class Parser:
    """Parse an AMBER parameter (parm.dat-style) force field file.

    The whole file is read by the constructor.  Sections are expected in
    the order below; every data section ends at a blank line:

      title line (ignored), atom masses, hydrophilic-atom line (ignored),
      bonds, bond angles, proper dihedrals, improper dihedrals,
      10-12 hydrogen-bond terms (ignored), equivalenced atom types, and
      one or more 6-12 sections, each introduced by a "LABEL KINDNB" line,
      with the last one followed by "END".

    Values are stored exactly as read.  The get*Params methods can
    optionally convert energies from kcal to kJ, lengths from Angstrom to
    nm, and angles from degrees to radians.
    """

    def __init__(self, filename):
        """Read and parse *filename*.

        Raises:
            IOError/OSError: the file cannot be opened.
            ValueError: a data line has too few fields, or holds a
                non-numeric value where a number is expected.
            Exception: a 6-12 section's KINDNB is not "RE".
            KeyError: an equivalenced atom's primary type has no entry in
                some 6-12 section.
        """
        # atom type -> (mass, rest of the line as a comment)
        self.dParamsMass = {}
        # (type1, type2) -> (RK, REQ, comment)
        self.dParamsBonds = {}
        # (type1, type2, type3) -> (TK, TEQ, comment)
        self.dParamsBondAngles = {}
        # (type1, type2, type3, type4) -> [(IDIVF, PK, PHASE, |PN|, comment), ...]
        # A key may appear on several consecutive lines (multi-term
        # dihedrals), so each key maps to a list of terms.
        self.dParamsDihedralAngles = {}
        # (type1, type2, type3, type4) -> (PK, PHASE, |PN|, comment)
        self.dParamsImproperDihedralAngles = {}
        # (LABEL, KINDNB) -> {atom type: (R, EDEP, comment)}
        self.dParams612Potential = {}
        # (LABEL, KINDNB) keys in file order; the first is get612Params' default
        self.namnbLabelKeys = []
        # primary atom type -> atom types that share its 6-12 parameters
        dParamsEquivalencedAtoms = {}
        namnbLabelKey = None

        with open(filename) as fIn:
            mode = 'start'
            for line0 in fIn:
                line = line0.rstrip()
                if mode == 'start':
                    # The first line is a title; nothing to keep.
                    mode = 'atom_mass'
                elif mode == 'atom_mass':
                    if line == '':
                        mode = 'hydrophilic'
                        continue
                    atom, mass, comment = self._splitFields(line, 2)
                    self.dParamsMass[atom] = (float(mass), comment)
                elif mode == 'hydrophilic':
                    # A single line listing hydrophilic atom types; not saved.
                    mode = 'bonds'
                elif mode == 'bonds':
                    if line == '':
                        mode = 'bond_angles'
                        continue
                    key = (line[:2].strip(),
                           line[3:5].strip())
                    RK, REQ, comment = self._splitFields(line[5:], 2)
                    self.dParamsBonds[key] = (float(RK), float(REQ), comment)
                elif mode == 'bond_angles':
                    if line == '':
                        mode = 'dihedral_angle'
                        continue
                    key = (line[:2].strip(),
                           line[3:5].strip(),
                           line[6:8].strip())
                    TK, TEQ, comment = self._splitFields(line[8:], 2)
                    self.dParamsBondAngles[key] = (float(TK), float(TEQ), comment)
                elif mode == 'dihedral_angle':
                    if line == '':
                        mode = 'improper_dihedral_angle'
                        continue
                    key = (line[:2].strip(),
                           line[3:5].strip(),
                           line[6:8].strip(),
                           line[9:11].strip())
                    IDIVF, PK, PHASE, PN, comment = self._splitFields(line[11:], 4)
                    # Only |PN| is kept.  In the AMBER format a negative PN
                    # marks that more terms for the same key follow.
                    items = (int(IDIVF), float(PK),
                             float(PHASE), abs(float(PN)), comment)
                    self.dParamsDihedralAngles.setdefault(key, []).append(items)
                elif mode == 'improper_dihedral_angle':
                    if line == '':
                        mode = 'h_bonds'
                        continue
                    key = (line[:2].strip(),
                           line[3:5].strip(),
                           line[6:8].strip(),
                           line[9:11].strip())
                    PK, PHASE, PN, comment = self._splitFields(line[11:], 3)
                    self.dParamsImproperDihedralAngles[key] = (
                        float(PK), float(PHASE), abs(float(PN)), comment)
                elif mode == 'h_bonds':
                    # 10-12 potential lines are skipped (not saved).
                    if line == '':
                        mode = 'equivalenced_atoms'
                elif mode == 'equivalenced_atoms':
                    if line == '':
                        mode = 'pre_6-12_potential'
                        continue
                    items = line.split()
                    # Accumulate, so repeated lines for one primary type
                    # extend the list instead of replacing it.
                    dParamsEquivalencedAtoms.setdefault(items[0], []).extend(items[1:])
                elif mode == 'pre_6-12_potential':
                    if line == '':
                        # Tolerate extra blank lines between 6-12 sections.
                        continue
                    if line.strip() == "END":
                        break
                    LABEL, KINDNB = line.split()
                    if KINDNB != 'RE':
                        raise Exception('Error: KINDNB not equal to "RE"--only van der Waals radius and '
                                        'the potential well depth parameters are read.')
                    namnbLabelKey = (LABEL, KINDNB)
                    self.namnbLabelKeys.append(namnbLabelKey)
                    self.dParams612Potential[namnbLabelKey] = {}
                    mode = '6-12_potential'
                elif mode == '6-12_potential':
                    if line == '':
                        mode = 'pre_6-12_potential'
                        continue
                    LTYNB, R, EDEP, comment = self._splitFields(line, 3)
                    self.dParams612Potential[namnbLabelKey][LTYNB] = (
                        float(R), float(EDEP), comment)
                else:
                    raise Exception('Error: Bad "mode" value: %s' % mode)

        # Give each equivalenced atom type the 6-12 parameters of its primary
        # type, in every section where it has no entry of its own.
        for equAtom, atoms in dParamsEquivalencedAtoms.items():
            for atom in atoms:
                for table in self.dParams612Potential.values():
                    if atom not in table:
                        R, EDEP, comment = table[equAtom]
                        table[atom] = (R, EDEP, 'Equivalenced to atom %s' % equAtom)

    @staticmethod
    def _splitFields(text, nFields):
        """Split *text* into *nFields* whitespace-separated fields plus a comment.

        The comment is the remainder of the text after the fields, or ''
        when the text has exactly *nFields* fields.  Raises ValueError when
        there are fewer than *nFields* fields.
        """
        fields = text.split(None, nFields)
        if len(fields) == nFields:
            fields.append('')
        if len(fields) != nFields + 1:
            raise ValueError('expected at least %d fields in %r' % (nFields, text))
        return fields

    def getMass(self, atomType):
        """Return the mass stored for *atomType* (KeyError if unknown)."""
        return self.dParamsMass[atomType][0]

    def getBondParams(self, atomType1,
                            atomType2,
                            useKJ=False,
                            useNM=False):
        """Return (k, r0, comment) for the bond atomType1-atomType2, or None.

        The lookup is order-sensitive: only the key as written in the file
        matches.  The returned force constant is twice the stored RK.  This
        presumably converts AMBER's K*(r-r0)**2 form to the k/2*(r-r0)**2
        convention; confirm against callers.

        useKJ: multiply the force constant by KJ_PER_KCAL.
        useNM: convert r0 from Angstrom to nm, and the force constant from
               per-Angstrom**2 to per-nm**2.
        """
        try:
            RK, REQ, comment = self.dParamsBonds[(atomType1, atomType2)]
        except KeyError:
            return None
        if useKJ:
            RK = RK * KJ_PER_KCAL
        if useNM:
            RK = RK / (NM_PER_A*NM_PER_A)
            REQ = REQ * NM_PER_A
        return (2*RK, REQ, comment)

    def getBondAngleParams(self, atomType1,
                                 atomType2,
                                 atomType3,
                                 useKJ=False,
                                 useRad=False):
        """Return (k, theta0, comment) for the angle type1-type2-type3, or None.

        The returned force constant is twice the stored TK, matching
        getBondParams.

        useKJ:  multiply the force constant by KJ_PER_KCAL.
        useRad: convert theta0 from degrees to radians.  The force constant
                is returned as stored (presumably already per radian**2).
        """
        try:
            TK, TEQ, comment = self.dParamsBondAngles[(atomType1,
                                                       atomType2,
                                                       atomType3)]
        except KeyError:
            return None
        if useKJ:
            TK = TK * KJ_PER_KCAL
        if useRad:
            TEQ = TEQ * RAD_PER_DEG
        return (2*TK, TEQ, comment)

    def getDihedralAngleParams(self, atomType1,
                                     atomType2,
                                     atomType3,
                                     atomType4,
                                     useKJ=False,
                                     useRad=False):
        """Return the list of (IDIVF, PK, PHASE, PN, comment) terms, or None.

        useKJ:  multiply each PK by KJ_PER_KCAL.
        useRad: convert each PHASE from degrees to radians.
        A new list is always returned, so callers cannot modify the parsed
        table through it.
        """
        try:
            terms = self.dParamsDihedralAngles[(atomType1,
                                                atomType2,
                                                atomType3,
                                                atomType4)]
        except KeyError:
            return None
        if useKJ:
            terms = [(idivf, pk * KJ_PER_KCAL, phase, pn, comment)
                     for (idivf, pk, phase, pn, comment) in terms]
        if useRad:
            terms = [(idivf, pk, phase * RAD_PER_DEG, pn, comment)
                     for (idivf, pk, phase, pn, comment) in terms]
        return list(terms)

    def getImproperDihedralAngles(self, atomType1,
                                        atomType2,
                                        atomType3,
                                        atomType4,
                                        useKJ=False,
                                        useRad=False):
        """Return (PK, PHASE, PN, comment) for an improper dihedral, or None.

        useKJ:  multiply PK by KJ_PER_KCAL.
        useRad: convert PHASE from degrees to radians.
        """
        try:
            PK, PHASE, PN, comment = self.dParamsImproperDihedralAngles[
                (atomType1, atomType2, atomType3, atomType4)]
        except KeyError:
            return None
        if useKJ:
            PK = PK * KJ_PER_KCAL
        if useRad:
            PHASE = PHASE * RAD_PER_DEG
        return (PK, PHASE, PN, comment)

    def get612Params(self, atomType,
                     useKJ=False,
                     useNM=False,
                     returnSigmaAsR=False,
                     namnbLabelKey=None):
        """Return (R, EDEP) for *atomType* from a 6-12 section.

        namnbLabelKey selects the (LABEL, KINDNB) section.  It defaults to
        the first section in the file.  An unknown atom type raises KeyError.

        useKJ:          multiply EDEP by KJ_PER_KCAL.
        useNM:          convert R from Angstrom to nm.
        returnSigmaAsR: return the Lennard-Jones sigma instead of the radius.
        """
        if namnbLabelKey is None:
            namnbLabelKey = self.namnbLabelKeys[0]

        R, EDEP, comment = self.dParams612Potential[namnbLabelKey][atomType]

        if useKJ:
            EDEP = KJ_PER_KCAL * EDEP
        if useNM:
            R = NM_PER_A * R
        if returnSigmaAsR:
            R = SIGMA_PER_VDW_RADIUS * R

        return (R, EDEP)
    


def parseCommandLine():
    """Return the positional command-line arguments (sys.argv[1:] minus options).

    ``-h`` prints the usage message and exits.  An unrecognized option does
    the same, instead of escaping as a getopt.GetoptError traceback.
    """
    import getopt
    try:
        opts, args_proper = getopt.getopt(sys.argv[1:], 'h')
    except getopt.GetoptError:
        usageError()
    for option, _ in opts:
        if option == '-h':
            usageError()
    return args_proper

def main():
    """Parse the parameter file named by the first command-line argument.

    The parsed result is not used further, so running the script only checks
    that the file parses.  With no argument, the usage message is printed
    and the script exits.
    """
    positional = parseCommandLine()
    if not positional:
        usageError()
    Parser(positional[0])

def usageError():
    """Print the usage message and exit with status 1.

    Uses print(...) with a single argument, which behaves the same under
    Python 2 and Python 3 (the original print statement was Python-2-only).
    """
    import os
    print('usage: %s inputParamFilename' % os.path.basename(sys.argv[0]))
    sys.exit(1)

# Run main() only when executed as a script, not when imported as a module.
if __name__=='__main__':
    main()

