#
#
#

# Module metadata.
__author__ = "Randall J. Radmer"
__version__ = "1.0"
# Assigning to __doc__ sets the module docstring explicitly.
__doc__ = """Class to manage parameters and topology for amber simulations."""


import sys, os, math
#

import simtk.utils.chemOps as chemOps
import simtk.utils.pdbParser as pdbParser
import simtk.utils.amber.paramFileLoader as paramLoader
import simtk.utils.amber.prepFileLoader as prepLoader


class AmberSystem():
    """Builds and manages the full amber param/top/crd data set"""
    def __init__(self,
                 pdbFilename,
                 paramFilename,
                 prepFilename,
                 ntPrepFilename=None,
                 ctPrepFilename=None,
                 arMappingFilename=None,
                 tryUsingCuda=True,
                 NCNB=2.0,
                 NCEE=1.2,
                 useGBSA=False,
                 temperature=None,
                 collisionFrequency=0.25,
                 randomSeed=None,
                 cutoff=None,
                 useNoCutoff=False,
                 useCutoffNonPeriodic=False,
                 useCutoffPeriodic=False):
        """Load the input files and build the OpenMM system.

        Parameters:
            pdbFilename: PDB file holding the input structure.
            paramFilename: amber parameter file.
            prepFilename: amber prep file; its ``itypf`` must be 2.
            ntPrepFilename: optional prep file used for residues that start
                a chain; its ``itypf`` must be 200.
            ctPrepFilename: optional prep file used for residues that end
                a chain; its ``itypf`` must be 201.
            arMappingFilename: optional atom/residue renaming file (see
                parseAtomResidueMappingFile).
            tryUsingCuda: forwarded to manageGPUs.importBestOpenMMLib.
            NCNB: divisor applied to the Lennard-Jones well depth of
                1-4 pairs.
            NCEE: divisor applied to the charge product of 1-4 pairs.
            useGBSA: if true, a GBSAOBCForce is created as well.
            temperature: if not None, an AndersenThermostat is added.
            collisionFrequency: collision frequency given to the thermostat.
            randomSeed: optional random seed given to the thermostat.
            cutoff: cutoff distance used by the two cutoff methods.
            useNoCutoff, useCutoffNonPeriodic, useCutoffPeriodic: select
                the nonbonded method; at most one may be true.

        Raises:
            Exception: if a prep file has an unexpected ``itypf`` value.
        """
        self.usingCuda = False
        self.NCNB = NCNB
        self.NCEE = NCEE
        self.temperature = temperature
        self.collisionFrequency = collisionFrequency
        self.randomSeed = randomSeed

        # Imported here rather than at module level, as in the original code.
        import simtk.utils.manageGPUs as gMan
        (self.mm, self.usingCuda) = gMan.importBestOpenMMLib(
                                         tryUsingCuda=tryUsingCuda)
        self.mmSystem = self.mm.System()

        self.pdbOrig = pdbParser.PDB(filename=pdbFilename)
        self.param = paramLoader.Parser(paramFilename)
        self.prep = prepLoader.Parser(prepFilename)
        if self.prep.itypf != 2:
            raise Exception('Input prep file is incorrect type')

        if ntPrepFilename:
            self.ntPrep = prepLoader.Parser(ntPrepFilename)
            if self.ntPrep.itypf != 200:
                raise Exception('Input ntPrep file is incorrect type (%d)'
                                % self.ntPrep.itypf)
        else:
            self.ntPrep = None

        if ctPrepFilename:
            self.ctPrep = prepLoader.Parser(ctPrepFilename)
            if self.ctPrep.itypf != 201:
                raise Exception('Input ctPrep file is incorrect type (%d)'
                                % self.ctPrep.itypf)
        else:
            self.ctPrep = None

        if arMappingFilename:
            self.arMapping = parseAtomResidueMappingFile(arMappingFilename)
        else:
            self.arMapping = None

        # applyMappingToPdb performs membership tests on the mapping, so an
        # absent mapping has to be handed over as an empty dict, not None.
        self.pdb = applyMappingToPdb(self.pdbOrig, self.arMapping or {})
        self.pdb = self.addMissingAtoms()

        self.atomPntDict = self.getMappingFromIGraphToRealAtom()

        (self.cPairs,
         self.cTriplets,
         self.cQuads,
         self.excludedAtomPairs,
         self.excluded14AtomPairs) = self.buildConnectLists()

        self.addParticlesAndForceField(useGBSA=useGBSA,
                                       cutoff=cutoff,
                                       useNoCutoff=useNoCutoff,
                                       useCutoffNonPeriodic=useCutoffNonPeriodic,
                                       useCutoffPeriodic=useCutoffPeriodic)


    def getPrep(self, resIndex=None, atomPnt=None):
        if (resIndex is not None and atomPnt is not None):
            raise Exception, "Error: must set only one of resIndex or atomPnt"
        if resIndex is not None:
            if self.ntPrep and resIndex in self.pdb.chainBreaks:
                prep=self.ntPrep
            elif self.ctPrep and resIndex+1 in self.pdb.chainBreaks:
                prep=self.ctPrep
            else:
                prep=self.prep
        elif atomPnt is not None:
            prep = self.getPrep(resIndex=self.pdb.resIndexList[atomPnt])
        else:
            raise Exception, "Error: must set resIndex or atomPnt"
        return prep

    def getMainAtoms(self, resIndex):
        resName = self.pdb.resNameList[resIndex]
        prep = self.getPrep(resIndex=resIndex)
        return prep.getMainAtoms(resName)

    def getFirstAndLastMainAtoms(self, resIndex):
        mainAtoms = self.getMainAtoms(resIndex)
        return (mainAtoms[0], mainAtoms[-1])


    def getAtomPropertiesByAtomPnt(self, atomPnt):
        resName=self.pdb.resNames[atomPnt]
        atomName=self.pdb.atomNames[atomPnt]
        prep = self.getPrep(atomPnt=atomPnt)
        return prep.getAtomPropertiesByName(resName, atomName)


    def getMappingFromIGraphToRealAtom(self):
        atomPntDict = {}
        for atomPnt in range(self.pdb.numAtomPnts):
            (iGraph,  atomType, iTree, chg) = \
                self.getAtomPropertiesByAtomPnt(atomPnt)

            resIndex=self.pdb.resIndexList[atomPnt]
            if resIndex not in atomPntDict:
                atomPntDict[resIndex]={}
            atomPntDict[resIndex][iGraph] = atomPnt
        return atomPntDict


    def buildConnectLists(self):
        """Derive the bonded connectivity lists from the prep data.

        Returns a tuple (cPairs, cTriplets, cQuads, excludedAtomPairs,
        excluded14AtomPairs) of sorted lists of atom-index tuples:
          cPairs: connected atom pairs, each stored as (smaller, larger).
          cTriplets: chains of three atoms built from two pairs that share
              exactly one atom; the shared atom is in the middle.
          cQuads: chains of four atoms built by extending a triplet at one
              end with a pair.
          excludedAtomPairs: the pairs plus the two end atoms of each
              triplet.
          excluded14AtomPairs: the two end atoms of each quad.
        """
        cPairs=[]
        oldLastMain=None
        for resIndex in range(max(self.pdb.resIndexList)+1):
            resName = self.pdb.resNameList[resIndex]
            firstMain, lastMain = self.getFirstAndLastMainAtoms(resIndex)
            # Link the last main atom of the previous residue to the first
            # main atom of this one, unless this residue is a chain break.
            if oldLastMain is not None and resIndex not in self.pdb.chainBreaks:
                ii=self.atomPntDict[resIndex-1][oldLastMain]
                jj=self.atomPntDict[resIndex][firstMain]
                cPairs.append( (min(ii,jj), max(ii,jj)) )
            oldLastMain=lastMain

            # Connections inside the residue, as listed by the prep data.
            prep=self.getPrep(resIndex=resIndex)
            for i, j in prep.getConnectedAtomPairs(resName):
                ii=self.atomPntDict[resIndex][i]
                jj=self.atomPntDict[resIndex][j]
                cPairs.append( (min(ii,jj), max(ii,jj)) )
        cPairs.sort()
        excludedAtomPairs=cPairs[:]

        cTriplets=[]
        numPairs=len(cPairs)
        # Compare every pair with every earlier pair.
        for ii in range(numPairs):
            for jj in range(ii):
                item=None
                # iii/jjj select the atom tested as the shared one in each
                # pair; iiiR/jjjR are the respective other atoms.
                for iii in range(2):
                    iiiR = 1-iii
                    for jjj in range(2):
                        jjjR = 1-jjj
                        if cPairs[ii][iii ]==cPairs[jj][jjj ] and \
                           cPairs[ii][iiiR]!=cPairs[jj][jjjR]:
                            # Keep the smaller of the two orientations so
                            # that each triplet has one canonical form.
                            item = min( (cPairs[ii][iiiR],
                                         cPairs[jj][jjj ],
                                         cPairs[jj][jjjR]),
                                        (cPairs[jj][jjjR],
                                         cPairs[jj][jjj ],
                                         cPairs[ii][iiiR]) )
                if item and item not in cTriplets:
                    cTriplets.append(item)
                if item and (item[0], item[-1]) not in excludedAtomPairs:
                    excludedAtomPairs.append( (item[0], item[-1]) )
        cTriplets.sort()
        excludedAtomPairs.sort()

        cQuads=[]
        excluded14AtomPairs=[]
        numTriplets=len(cTriplets)
        # Compare every triplet with every pair.
        for ii in range(numTriplets):
            for jj in range(numPairs):
                item=None
                # iii is the triplet end being extended (0 or 2) and iiiR
                # is the opposite end.
                for iii in range(0, 3, 2):
                    iiiR  = 2-iii
                    for jjj in range(2):
                        jjjR = 1-jjj
                        # The new atom must differ from both the middle
                        # atom and the far end of the triplet.
                        if cTriplets[ii][iii ]==cPairs[jj][jjj ] and \
                           cTriplets[ii][1   ]!=cPairs[jj][jjjR] and \
                           cTriplets[ii][iiiR]!=cPairs[jj][jjjR]:
                            item = min( (cTriplets[ii][iiiR],
                                         cTriplets[ii][1   ],
                                         cPairs[jj][jjj ],
                                         cPairs[jj][jjjR]),
                                        (cPairs[jj][jjjR],
                                         cPairs[jj][jjj ],
                                         cTriplets[ii][1   ],
                                         cTriplets[ii][iiiR]) )
                if item and item not in cQuads:
                    cQuads.append(item)
                if item and (item[0], item[-1]) not in excluded14AtomPairs:
                    excluded14AtomPairs.append( (item[0], item[-1]) )
        cQuads.sort()
        excluded14AtomPairs.sort()

        return (cPairs, cTriplets, cQuads,
                excludedAtomPairs, excluded14AtomPairs)


    def addParticlesAndForceField(self, useGBSA,
                                        cutoff=None,
                                        useNoCutoff=False,
                                        useCutoffNonPeriodic=False,
                                        useCutoffPeriodic=False):
        """Add all particles and force terms to ``self.mmSystem``.

        Creates the nonbonded, (optional) GBSA, bond, angle and torsion
        forces from ``self.cPairs``, ``self.cTriplets``, ``self.cQuads`` and
        the improper lists of the prep data, registers them with the
        system, and adds an Andersen thermostat when ``self.temperature``
        is set.

        Parameters:
            useGBSA: if true, a GBSAOBCForce is created and filled.
            cutoff: distance passed to setCutoffDistance when one of the
                two cutoff methods is selected.
            useNoCutoff, useCutoffNonPeriodic, useCutoffPeriodic: select
                the nonbonded method; at most one may be true.

        Raises:
            Exception: if more than one nonbonded method is selected, or if
                a bond, angle, dihedral or improper parameter cannot be
                found, or an improper entry is neither an atom index nor
                '-M'/'+M'.
        """
        # Atom pairs that already have a nonbonded exception; a set keeps
        # the "already added?" test cheap.
        atomExceptions = set()

        self.nbs = self.mm.NonbondedForce()
        numMethodSets = 0
        if useNoCutoff:
            self.nbs.setNonbondedMethod(self.nbs.NoCutoff)
            numMethodSets += 1
        if useCutoffNonPeriodic:
            self.nbs.setNonbondedMethod(self.nbs.CutoffNonPeriodic)
            self.nbs.setCutoffDistance(cutoff)
            numMethodSets += 1
        if useCutoffPeriodic:
            self.nbs.setNonbondedMethod(self.nbs.CutoffPeriodic)
            self.nbs.setCutoffDistance(cutoff)
            numMethodSets += 1
        if numMethodSets > 1:
            raise Exception("Error: Set only one Nonbonded Method")

        if useGBSA:
            self.gbsa = self.mm.GBSAOBCForce()
        else:
            self.gbsa = None

        # Particles
        for atomPnt in range(self.pdb.numAtomPnts):
            (iGraph, atomType, iTree, chg) = \
                self.getAtomPropertiesByAtomPnt(atomPnt)

            self.mmSystem.addParticle(self.param.getMass(atomType))
            r, eps = self.param.get612Params(atomType,
                                             useKJ=True,
                                             useNM=True,
                                             returnSigmaAsR=True)
            self.nbs.addParticle(chg, r, eps)
            if self.gbsa:
                # Same fixed GBSA arguments for every atom, as before.
                self.gbsa.addParticle(chg, 0.3, 1.0)

        # Bonds
        self.bonds = self.mm.HarmonicBondForce()
        for ii, jj in self.cPairs:
            atomTypeII = self.getAtomPropertiesByAtomPnt(ii)[1]
            atomTypeJJ = self.getAtomPropertiesByAtomPnt(jj)[1]
            key1 = (atomTypeII, atomTypeJJ)
            key2 = (atomTypeJJ, atomTypeII)
            items = None
            # Try the type pair in both orders.
            for key in (key1, key2):
                items = self.param.getBondParams(key[0],
                                                 key[1],
                                                 useKJ=True,
                                                 useNM=True)
                if items:
                    break
            if not items:
                raise Exception("Cannot find bond type %s %s" % key1)
            (bondK, bondR, comment) = items
            self.bonds.addBond(ii, jj, bondR, bondK)
            if (ii, jj) not in atomExceptions:
                atomExceptions.add((ii, jj))
                # Zero charge product and well depth for bonded atoms.
                self.nbs.addException(ii, jj, 0, 1, 0)

        # Angles
        self.angles = self.mm.HarmonicAngleForce()
        for ii, jj, kk in self.cTriplets:
            atomTypeII = self.getAtomPropertiesByAtomPnt(ii)[1]
            atomTypeJJ = self.getAtomPropertiesByAtomPnt(jj)[1]
            atomTypeKK = self.getAtomPropertiesByAtomPnt(kk)[1]
            key1 = (atomTypeII, atomTypeJJ, atomTypeKK)
            key2 = (atomTypeKK, atomTypeJJ, atomTypeII)
            items = None
            for key in (key1, key2):
                items = self.param.getBondAngleParams(key[0],
                                                      key[1],
                                                      key[2],
                                                      useKJ=True,
                                                      useRad=True)
                if items:
                    break
            if not items:
                raise Exception("Cannot find angle type %s %s %s" % key1)
            (angleK, angle0, comment) = items
            self.angles.addAngle(ii, jj, kk, angle0, angleK)
            if (ii, kk) not in atomExceptions:
                atomExceptions.add((ii, kk))
                # Zero charge product and well depth for the angle's ends.
                self.nbs.addException(ii, kk, 0, 1, 0)

        # Dihedrals
        self.dihedrals = self.mm.PeriodicTorsionForce()
        for ii, jj, kk, ll in self.cQuads:
            (iGraph, atomTypeII, iTree, chgII) = \
                self.getAtomPropertiesByAtomPnt(ii)
            atomTypeJJ = self.getAtomPropertiesByAtomPnt(jj)[1]
            atomTypeKK = self.getAtomPropertiesByAtomPnt(kk)[1]
            (iGraph, atomTypeLL, iTree, chgLL) = \
                self.getAtomPropertiesByAtomPnt(ll)
            key1 = (atomTypeII, atomTypeJJ, atomTypeKK, atomTypeLL)
            key2 = (atomTypeLL, atomTypeKK, atomTypeJJ, atomTypeII)
            # 'X' acts as a wildcard for the two outer atom types.
            key3 = ('X', atomTypeJJ, atomTypeKK, 'X')
            key4 = ('X', atomTypeKK, atomTypeJJ, 'X')
            itemsList = None
            for key in (key1, key2, key3, key4):
                itemsList = self.param.getDihedralAngleParams(key[0],
                                                              key[1],
                                                              key[2],
                                                              key[3],
                                                              useKJ=True,
                                                              useRad=True)
                if itemsList:
                    break
            if not itemsList:
                raise Exception("Cannot find dihedral type %s %s %s %s" % key1)
            # One torsion term per parameter set found for this key.
            for (idivf, pk, phase, pn, comment) in itemsList:
                self.dihedrals.addTorsion(ii, jj, kk, ll,
                                          int(pn), phase, pk/float(idivf))

            rII, epsII = self.param.get612Params(atomTypeII,
                                                 useKJ=True,
                                                 useNM=True,
                                                 returnSigmaAsR=True)
            rLL, epsLL = self.param.get612Params(atomTypeLL,
                                                 useKJ=True,
                                                 useNM=True,
                                                 returnSigmaAsR=True)

            # Scaled interaction between the two end atoms of the dihedral,
            # skipped when the pair already has a bond/angle exception.
            chargeProd = chgII*chgLL/self.NCEE
            sigma = (rII+rLL)/2.0
            epsilon = math.sqrt(epsII*epsLL)/self.NCNB
            if (ii, ll) not in atomExceptions:
                atomExceptions.add((ii, ll))
                self.nbs.addException(ii, ll, chargeProd, sigma, epsilon)

        # Impropers
        for resIndex in range(max(self.pdb.resIndexList)+1):
            resName = self.pdb.resNameList[resIndex]
            prep = self.getPrep(resIndex=resIndex)
            for improper in prep.dPrepImpropers[resName]:
                atomPntList = []
                atomTypeList = []
                for atomIndex in improper:
                    try:
                        atomIndex = int(atomIndex)
                        atomPnt = self.atomPntDict[resIndex][atomIndex]
                    except ValueError:
                        # Not a number: '-M' selects the last main atom of
                        # the previous residue, '+M' the first main atom of
                        # the next residue.
                        if atomIndex == '-M':
                            lastMain = \
                                self.getFirstAndLastMainAtoms(resIndex-1)[1]
                            atomPnt = self.atomPntDict[resIndex-1][lastMain]
                        elif atomIndex == '+M':
                            firstMain = \
                                self.getFirstAndLastMainAtoms(resIndex+1)[0]
                            atomPnt = self.atomPntDict[resIndex+1][firstMain]
                        else:
                            raise Exception('Bad improper type: %s'
                                            % atomIndex)
                    atomPntList.append(atomPnt)
                    atomTypeList.append(
                        self.getAtomPropertiesByAtomPnt(atomPnt)[1])

                # Candidate keys alternate between forward (even position)
                # and reversed (odd position) atom order, from most to
                # least specific.
                keys = [tuple(atomTypeList),
                        (atomTypeList[3], atomTypeList[2],
                         atomTypeList[1], atomTypeList[0]),
                        ('X', atomTypeList[1],
                         atomTypeList[2], atomTypeList[3]),
                        ('X', atomTypeList[2],
                         atomTypeList[1], atomTypeList[0]),
                        ('X', 'X', atomTypeList[2], atomTypeList[3]),
                        ('X', 'X', atomTypeList[1], atomTypeList[0])]
                items = None
                reverseLists = 0
                for keyPos, key in enumerate(keys):
                    reverseLists = keyPos % 2
                    items = self.param.getImproperDihedralAngles(key[0],
                                                                 key[1],
                                                                 key[2],
                                                                 key[3],
                                                                 useKJ=True,
                                                                 useRad=True)
                    if items:
                        break
                if not items:
                    # Fail with a readable message instead of an unpacking
                    # error on None.
                    raise Exception("Cannot find improper type %s %s %s %s"
                                    % keys[0])
                if reverseLists == 0:
                    atom0, atom1, atom2, atom3 = atomPntList[:4]
                else:
                    atom0, atom1, atom2, atom3 = (atomPntList[3],
                                                  atomPntList[2],
                                                  atomPntList[1],
                                                  atomPntList[0])
                pk, phase, pn, comment = items
                self.dihedrals.addTorsion(atom0, atom1, atom2, atom3,
                                          int(pn), phase, pk)

        # Register every force with the system. thisown is cleared after
        # each addForce call; presumably this hands ownership of the SWIG
        # object to the System -- confirm against the OpenMM wrapper docs.
        # The nonbonded force was previously released without ever being
        # added, which left the system without nonbonded interactions.
        # NOTE(review): confirm no caller adds self.nbs to the system itself.
        self.mmSystem.addForce(self.nbs)
        self.nbs.thisown = False
        if self.gbsa:
            self.mmSystem.addForce(self.gbsa)
            self.gbsa.thisown = False
        self.mmSystem.addForce(self.bonds)
        self.bonds.thisown = False
        self.mmSystem.addForce(self.angles)
        self.angles.thisown = False
        self.mmSystem.addForce(self.dihedrals)
        self.dihedrals.thisown = False

        # Add temperature coupling to the OpenMM system
        if self.temperature is not None:
            self.thermostat = self.mm.AndersenThermostat(
                                               self.temperature,
                                               self.collisionFrequency)
            # "is not None" so that an explicit seed of 0 is passed on too.
            if self.randomSeed is not None:
                self.thermostat.setRandomNumberSeed(self.randomSeed)
            self.mmSystem.addForce(self.thermostat)
            self.thermostat.thisown = False


    def addMissingAtoms(self):
        """Return a new PDB object in which atoms missing from ``self.pdb``
        have been added.

        Coordinates of missing atoms are generated with chemOps.addAtom
        from each residue's prep-file z-matrix (and its reversed form, see
        buildReverseZMat). The new atom records are written after the last
        existing atom record of their residue and all atom records are
        renumbered consecutively.

        Raises:
            Exception: if an atom is missing and no z-matrix entry can
                place it.
        """
        resNameByResIndex = {}
        coords = {}      # resIndex -> {prep atom index -> coordinates}
        addedAtoms = {}  # resIndex -> prep indices of generated atoms
        for atomPnt in range(self.pdb.numAtomPnts):
            prep = self.getPrep(atomPnt=atomPnt)
            resIndex = self.pdb.resIndexList[atomPnt]
            resName = self.pdb.resNames[atomPnt]
            atomName = self.pdb.atomNames[atomPnt]
            if resIndex not in resNameByResIndex:
                resNameByResIndex[resIndex] = resName
                coords[resIndex] = {}
                addedAtoms[resIndex] = []
            atomIndex = prep.dPrepAtomNameToIndex[resName][atomName]
            coords[resIndex][atomIndex] = self.pdb.coords[atomPnt]

        # Ascending order: a residue may read coordinates of the one
        # before it.
        for resIndex in sorted(resNameByResIndex):
            prep = self.getPrep(resIndex=resIndex)
            resName = resNameByResIndex[resIndex]
            sys.stdout.write("residue: %s %s\n" % (resIndex, resName))
            numAtoms = prep.getNumAtomsByResName(resName)
            zMat = prep.dPrepZMat[resName]
            zMatList = [zMat, buildReverseZMat(zMat)]

            c = coords[resIndex]
            # Only residues that continue a chain borrow from the previous
            # residue, so its main atoms are looked up only in that case.
            linked = resIndex not in self.pdb.chainBreaks
            if linked:
                prevCoords = coords[resIndex-1]
                mainAtoms = self.getMainAtoms(resIndex-1)
                tail = (mainAtoms[-3], mainAtoms[-2], mainAtoms[-1])
                # Slots 1-3 of this residue take the coordinates of the
                # last three main atoms of the previous residue.
                for slot, mainAtom in zip((1, 2, 3), tail):
                    if mainAtom in prevCoords:
                        c[slot] = prevCoords[mainAtom]

            numCoordsLastPass = None
            while len(c) < numAtoms:
                for zM in zMatList:
                    for key in zM:
                        (index, na, nb, nc) = key
                        # Place an atom only when it is missing and all
                        # three of its reference atoms are known.
                        if index in c or \
                           na not in c or \
                           nb not in c or \
                           nc not in c:
                            continue
                        bondLength = zM[key][0]
                        angle = zM[key][1]
                        dAngle = zM[key][2]
                        dAngle += self.getDihedralAngleOffset(resIndex,
                                                              coords,
                                                              zMatList,
                                                              index,
                                                              na, nb, nc)
                        c[index] = \
                            chemOps.addAtom(c[na], c[nb], c[nc],
                                            bondLength=bondLength,
                                            angle=angle,
                                            dihedralAngle=dAngle)
                        # Indices 1-3 are not written to the output.
                        if index > 3:
                            addedAtoms[resIndex].append(index)
                if numCoordsLastPass == len(c):
                    # No progress in a full pass: report the first atom
                    # that cannot be placed.
                    for atomIndex in range(1, numAtoms):
                        if atomIndex not in c:
                            (IGRAPH, ISYMBL, ITREE, chg) = \
                                prep.getAtomPropertiesByIndex(resName,
                                                              atomIndex)
                            # NOTE(review): resNums is indexed by residue
                            # index as before; confirm it is a per-residue
                            # list.
                            raise Exception(
                                'Atom %s in residue %s--%d is missing '
                                'and cannot be added\n'
                                % (IGRAPH,
                                   resName,
                                   self.pdb.resNums[resIndex]))
                    break
                numCoordsLastPass = len(c)

            if linked:
                # Store slots 1-3 back as the previous residue's last three
                # main atoms (they may have been generated above).
                prevCoords[mainAtoms[-3]] = c[1]
                prevCoords[mainAtoms[-2]] = c[2]
                prevCoords[mainAtoms[-1]] = c[3]

        pieces = []
        lastAtomLine = None
        # Residue of the most recently written atom record. Tracking it
        # directly means pending atoms are flushed for the right residue
        # even when that residue has a single atom record.
        prevResIndex = None
        prevResName = None
        atomPnt = 0
        atomCount = 0
        for line in self.pdb.pdbLines:
            if line.find('END') == 0 or line.find('TER') == 0:
                if lastAtomLine is not None:
                    atomCount = self._flushAddedAtoms(pieces,
                                                      prevResIndex,
                                                      prevResName,
                                                      addedAtoms, coords,
                                                      lastAtomLine,
                                                      atomCount)
                if line.find('END') == 0:
                    pieces.append("END\n")
                if line.find('TER') == 0:
                    pieces.append("TER\n")
            # The HETATM test compares six characters; the former
            # four-character slice could never match.
            # NOTE(review): assumes the PDB parser counts HETATM records
            # as atoms -- confirm against pdbParser.
            elif len(line) >= 54 and (line[:4] == 'ATOM' or
                                      line[:6] == 'HETATM'):
                resIndex = self.pdb.resIndexList[atomPnt]
                if prevResIndex is not None and \
                   resIndex != prevResIndex and \
                   addedAtoms[prevResIndex]:
                    atomCount = self._flushAddedAtoms(pieces,
                                                      prevResIndex,
                                                      prevResName,
                                                      addedAtoms, coords,
                                                      lastAtomLine,
                                                      atomCount)
                pieces.append('%s%5d%s\n'
                              % (line[:6], atomCount+1, line[11:54]))
                atomCount += 1

                lastAtomLine = line
                prevResIndex = resIndex
                prevResName = self.pdb.resNames[atomPnt]
                atomPnt += 1
            else:
                pieces.append(line)

        # Atoms still pending for the final residue.
        if lastAtomLine is not None:
            atomCount = self._flushAddedAtoms(pieces,
                                              prevResIndex, prevResName,
                                              addedAtoms, coords,
                                              lastAtomLine, atomCount)
        return pdbParser.PDB(asString=''.join(pieces))

    def _flushAddedAtoms(self, pieces, resIndex, resName,
                         addedAtoms, coords, lastAtomLine, atomCount):
        """Append the pending generated atoms of one residue to pieces, in
        ascending prep-index order, and return the updated atom count."""
        addedAtoms[resIndex].sort()
        (text, atomCount) = self.addAtomsToPDBString(resIndex, resName,
                                                     addedAtoms, coords,
                                                     lastAtomLine,
                                                     '', atomCount)
        pieces.append(text)
        return atomCount



    def getDihedralAngleOffset(self, resIndex, coords, zMatList, indexMissing,
                               naMissing, nbMissing, ncMissing):
        dAngleOffsetKey = (resIndex, naMissing, nbMissing, ncMissing)
        try:
            if dAngleOffsetKey in self.dAngleOffset:
                return self.dAngleOffset[dAngleOffsetKey]
        except AttributeError:
                self.dAngleOffset={}

        offset=None
        for zM in zMatList:
            if offset is not None:
                break
            for index, na, nb, nc in zM:
                if index != indexMissing and \
                   na == naMissing and \
                   nb == nbMissing and \
                   nc == ncMissing and \
                   index in coords[resIndex]:
                    dihedralAngleRef=zM[(index, na, nb, nc)][2]
                    dihedralAngle = \
                        chemOps.getDihedralAngle(coords[resIndex][index],
                                                 coords[resIndex][na],
                                                 coords[resIndex][nb],
                                                 coords[resIndex][nc])
                    offset = (dihedralAngle%360 - dihedralAngleRef%360) % 360
                    break

        if offset is None:
            offset = 0

        self.dAngleOffset[dAngleOffsetKey] = offset
        return offset


    def addAtomsToPDBString(self, resIndex, resName,
                             addedAtoms, coords,
                             lastAtomLine,
                             sPDB, atomCount):
        prep = self.getPrep(resIndex=resIndex)
        for addedAtomPnt in addedAtoms[resIndex]:
            (atomName, ISYMBL, ITREE, chg) = \
                prep.getAtomPropertiesByIndex(resName,
                                              addedAtomPnt)
            if len(atomName)>=4:
                atomName = '%s%s' % (atomName[3],atomName[:3])
            else:
                atomName = ' %-3s' % atomName
    
            (x, y, z) = coords[resIndex][addedAtomPnt]
            sPDB += '%s%5d %4s%s%8.3f%8.3f%8.3f\n' \
                % (lastAtomLine[:6], atomCount+1, atomName,
                   lastAtomLine[16:30], x, y, z)
            atomCount += 1
        addedAtoms[resIndex]=[]
        return (sPDB, atomCount)


