#!/usr/local/bin/python

# USAGE #
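# python calcTheseDistances.py input.pdb PTfile.txt outputFile.txt cutoff
# A frame is written to outputFile.txt only if none of the pair distances
# listed in PTfile.txt exceeds cutoff.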

import sys, math

def calcDistance(point1, point2):
    """Return the Euclidean distance between two 3-D points.

    point1, point2 -- sequences of exactly three numbers (x, y, z).
    """
    ax, ay, az = point1
    bx, by, bz = point2
    dx = ax - bx
    dy = ay - by
    dz = az - bz
    return (dx**2 + dy**2 + dz**2) ** .5

def calcDistList(position, posList):
    """Measure the distance for every requested pair of positions.

    position -- mapping from an identifier to its (x, y, z) coordinates.
    posList  -- iterable of (id1, id2) pairs; both ids must be keys of
                position.

    Returns a tuple of three parallel lists: (distances, first ids,
    second ids), in the same order as posList.
    """
    distances = []
    firsts = []
    seconds = []
    for first, second in posList:
        distances.append(calcDistance(position[first], position[second]))
        firsts.append(first)
        seconds.append(second)
    return (distances, firsts, seconds)

def getPos(input):
    """Map the residue number of each ATOM record to its coordinates.

    input -- iterable of PDB-format text lines (a list of lines or an
             open file).

    Only lines that start with 'ATOM' are read; all other lines are
    skipped.  Fields are taken from the fixed PDB columns: the residue
    sequence number from columns 23-26 and x, y, z from columns 31-38,
    39-46 and 47-54.  If several ATOM records carry the same residue
    number, the last one read wins.

    Returns a dict {residue number (int): (x, y, z) tuple of floats}.
    Raises ValueError if one of the numeric fields cannot be parsed.
    """
    position = {}
    for line in input:
        if line[0:4] != 'ATOM':
            continue
        # The PDB resSeq field spans columns 23-26, i.e. slice [22:26].
        # The narrower [23:26] silently dropped the leading digit of
        # residue numbers >= 1000, making them collide with smaller ones.
        atom = int(line[22:26])
        x = float(line[30:38])
        y = float(line[38:46])
        z = float(line[46:54])
        position[atom] = (x, y, z)
    return position

def makePosList(atomsList):
    """Return every unordered pair of entries from atomsList.

    Each pair (a, b) keeps a before b in the order they appear in
    atomsList, and an entry is never paired with itself.
    """
    return [
        (first, second)
        for index, first in enumerate(atomsList)
        for second in atomsList[index + 1:]
    ]

def calcRg(input):
    """Return the mean distance over all pairs of atoms in a frame.

    input -- iterable of PDB-format text lines (a list of lines or an
             open file).  It is read exactly once, so a file object
             works as well as a list.

    Atoms are identified by the keys that getPos() assigns, so the
    pairs built here always exist in the position table.  Despite the
    name, the value computed is the plain average of the pairwise
    distances.

    Returns 0.0 when the frame holds fewer than two atoms, since there
    is no pair to average over.
    """
    position_D = getPos(input)
    # Sort the keys so the pair order (and therefore the floating-point
    # summation order) does not depend on dict ordering.
    atomsList = sorted(position_D)
    posList = makePosList(atomsList)
    if not posList:
        return 0.0
    # calcDistList returns three parallel lists; only the distances are
    # needed for the average.
    distList, _pos1List, _pos2List = calcDistList(position_D, posList)
    return sum(distList) / float(len(distList))
    

if __name__=='__main__':
    # Command line: input.pdb PTfile.txt outputFile.txt cutoff
    # Validate and parse the arguments BEFORE opening the output file:
    # opening with 'w' truncates it, so a missing or malformed cutoff
    # must not be allowed to wipe an existing file.
    if len(sys.argv) < 5:
        sys.exit("usage: python %s input.pdb PTfile.txt outputFile.txt cutoff"
                 % sys.argv[0])
    cutoff = float(sys.argv[4])

    # Read the pairs to monitor: two integer ids per line.  Lines that
    # start with '#' are comments; blank lines are skipped.
    posList = []
    with open(sys.argv[2]) as distData:
        for line in distData:
            cols = line.split()
            if not cols or line[0] == "#":
                continue
            posList.append((int(cols[0]), int(cols[1])))

    # Stream the input one frame at a time.  A frame is every line up to
    # and including a line that starts with 'END'.  Lines after the last
    # such terminator are not written.
    with open(sys.argv[1]) as inFile, open(sys.argv[3], 'w') as output:
        frame = []
        for line in inFile:
            frame.append(line)
            if line[0:3] == 'END':
                position_D = getPos(frame)
                (distList, pos1List, pos2List) = calcDistList(position_D, posList)
                # Keep the frame only if no monitored distance exceeds
                # the cutoff.
                if not any(distance > cutoff for distance in distList):
                    output.writelines(frame)
                frame = []
