#!/usr/local/bin/python

import sys, math


# Run-time parameters: argv[2] names an importable module that must define
# B_R, A_R and D_R. It is read at import time because unfold() relies on the
# module-level `bond`. `angle` and `dihedral` are not used by the functions
# below (calcPosition/unfold take their own `angle`).
paramModule = sys.argv[2]
params = __import__(paramModule)
# B_R is multiplied by 10 -- presumably a unit conversion (e.g. nm -> Angstrom);
# TODO confirm against the parameter module.
bond, angle, dihedral = params.B_R * 10, params.A_R, params.D_R


# calculate the distance between two points r1 and r2
def distance(r1,r2):
    (x1,y1,z1) = r1
    (x2,y2,z2) = r2
    distance = math.sqrt((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1))
    return distance

def vecLength(v):
    """Return the Euclidean length (norm) of the 3-component vector v.

    Equivalent to distance(v, (0, 0, 0)); v must unpack into exactly three
    numbers, otherwise ValueError is raised.
    """
    vx, vy, vz = v
    return math.sqrt(vx * vx + vy * vy + vz * vz)

def calcPosition(r1, r2, angle):
    """Return the next chain point after r1 -> r2.

    The step vector r2 - r1 is split into its xy part and its z part. The xy
    part is rotated counter-clockwise by (pi - angle), so that in the
    xy-projection the angle at r2 between r1 and the new point equals
    `angle`. The z part is carried over unchanged plus 0.01.

    r1, r2 -- (x, y, z) tuples of the two preceding points
    angle  -- desired angle at r2, in radians
    Returns the new point as an (x, y, z) tuple.
    """
    turn = math.pi - angle
    cos_t = math.cos(turn)
    sin_t = math.sin(turn)
    rotation = [(cos_t, -sin_t), (sin_t, cos_t)]

    step_x, step_y, step_z = vecSub(r2, r1)
    new_x, new_y = MvdotProduct(rotation, (step_x, step_y))

    # The z component of the previous step is inherited and increased by
    # 0.01, so along a chain built by repeated calls the vertical rise grows
    # by 0.01 per step.
    return vecAdd(r2, (new_x, new_y, step_z + 0.01))
    

def MvdotProduct(M,v):
    """Multiply a 2x2 matrix by a 2-vector and return the result as a tuple.

    M -- indexable whose first two rows M[0], M[1] each unpack into two numbers
    v -- pair (x, y)
    Returns (M[0][0]*x + M[0][1]*y, M[1][0]*x + M[1][1]*y).
    """
    vx, vy = v
    rows = (M[0], M[1])
    return tuple(a * vx + b * vy for a, b in rows)


def vecSub(v1, v2):
    """Return the component-wise difference v1 - v2 as an (x, y, z) tuple.

    Both arguments must unpack into exactly three numbers.
    """
    ax, ay, az = v1
    bx, by, bz = v2
    return (ax - bx, ay - by, az - bz)

def vecAdd(v1,v2):
    """Return the component-wise sum v1 + v2 as an (x, y, z) tuple.

    Both arguments must unpack into exactly three numbers.
    """
    px, py, pz = v1
    qx, qy, qz = v2
    return (px + qx, py + qy, pz + qz)


def calcAngle(r1,r2,r3):
    """Return the angle r1-r2-r3 (vertex at r2) in radians, in [0, pi].

    r1, r2, r3 -- (x, y, z) points; each must unpack into three numbers.

    Raises ValueError if r1 or r3 coincides with r2, because the angle is
    undefined for a zero-length arm.
    """
    x1, y1, z1 = r1
    x2, y2, z2 = r2
    x3, y3, z3 = r3
    # Arms of the angle, both pointing away from the vertex r2.
    ax, ay, az = x1 - x2, y1 - y2, z1 - z2
    bx, by, bz = x3 - x2, y3 - y2, z3 - z2

    dot = ax * bx + ay * by + az * bz
    norms = math.sqrt(ax * ax + ay * ay + az * az) * \
            math.sqrt(bx * bx + by * by + bz * bz)
    if norms == 0:
        raise ValueError('angle undefined: r1 or r3 coincides with r2')

    # Rounding can push the cosine marginally outside [-1, 1] for (anti)
    # parallel arms, which would make math.acos raise a domain error.
    cosine = max(-1.0, min(1.0, dot / norms))
    return math.acos(cosine)

def printMatrix(M):
    """Hand each row of M, first to last, to printVector().

    Every row must be a 3-component vector, since printVector() unpacks it
    (as written, printVector() produces no output).
    """
    for row in M:
        printVector(row)

def printVector(v):
    """Debug hook for a 3-component vector; its output is currently disabled.

    The function only unpacks v, so it raises ValueError unless v has exactly
    three components, and returns None. The original tab-separated '%f'
    print (Python 2 syntax) was commented out.
    """
    x, y, z = v

