#
#
#

__author__ = "Randall J. Radmer"
__version__ = "1.0"
__doc__ = """Parse and process a PDB file or string.
Note that this is a simple, fragile parser.
A better parser can be found in the Biopython project."""


import sys
import re

class PDB:
    """In-memory view of the ATOM records of a PDB file or string.

    Only lines starting with 'ATOM' (and long enough to hold coordinates)
    are parsed; 'TER'/'END' lines are used to record chain breaks.  Every
    other line is kept verbatim in ``pdbLines`` but otherwise ignored.

    An "atom pointer" (``atomPnt``) is the 0-based index of an atom in the
    parallel per-atom lists below.

    Attributes set by the constructor:
        pdbLines         -- all input lines, each ending with a newline
        atomNums         -- per-atom serial number (int)
        atomNames        -- per-atom name (see getAtomName)
        atomTypes        -- per-atom first upper-case letter of the name,
                            or '' when the name has none
        resNums          -- per-atom residue number (int)
        resNames         -- per-atom residue name
        coords           -- per-atom (x, y, z) tuple of floats, as written
                            in the input
        resIndexList     -- per-atom 0-based running residue index
        resNameList      -- one residue name per run of equal residue numbers
        chainBreaks      -- starts as [0]; each TER/END line adds
                            (residues completed so far + 1) if not present
        numAtomPnts      -- number of parsed atoms
        minResNum/maxResNum -- smallest/largest residue number, or None
                            when no ATOM record was found
        atomPntsByResNum -- dict: residue number -> list of atom pointers
    """

    # First upper-case letter of an atom name; used as a crude element type.
    _firstUpperCase = re.compile('[A-Z]')

    # Minimum line length (line terminator included) for an ATOM record to
    # be parsed.  Shared by the parser and buildNewPDB so both agree on
    # which lines count as atoms.
    _minAtomLineLen = 53

    def __init__(self, filename=None, asString=None):
        """Parse a PDB from a file path or from a string (exactly one).

        Raises ValueError when neither or both inputs are given.  Errors
        from opening the file or from converting numeric fields propagate.
        """
        if filename and asString:
            raise ValueError('Set only one of filename or asString args for PDB object')
        if filename:
            # 'with' guarantees the handle is closed even if reading fails.
            with open(filename) as f:
                self.pdbLines = f.readlines()
        elif asString:
            self.pdbLines = [line+'\n' for line in asString.split('\n')]
        else:
            raise ValueError('No input for PDB object')

        self.atomNums = []
        self.atomNames = []
        self.resNums = []
        self.resNames = []
        self.coords = []
        self.atomTypes = []
        self.chainBreaks = [0]
        self.resIndexList = []
        self.resNameList = []
        resCount = 0
        lastResNum = None
        lastResName = None
        for line in self.pdbLines:
            if line.startswith(('TER', 'END')):
                newBreak = resCount+1
                if newBreak not in self.chainBreaks:
                    self.chainBreaks.append(newBreak)
            elif self._isAtomLine(line):
                atomName = self.getAtomName(line)
                resName = self.getResName(line)
                resNum = self.getResNum(line)
                m = self._firstUpperCase.search(atomName)

                # A change in residue number closes the previous residue.
                if lastResNum is not None and resNum != lastResNum:
                    resCount += 1
                    self.resNameList.append(lastResName)

                self.atomNames.append(atomName)
                self.atomTypes.append(m.group() if m else '')
                self.atomNums.append(self.getAtomNum(line))
                self.resNames.append(resName)
                self.resNums.append(resNum)
                self.coords.append(self.getCoord(line))
                self.resIndexList.append(resCount)
                lastResNum = resNum
                lastResName = resName

        # Close the final residue; guarded so atom-free input does not crash.
        if self.resNames:
            self.resNameList.append(self.resNames[-1])
        self.numAtomPnts = len(self.atomNums)
        self.minResNum = min(self.resNums) if self.resNums else None
        self.maxResNum = max(self.resNums) if self.resNums else None

        self.atomPntsByResNum = {}
        for i, resNum in enumerate(self.resNums):
            self.atomPntsByResNum.setdefault(resNum, []).append(i)

    def _isAtomLine(self, line):
        """Return True if line is an ATOM record long enough to be parsed"""
        return len(line) >= self._minAtomLineLen and line[:4] == 'ATOM'

    def getNumAtoms(self):
        """Return the number of parsed atoms"""
        return self.numAtomPnts

    def getAtomNum(self, line):
        """Get Atom Number from an ATOM line"""
        return int(line[6:11])

    def getAtomName(self, line):
        """Get Atom Name from an ATOM line

        A hydrogen name written with a leading digit (e.g. '1HB ') is
        returned with the digit moved to the end ('HB1').
        """
        name = line[12:16].strip()
        if line[13] == 'H' and line[12] in '0123456789':
            name = line[13:16].strip()+line[12]
        return name

    def getResName(self, line):
        """Get Residue Name from an ATOM line"""
        return line[17:21].strip()

    def getResChain(self, line):
        """Get Chain Identifier from an ATOM line"""
        return line[21:22].strip()

    def getResNum(self, line):
        """Get Residue Number from an ATOM line"""
        return int(line[22:26])

    def getCoord(self, line):
        """Get the (x, y, z) Coordinates from an ATOM line"""
        return (float(line[30:38]), float(line[38:46]), float(line[46:54]))

    def atomPntsInRes(self, resNum):
        """Return a sorted list of all atom pointers in a residue

        An unknown residue number gives an empty list.
        """
        # The index lists are built in increasing atom order, so a copy of
        # the stored list is already sorted.
        return list(self.atomPntsByResNum.get(resNum, ()))

    def isHeavyAtom(self, atomPnt):
        """Return a Boolean value indicating if an atom is *not* hydrogen

        The decision is based on the first upper-case letter of the atom
        name; a name without any upper-case letter counts as heavy.
        """
        m = self._firstUpperCase.search(self.atomNames[atomPnt])
        return m is None or m.group() != 'H'

    def getFullResNameList(self):
        """Return the list of residue names in sequence order"""
        return self.getFullResNameNumList()[0]

    def getFullResNameNumList(self):
        """Return (residue names, residue numbers) for the residue sequence

        A new residue is started whenever the residue number differs from
        that of the previous atom.
        """
        resNameList = []
        resNumList = []
        # None (not -1) as sentinel: -1 can be a real residue number.
        lastResNum = None
        for resName, resNum in zip(self.resNames, self.resNums):
            if resNum != lastResNum:
                resNameList.append(resName)
                resNumList.append(resNum)
            lastResNum = resNum
        return (resNameList, resNumList)

    def getUniqueResNameList(self):
        """Return a sorted list of residues types in the PDB file"""
        return sorted(set(self.resNames))

    def getChainPnts(self, atomNamePntDict):
        """Return a dictionary mapping residue number to specific atoms in each residue

        atomNamePntDict maps atom name -> slot index.  Only residues that
        contain a 'CA' atom get an entry.  Each entry is a list with
        max(slot)+1 items holding the atom pointer for every slot, or None
        when the residue has no atom for that slot.  If several atoms of a
        residue map to the same slot, the last one wins.
        """
        numSlots = max(atomNamePntDict.values())+1 if atomNamePntDict else 0
        aDict = {}
        # First pass: only residues containing a CA atom get an entry.
        for resNum, atomName in zip(self.resNums, self.atomNames):
            if atomName == 'CA' and resNum not in aDict:
                aDict[resNum] = numSlots*[None]
        # Second pass: fill slots, including atoms listed before the CA.
        for i, (resNum, atomName) in enumerate(zip(self.resNums, self.atomNames)):
            if resNum in aDict and atomName in atomNamePntDict:
                aDict[resNum][atomNamePntDict[atomName]] = i
        return aDict

    def getMainChainPnts(self):
        """Return a dictionary mapping residue number to the three main chain atoms in each residue

        Entries are [N, CA, C] atom pointers.  Residues named NHE/NME
        contribute only their N, and residues named ACE only their C; the
        remaining slots of those entries are None.
        """
        chainPnts = self.getChainPnts({'N': 0, 'CA': 1, 'C': 2})
        for i, (resName, atomName) in enumerate(zip(self.resNames, self.atomNames)):
            if atomName == 'N' and resName in ('NHE', 'NME'):
                chainPnts[self.resNums[i]] = [i, None, None]
            elif atomName == 'C' and resName == 'ACE':
                chainPnts[self.resNums[i]] = [None, None, i]
        return chainPnts

    def getBackBonePnts(self):
        """Return a dictionary mapping residue number to the five backbone atoms in each residue

        Entries are [H, N, CA, C, O] atom pointers (None when absent).
        """
        atomNamePntDict = {'H': 0,
                           'N': 1,
                           'CA': 2,
                           'C': 3,
                           'O': 4}
        return self.getChainPnts(atomNamePntDict)

    def getSideChainPnts(self):
        """Return a dictionary mapping residue number to some of the sidechain atoms in each residue

        Slots are, in order: CA, CB, gamma (CG/CG1/OG/OG1/SG),
        delta (CD/SD), epsilon (CE/NE) and zeta (CZ/NZ).
        """
        atomNamePntDict = {'CA': 0,
                           'CB': 1,
                           'CG': 2, 'CG1': 2, 'OG': 2, 'OG1': 2, 'SG': 2,
                           'CD': 3, 'SD': 3,
                           'CE': 4, 'NE': 4,
                           'CZ': 5, 'NZ': 5}
        return self.getChainPnts(atomNamePntDict)

    def getResNumAtomNameMatch(self, matchResNum, matchAtomName):
        """Return an atom pointer to a specific atom name in a specific residue

        The atom name is upper-cased before matching.  Returns None when
        nothing matches.
        """
        matchAtomNameUpper = matchAtomName.upper()
        for i in self.atomPntsByResNum.get(matchResNum, ()):
            if self.atomNames[i] == matchAtomNameUpper:
                return i
        return None

    def getAtomPntByAtomNameResNum(self, atomName, resNum):
        """Return the first atom pointer with the given name in a residue

        The atom name is matched exactly as given (no case conversion).
        Returns None when the residue or the atom is not found.
        """
        for ii in self.atomPntsByResNum.get(resNum, ()):
            if self.atomNames[ii] == atomName:
                return ii
        return None

    def getAtomNameMatchList(self, matchAtomName='CA'):
        """Return a list of pointers to all atoms with a specific name"""
        matchAtomNameUpper = matchAtomName.upper()
        return [i for i, atomName in enumerate(self.atomNames)
                if atomName == matchAtomNameUpper]

    def getResNameAtomNameMatchList(self, matchResName, matchAtomName):
        """Return a list of pointers to all atoms in a specific residue type that have a specific atom name"""
        matchResNameUpper = matchResName.upper()
        matchAtomNameUpper = matchAtomName.upper()
        return [i for i, (resName, atomName)
                in enumerate(zip(self.resNames, self.atomNames))
                if resName == matchResNameUpper
                and atomName == matchAtomNameUpper]

    def getResNumListAtomNameMatchList(self, matchResNums, matchAtomName):
        """Return a list of pointers to all atoms with a specific name that are in one of a list of residue numbers"""
        matchAtomNameUpper = matchAtomName.upper()
        # Build the lookup set once instead of scanning the list per atom.
        wantedResNums = frozenset(matchResNums)
        return [i for i, (atomName, resNum)
                in enumerate(zip(self.atomNames, self.resNums))
                if atomName == matchAtomNameUpper
                and resNum in wantedResNums]

    def getResNumListAtomTypeMatchList(self, matchResNums, matchAtomType, excludedResList=None):
        """Return a list of pointers to all atoms of a specific type that are in one of a list of residue numbers

        Atoms whose residue name appears in excludedResList are skipped.
        The atom type is upper-cased before matching.
        """
        if excludedResList is None:
            excludedResList = ()
        matchAtomTypeUpper = matchAtomType.upper()
        wantedResNums = frozenset(matchResNums)
        return [i for i in range(self.numAtomPnts)
                if self.atomTypes[i] == matchAtomTypeUpper
                and self.resNums[i] in wantedResNums
                and self.resNames[i] not in excludedResList]

    def buildNewPDB(self, coords, addRemarks='', useNM=False, removeEND=False):
        """Return a new PDB file (as a string) with modified coordinates

        coords     -- sequence of (x, y, z), one per parsed atom, in order
        addRemarks -- text placed verbatim in front of the output
        useNM      -- if true, coords are multiplied by 10 before writing
        removeEND  -- if true, lines whose stripped text starts with 'END'
                      are left out

        Each ATOM line keeps its first 30 columns and is followed by the
        three new coordinates; anything after the coordinates on the
        original line is not copied.  All other lines are copied unchanged.
        """
        parts = [addRemarks]
        atomNum = 0
        for line in self.pdbLines:
            if removeEND and line.strip().startswith('END'):
                continue
            # Same predicate as the parser, so atomNum stays aligned with
            # the atoms that were actually read.
            if self._isAtomLine(line):
                x, y, z = coords[atomNum]
                if useNM:
                    x *= 10
                    y *= 10
                    z *= 10
                parts.append('%s%8.3f%8.3f%8.3f\n' % (line[:30], x, y, z))
                atomNum += 1
            else:
                parts.append(line)
        return ''.join(parts)

    def getCoords(self, useNM=False):
        """Return the per-atom coordinates

        With useNM each component is multiplied by 0.1 and a new list is
        returned; otherwise the internal list itself is returned (not a
        copy).
        """
        if useNM:
            return [(0.1*x, 0.1*y, 0.1*z) for (x, y, z) in self.coords]
        return self.coords

    def checkCaChirality(self):
        """Return a list of residue numbers with bad chirality

        Returns (badResList, (n, ave, rms)): the residues whose chirality
        value is negative, followed by the count, mean and standard
        deviation of the value over all residues that were evaluated.
        Residues lacking any of N, CA, C or CB are skipped.  When no
        residue can be evaluated the statistics are (0, 0.0, 0.0).

        The chirality value comes from chemStruct.getChirality(CA, N, C, CB)
        -- NOTE(review): its exact definition lives in chemStruct, confirm
        there.
        """
        import math
        import chemStruct
        bbPnts = self.getBackBonePnts()
        scPnts = self.getSideChainPnts()
        badResList = []
        chiCount = 0
        chiTotal = 0.0
        chiTotal2 = 0.0
        for resNum in sorted(bbPnts):
            pntH, pntN, pntCA, pntC, pntO = bbPnts[resNum]
            pntCB = scPnts[resNum][1]
            if pntN is None or pntCA is None or pntC is None or pntCB is None:
                continue
            chi = chemStruct.getChirality(self.coords[pntCA],
                                          self.coords[pntN],
                                          self.coords[pntC],
                                          self.coords[pntCB])
            chiCount += 1
            chiTotal += chi
            chiTotal2 += chi**2
            if chi < 0:
                badResList.append(resNum)

        n = chiCount
        if n == 0:
            return (badResList, (0, 0.0, 0.0))
        ave = chiTotal/n
        # Clamp at zero: round-off can make the variance slightly negative.
        rms = math.sqrt(max(0.0, chiTotal2/n-ave**2))
        return (badResList, (n, ave, rms))



