#
#
#

# Module metadata.  __doc__ is assigned explicitly below instead of via a
# leading module docstring.
__author__ = "Magdalena A. Jonikas"
__version__ = "1.0"
__date__ = " 04/08/09 "
__doc__ = """ Functions and classes needed for C2A """

import sys, math, time, random, copy

import simtk.utils.rmsd as rms
import simtk.utils.vecOps as vecOps

def parseFrames(templateTrace):
    """Split a multi-frame PDB-style trace file into frames.

    templateTrace: path of the file to read.

    Returns a list of frames; each frame is the list of raw lines (newlines
    kept) up to and including a line starting with 'END' or 'TER'.  Lines
    after the last terminator are not returned.
    """
    frameList = []  # Each item is an entire frame
    frame = []
    templateData = open(templateTrace)
    try:
        for line in templateData:
            # Each line is added exactly once (the first line used to be
            # appended twice).
            frame.append(line)
            if line[:3] == 'END' or line[:3] == 'TER':
                frameList.append(frame)
                frame = []
    finally:
        templateData.close()
    return frameList

def calcDist(pos1, pos2):
    """Return the Euclidean distance between two 3-component points.

    Both arguments must unpack into exactly three numbers.
    """
    xa, ya, za = pos1
    xb, yb, zb = pos2
    squared = (xa - xb) ** 2 + (ya - yb) ** 2 + (za - zb) ** 2
    return squared ** 0.5

def calcRMSd(coords, coordsRef):
    """Return rms.fast_rmsd of `coords` against `coordsRef`.

    Both inputs are copied into lists and passed with transposeCoords and
    transposeCoordsRef set to True.
    """
    referenceRows = list(coordsRef)
    movingRows = list(coords)
    return rms.fast_rmsd(referenceRows, movingRows,
                         transposeCoordsRef=True, transposeCoords=True)

def calcTransRot(coords, coordsRef):
    """Return (tran, rot), the rigid transform superimposing coords onto coordsRef.

    Both arguments go to rms.calculate_rotation_rmsd with transposeCoords and
    transposeCoordsRef set to True.  The result is meant for
    calcRMSdFromTransRot, which maps each point p (as a row vector) to
    p * rot + tran.

    Any exception raised by rms.calculate_rotation_rmsd propagates to the
    caller unchanged.  A genuine segmentation fault inside a C extension
    cannot be caught from Python at all.
    """
    (centroid0,
     deltaCentroid0,
     rot0,
     rmsd) = rms.calculate_rotation_rmsd(coordsRef=coordsRef,
                                         coords=coords,
                                         transposeCoordsRef=True,
                                         transposeCoords=True)
    # rot0 multiplies column vectors (see multMat3Vec3 below), while
    # calcRMSdFromTransRot multiplies row vectors, so return the transpose.
    rot = vecOps.transposeMat3(rot0)
    # NOTE(review): assumes centroid0 + deltaCentroid0 is the reference
    # centroid -- confirm against simtk.utils.rmsd.
    centroidRef0 = vecOps.sum3(deltaCentroid0, centroid0)
    tran = vecOps.delta3(centroidRef0, vecOps.multMat3Vec3(rot0, centroid0))
    return (tran, rot)

def calcRMSdFromTransRot(coords, tran, rot):
    """Apply the rigid transform (tran, rot) to `coords` and return the result.

    Despite its name, this returns transformed coordinates, not an RMSD.
    `coords` is transposed into one row per point, each row is multiplied by
    `rot`, and `tran` is added.  The points are then transposed back, so the
    output has the same layout as the input (assumed ([x...], [y...],
    [z...]) as produced by getResListCoords -- confirm for other callers).
    """
    rotatedRows = vecOps.multMatMat(vecOps.transposeMat(coords), rot)
    shiftedRows = [(row[0] + tran[0], row[1] + tran[1], row[2] + tran[2])
                   for row in rotatedRows]
    return vecOps.transposeMat(shiftedRows)

def getOptions(name, structResList, modelStruct, searchStructure, templateName,
               frameIndex, convStructure, maxHelixOptions=10, maxOtherOptions=200):
    """Collect the best-scoring replacement fragments for one piece.

    name: piece name; names whose first character is 'H' are treated as
        helices when choosing how many options to keep.
    structResList: list of residue-number lists.  When it holds exactly two
        lists they are matched as a double strand, otherwise the residues are
        concatenated and matched as a single strand.
    modelStruct / searchStructure: Coarse-like structures; searchStructure
        must provide dfudge, dhfudge and a fullAtomic attribute.
    templateName, frameIndex: used to name the stats output file.
    convStructure: ConvStructure used to rebuild fragments on the target
        sequence.
    maxHelixOptions / maxOtherOptions: upper bounds on kept options for
        helix and non-helix pieces (defaults 10 and 200, as before).

    Writes one RMSD per accepted option to
    '<templateName>-<frameIndex>-<name>-stats.dat'.  Returns
    {rank: (newCoords, altResList)} for the lowest-RMSD options, with rank 0
    being the best.  Prints a warning when fewer options than the limit exist.
    """
    resList = []
    for strand in structResList:
        resList.extend(strand)
    if len(structResList) == 2:
        optionsListH1 = searchStructure.findSingleStrandMatches(structResList[0], modelStruct, searchStructure.dhfudge)
        optionsListH2 = searchStructure.findSingleStrandMatches(structResList[1], modelStruct, searchStructure.dhfudge)
        optionsList = searchStructure.findDoubleStrandMatches(structResList, modelStruct, optionsListH1, optionsListH2)
    else:
        optionsList = searchStructure.findSingleStrandMatches(resList, modelStruct, searchStructure.dfudge)
    modelCoordsRef = modelStruct.getResListCoords(resList)
    structSeq = modelStruct.getResListNames(resList)
    optionsStatsOut = open('%s-%i-%s-stats.dat' % (templateName, frameIndex, name), 'w')
    try:
        optionLib = parseOptions(optionsList, searchStructure, modelCoordsRef, convStructure,
                                 optionsStatsOut, name, resList, structSeq)
    finally:
        # Close the stats file even if parsing fails part-way.
        optionsStatsOut.close()
    if name[0] == 'H':
        maxKeep = maxHelixOptions
    else:
        maxKeep = maxOtherOptions
    if len(optionLib) >= maxKeep:
        nkeep = maxKeep
    else:
        nkeep = len(optionLib)
        print("  Warning, keeping only %i options for %s" % (nkeep, name))
    lib = {}
    for i in range(nkeep):
        (rmsd, newCoords, altResList) = optionLib[i]
        lib[i] = (newCoords, altResList)
    return lib