def unfold(residueL, offset):
    """Lay one chain of residues out as a regular polygon.

    The first point is (0, 0, offset) and the second is (bond, 0, offset),
    where `bond` is the module-level bond length. Every further point comes
    from calcPosition() applied to the two points before it. It uses the
    interior angle of a regular polygon with len(residueL) corners, so the
    xy-projection of the chain traces that polygon.

    residueL -- sequence of residues; only its length is used
    offset   -- z coordinate of the first two points
    Returns a list with exactly one (x, y, z) tuple per residue. The list is
    empty when residueL is empty.
    """
    numR = len(residueL)

    # The first two points are placed arbitrarily along +x. Slicing keeps
    # exactly one position per residue for chains shorter than two.
    positionL = [(0, 0, offset), (bond, 0, offset)][:numR]

    if numR > 2:
        # Interior angle of a regular numR-gon. It is computed only when
        # needed, so an empty chain no longer divides by zero.
        angle = (numR - 2) * math.pi / numR
        for r in range(2, numR):
            p3 = calcPosition(positionL[r - 2], positionL[r - 1], angle)
            positionL.append(p3)

    return positionL

def unfoldFile(sequence, outputFile):
    """Read molecules from `sequence`, lay each out with unfold(), write PDB.

    If the path contains 'primary', the file is a plain sequence listing. The
    first whitespace-separated token of every non-blank line is one residue,
    and all of them form a single molecule.

    Otherwise the file is read as PDB:
      * every line starting with 'ATOM' contributes the character at index
        19 as a residue;
      * a line starting with 'TER' or 'END' closes the current molecule;
      * atoms after the last such line form one more molecule.

    Each molecule is placed by unfold() with a z-offset 10.0 larger than the
    previous one. Output depends on the input type:
      * PDB input: each ATOM line keeps its first 30 columns followed by the
        new coordinates (later columns are dropped), TER/END lines are
        copied, and all other lines are omitted.
      * Sequence input: one 'C3*' ATOM record is written per residue, with
        'TER' after each molecule.

    sequence   -- path of the input file
    outputFile -- path of the PDB file to write (overwritten)
    """
    with open(sequence) as handle:
        inputLines = handle.readlines()

    isPdb = 'primary' not in sequence
    if isPdb:
        molList = _readPdbMolecules(inputLines)
    else:
        molList = [_readPrimaryResidues(inputLines)]

    posLL = []
    for index, residues in enumerate(molList):
        # Empty molecules (e.g. a TER line directly followed by END) have no
        # atoms to place, so unfold() is not called for them.
        posLL.append(unfold(residues, 10.0 * index) if residues else [])

    if isPdb:
        outLines = _formatPdbLines(inputLines, posLL)
    else:
        outLines = _formatPrimaryLines(molList, posLL)

    with open(outputFile, 'w') as output:
        output.write(''.join(outLines))


def _readPrimaryResidues(inputLines):
    """Return the first whitespace-separated token of every non-blank line."""
    return [cols[0] for cols in map(str.split, inputLines) if cols]


def _readPdbMolecules(inputLines):
    """Split PDB lines into per-molecule lists of one-character residues."""
    molList = []
    residues = []
    for line in inputLines:
        if line.startswith(('TER', 'END')):
            molList.append(residues)
            residues = []
        if line.startswith('ATOM'):
            residues.append(line[19])
    # A final molecule that is not followed by TER/END must not be lost,
    # otherwise its ATOM lines have no positions when the output is written.
    if residues:
        molList.append(residues)
    return molList


def _formatPdbLines(inputLines, posLL):
    """Rewrite ATOM coordinates from posLL; copy TER/END lines; drop others."""
    out = []
    mol = 0
    count = 0
    for line in inputLines:
        if line.startswith('ATOM'):
            x, y, z = posLL[mol][count]
            out.append('%s%8.3f%8.3f%8.3f\n' % (line[:30], x, y, z))
            count += 1
        if line.startswith(('TER', 'END')):
            out.append(line)
            mol += 1
            count = 0
    return out


def _formatPrimaryLines(molList, posLL):
    """Emit one C3* ATOM record per residue and a TER line per molecule."""
    out = []
    serial = 0
    for residues, positions in zip(molList, posLL):
        # zip() pairs residues with positions, so one record is written per
        # residue even if more positions than residues were produced.
        for resName, (x, y, z) in zip(residues, positions):
            serial += 1
            out.append('ATOM %6d  C3*   %s   %3d    %8.3f%8.3f%8.3f\n' %
                       (serial, resName, serial, x, y, z))
        out.append('TER\n')
    return out


if __name__=='__main__':
    # Usage: <script> INPUT PARAM_MODULE OUTPUT
    # PARAM_MODULE (argv[2]) is already consumed at import time by the
    # module-level parameter setup. The explicit check gives a readable error
    # instead of an IndexError when OUTPUT is missing. Local names avoid
    # rebinding the builtin `input` at module scope.
    if len(sys.argv) < 4:
        sys.exit('usage: %s INPUT PARAM_MODULE OUTPUT' % sys.argv[0])
    unfoldFile(sys.argv[1], sys.argv[3])