def buildReverseZMat(zMat):
    """Build z-matrix entries for walking the atom chains of zMat backwards.

    zMat maps (index, na, nb, nc) keys to parameter tuples whose first
    three items are used here. For every key whose reversed form
    (nc, nb, na, index) is not already in zMat and whose four indices are
    all >= 1, an entry is created under the reversed key from:
      * item 0 of an entry keyed (nb, nc, ...),
      * item 1 of an entry keyed (na, nb, nc, ...),
      * item 2 of the entry itself.
    Nothing is created unless both supporting entries exist; when several
    match, the last one in iteration order is used.
    """
    reversedEntries = {}
    for key in zMat:
        (atom, na, nb, nc) = key
        mirrored = (nc, nb, na, atom)
        if mirrored in zMat or min(key) < 1:
            continue
        middleSources = [k for k in zMat
                         if k[0] == na and k[1] == nb and k[2] == nc]
        firstSources = [k for k in zMat
                        if k[0] == nb and k[1] == nc]
        if not middleSources or not firstSources:
            continue
        reversedEntries[mirrored] = (zMat[firstSources[-1]][0],
                                     zMat[middleSources[-1]][1],
                                     zMat[key][2])
    return reversedEntries


def applyMappingToPdb(pdb, arMapping):
    """Return a new PDB object with residue and atom names renamed.

    Only lines starting with ATOM, HETATM, TER or END are kept and each
    kept line is cut to its first 54 columns. For every kept line the
    mapping is searched, in this order, for the keys "resNum:atomName",
    "resName:atomName", "resNum" and "resName". The first two kinds of
    key map to "newResName:newAtomName", the last two to a new residue
    name only.

    Parameters:
        pdb: object providing the input lines as ``pdb.pdbLines``.
        arMapping: dict as produced by parseAtomResidueMappingFile, or
            None to apply no renaming.
    """
    if arMapping is None:
        # No mapping supplied: still filter and truncate the records.
        arMapping = {}

    mappedLines = []
    for line in pdb.pdbLines:
        if not line.startswith(('ATOM', 'HETATM', 'TER', 'END')):
            continue
        # Strip any line ending first so that short records (e.g. "TER")
        # do not end up followed by a blank line.
        line = line.rstrip('\r\n')[:54]
        resNum = line[22:26].strip()
        resName = line[17:21].strip()
        atomName = line[12:16].strip()
        resNumAtomName = '%s:%s' % (resNum, atomName)
        resNameAtomName = '%s:%s' % (resName, atomName)
        newResName = ''
        newAtomName = ''
        if resNumAtomName in arMapping:
            newResName, newAtomName = arMapping[resNumAtomName].split(':')
        elif resNameAtomName in arMapping:
            newResName, newAtomName = arMapping[resNameAtomName].split(':')
        elif resNum in arMapping:
            newResName = arMapping[resNum]
        elif resName in arMapping:
            newResName = arMapping[resName]

        if newResName:
            newResName = newResName[:4]
            # A four-character name also takes over column 21; shorter
            # names are right-justified in columns 18-20.
            if len(newResName) == 4:
                line = "%s%4s%s" % (line[:17], newResName, line[21:])
            else:
                line = "%s%3s%s" % (line[:17], newResName, line[20:])
        if newAtomName:
            newAtomName = newAtomName[:4]
            # Keep the leading blanks of the old name where the new name
            # still fits into the four-column field.
            numLeadingSpaces = 4-len(line[12:16].lstrip())
            numLeadingSpaces -= max(0, numLeadingSpaces+len(newAtomName)-4)
            line = "%s%s%s%s%s" % (line[:12],
                                   numLeadingSpaces*' ',
                                   newAtomName,
                                   (4-len(newAtomName)-numLeadingSpaces)*' ',
                                   line[16:])
        mappedLines.append('%s\n' % line)
    return pdbParser.PDB(asString=''.join(mappedLines))