def parseOptions(optionsList, searchStructure, modelCoordsRef, convStructure, optionsStatsOut, name, structResList, structSeq):
    """Superimpose each candidate fragment on the model and rebuild it on the target sequence.

    optionsList entries are either (resi, resj, d) single-strand matches or
    (resA, resB, resC, resD) double-strand matches.  The residues resi..resj
    (or the two inclusive ranges spanned by A/B and C/D) form the candidate.
    The candidate's coarse coordinates are fitted onto modelCoordsRef.  The
    same transform is then applied to its full-atomic coordinates, and
    convStructure.fixPiece adapts it to structSeq / structResList.
    `name` is accepted but not used.

    For every candidate that fixPiece accepts, its RMSD is written to
    optionsStatsOut ('%.2f' per line).  Returns a list of
    (rmsd, newCoords, altResList) sorted by ascending RMSD; candidates with
    equal RMSD keep their input order.
    """
    optionLib = []
    for option in optionsList:
        if len(option) == 3:
            (resi, resj, d) = option
            altResList = list(range(resi, resj + 1))
        else:
            (resA, resB, resC, resD) = option
            altResList = list(range(min(resA, resB), max(resA, resB) + 1))
            altResList.extend(range(min(resC, resD), max(resC, resD) + 1))
        altSeq = searchStructure.getSequence(altResList)
        coords = searchStructure.fullAtomic.getResListCoords(altResList)
        coarseCoords = searchStructure.getResListCoords(altResList)
        (trans, rot) = calcTransRot(coarseCoords, modelCoordsRef)
        coordsTR = calcRMSdFromTransRot(coarseCoords, trans, rot)
        rmsd = calcRMSd(coordsTR, modelCoordsRef)
        fullAtomicCoordsTR = calcRMSdFromTransRot(coords, trans, rot)
        altCoords = searchStructure.fullAtomic.getAtomicCoordsLib(altResList, fullAtomicCoordsTR)
        newCoords = convStructure.fixPiece(structSeq, structResList, altSeq, altResList, altCoords)
        if newCoords:
            optionsStatsOut.write('%.2f\n' % (rmsd))
            optionLib.append((rmsd, newCoords, altResList))
    # Sort on RMSD only: comparing whole tuples would fall through to the
    # coordinate dicts on ties (arbitrary in Python 2, TypeError in Python 3).
    optionLib.sort(key=lambda entry: entry[0])
    return optionLib

def checkCollisions(lines, piecesObject, d1Cut, d2Cut, centerAtoms):
    """Count atomic clashes in a structure and report the pieces involved.

    `lines` is parsed twice (coarse and full-atomic), so it should be a
    re-iterable sequence of PDB lines.  A residue pair is examined at atom
    level only when its coarse center distance is <= d1Cut.  Atom pairs
    closer than d2Cut count as clashes (see FullAtomic.checkCollision).

    Returns (total clash count, list of distinct piece IDs containing a
    clashing residue, in first-seen order).
    """
    coarse = Coarse(lines, typelist=centerAtoms)
    atomic = FullAtomic(lines)
    residues = list(coarse.res2type2index.keys())
    totalBad = 0
    clashing = []
    for i, resA in enumerate(residues):
        for resB in residues[i + 1:]:
            if not (coarse.checkClose(resA, resB) <= d1Cut):
                continue
            hits = atomic.checkCollision(resA, resB, d2Cut)
            if not hits:
                continue
            clashing.append(resA)
            clashing.append(resB)
            totalBad += hits
    pieceIDs = []
    for res in clashing:
        owner = piecesObject.res2piece[res]
        if owner not in pieceIDs:
            pieceIDs.append(owner)
    return totalBad, pieceIDs

class Helix:
    """Double-stranded helix piece parsed from a line such as 'H1 a:b,c:d'.

    Strand 1 covers residues a..b and strand 2 covers c..d (both inclusive).
    Every residue is registered in pieces.res2name under this helix's name.
    Exits the program if the line does not start with 'H'.
    """
    def __init__(self, line, pieces):
        if line[0] != 'H':
            print("Error, this is not a helix")
            print(line)
            sys.exit(1)
        cols0 = line.split()
        self.name = cols0[0]
        cols1 = cols0[1].split(',')
        H1 = cols1[0].split(':')
        H2 = cols1[1].split(':')
        self.H1start = int(H1[0])
        self.H1end = int(H1[1])
        self.H2start = int(H2[0])
        self.H2end = int(H2[1])
        # Real lists (not range objects) so concatenation works on any Python.
        self.H1resList = list(range(self.H1start, self.H1end + 1))
        self.H2resList = list(range(self.H2start, self.H2end + 1))
        self.resList = self.H1resList + self.H2resList
        for res in self.resList:
            pieces.res2name[res] = self.name
        self.longResList = self.resList

    def printHelixInfo(self):
        """Print strand 1 and strand 2, the latter in reverse order."""
        print("%s %s" % ("5' end ", self.H1resList))
        # Print a reversed copy: other pieces (e.g. Junction, End) may hold a
        # reference to H2resList, so it must not be reversed in place.
        print("%s %s" % ("3' end ", self.H2resList[::-1]))

    def printLongResList(self):
        print("%s %s" % (self.name, self.resList))

