#!/usr/bin/env python

import os, sys

# Conversion constants from AMBER units (kcal/mol, Angstrom) to OpenMM units
# (kJ/mol, nm).
# OpenMM defines the calorie as the thermochemical calorie (1 kcal = 4.184 kJ);
# using the International Table value 4.1868 here would bias every converted
# force by ~0.07% and show up as a spurious offset in the percent columns.
KJ_PER_KCAL = 4.184
KCAL_PER_KJ = 1.0/KJ_PER_KCAL

A_PER_NM = 10.0
NM_PER_A = 1.0/A_PER_NM


def usageError():
    """Print the command-line usage line to stdout and exit with status 1."""
    prog = os.path.basename(sys.argv[0])
    sys.stdout.write('usage: %s  amberOutputFilename openmmOutputFilename\n'
                     % prog)
    sys.exit(1)


def getNextSet(f):
    """Return the next COORD/FORCE record read from the open text file *f*.

    Lines are consumed until one whose stripped text starts with ``COORD``
    or ``FORCE``; every other line is skipped.  The matching line is split
    on whitespace into exactly five fields.

    Returns:
        ``(kind, atomNum, x, y, z)`` as ``(str, int, float, float, float)``,
        or ``None`` once end of file is reached without finding a record.

    Raises:
        ValueError: if a matching line does not have exactly five fields, or
            its numeric fields cannot be parsed.
    """
    while True:
        line = f.readline()
        if not line:
            return None
        clean = line.strip()
        if clean.startswith(('COORD', 'FORCE')):
            kind, atomNum, x, y, z = clean.split()
            return (kind, int(atomNum), float(x), float(y), float(z))


try:
    amberFilename = sys.argv[1]
    openmmFilename = sys.argv[2]
except IndexError:
    usageError()

# Output row: step, record type, atom number, OpenMM xyz, converted AMBER xyz,
# OpenMM/AMBER ratios and percent differences.  The short row is used when an
# AMBER component is zero, so the ratios are undefined.
SHORT_ROW_FORMAT = '%3d %s %3d: %17.6f %17.6f %17.6f  %17.6f %17.6f %17.6f'
ROW_FORMAT = (SHORT_ROW_FORMAT +
              '  %17.6f %17.6f %17.6f  %17.6f%% %17.6f%% %17.6f%%')

# Both dumps are closed on every exit path, including sys.exit() on mismatch.
with open(amberFilename) as fAmber, open(openmmFilename) as fOpenMM:
    stepNum = 0
    atomNumO_old = -1
    maxAtomNumForce = -1
    maxAtomNumCoord = -1
    while True:
        recordA = getNextSet(fAmber)
        if recordA is None:
            break
        (typeA, atomNumA, xA, yA, zA) = recordA

        # Skip any AMBER record whose atom number exceeds the largest seen so
        # far for its type.  Presumably this drops the first COORD and FORCE
        # sets of the AMBER dump (atom numbers ascending) -- confirm intent.
        # The OpenMM file is not advanced for skipped records.
        if typeA == 'COORD' and atomNumA > maxAtomNumCoord:
            maxAtomNumCoord = atomNumA
            continue
        if typeA == 'FORCE' and atomNumA > maxAtomNumForce:
            maxAtomNumForce = atomNumA
            continue

        # Convert AMBER values to OpenMM units: Angstrom -> nm for
        # coordinates, kcal/(mol*Angstrom) -> kJ/(mol*nm) for forces.
        if typeA == 'COORD':
            xA *= NM_PER_A
            yA *= NM_PER_A
            zA *= NM_PER_A
        if typeA == 'FORCE':
            forceScale = KJ_PER_KCAL/NM_PER_A
            xA *= forceScale
            yA *= forceScale
            zA *= forceScale

        # Stop quietly when either dump runs out of records.
        recordO = getNextSet(fOpenMM)
        if recordO is None:
            break
        (typeO, atomNumO, xO, yO, zO) = recordO

        if typeA != typeO or atomNumA != atomNumO:
            sys.stdout.write("ERROR: line mismatch:\n"
                             "  AMBER:  %s %d\n"
                             "  OpenMM: %s %d\n"
                             % (typeA, atomNumA, typeO, atomNumO))
            sys.exit(1)

        # A drop in the OpenMM atom number marks the start of a new set;
        # stepNum//2 below assumes each step contributes two sets
        # (COORD and FORCE).
        if atomNumO < atomNumO_old:
            stepNum += 1

        try:
            s = ROW_FORMAT % (stepNum//2, typeA, atomNumA,
                              xO, yO, zO,
                              xA, yA, zA,
                              xO/xA, yO/yA, zO/zA,
                              100*(xO-xA)/xA,
                              100*(yO-yA)/yA,
                              100*(zO-zA)/zA)
        except ZeroDivisionError:
            s = SHORT_ROW_FORMAT % (stepNum//2, typeA, atomNumA,
                                    xO, yO, zO,
                                    xA, yA, zA)
        sys.stdout.write('%s\n' % s)

        atomNumO_old = atomNumO