def parseAtomResidueMappingFile(filename):
    """Read an atom/residue mapping file into a dict.

    Each non-blank line holds whitespace-separated fields; the first
    field becomes the key and the second the value. Further fields are
    ignored and blank lines are skipped. A later line with the same key
    replaces an earlier one.

    Raises:
        IndexError: if a non-blank line has only one field.
    """
    mapping = {}
    # The with-block closes the file even if a line fails to parse.
    with open(filename) as fIn:
        for line in fIn:
            resAtoms = line.split()
            if not resAtoms:
                continue
            mapping[resAtoms[0]] = resAtoms[1]
    return mapping

        
def parseCommandLine():
    """Parse sys.argv and return the file names given on the command line.

    Recognized options:
        -h, --help       print the usage text and exit (via usageError)
        -d, --pdbIn      PDB input file
        -a, --paramIn    parameter input file
        -r, --prepIn     prep input file
        -n, --ntPrepIn   prep file used at chain starts
        -c, --ctPrepIn   prep file used at chain ends
    '-o VALUE' is still accepted for backward compatibility and ignored.

    Returns:
        (args_proper, pdbFilename, paramFilename, prepFilename,
         ntPrepFilename, ctPrepFilename); file names that were not given
        are None and args_proper holds the remaining arguments.

    Raises:
        getopt.GetoptError: on an unknown option or a missing argument.
    """
    import getopt
    # -n/-c and their long forms were tested for below but never declared
    # to getopt, so they could not be used; 'help' takes no argument.
    shortOpts = 'hd:r:a:n:c:o:'
    longOpts = ['help', 'pdbIn=', 'paramIn=', 'prepIn=',
                'ntPrepIn=', 'ctPrepIn=']
    opts, args_proper = getopt.getopt(sys.argv[1:], shortOpts, longOpts)
    pdbFilename = None
    paramFilename = None
    prepFilename = None
    ntPrepFilename = None
    ctPrepFilename = None
    for option, parameter in opts:
        if option in ('-h', '--help'):
            usageError()
        elif option in ('-d', '--pdbIn'):
            pdbFilename = parameter
        elif option in ('-a', '--paramIn'):
            paramFilename = parameter
        elif option in ('-r', '--prepIn'):
            prepFilename = parameter
        elif option in ('-n', '--ntPrepIn'):
            ntPrepFilename = parameter
        elif option in ('-c', '--ctPrepIn'):
            ctPrepFilename = parameter
    return (args_proper,
            pdbFilename,
            paramFilename,
            prepFilename,
            ntPrepFilename,
            ctPrepFilename)