class Loop:
    """Hairpin-loop piece parsed from a line such as 'L1 H1'.

    The loop residues lie strictly between the end of the named helix's
    strand 1 (H1end) and the start of its strand 2 (H2start).  Only those
    residues are registered in pieces.res2name.  Exits the program if the
    line does not start with 'L'.
    """
    def __init__(self, line, pieces):
        if line[0] != 'L':
            print("Error, this is not a loop")
            print(line)
            sys.exit(1)
        cols = line.split()
        self.name = cols[0]
        self.helixName = cols[1]
        self.helix = pieces.helices[self.helixName]
        self.shortStart = self.helix.H1end + 1
        self.shortEnd = self.helix.H2start - 1
        self.shortResList = list(range(self.shortStart, self.shortEnd + 1))
        # Loop residues plus one flanking helix residue on each side.
        self.o1ResList = list(range(self.shortStart - 1, self.shortEnd + 2))
        self.resList = self.shortResList
        self.longStart = self.helix.H1start
        self.longEnd = self.helix.H2end
        self.longResList = list(range(self.longStart, self.longEnd + 1))
        for res in self.shortResList:
            pieces.res2name[res] = self.name

    def printLoopInfo(self):
        """Print the closing helix's strand 1, the loop, and strand 2 reversed."""
        print("%s %s" % ("5' end ", self.helix.H1resList))
        print("%s %s" % ("loop          ", self.shortResList))
        # Reversed copy: the helix's list must not be mutated.
        print("%s %s" % ("3' end ", self.helix.H2resList[::-1]))

    def printLongResList(self):
        print("%s %s" % (self.name, self.longResList))

class End:
    """Dangling-end piece parsed from a line such as 'E1 H1:5-3' or 'E2 H1:32'.

    The second column is '<helix>:<side><length>'.  Side '5' attaches to the
    helix's strand 1 and any other side character to strand 2.  A positive
    length adds residues after the strand end.  A non-positive length adds
    |length| residues before the strand start.  The new residues are
    registered in pieces.res2name.
    """
    def __init__(self, line, pieces):
        cols = line.split()
        self.name = cols[0]
        cols = cols[1].split(':')
        helix = cols[0]
        self.end = cols[1][0]
        length = int(cols[1][1:])
        h = pieces.helices[helix]
        if self.end == '5':
            strandStart, strandEnd, strand = h.H1start, h.H1end, h.H1resList
        else:
            strandStart, strandEnd, strand = h.H2start, h.H2end, h.H2resList
        if length > 0:
            self.resList = list(range(strandEnd + 1, strandEnd + length + 1))
            # Includes the adjacent helix residue as a one-residue overlap.
            self.o1ResList = list(range(strandEnd, strandEnd + length + 1))
        else:
            self.resList = list(range(strandStart + length, strandStart))
            self.o1ResList = list(range(strandStart + length, strandStart + 1))
        self.overlap = strand
        self.longResList = sorted(strand + self.resList)
        for res in self.resList:
            pieces.res2name[res] = self.name
        self.shortResList = self.resList

    def printEndInfo(self):
        print("%s %s" % ("end", self.resList))
        print("%s %s" % ("helix    ", self.overlap))

    def printLongResList(self):
        print("%s %s" % (self.name, self.longResList))

        
class Junction:
    """Junction piece joining two helix strands, parsed from e.g. 'J1 H1:5,H2:3'.

    Each '<helix>:<side>' selects that helix's strand 1 (side '5') or
    strand 2 (any other side).  The first selection is the 5' flank
    (end5resList) and the second is the 3' flank (end3resList).  Both are
    references to the helix's own lists.  The junction residues lie strictly
    between the last 5'-flank residue and the first 3'-flank residue.
    """
    def __init__(self, line, pieces):
        cols = line.split()
        self.name = cols[0]
        helixNames = cols[1].split(',')
        helix1 = helixNames[0].split(':')
        self.H1name = helix1[0]
        self.H1side = helix1[1]
        helix2 = helixNames[1].split(':')
        self.H2name = helix2[0]
        self.H2side = helix2[1]

        self.end5resList = self._strand(pieces, self.H1name, self.H1side)
        self.end3resList = self._strand(pieces, self.H2name, self.H2side)

        first = self.end5resList[-1]
        last = self.end3resList[0]
        self.shortResList = list(range(first + 1, last))
        # The window always spans both flanking residues.  Very short
        # junctions are widened by `pad` extra residues on each side.
        if len(self.shortResList) == 1:
            pad = 2
        elif len(self.shortResList) == 2:
            pad = 1
        else:
            pad = 0
        self.o1ResList = list(range(first - pad, last + pad + 1))
        self.resList = self.shortResList
        self.longResList = self.end5resList + self.shortResList + self.end3resList
        # Indices into longResList of residues that belong to the flanks.
        junctionResidues = set(self.resList)
        self.overlapList = [i for i, res in enumerate(self.longResList)
                            if res not in junctionResidues]

        for res in self.shortResList:
            pieces.res2name[res] = self.name

    @staticmethod
    def _strand(pieces, helixName, side):
        """Return the named helix's strand-1 list for side '5', else its strand-2 list."""
        helix = pieces.helices[helixName]
        if side == '5':
            return helix.H1resList
        return helix.H2resList

    def printJunctionInfo(self):
        print("%s %s" % ("5' end ", self.end5resList))
        print("%s %s" % ("junction        ", self.shortResList))
        print("%s %s" % ("3' end ", self.end3resList))

    def printLongResList(self):
        print("%s %s" % (self.name, self.longResList))


