#!/usr/bin/env python

import os
import sys
import re
import math
import getopt


def readHits(filename):
    """Read a whitespace-separated hits file.

    Each non-blank line must have at least two columns: column index 1 is
    parsed as the score and columns 2-4 as the coordinates.

    Args:
        filename: path of the hits file.

    Returns:
        A list of ``(score, [x, y, z])`` tuples, in file order.

    Exits the process with status -1 if the file cannot be opened.
    """
    try:
        infile = open(filename, "r")
    except (IOError, OSError):
        # Only I/O failures are reported here; a bare except would also
        # swallow KeyboardInterrupt/SystemExit.
        print("Couldn't open hits file (%s) for reading" % filename)
        sys.exit(-1)

    hits = []
    # `with` guarantees the file is closed even if a line fails to parse.
    with infile:
        for line in infile:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            score = float(fields[1])
            coord = [float(value) for value in fields[2:5]]
            hits.append((score, coord))
    return hits

def readBackgroundHits(filename):
    """Read background hit scores and return them as z-scores.

    Column index 1 of every non-blank line is parsed as a score. Each score
    is standardized against the mean and the population standard deviation
    (divide by N) of all scores in the file.

    Args:
        filename: path of the background hits file.

    Returns:
        A list of z-scores in file order. Returns ``[]`` for a file with no
        scores, and all zeros when every score is identical (standard
        deviation 0) instead of dividing by zero.
    """
    scores = []
    with open(filename, "r") as infile:
        for line in infile:
            fields = line.split()
            if not fields:
                continue
            scores.append(float(fields[1]))

    if not scores:
        return []

    count = float(len(scores))
    mean = sum(scores) / count
    variance = sum((mean - score) ** 2 for score in scores) / count
    stdev = math.sqrt(variance)
    if stdev == 0:
        # All scores equal the mean, so every z-score is 0.
        return [0.0] * len(scores)
    return [(score - mean) / stdev for score in scores]
        

def printHeaders(outfile,name,color):
    """Write the preamble of the generated PyMOL script to *outfile*.

    The preamble imports the pymol modules, binds the script variable
    ``color`` to the RGB value that PyMOL reports for *color*, and starts an
    empty list named *name* that later CGO entries are appended to.
    """
    header_lines = [
        "",
        "from pymol import cmd",
        "from pymol.cgo import *",
        "from pymol import viewing",
        "from pymol import querying",
        'color = cmd._cmd.get_color("%s",0)' % color,
        "",
        "%s = []" % name,
        "",
    ]
    outfile.write("\n".join(header_lines))
    

def printSpheres2(objs,name,color,radius,outfile):
    """Append a COLOR directive and one SPHERE per coordinate to list *name*.

    Args:
        objs: iterable of ``(x, y, z)`` coordinates.
        name: name of the script-level list the CGO entries are appended to.
        color: unused here. The COLOR entry reads the generated script's own
            ``color`` variable, which must already be defined earlier in the
            script.
        radius: sphere radius, written with ``%f`` precision.
        outfile: writable text stream receiving the generated lines.

    Every coordinate is written; there is no cap. A warning is printed once
    the 10000th sphere is reached, because very large CGO lists can be slow
    to load.
    """
    outfile.write("%s += [COLOR,color[0],color[1],color[2]]\n" % name)
    for count, (x, y, z) in enumerate(objs, 1):
        if count == 10000:
            # The old message claimed "Stopped at ... obj3", but nothing stops
            # here (the break was disabled), so say what actually happens.
            print("Warning: %s has reached %i spheres; writing all of them anyway" % (name, count))
        outfile.write("%s += [SPHERE, %.3f,%.3f,%.3f,%f]\n" % (name, x, y, z, radius))
    outfile.write("\n")


def _writeSphereGroup(outfile, index, objs):
    """Write one group: a COLOR entry, then up to 9999 SPHERE entries.

    Lines append to the generated script's list ``objs<index>`` and take
    their colour from the script variable ``color<index>``. These variables
    must be defined elsewhere in the generated script.
    """
    group = "objs%d" % index
    palette = "color%d" % index
    outfile.write("%s += [COLOR,%s[0],%s[1],%s[2]]\n" % (group, palette, palette, palette))
    for count, (x, y, z) in enumerate(objs, 1):
        if count == 10000:
            # Hard cap: the 10000th coordinate and everything after it is skipped.
            print("Stopped at %i obj%d" % (count, index))
            break
        outfile.write("%s += [SPHERE, %.3f,%.3f,%.3f,0.3]\n" % (group, x, y, z))
    outfile.write("\n")


def printSpheres(objs1,objs2,objs3,outfile):
    """Write three colour groups of radius-0.3 spheres to *outfile*.

    Group N's coordinates go into the script list ``objsN``, coloured with
    ``colorN``. Each group is capped at 9999 spheres.

    Bug fix: group 1's spheres used to be appended to ``objs2``, although its
    COLOR entry went to ``objs1``. They now go to ``objs1`` as well.
    """
    for index, objs in enumerate((objs1, objs2, objs3), 1):
        _writeSphereGroup(outfile, index, objs)

