#!/usr/local/bin/python

import sys, math

# calculate the distance between two points r1 and r2
def distance(r1,r2):
    """Return the Euclidean distance between the 3-D points r1 and r2.

    Each argument is unpacked into exactly three coordinates (x, y, z).
    """
    ax, ay, az = r1
    bx, by, bz = r2
    dx = bx - ax
    dy = by - ay
    dz = bz - az
    return math.sqrt(dx*dx + dy*dy + dz*dz)

# calc the length of a vector v
def vecLength(v):
    """Return the length of vector v, i.e. its distance from the origin."""
    origin = (0, 0, 0)
    # distance() unpacks v into three components itself, so a separate
    # unpacking step is not needed here.
    return distance(v, origin)

# calculate the position of the next residue based on past three
def calcPosition(r1, r2, r3, bond, angle, dihedral):
    """Return the point r4 that follows the three points r1, r2, r3.

    r4 lies at distance `bond` from r3.  Its direction is built from two
    rotations: the r1->r2 direction is rotated about the r2->r3 axis by
    (pi - dihedral), and the r2->r3 direction is then rotated by
    (pi - angle) about the normal of the plane spanned by itself and the
    result of the first rotation.  Angles are in radians (they are combined
    with math.pi and fed to math.cos/math.sin by M_v_theta).
    """
    # Unit vectors along the two most recent bonds.
    olderBond = unitVec(vecSub(r2, r1))
    newerBond = unitVec(vecSub(r3, r2))

    # Dihedral step: rotate the older bond direction about the newer one.
    dihedralRot = M_v_theta(newerBond, math.pi - dihedral)
    turned = vMdotProduct(olderBond, dihedralRot)

    # Angle step: rotate the newer bond direction about the normal of the
    # plane it spans with the turned vector.
    normal = unitVec(crossProduct(newerBond, turned))
    angleRot = M_v_theta(normal, math.pi - angle)
    step = setVecLength(vMdotProduct(newerBond, angleRot), bond)

    return vecAdd(r3, step)
    
# resize the vector to have the length of a bond
def setVecLength(v, bond):
    """Return v rescaled to have length `bond`, as an (x, y, z) tuple.

    A zero-length v makes the division below raise ZeroDivisionError.
    """
    vx, vy, vz = v
    norm = vecLength(v)
    # Normalise first, then scale, component by component.
    return ((vx/norm)*bond, (vy/norm)*bond, (vz/norm)*bond)

# resize to unit vector
def unitVec(v):
    """Return v scaled to unit length, as an (x, y, z) tuple.

    A zero-length v makes the division below raise ZeroDivisionError.
    """
    vx, vy, vz = v
    norm = vecLength(v)
    return (vx/norm, vy/norm, vz/norm)

# calculate the cross product of r1 and r2
def crossProduct(r1,r2):
    """Return the cross product r1 x r2 as an (x, y, z) tuple."""
    ax, ay, az = r1
    bx, by, bz = r2
    return (ay*bz - az*by,
            az*bx - ax*bz,
            ax*by - ay*bx)

# calculate the dot product of two vectors r1 and r2
def vvdotProduct(r1,r2):
    """Return the scalar (dot) product of the 3-component vectors r1 and r2."""
    ax, ay, az = r1
    bx, by, bz = r2
    return ax*bx + ay*by + az*bz

# calculate the dot product of a vector v and a matrix M
def vMdotProduct(v,M):
    """Return the row-vector-times-matrix product v*M as an (x, y, z) tuple.

    M is indexed as three rows of three entries each.  Component j of the
    result is the dot product of column j of M with v.
    """
    vx, vy, vz = v
    a0, a1, a2 = M[0]
    b0, b1, b2 = M[1]
    c0, c1, c2 = M[2]
    # Regroup the row-major entries into columns.
    columns = ((a0, b0, c0), (a1, b1, c1), (a2, b2, c2))
    return tuple([vvdotProduct(column, (vx, vy, vz)) for column in columns])

# calculate the rotation matrix M(v,theta)
def M_v_theta(v,theta):
    """Return the 3x3 matrix for a rotation by theta (radians) about axis v.

    The result is a list of three row tuples.  The entries follow the usual
    axis-angle rotation formula; v is used as given, with no normalisation
    done here.
    """
    x, y, z = v
    c = math.cos(theta)
    s = math.sin(theta)
    t = 1 - c
    return [
        (c + t*x*x,   t*x*y - s*z, t*x*z + s*y),
        (t*y*x + s*z, c + t*y*y,   t*y*z - s*x),
        (t*z*x - s*y, t*z*y + s*x, c + t*z*z),
    ]