class Pieces:
    """Container that parses piece-definition lines into typed piece objects.

    Each line is dispatched on its first character:
    'H' -> Helix, 'L' -> Loop, 'J' -> Junction, 'E' -> End, 'T' -> Tertiary.
    Each piece is stored in the matching dict under its name.  Lines with any
    other first character are ignored.  res2name is filled in by the piece
    constructors.
    """
    def __init__(self, dataIn):
        self.helices = {}
        self.junctions = {}
        self.loops = {}
        self.ends = {}
        self.tertiary = {}
        self.res2name = {}
        for line in dataIn:
            tag = line[0]
            if tag == 'H':
                piece = Helix(line, self)
                self.helices[piece.name] = piece
            elif tag == 'L':
                piece = Loop(line, self)
                self.loops[piece.name] = piece
            elif tag == 'J':
                piece = Junction(line, self)
                self.junctions[piece.name] = piece
            elif tag == 'E':
                piece = End(line, self)
                self.ends[piece.name] = piece
            elif tag == 'T':
                piece = Tertiary(line, self)
                self.tertiary[piece.name] = piece

class Coarse:
    """Coarse-grained model with one or more center atoms per residue.

    `input` is an iterable of PDB-formatted lines.  Only records whose atom
    name field (line[13:16]) is in `typelist` are used.  Only residues that
    have every atom type in `typelist` are indexed (res2type2index, resList).
    Coordinates are stored column-wise in self.coords = (xs, ys, zs).
    """
    def __init__(self, input, typelist=None):
        # None instead of a shared mutable default; the default set is still ['C3*'].
        if typelist is None:
            typelist = ['C3*']
        self.f = input
        self.typelist = typelist
        self.readPDB()
        self.resList = sorted(self.res2type2index.keys())
        self.firstRes = self.resList[0]
        self.lastRes = self.resList[-1]
        self.n = len(self.resList)

    def makeBreakList(self):
        """Set self.breakList to the residues followed by a numbering gap in resList."""
        self.breakList = []
        for prev, cur in zip(self.resList, self.resList[1:]):
            if cur != prev + 1:
                self.breakList.append(prev)

    def readPDB(self):
        """Parse self.f into per-residue coordinates and a flat coordinate index.

        The first occurrence of each (residue, atom type) wins.  Residue names
        come from line[17:20], stripped.
        """
        self.coords = ([], [], [])
        self.res2type2coords = {}
        self.res2type2index = {}
        self.res2name = {}
        for line in self.f:
            atomType = line[13:16]
            if atomType not in self.typelist:
                continue
            res = int(line[22:26])
            self.res2name.setdefault(res, line[17:20].strip())
            typeCoords = self.res2type2coords.setdefault(res, {})
            if atomType not in typeCoords:
                typeCoords[atomType] = (float(line[30:38]),
                                        float(line[38:46]),
                                        float(line[46:54]))
        index = 0
        for res in self.res2type2coords.keys():
            typeCoords = self.res2type2coords[res]
            # Skip residues missing any of the requested center atoms.
            if len(typeCoords) != len(self.typelist):
                continue
            typeIndex = self.res2type2index.setdefault(res, {})
            for atomType in self.typelist:
                (x, y, z) = typeCoords[atomType]
                typeIndex.setdefault(atomType, index)
                self.coords[0].append(x)
                self.coords[1].append(y)
                self.coords[2].append(z)
                index += 1

    def makeSearchLibs(self):
        """Precompute pairwise lookup tables over resList (pairs with i < j).

        p2d[(a, b)] and p2d[(b, a)]: distance between first center atoms.
        p2l[(a, b)]: b - a + 1, the inclusive numbering span.
        p2b[(a, b)]: 1 if resList has a numbering gap between a and b, else 0.
        """
        self.p2d = {}
        self.p2l = {}
        self.p2b = {}
        firstCoords = [self.getResCoords(self.resList[i])[0] for i in range(self.n)]
        for i in range(self.n):
            resi = self.resList[i]
            posi = firstCoords[i]
            for j in range(i + 1, self.n):
                resj = self.resList[j]
                d = calcDist(posi, firstCoords[j])
                l = abs(resj - resi + 1)
                if abs(j - i + 1) == l:
                    b = 0
                else:
                    b = 1
                self.p2d[(resi, resj)] = d
                self.p2d[(resj, resi)] = d
                self.p2l[(resi, resj)] = l
                self.p2b[(resi, resj)] = b

    def _point(self, index):
        """Return the (x, y, z) tuple stored at flat coordinate `index`."""
        return (self.coords[0][index], self.coords[1][index], self.coords[2][index])

    def getResCoords(self, res):
        """Return [(x, y, z), ...] for `res`, one entry per atom type in typelist order.

        Raises KeyError if `res` was not indexed.  Each coordinate now appears
        once; it used to be repeated len(typelist) times.
        """
        typeIndex = self.res2type2index[res]
        return [self._point(typeIndex[atomType]) for atomType in self.typelist]

    def findSingleStrandMatches(self, refResList, refStruct, fudge):
        """Find contiguous stretches whose end-to-end distance matches a reference strand.

        NOTE: sorts refResList in place (the caller's list is modified).
        Requires makeSearchLibs() to have been called.  A single-residue
        reference (start == end) raises KeyError because p2b has no (r, r)
        keys.

        Returns [(resi, resj, d), ...] where resi..resj spans as many
        residues as the reference, has no numbering gap, and has
        D*(1-fudge) <= d <= D*(1+fudge), D being the reference distance.
        """
        refResList.sort()
        resStart = refResList[0]
        resEnd = refResList[-1]
        L = resEnd - resStart + 1
        coordsStart = refStruct.getResCoords(resStart)[0]
        coordsEnd = refStruct.getResCoords(resEnd)[0]
        D = calcDist(coordsStart, coordsEnd)
        present = set(self.resList)
        options = []
        for resi in self.resList:
            resj = resi + L - 1
            if resj not in present:
                continue
            if self.p2b[(resi, resj)]:
                continue
            d = self.p2d[(resi, resj)]
            if d <= (D * (1 + fudge)) and d >= (D * (1 - fudge)):
                options.append((resi, resj, d))
        return options

    def findDoubleStrandMatches(self, refResList, refStruct, listH1, listH2):
        """Pair single-strand matches into double-strand candidates.

        The reference strands are refResList[0] (a..b) and refResList[1]
        (c..d).  A pair (A, B) x (C, D) is kept when the cross distances
        match the reference a-d and b-c distances within self.dhfudge.
        (A, B, D, C) is appended when the C/D orientation is swapped, and
        (A, B, C, D) when it is direct.
        """
        resa = refResList[0][0]
        resb = refResList[0][-1]
        resc = refResList[1][0]
        resd = refResList[1][-1]
        dad = refStruct.checkClose(resa, resd)
        dbc = refStruct.checkClose(resb, resc)

        def near(value, ref):
            fudge = self.dhfudge
            return value <= (ref * (1 + fudge)) and value >= (ref * (1 - fudge))

        options = []
        for resA, resB, d1 in listH1:
            for resC, resD, d2 in listH2:
                if resA == resC or resA == resD or resB == resC or resB == resD:
                    continue
                dAC = self.p2d[(resA, resC)]
                dAD = self.p2d[(resA, resD)]
                dBC = self.p2d[(resB, resC)]
                dBD = self.p2d[(resB, resD)]
                if near(dAC, dad) and near(dBD, dbc):
                    options.append((resA, resB, resD, resC))
                if near(dAD, dad) and near(dBC, dbc):
                    options.append((resA, resB, resC, resD))
        return options

    def checkClose(self, resi, resj):
        """Return the distance between the first center atoms of two residues."""
        coordsi = self.getResCoords(resi)[0]
        coordsj = self.getResCoords(resj)[0]
        return calcDist(coordsi, coordsj)

    def getResListCoords(self, resList):
        """Return ([x...], [y...], [z...]) for the residues, typelist order within each.

        Each coordinate now appears once per residue; it used to be repeated
        len(typelist) times.
        """
        coordsList = ([], [], [])
        for res in resList:
            typeIndex = self.res2type2index[res]
            for atomType in self.typelist:
                index = typeIndex[atomType]
                coordsList[0].append(self.coords[0][index])
                coordsList[1].append(self.coords[1][index])
                coordsList[2].append(self.coords[2][index])
        return coordsList

    def getResListNames(self, resList):
        """Return the residue names of `resList`, in order."""
        return [self.res2name[res] for res in resList]

    def getSequence(self, resList):
        """Return the residue names of `resList`, in order (same as getResListNames)."""
        return [self.res2name[res] for res in resList]