def displayHits(outfile,name,obj_name):
    """Write the script epilogue that loads the CGO list into PyMOL.

    Args:
        outfile: writable text stream receiving the generated lines.
        name: PyMOL object name to (re)create; any existing object with that
            name is deleted first.
        obj_name: name of the script-level CGO list passed to
            ``cmd.load_cgo``.

    The current view is saved before loading and restored afterwards, so
    loading does not move the camera.
    """
    epilogue = [
        "",
        'obj_name="%s"' % name,
        "cur_view=viewing.get_view(0)",
        "cmd.delete(obj_name)",
        "cmd.load_cgo(%s,obj_name)" % obj_name,
        "viewing.set_view(cur_view)",
        "",
    ]
    outfile.write("\n".join(epilogue))
    

def getNumberOfHits(filename):
    """Return the number of newline characters in *filename*.

    This matches ``wc -l``: a final line without a trailing newline is not
    counted. The count is done in Python rather than by running
    ``wc -l`` through ``os.popen``. The old approach put the filename
    unquoted into a shell command, so it broke on spaces and allowed shell
    injection.

    Raises:
        IOError/OSError if the file cannot be opened.
    """
    count = 0
    with open(filename, "rb") as infile:
        while True:
            chunk = infile.read(65536)
            if not chunk:
                break
            count += chunk.count(b"\n")
    return count


def readDensityFile(filenames):
    """Compute density statistics for each file in *filenames*.

    The last whitespace-separated column of every non-blank line is parsed
    as a density value.

    Args:
        filenames: iterable of file paths.

    Returns:
        A list with one ``(mean, stdev)`` tuple per file, in input order.
        ``stdev`` is the population standard deviation (divide by N).

    Raises:
        ValueError: if a file contains no density values. Previously this
            surfaced as ZeroDivisionError.
    """
    model_density_stats = []
    for filename in filenames:
        with open(filename, "r") as infile:
            densities = [float(line.split()[-1]) for line in infile if line.strip()]
        if not densities:
            raise ValueError("Density file %s contains no values" % filename)
        count = float(len(densities))
        mean = sum(densities) / count
        variance = sum((mean - d) ** 2 for d in densities) / count
        model_density_stats.append((mean, math.sqrt(variance)))
    return model_density_stats

def printUsage():
    """Print command-line usage help to stdout.

    The help text uses spaces only (the old text mixed in tabs, which
    misaligned the option descriptions) and fixes the "colod" typo.
    """
    print("""
    Usage: viewhits.py -i <hits file> -s <score>
             [-r <sphere radius> -o <outfile> -n <pymol object name>
              -c <sphere color> -h]

    Arguments:

    -i,--infile     Hits filename (required)
    -s,--score      Score cutoff (required)
    -r,--radius     Sphere radius (optional; defaults to 0.3)
    -o,--outfile    Output filename (optional; defaults to 'view.py')
    -n,--name       Name of pymol object (optional; defaults to 'hits')
    -c,--color      Color of pymol spheres (optional; defaults to red)
    -h,--help       Print this message
    """)

def main(arguments):
    """Parse command-line options and write a PyMOL script that draws hits.

    Hits whose score is >= the ``-s`` cutoff are written as spheres into a
    CGO list named by ``-n``. The script goes to ``-o`` (default
    ``view.py``). The CGO object is loaded only if at least one hit passes
    the cutoff.

    Args:
        arguments: command-line arguments, without the program name.

    Exits with status -1 on unparsable, invalid, or missing arguments, and
    with status 0 after printing ``--help``.
    """
    try:
        optlist, _ = getopt.getopt(
            arguments, "i:r:s:o:n:c:h",
            ["infile=", "radius=", "score=", "outfile=", "name=", "color=",
             "help"])
    except getopt.GetoptError:
        print("Couldn't parse command line arguments\n")
        printUsage()
        sys.exit(-1)

    infilename = None
    outfilename = "view.py"
    name = "hits"
    color = "red"
    score_cutoff = None
    radius = 0.3
    for opt, arg in optlist:
        if opt in ("-i", "--infile"):
            infilename = arg
        elif opt in ("-r", "--radius"):
            try:
                radius = float(arg)
            except ValueError:
                print("Expected float for radius argument")
                # Invalid input is a failure; this used to exit with status 0.
                sys.exit(-1)
        elif opt in ("-n", "--name"):
            name = arg
        elif opt in ("-c", "--color"):
            color = arg
        elif opt in ("-o", "--outfile"):
            outfilename = arg
        elif opt in ("-s", "--score"):
            # Must match the long option registered above ("score="). The old
            # "--scores" check meant --score always fell through to the error
            # branch.
            try:
                score_cutoff = float(arg)
            except ValueError:
                print("Expected float for score cutoff argument")
                sys.exit(-1)
        elif opt in ("-h", "--help"):
            printUsage()
            sys.exit(0)
        else:
            # Unreachable while every option in the getopt spec is handled
            # above; kept as a guard in case the spec grows.
            print("Didn't understand command line argument: %s %s \n" % (opt, arg))
            printUsage()
            sys.exit(-1)

    if infilename is None:
        print("Specify hits filename with -i option\n")
        printUsage()
        sys.exit(-1)

    if score_cutoff is None:
        print("Specify score cutoff using -s command line option\n")
        printUsage()
        sys.exit(-1)

    # Read the hits before creating the output file, so a bad input file
    # does not leave an empty script behind.
    hits = readHits(infilename)
    objs = [coord for score, coord in hits if score >= score_cutoff]

    # `with` makes sure the script is flushed and closed.
    with open(outfilename, "w") as outfile:
        printHeaders(outfile, name, color)
        printSpheres2(objs, name, color, radius, outfile)
        if objs:
            displayHits(outfile, name, name)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main(sys.argv[1:])
              