# calculate v1-v2, return result in form (x,y,z)
def vecSub(v1, v2):
    """Return the component-wise difference v1 - v2 as an (x, y, z) tuple."""
    ax, ay, az = v1
    bx, by, bz = v2
    return (ax - bx, ay - by, az - bz)

# calculate v1+v2, return result in form (x,y,z)
def vecAdd(v1,v2):
    """Return the component-wise sum v1 + v2 as an (x, y, z) tuple."""
    ax, ay, az = v1
    bx, by, bz = v2
    return (ax + bx, ay + by, az + bz)

# set the third residue at angle in plane xy
def setR3(r1,r2, bond, angle):
    """Return a third point r3 at distance `bond` from r2.

    The unit direction r2->r1 is rotated by `angle` (radians) about the
    normal of the plane it spans with the fixed reference direction
    (1, 1, 0); the rotated direction is scaled to `bond` and added to r2.
    """
    reference = unitVec((1, 1, 0))
    backward = unitVec(vecSub(r1, r2))
    axis = unitVec(crossProduct(backward, reference))
    rotated = vMdotProduct(backward, M_v_theta(axis, angle))
    return vecAdd(r2, setVecLength(rotated, bond))

# calculate the angle formed by three points (vertex at r2)
def calcAngle(r1,r2,r3):
    """Return the angle r1-r2-r3 (vertex at r2) in radians, in [0, pi].

    Raises ZeroDivisionError if r1 or r3 coincides with r2, because one of
    the two vectors then has zero length.
    """
    a = vecSub(r1, r2)
    b = vecSub(r3, r2)
    cosine = vvdotProduct(a, b) / (vecLength(a) * vecLength(b))
    # Floating-point rounding can leave the ratio marginally outside [-1, 1]
    # for (anti)parallel vectors, which would make math.acos raise
    # ValueError; clamp it back into the valid domain.
    if cosine > 1.0:
        cosine = 1.0
    elif cosine < -1.0:
        cosine = -1.0
    return math.acos(cosine)

# check the dihedral given 4 points
def calcDihedral(r1,r2,r3,r4):
    """Return the unsigned dihedral for the points r1, r2, r3, r4 in radians.

    The value is the angle, in [0, pi], between the vectors
    (r2-r1) x (r3-r2) and (r4-r3) x (r3-r2).

    Raises ZeroDivisionError if either cross product has zero length, i.e.
    when three consecutive points are collinear or two of them coincide.
    """
    a = vecSub(r2, r1)
    b = vecSub(r3, r2)
    c = vecSub(r4, r3)
    d1 = crossProduct(a, b)
    d2 = crossProduct(c, b)
    cosine = vvdotProduct(d1, d2) / (vecLength(d1) * vecLength(d2))
    # Floating-point rounding can leave the ratio marginally outside [-1, 1]
    # for planar arrangements, which would make math.acos raise ValueError;
    # clamp it back into the valid domain.
    if cosine > 1.0:
        cosine = 1.0
    elif cosine < -1.0:
        cosine = -1.0
    return math.acos(cosine)

# print out a matrix to the screen
def printMatrix(M):
    """Print M one row per line by handing each row to printVector."""
    for row in M:
        printVector(row)

# print out a vector to the screen
def printVector(v):
    """Print the three components of v on one line, tab-separated, as %f."""
    (x, y, z) = v
    # A single parenthesised argument prints identically under the Python 2
    # print statement and the Python 3 print function, so this line keeps
    # the module importable on both.
    print('%f\t%f\t%f' % (x, y, z))

# correct r1 position based on r2, r3, r4
def correctR1(r2,r3,r4, bond, angle, dihedral):
    """Return a recomputed r1 derived from r2, r3 and r4.

    The chain is walked in reverse: calcPosition is applied to (r4, r3, r2),
    so the result is the point that "follows" r2 when coming from r4.
    """
    return calcPosition(r4, r3, r2, bond, angle, dihedral)