class FullAtomic:
    """All-atom model built from PDB-formatted ATOM records.

    `input` is an iterable of lines.  Parsing stops at the first line starting
    with 'END' or 'TER'.  Atoms are keyed by the integer serial from
    line[4:11].  The residue name is the single character line[19:20], and
    the atom type is line[13:16].  Coordinates are stored column-wise in
    self.coords.
    """
    def __init__(self, input):
        self.f = input
        self.readPDB()

    def readPDB(self):
        """Parse self.f and fill the residue/atom lookup tables.

        firstRes/lastRes are the residue numbers of the first and last ATOM
        records read.  Both are None when there are none (lastRes used to
        raise UnboundLocalError in that case).
        """
        self.res2atoms = {}
        self.res2resName = {}
        self.res2type2coords = {}
        self.res2type2atomindex = {}
        self.res2type2line = {}
        self.atom2index = {}
        self.atom2type = {}
        self.atom2resName = {}
        self.res2type2atom = {}
        self.coords = ([], [], [])
        self.firstRes = None
        self.lastRes = None
        index = 0
        for line in self.f:
            if line.startswith('ATOM'):
                res = int(line[22:26])
                if index == 0:
                    self.firstRes = res
                atom = int(line[4:11])
                resName = line[19:20]
                atomType = line[13:16]
                x = float(line[30:38])
                y = float(line[38:46])
                z = float(line[46:54])
                self.res2atoms.setdefault(res, []).append(atom)
                self.res2resName.setdefault(res, resName)
                self.res2type2coords.setdefault(res, {})[atomType] = (x, y, z)
                self.res2type2atomindex.setdefault(res, {})[atomType] = atom
                self.res2type2atom.setdefault(res, {})[atomType] = atom
                self.res2type2line.setdefault(res, {})[atomType] = line[11:19]
                self.atom2type[atom] = atomType
                self.atom2resName[atom] = resName
                self.atom2index[atom] = index
                self.coords[0].append(x)
                self.coords[1].append(y)
                self.coords[2].append(z)
                self.lastRes = res
                index += 1
            if line.startswith('END') or line.startswith('TER'):
                break

    def getResListCoords(self, resList):
        """Return ([x...], [y...], [z...]) for all atoms of the residues, in file order per residue."""
        coordsList = ([], [], [])
        for res in resList:
            for atom in self.res2atoms[res]:
                index = self.atom2index[atom]
                coordsList[0].append(self.coords[0][index])
                coordsList[1].append(self.coords[1][index])
                coordsList[2].append(self.coords[2][index])
        return coordsList

    def getAtomicCoordsLib(self, resList, coords):
        """Map externally supplied column coordinates back onto residues and atom types.

        `coords` must list atoms in the order getResListCoords(resList) would.
        Returns {res: {atomType: (x, y, z)}}.
        """
        coordsLib = {}
        index = 0
        for res in resList:
            resLib = coordsLib.setdefault(res, {})
            for atom in self.res2atoms[res]:
                resLib[self.atom2type[atom]] = (coords[0][index],
                                                coords[1][index],
                                                coords[2][index])
                index += 1
        return coordsLib

    def checkCollision(self, resi, resj, dCut):
        """Count atom pairs between residues resi and resj closer than dCut."""
        bad = 0
        jCoords = [self.getAtomCoords(jatom) for jatom in self.res2atoms[resj]]
        for iatom in self.res2atoms[resi]:
            icoords = self.getAtomCoords(iatom)
            for jcoords in jCoords:
                if calcDist(icoords, jcoords) < dCut:
                    bad += 1
        return bad

    def getAtomCoords(self, atom):
        """Return the (x, y, z) of the atom with serial `atom`."""
        index = self.atom2index[atom]
        return (self.coords[0][index], self.coords[1][index], self.coords[2][index])

