#!/usr/local/bin/python

import os, sys
import math
import getopt
import glob
import time
import sha
import socket
import random

import LinearAlgebra

import nast.deltaTime as deltaTime
import nast.gomodelRNAmaj2 as gomodel
import nast.md.gosim as gosim
import nast.version
import nast.geom as geom
import nast.pdbFile as pdbFile

# Seed material: process table, host name, wall-clock time and PID are
# hashed together, presumably so that concurrent runs get different seeds.
# NOTE(review): `seed` is not referenced in the code visible here (the
# simulation uses params.seed) -- confirm it is still needed.
f = os.popen('ps -efa')
try:
    s = f.read()+'\n'
finally:
    # Close the pipe so the `ps` child is waited on rather than left behind.
    f.close()
s += socket.gethostname()+'\n'
s += str(time.time())+'\n'
s += str(os.getpid())+'\n'
# Last 7 hex digits of the SHA-1 digest -> an integer below 16**7.
seed = int(sha.new(s).hexdigest()[-7:], 16)

# NOTE(review): labelled kJ/K*mol, but the molar gas constant is about
# 8.31/1000 kJ/(K*mol); 3.31 may be a typo -- confirm before relying on R.
R=3.31/1000 # kJ/K*mol
T=300 # K
w=0.2

class Structure:
    def __init__(self, params):
        self.params = params
        self.ignoreList=[]
        if self.params.ignoreList:
            for line in open(self.params.ignoreList).readlines():
                self.ignoreList.append(int(line))
    def processPrimary(self, pSeq):
        self.sequence = open(pSeq).readlines()
        self.seqLength = len(self.sequence)

    def calcIdealE(self):
        self.ENHbonds = geom.calcBestBondE(self.params.B_K, self.params.B_R)
        self.EHbonds = geom.calcBestBondE(self.params.Bh_K, self.params.Bh_R)
        self.ENHangles = geom.calcBestAngleE(self.params.A_K, self.params.A_R)
        self.EH1angles = geom.calcBestAngleE(self.params.Ah_K, self.params.Ah_R)
        self.EH2angles = geom.calcBestAngleE(self.params.Ah2_K, self.params.Ah2_R)
        self.ENHdihedrals = geom.calcBestDiE(self.params.D_K1, self.params.D_K2, self.params.D_K3, self.params.D_R)
        self.EH1dihedrals = geom.calcBestDiE(self.params.Dh_K1, self.params.Dh_K2, self.params.Dh_K3, self.params.Dh_R)
        self.EH2dihedrals = geom.calcBestDiE(self.params.Dh1_K1, self.params.Dh1_K2, self.params.Dh1_K3, self.params.Dh1_R)
        self.EH3dihedrals = geom.calcBestDiE(self.params.Dh2_K1, self.params.Dh2_K2, self.params.Dh2_K3, self.params.Dh2_R)

    def processKnown(self, defFile, structureFile):
        defInput = open(defFile, 'r').readlines()
        self.knownStructure = pdbFile.Pdb(structureFile)
        groupList = []
        group = []
        resList = []
        for line in defInput:
            if line[0]=='#':
                groupList.append(group)
                group = []
            else:
                mark = line.find(':')
                start = int(line[:mark])
                end = int(line[mark+1:])
                for i in range(start, end+1):
                    group.append(i)
                    if i not in resList:
                        resList.append(i)
        group = []
        self.ksList = []
        for group in groupList:
            for i in range(len(group)):
                for j in range(i+1,len(group)):
                    distance = self.knownStructure.getDist(i, j)
                    self.ksList.append((group[i],group[j],distance))
    
    def newStructure(self, type):
        if type=='5to3':
            firstResNum = 0
        if type=='3to5':
            firstResNum = self.seqLength-1
        if type=='RAND':
            firstResNum = random.choice(range(self.seqLength))
        firstResID = self.sequence[firstResNum].strip()
        (x,y,z) = (0.0, 0.0, 0.0)
        line = 'ATOM %6d  C3*   %s   %3d    %8.3f%8.3f%8.3f\n' % \
                  (firstResNum+1,firstResID,firstResNum+1,x,y,z)
        structure = []
        structure.append(line)
        if firstResNum+1<self.seqLength:
            secondResNum = firstResNum+1
        else:
            secondResNum = firstResNum-1
        secondResID = self.sequence[secondResNum].strip()
        (x,y,z) = (5.76, 0.0, 0.0)
        line = 'ATOM %6d  C3*   %s   %3d    %8.3f%8.3f%8.3f\n' % \
                  (secondResNum+1,secondResID,secondResNum+1,x,y,z)

        structure.append(line)
        structure.append('END\n')
        self.startStructure = structure

    def processMissing(self):
        startList = []
        for line in self.startStructure:
            if line[0:4] == 'ATOM':
                res = int(line[23:26])
                startList.append(res)
        endPoints = []
        last = -1
        start = 0
        for i in range(1, self.seqLength+1):
            if i not in startList:
                 if i == 1:
                     start = i
                 elif i != last+1:
                     end = last
                     if start:
                         endPoints.append( (start, end) )
                     start = i
                 last = i
        if last == self.seqLength:
            endPoints.append( (start, last) )
        else:
            endPoints.append( ( start, last) )
        print 'COMMENT endPoints ', endPoints

        missing = []
        self.missingList = []
        for start, end in endPoints:
            if start == 1:
                dir = -1
            elif end == self.seqLength:
                dir = 1
            else:
                dir = random.choice([-1,1])
            if dir==-1:
                temp = end
                end = start
                start = temp
                end-=1
            else:
                end+=1
            missing.extend(range(start,end,dir))

        lastRes=-1
        thisList=[]
        for res in missing:
            if res == (lastRes+1) or res == (lastRes-1):
                temp=0
            elif thisList:
                self.missingList.append(thisList)
                thisList = []

            if res not in self.ignoreList:
                thisList.append(res)
            lastRes = res
        self.missingList.append(thisList)
        print "COMMENT missingList"
        print self.missingList
            
    def processTertiary(self, PT):
        self.ptList = self.parsePredictedTertiary(PT)

    def parsePredictedTertiary(self,defFile):
        defInput = open(defFile, 'r').readlines()
        distList = []
        for line in defInput:
            if line[0]=='#': continue
            cols = line.split()
            res1 = int(cols[0])
            res2 = int(cols[1])
            dist = float(cols[2])
            strength = float(cols[3])
            distList.append( (res1,res2,dist,strength))
        return distList

    def setupStart(self):
        self.eqRun = 0
        self.stepNum = 0
        self.res2index={}
        self.index2res={}
        self.res2pos={}
        self.res2name={}
        self.parseStart(self.stepNum, self.startStructure)

    def findNearby(self, lastRes):
        nearby=[]
        for res in self.res2pos[self.stepNum].keys():
            distance = geom.calcDist(self.res2pos[self.stepNum][lastRes],self.res2pos[self.stepNum][res])
            if distance < 9.0:
                print "distance: ", distance
                nearby.append(res)
        return nearby

    def addResidue(self, res):
        direction="none"
        if res-1 in self.res2pos[self.stepNum].keys():
            direction="pos"
            lastRes = res-1
        if res+1 in self.res2pos[self.stepNum].keys():
            direction="neg"
            lastRes = res+1
        if direction=="none":
            print 'COMMENT Error: No last residue to tack on to ', res
            sys.exit(3)
        self.stepNum+=1
        self.res2pos[self.stepNum] = self.res2pos[self.stepNum-1].copy()