# unfold
def unfold(residueL, offset, bond, angle, dihedral):
    """Return a list of (x, y, z) positions, one per entry of residueL.

    residueL -- sequence of residues; only its length is used here.
    offset   -- z coordinate given to the first two positions.
    bond     -- distance between consecutive positions.
    angle    -- bond angle in radians.
    dihedral -- dihedral angle in radians.

    The first two positions are fixed arbitrarily, the third is placed by
    setR3, and every later one by calcPosition from the preceding three.
    For chains of four or more residues the first position is finally
    recomputed from positions 2-4 via correctR1.

    Exactly len(residueL) positions are returned, including for chains of
    zero or one residue, so that callers can pair positions with residues
    without running past the end of either list.
    """
    numR = len(residueL)
    positionL = []

    # First position (arbitrary): on the z axis at height `offset`.
    if numR > 0:
        positionL.append((0, 0, offset))

    # Second position (arbitrary): one bond length along +x.
    if numR > 1:
        positionL.append((bond, 0, offset))

    # Third position: placed relative to the first two at the bond angle.
    if numR > 2:
        positionL.append(setR3(positionL[0], positionL[1], bond, angle))

    if numR > 3:
        # Every remaining position follows from the three before it.
        for r in range(3, numR):
            p1, p2, p3 = positionL[r-3:r]
            positionL.append(calcPosition(p1, p2, p3, bond, angle, dihedral))

        # Replace the arbitrary first position with one derived from the
        # next three, walking the chain backwards.
        positionL[0] = correctR1(positionL[1], positionL[2], positionL[3],
                                 bond, angle, dihedral)

    return positionL

def unfoldFile(sequence, params, outputFile):
    """Read a sequence or PDB file and write the unfolded chain(s) to a file.

    sequence   -- path of the input file.  If the path contains 'primary',
                  the file is read as one residue per line (first
                  whitespace-separated column, blank lines skipped) forming
                  a single chain.  Otherwise it is read as PDB text: lines
                  starting with 'ATOM' are residues and lines starting with
                  'TER' or 'END' close a chain.
    params     -- object providing B_R (bond length, multiplied by 10 here;
                  presumably a unit conversion -- confirm), A_R (angle) and
                  D_R (dihedral).
    outputFile -- path of the file to write.

    Each chain is unfolded with unfold(); successive chains are given a z
    offset 10.0 larger than the previous one.  For PDB input the first 30
    characters of every ATOM line are kept and the coordinates replaced;
    TER/END lines are copied through.  For 'primary' input new ATOM records
    are generated, with a TER line after each chain.
    """
    with open(sequence) as source:
        inputLines = source.readlines()

    bond = params.B_R * 10
    angle = params.A_R
    dihedral = params.D_R

    isPdb = sequence.find('primary') < 0
    chainEnd = ('TER', 'END')

    # Collect the residues of each chain.
    molList = []
    residueL = []
    if isPdb:
        for line in inputLines:
            if line.startswith(chainEnd):
                molList.append(residueL)
                residueL = []
            if line.startswith('ATOM'):
                residueL.append(line[19])
        # A final chain that is not closed by TER/END must still be kept,
        # otherwise the output pass below has no positions for its atoms.
        if residueL:
            molList.append(residueL)
    else:
        for line in inputLines:
            cols = line.split()
            if cols:
                residueL.append(cols[0])
        molList.append(residueL)

    # Unfold every chain, stacking them 10.0 apart along z.
    posLL = []
    offset = 0.0
    for residues in molList:
        posLL.append(unfold(residues, offset, bond, angle, dihedral))
        offset += 10.0

    # Build the output as a list of lines and join once at the end.
    outLines = []
    if isPdb:
        mol = 0
        count = 0
        for line in inputLines:
            if line.startswith('ATOM'):
                (x, y, z) = posLL[mol][count]
                outLines.append('%s%8.3f%8.3f%8.3f\n' % (line[:30], x, y, z))
                count += 1
            if line.startswith(chainEnd):
                outLines.append(line)
                mol += 1
                count = 0
    else:
        serial = 0
        for residues, posL in zip(molList, posLL):
            # Pair residues with positions so that a chain shorter than the
            # position list cannot cause an out-of-range lookup.
            for name, (x, y, z) in zip(residues, posL):
                serial += 1
                outLines.append(
                    'ATOM %6d  C3*   %s   %3d    %8.3f%8.3f%8.3f\n' %
                    (serial, name, serial, x, y, z))
            outLines.append('TER\n')

    with open(outputFile, 'w') as output:
        output.write(''.join(outLines))


if __name__=='__main__':
    # Command line: <sequenceFile> <paramModule> <outputFile>
    # paramModule is a module name importable from sys.path (no ".py"); it
    # is expected to define B_R, A_R and D_R, which unfoldFile reads.
    if len(sys.argv) < 4:
        sys.stderr.write('usage: %s <sequenceFile> <paramModule> <outputFile>\n'
                         % sys.argv[0])
        sys.exit(1)
    sequenceFile = sys.argv[1]
    paramModule = sys.argv[2]
    params = __import__(paramModule)
    outputFile = sys.argv[3]
    unfoldFile(sequenceFile, params, outputFile)