class ConvStructure:
    """Rebuild fragment coordinates so they match a target (primary) sequence.

    primarySeq: iterable of lines, one residue name per line.  Residue number
        `res` maps to line res-1 (see getPieceInfo).
    pieces: object exposing dicts helices/junctions/loops/ends/tertiary whose
        values have `name` and `longResList` attributes.
    baseData: mapping residue name -> atom name -> (x, y, z) used by
        addNewBase.  Presumably these are ideal base coordinates in the local
        frame built by makeHanger -- TODO confirm against its producer.
    """
    def __init__(self, primarySeq, pieces, baseData):
        self.setPrimary(primarySeq)
        self.parsePieces(pieces)
        backbone = ['P  ', 'O1P', 'O2P', 'O5*', 'C5*', 'C4*', 'O4*', 'C3*', 'O3*', 'C2*', 'O2*', 'C1*']
        # Atoms that fixRes copies verbatim from the template residue.
        self.keepAtomList = list(backbone)
        # Complete expected atom set per residue type (backbone + base).
        # fixPiece rejects residues whose atom set differs from this.
        self.base2atoms = {
            'A': backbone + ['N9 ', 'C8 ', 'N7 ', 'C5 ', 'C6 ', 'N6 ', 'N1 ', 'C2 ', 'N3 ', 'C4 '],
            'G': backbone + ['N9 ', 'C8 ', 'N7 ', 'C5 ', 'C6 ', 'O6 ', 'N1 ', 'C2 ', 'N2 ', 'N3 ', 'C4 '],
            'C': backbone + ['N1 ', 'C2 ', 'O2 ', 'N3 ', 'C4 ', 'N4 ', 'C5 ', 'C6 '],
            'U': backbone + ['N1 ', 'C2 ', 'O2 ', 'N3 ', 'C4 ', 'O4 ', 'C5 ', 'C6 '],
        }
        self.baseData = baseData
        self.setTypeList()

    def setPrimary(self, primarySeq):
        """Store the stripped lines of primarySeq as self.primarySeq."""
        self.primarySeq = [line.strip() for line in primarySeq]

    def parsePieces(self, pieces):
        """Record residue lists and sequences for every piece (helices first, tertiary last)."""
        self.pieces = {}
        self.pieceSeq = {}
        for group in (pieces.helices, pieces.junctions, pieces.loops,
                      pieces.ends, pieces.tertiary):
            for key in group.keys():
                self.getPieceInfo(group[key])

    def setTypeList(self):
        """Define the base atoms per residue type and the three frame-anchor atoms."""
        self.type2list = {
            'A': ['N9 ', 'C4 ', 'N3 ', 'C2 ', 'N1 ', 'C6 ', 'N6 ', 'C5 ', 'N7 ', 'C8 '],
            'G': ['N9 ', 'C4 ', 'N3 ', 'C2 ', 'N2 ', 'N1 ', 'C6 ', 'O6 ', 'C5 ', 'N7 ', 'C8 '],
            'C': ['N1 ', 'C2 ', 'O2 ', 'N3 ', 'C4 ', 'N4 ', 'C5 ', 'C6 '],
            'U': ['N1 ', 'C2 ', 'O2 ', 'N3 ', 'C4 ', 'O4 ', 'C5 ', 'C6 '],
        }
        # makeHanger anchors: origin atom (res0), x-axis atom (resA) and a
        # second in-plane atom (resB).
        self.type2res0 = {'A': 'N9 ', 'G': 'N9 ', 'C': 'N1 ', 'U': 'N1 '}
        self.type2resA = {'A': 'N1 ', 'G': 'N1 ', 'C': 'N3 ', 'U': 'N3 '}
        self.type2resB = {'A': 'C5 ', 'G': 'C5 ', 'C': 'C5 ', 'U': 'C5 '}

    def getPieceInfo(self, piece):
        """Store piece.longResList and its primary-sequence residue names under piece.name."""
        name = piece.name
        self.pieces[name] = piece.longResList
        self.pieceSeq[name] = [self.primarySeq[res - 1] for res in piece.longResList]

    def fixPiece(self, structSeq, structResList, altSeq, altResList, altCoordLib):
        """Map a template fragment's coordinates onto the target residues.

        Position i takes template residue altResList[i] (name altSeq[i]) and
        assigns it to target residue structResList[i] (name structSeq[i]).
        Residues with matching names keep their coordinates, filtered to the
        expected atom set.  Mismatched residues get a rebuilt base via
        fixRes.

        Returns {targetRes: {atomType: (x, y, z)}}, or None if a base cannot
        be rebuilt or any residue ends up with an atom set different from
        base2atoms.
        """
        newCoordLib = {}
        for resIndex in range(len(structSeq)):
            structResName = structSeq[resIndex]
            altResName = altSeq[resIndex]
            res = structResList[resIndex]
            altCoords = altCoordLib[altResList[resIndex]]
            expected = self.base2atoms[structResName]
            if structResName == altResName:
                # Same residue type: keep the template coordinates.
                resCoords = {}
                for atomType in altCoords.keys():
                    if atomType in expected:
                        resCoords[atomType] = altCoords[atomType]
            else:
                resCoords = self.fixRes(res, structResName, altResName, altCoords)
                if not resCoords:
                    return None
            newCoordLib[res] = resCoords
            # sorted() copies: the shared base2atoms lists must not be
            # reordered in place.
            if sorted(expected) != sorted(resCoords.keys()):
                return None
        return newCoordLib

    def fixRes(self, structResNum, structResName, altResName, altCoords):
        """Rebuild a residue as type structResName from a template of type altResName.

        Backbone atoms (keepAtomList) are copied, and the new base is placed in
        the frame of the template's base.  structResNum is not used.
        Returns {atomType: (x, y, z)}, or None if the template base lacks
        the frame anchor atoms.
        """
        newCoords = {}
        for atomType in altCoords.keys():
            if atomType in self.keepAtomList:
                (x, y, z) = altCoords[atomType]
                newCoords[atomType] = (x, y, z)
        frame = self.makeHanger(altResName, altCoords)
        if not frame:
            return None
        (trans, hanger) = frame
        for atomType, x, y, z in self.addNewBase(structResName, trans, hanger):
            newCoords[atomType] = (x, y, z)
        return newCoords

    def makeHanger(self, resType, altCoords):
        """Build a local frame on the base of a residue of type resType.

        The origin is the res0 atom.  X points towards the resA atom, Z is
        normal to the plane spanned by resA and resB (both relative to the
        origin), and Y = Z x X.  Returns (origin, hanger), where hanger has
        the axes X, Y, Z as its columns.  Returns None if any of the three
        anchor atoms is missing.
        """
        origin = self.type2res0[resType]
        axisA = self.type2resA[resType]
        axisB = self.type2resB[resType]
        anchors = {}
        for atomType in altCoords.keys():
            if atomType == origin or atomType == axisA or atomType == axisB:
                (x, y, z) = altCoords[atomType]
                anchors[atomType] = (x, y, z)
        if len(anchors) != 3:
            return None
        resA = vecOps.delta3(anchors[axisA], anchors[origin])
        resB = vecOps.delta3(anchors[axisB], anchors[origin])
        NZ = vecOps.unitVec3(vecOps.crossProd3(resA, resB))
        NX = vecOps.unitVec3(resA)
        NY = vecOps.crossProd3(NZ, NX)
        hanger = ((NX[0], NY[0], NZ[0]),
                  (NX[1], NY[1], NZ[1]),
                  (NX[2], NY[2], NZ[2]))
        return (anchors[origin], hanger)

    def addNewBase(self, resType, trans, hanger):
        """Place the baseData atoms of resType into the frame (trans, hanger).

        Each atom p is mapped to hanger * p + trans.  Returns
        [(atomType, x, y, z), ...] in type2list order.
        """
        template = self.baseData[resType]
        placed = []
        for atomType in self.type2list[resType]:
            (x, y, z) = template[atomType]
            rotated = vecOps.multMat3Vec3(hanger, (x, y, z))
            (x, y, z) = vecOps.sum3(rotated, trans)
            placed.append((atomType, x, y, z))
        return placed