#        if direction=="neg" and direction=="pos":
        temp=0
        if temp:
            res0 = self.res2pos[self.stepNum][res-1]
            res2 = self.res2pos[self.stepNum][res+1]
            (xd,yd,zd) = geom.vecSub(res0, res2)
            (x0,y0,z0) = res0
            (x2,y2,z2) = res2
            newx = x0+(xd/2)
            newy = y0+(yd/2)
            newz = z0+(zd/2)
        else:
            (lastx, lasty, lastz) = self.res2pos[self.stepNum][lastRes]
            R = self.params.B_R*10.0
            zList = geom.arange(-1*R, R,R/20.0)
            phiList = geom.arange(0, 2*math.pi, 0.3)
            allPos = []
            for z in zList:
                theta = math.asin(z/R)
                for phi in phiList:
                    x = R*math.cos(theta)*math.cos(phi)
                    y = R*math.cos(theta)*math.sin(phi)
                    newx = lastx+x
                    newy = lasty+y
                    newz = lastz+z
                    allPos.append( (newx,newy,newz) )
            resNearby = self.findNearby(lastRes)
            del resNearby[resNearby.index(lastRes)]
            print "COMMENT these residues are closeby"
            print resNearby
            goodPos = []
            for pos in allPos:
                good = 1
                tempResi = 0
                while good:
                    if tempResi>=len(resNearby): break
                    tempRes = resNearby[tempResi]
                    distance = geom.calcDist(pos,self.res2pos[self.stepNum][tempRes])
                    if distance < 7.0:
                        good=0
                    tempResi+=1
                if good: goodPos.append(pos)
            print "Of ", len(allPos), ", ", len(goodPos), "are left"
            if goodPos: self.res2pos[self.stepNum][res] = random.choice(goodPos)
            else:
                print "No good position for ", res
                sys.exit(1)
            self.setNastID()

    def setNastID(self):
        list = self.res2pos[self.stepNum].keys()
        list.sort()
        self.res2index[self.stepNum] = {}
        self.index2res[self.stepNum] = {}
        self.res2name[self.stepNum] = {}
        index = 0
        for res in list:
            index+=1
            self.res2index[self.stepNum][res] = index
            self.index2res[self.stepNum][index] = res
            self.res2name[self.stepNum][res] = self.sequence[index-1].strip()

    def parseStart(self, stepNum, structure):
        self.res2index[stepNum] = {}
        self.index2res[stepNum] ={}
        self.res2pos[stepNum] = {}
        self.res2name[stepNum] = {}
        index = 0
        for line in structure:
            if line[0:4]=='ATOM':
                res = int(line[23:26])
                index+=1
                self.res2index[stepNum][res] = index
                self.index2res[stepNum][index] = res