def usageError():
    """Write the command-line usage text to stdout and exit with status 1."""
    firstPrefix = 'usage: %s' % os.path.basename(sys.argv[0])
    # Continuation lines are indented to line up under the first option.
    padding = ' ' * len(firstPrefix)
    usageLines = (
        (firstPrefix, '%s  --pdbIn=PDB_IN_FILENAME           \\\n'),
        (padding, '%s  --paramIn=PARM_IN_FILENAME        \\\n'),
        (padding, '%s  --prepIn=PREP_IN_FILENAME         \\\n'),
        (padding, '%s  --ntPrepIn=NTEMR_PREP_IN_FILENAME \\\n'),
        (padding, '%s  --ctPrepIn=CTEMR_PREP_IN_FILENAME\n'),
    )
    for lead, template in usageLines:
        sys.stdout.write(template % lead)
    sys.exit(1)

def main():
    """Command-line entry point: build an AmberSystem from the given files.

    Prints the usage text and exits (via usageError) unless the PDB,
    parameter and prep file names are all supplied.
    """
    (args_proper,
     pdbFilename,
     paramFilename,
     prepFilename,
     ntPrepFilename,
     ctPrepFilename) = parseCommandLine()

    if not pdbFilename or not paramFilename or not prepFilename:
        usageError()

    # The class defined in this module is AmberSystem; the name used here
    # before (AmberSetup) does not exist and raised NameError.
    AmberSystem(pdbFilename,
                paramFilename,
                prepFilename,
                ntPrepFilename,
                ctPrepFilename)


# Run the command-line entry point only when executed as a script.
if __name__=='__main__':
    main()