class AssemblyPieces:
    """Piece bookkeeping for assembly: trimmed option libraries plus residue ownership.

    piecesDefFile: iterable of piece-definition lines.  The first column is
        the piece name and the last column is the maximum number of options
        to keep.  The lines are also parsed by Pieces.
    piecesLib: {pieceName: {optionKey: option}}.  It is trimmed IN PLACE so
        that each piece keeps only its first `max` option keys in sorted
        order.
    """
    def __init__(self, piecesDefFile, piecesLib):
        # Materialize once: the definitions are read here and again by
        # parsePieces, and a one-shot iterator such as an open file would
        # otherwise be exhausted before the second pass.
        defLines = list(piecesDefFile)
        piece2max = {}
        for line in defLines:
            cols = line.split()
            if not cols:
                continue  # tolerate blank lines
            piece2max[cols[0]] = int(cols[-1])
        for piece in list(piecesLib.keys()):
            options = piecesLib[piece]
            maxOptions = piece2max[piece]
            for key in sorted(options.keys())[maxOptions:]:
                del options[key]
        self.piecesLib = piecesLib
        self.parsePieces(defLines)
        self.numPieces = len(self.pieces)

    def parsePieces(self, piecesDefFile):
        """Build self.pieces (name -> residue list) and self.res2piece.

        Helix residues (full resList) are collected in self.helices.  The
        short residue lists of loops, junctions and ends are collected in
        self.nonHelices.
        """
        pieces = Pieces(piecesDefFile)
        self.pieces = {}
        self.nonHelices = []
        self.helices = []
        for helix in pieces.helices.keys():
            resList = pieces.helices[helix].resList
            self.helices.extend(resList)
            self.pieces[helix] = resList
        for group in (pieces.loops, pieces.junctions, pieces.ends):
            for name in group.keys():
                resList = group[name].shortResList
                self.nonHelices.extend(resList)
                self.pieces[name] = resList
        self.res2piece = {}
        for p in self.pieces.keys():
            for res in self.pieces[p]:
                self.res2piece[res] = p