#                resID = line[19:20]
                resName = line[19:20]
                x = float(line[30:38])
                y = float(line[38:46])
                z = float(line[46:54])
                self.res2name[stepNum][res] = resName
                self.res2pos[stepNum][res] = (x, y, z)

    def parseFile(self, stepNum, structure):
        self.res2index[stepNum] = {}
        self.index2res[stepNum] ={}
        self.res2pos[stepNum] = {}
        self.res2name[stepNum] = {}
        for line in structure:
            if line[0:4]=='ATOM':
                index = int(line[23:26])
                res = self.index2res[stepNum-1][index]
                self.index2res[stepNum][index] = res
                self.res2index[stepNum][res] = index
                name = line[19:20]
                x = float(line[30:38])
                y = float(line[38:46])
                z = float(line[46:54])
                self.res2name[stepNum][res] = name
                self.res2pos[stepNum][res] = (x, y, z)

    def printFile(self, stepNum, name):
        output = open(name, 'w')
        list = self.res2pos[stepNum].keys()
        list.sort()
        for res in list:
            index=self.res2index[stepNum][res]
            (x,y,z) = self.res2pos[stepNum][res]
            name = self.res2name[stepNum][res]
            if index>1:
                lastRes=self.index2res[stepNum][index-1]
                if res != (lastRes+1):
                    output.write('TER\n')
            line = 'ATOM %6d  C3*   %s   %3d    %8.3f%8.3f%8.3f\n' % (res,name,res,x,y,z)
            output.write(line)
        output.write('END\n')
        output.close()

    def setKS(self):
        resList = self.res2pos[self.stepNum].keys()
        for iRes, jRes, dist in self.ksList:
            if iRes in resList and jRes in resList:
                if dist < 20.0:
                    self.gModel.addArbitrary(self.res2index[self.stepNum][iRes]-1, self.res2index[self.stepNum][jRes]-1, self.params.kS_K, dist/10.0, 0, 2, 0)

    def setPT(self):
        resList = self.res2pos[self.stepNum].keys()
        for iRes, jRes, dist, strength in self.ptList:
            if iRes in resList and jRes in resList:
                self.gModel.addArbitrary(self.res2index[self.stepNum][iRes]-1, self.res2index[self.stepNum][jRes]-1, strength, dist/10.0, 0, 2, 0)
        
    def addOneRes(self):
        if len(self.missingLong) > self.i:
            self.thisAdd.append( (self.i, self.missingLong[self.i]) )
        if len(self.thisAdd)==0:
            print 'COMMENT no more residues to add'
            self.printFile(self.stepNum-1, '%s.pdb' % self.outputFilename)
            self.gModel.makePSF('%s.psf' % self.outputFilename)
            self.keepAdding = 0
            sys.exit(0)
        self.i+=1

    def addSeveralRes(self):
        for list in self.missingList:
            if len(list) > self.i:
                self.thisAdd.append( (self.i, list[self.i]) )
        if len(self.thisAdd)==0:
            print 'COMMENT no more residues to add'
            self.printFile(self.stepNum-1, '%s.pdb' % self.outputFilename)
            self.gModel.makePSF('%s.psf' % self.outputFilename)
            self.keepAdding = 0
            sys.exit(0)
        self.i+=1

    def setgModel(self):
        self.gModel = gomodel.Gomodel(self.pdbInitFilename, 
                     B_K=self.params.B_K, B_R=self.params.B_R,
                     Bh_K=self.params.Bh_K, Bh_R=self.params.Bh_R,
                     A_K=self.params.A_K, A_R=self.params.A_R,
                     Ah_K=self.params.Ah_K, Ah_R=self.params.Ah_R,
                     Ah2_K=self.params.Ah2_K, Ah2_R=self.params.Ah2_R,
                     D_K1=self.params.D_K1, D_K2=self.params.D_K2, D_K3=self.params.D_K3, D_R=self.params.D_R,
                     Dh_K1=self.params.Dh_K1, Dh_K2=self.params.Dh_K2, Dh_K3=self.params.Dh_K3, Dh_R=self.params.Dh_R,
                     Dh1_K1=self.params.Dh1_K1, Dh1_K2=self.params.Dh1_K2, Dh1_K3=self.params.Dh1_K3, Dh1_R=self.params.Dh1_R,
                     Dh2_K1=self.params.Dh2_K1, Dh2_K2=self.params.Dh2_K2, Dh2_K3=self.params.Dh2_K3, Dh2_R=self.params.Dh2_R,
                     eR12=self.params.eR12, HsR12=self.params.HsR12, NHsR12=self.params.NHsR12, eps=self.params.eps,
                     helixDefFilename=self.params.predictedSecondary,
                     centerAtomName=self.params.centerAtomName,
                     mass=self.params.massDefault)

    def setgSim(self):
        self.pArrays = self.gModel.buildParamarrays()
        self.gSim = gosim.Gosim(self.pArrays, dt=self.params.dt, seed=self.params.seed)
        self.gSim.printSimInfo()
        self.gSim.setMaxPotEnergy(999999)
        self.gSim.setTemp(self.params.temp0)

    def rungSim(self):
        self.ranWithNoErrors = self.gSim.doSim(self.params.numSteps,
                                     numStepsVV=10,
                                     numStepsTemp=self.params.numStepsTemp,
                                     numStepsPrint=self.params.numStepsPrint,
                                     printEPArray=self.params.printEPArray)

    def calcEresidual(self):
        coords = self.gSim.getCoords()
        (NHbonds, Hbonds, NHangles, H1angles, H2angles, NHdihedrals, \
                H1dihedrals, H2dihedrals, H3dihedrals, \
                NHvdw, Hvdw) = self.gModel.calcEcoord(coords)
        residual=0.0
        if NHbonds: residual+=math.fabs(NHbonds-self.ENHbonds)/math.fabs(self.ENHbonds)
        if Hbonds: residual+=math.fabs(Hbonds-self.EHbonds)/math.fabs(self.EHbonds)
        if NHangles: residual+=math.fabs(NHangles-self.ENHangles)/math.fabs(self.ENHangles)
        if H1angles: residual+=math.fabs(H1angles-self.EH1angles)/math.fabs(self.EH1angles)
        if H2angles: residual+=math.fabs(H2angles-self.EH2angles)/math.fabs(self.EH2angles)
        if NHdihedrals: residual+=math.fabs(NHdihedrals-self.ENHdihedrals)/math.fabs(self.ENHdihedrals)
        if H1dihedrals: residual+=math.fabs(H1dihedrals-self.EH1dihedrals)/math.fabs(self.EH1dihedrals)
        if H2dihedrals: residual+=math.fabs(H2dihedrals-self.EH2dihedrals)/math.fabs(self.EH2dihedrals)
        if H3dihedrals: residual+=math.fabs(H3dihedrals-self.EH3dihedrals)/math.fabs(self.EH3dihedrals)
        if NHvdw: residual+=NHvdw
        if Hvdw: residual+=Hvdw
        return residual

    def equilibrate(self):
        print 'COMMENT equilibrating'
        self.calcIdealE()
        self.eqRun+=1
        startName = 'start_%i.pdb' % self.stepNum
        self.printFile(self.stepNum, startName)
        self.pdbInitFilename = startName