class CombinedStructure:
    def __init__(self, piecesDefFile, p2f2r, coarseModelIN, cutoff, centerAtoms):
        """Set up the fragment library, reference model and search settings.

        piecesDefFile, p2f2r -- handed to AssemblyPieces (which trims p2f2r in
                                place to the per-piece option limits).
        coarseModelIN        -- loaded as a Coarse model restricted to the
                                ``centerAtoms`` types; supplies residue names.
        cutoff               -- passed as the atomic clash cutoff to
                                FullAtomic.checkCollision in checkCollisions.
        """
        # Construction order kept: the library is built before the reference model.
        self.piecesObject = AssemblyPieces(piecesDefFile, p2f2r)
        self.refStruct = Coarse(coarseModelIN, typelist=centerAtoms)

        # Clash screening: coarse distance pre-filter, then atomic cutoff.
        self.coarseCut = 20.0
        self.fineCut = cutoff

        # stepForward reports success once the clash count is <= goal.
        self.goal = 0

        # Step bookkeeping (``steps`` and ``t`` are not read in this section).
        self.steps = 100
        self.numSteps = 1
        self.t = 1

        # Wall-clock reference for the timing reported by stepForward.
        self.t0 = time.time()

    def randomChoice(self):
        piecesChoice = {}
        piecesList = self.piecesObject.piecesLib.keys()
        for p in piecesList:
            choice = random.choice(self.piecesObject.piecesLib[p].keys())
            piecesChoice[p] = choice
        return piecesChoice

    def assemblePieces(self, piecesChoice):
        structureCoords = {}
        for pieceID in piecesChoice.keys():
            choiceID = piecesChoice[pieceID]
            TRfragmentCoords = self.piecesObject.piecesLib[pieceID][choiceID][0]
            for resID in TRfragmentCoords.keys():
                if resID in structureCoords.keys():
                    if resID not in self.piecesObject.nonHelices and pieceID[0]=='H':
                        del structureCoords[resID]
                        structureCoords[resID] = TRfragmentCoords[resID]
                else: 
                    structureCoords[resID] = TRfragmentCoords[resID]
        return self.printStructure(structureCoords)

    def printStructure(self, structureCoords):
        resIndex = 0
        atomIndex = 0
        lines = []
        resIDlist = structureCoords.keys()
        resIDlist.sort()
        for resID in resIDlist:
            resIndex+=1
            atomTypeList = structureCoords[resID].keys()
            atomTypeList.sort()
            for atomType in atomTypeList:
                atomIndex+=1
                (x,y,z) = structureCoords[resID][atomType]
                resName = self.refStruct.res2name[resID]
                lines.append( 'ATOM %6d  %s   %s %5d    %8.3f%8.3f%8.3f\n' % (atomIndex,atomType,resName,resIndex,x,y,z) )
        lines+='END\n'
        return lines

    def checkCollisions(self, lines, centerAtoms):
        """Find clashing residue pairs in a structure and the pieces they belong to.

        lines       -- structure text lines (as returned by printStructure).
        centerAtoms -- atom types passed as ``typelist`` to the Coarse model.

        Every unordered residue pair is screened with Coarse.checkClose
        against ``self.coarseCut``. Only pairs within it are tested with
        FullAtomic.checkCollision using ``self.fineCut``.

        Returns (badPieces, badCount):
            badPieces -- IDs of pieces owning a residue in any clashing pair,
                         without duplicates, in order of first appearance.
            badCount  -- sum of the checkCollision results for clashing pairs.
        """
        coarse = Coarse(lines, typelist=centerAtoms)
        atomic = FullAtomic(lines)
        residues = list(coarse.res2type2index.keys())

        clashing = []
        total = 0
        for pos, first in enumerate(residues):
            for second in residues[pos + 1:]:
                if not (coarse.checkClose(first, second) <= self.coarseCut):
                    continue
                clash = atomic.checkCollision(first, second, self.fineCut)
                if not clash:
                    continue
                clashing.extend((first, second))
                total += clash

        owners = []
        for res in clashing:
            owner = self.piecesObject.res2piece[res]
            if owner not in owners:
                owners.append(owner)
        return (owners, total)
    
    def writeFile(self, lines):
        outFile = open('%s-%i.pdb' % (self.outName, self.i), 'w')
        outFile.writelines(lines)
        outFile.close()

    def makeLastLine(self):
        lastline = '#'
        for p in self.NEWchoice.keys():
            lastline += '%i ' % self.NEWchoice[p]
        lastline+='\n# %i\n' % self.numSteps
        self.logFile.write(lastline)

    def updateLog(self, steps, count, accept):
        self.logFile.write('%i %i %i\n' % (steps, count, accept))
        self.logFile.flush()

    def stepForward(self):

        if self.OLDbadCount<=self.goal:
            totalTime = time.time()-self.t0
            print 'Found a good one! (<=%i)' % self.goal
            print 'this took %i steps and %.2f seconds' % (self.numSteps, totalTime)
            lines = ['REMARK SOLUTION %i steps %.2f seconds\n' % (self.numSteps, totalTime)]
            lines+=self.OLDlines
            self.writeFile(lines)
            self.updateLog(self.numSteps, self.OLDbadCount, 1)    
            if self.numSteps==1:
                self.NEWchoice=self.OLDchoice
            self.makeLastLine()
            return 1

        # If solution is not found yet, continue searching

        chosenPiece = random.choice(self.OLDbadPieces)
        chosenFragment = random.choice(self.piecesObject.piecesLib[chosenPiece].keys())
        self.NEWchoice = copy.copy(self.OLDchoice)
        self.NEWchoice[chosenPiece]=chosenFragment
        self.NEWlines = self.assemblePieces(self.NEWchoice)
        (self.NEWbadPieces, self.NEWbadCount) = self.checkCollisions(self.NEWlines, self.centerAtoms)
        if self.NEWbadCount<self.OLDbadCount: Paccept = 1.0
        else:
            Paccept = 0.5*math.exp(-1.0*float(self.NEWbadCount-self.OLDbadCount)/float(self.OLDbadCount))
        Prandom = random.random()
        if Prandom<=Paccept: accept=1
        else: accept=0

        if accept:
            self.OLDlines = self.NEWlines
            self.OLDchoice = copy.copy(self.NEWchoice)
            self.OLDbadCount = self.NEWbadCount
            self.OLDbadPieces = self.NEWbadPieces
            if self.NEWbadCount<self.minBadCount:
                self.bestLines = self.NEWlines
                self.minBadCount = self.NEWbadCount
        #print "step %i, accept %i, bad %i" % (self.numSteps, accept, self.OLDbadCount)
        return 0