#        self.setTempPS()
        try:
            self.setgModel()
        except ValueError:
            print 'COMMENT had a ValueError'
            self.ranWithNoErrors = 0
            print 'COMMENT stepNum: ', self.stepNum
            sys.exit(1)
            return
        if self.ksList:
            self.setKS()
        if self.ptList:
            self.setPT()

        self.gModel.makePSF('start_%i.psf' % self.stepNum)
        self.setgSim()

        self.startOver = 0
        print 'COMMENT Starting Residual ', self.calcEresidual()
        for run in range(10):
            for iRun in range(10):
                try:
                    self.rungSim()
                except LinearAlgebra.LinAlgError:
                    print 'COMMENT Linear Algebra Error'
                    self.startOver = 1
                    self.ranWithNoErrors = 0
                    break
                self.gSim.writePDBFile('trace_%i.pdb' % self.stepNum, append=True)
            energy = self.gSim.eForce()
            if energy > 10000:
                print 'COMMENT energy went crazy'
                self.ranWithNoErrors=0
                self.startOver = 1
                break
            residual = self.calcEresidual()
            print 'COMMENT Residual: ', residual
            print
            self.gModel.printValStats()
            print
            self.gModel.printEStats()
            print
            if residual < 20.0:
                print 'COMMENT Done minimizing'
                break

        if self.ranWithNoErrors:
            self.gSim.writePDBFile('last.pdb', append=False)
        else:
            self.startOver = 1


