#!/usr/bin/env python
# encoding: utf-8
"""
create_network.py
Transform a drug-pair prediction CSV file into a Cytoscape network in simple interaction format (SIF), plus a matching edge-attribute file.

Created by Nicholas Tatonetti on 2010-05-13.
Copyright (c) 2010 Stanford University. All rights reserved.
"""

import os
import csv
import sys
import numpy
import getopt


# Text shown for -h/--help. It is raised as a Usage error, so it goes to stderr
# with exit status 2.
help_message = """
USAGE
=====

python create_network.py -i prediction_file.csv -e edge_type [-o network_file.sif] [-a attr_file.attr] [-s score_threshold]

OPTIONS
=======

-i, --input FILE       prediction CSV with RowID ("DRUG1,DRUG2"), Pair_Label and
                       Pair_Score columns (required)
-e, --edge-type NAME   interaction label written on every edge (required)
-o, --output FILE      network file to write (default: derived from the input
                       file name, with a .sif extension)
-a, --attr-file FILE   edge-attribute file to write (default: derived from the
                       input file name, with a .attr extension)
-s THRESHOLD           minimum raw Pair_Score for an edge to be written
                       (default: 0.0); the attribute file reports z-normalized
                       scores
-h, --help             show this message

Known pairs (Pair_Label == 1) are never written.
"""

class Usage(Exception):
    """Command-line usage error.

    ``msg`` carries the text (or the underlying getopt error) that ``main``
    reports to the user on stderr.
    """

    def __init__(self, msg):
        super(Usage, self).__init__(msg)
        self.msg = msg

def _clean_drug_name(name):
    """Make a drug name usable as a node id in the space-separated output lines.

    Spaces become underscores; the characters ' / \\ ( ) - are removed.
    """
    name = name.replace(' ', '_')
    for ch in "'/\\()-":
        name = name.replace(ch, "")
    return name


def create_network(input_file, output_file, attr_file, edge_name, score_threshold, normalize, include_known):
    """
    Parse a prediction CSV file and write the selected drug pairs as a simple
    network file plus a matching edge-attribute file.

    Network file lines:
        node1 edge_name node2
    Attribute file: a "NormalizedRegressionScore" header line, then one line per edge:
        node1 (edge_name) node2 = score

    :param input_file: CSV path with the columns "RowID" (drug pair stored as
        "DRUG1,DRUG2"), "Pair_Label" and "Pair_Score".
    :param output_file: path of the network file to write.
    :param attr_file: path of the edge-attribute file to write.
    :param edge_name: interaction label written between the two nodes.
    :param score_threshold: minimum score, in raw Pair_Score units, for an edge to be written.
    :param normalize: if true, scores (and the threshold) are z-normalized with the
        mean and population standard deviation (numpy.std default) of all scores;
        the attribute file then reports normalized scores.
    :param include_known: if false, rows whose Pair_Label equals 1 are skipped.
    :raises ValueError: if the file is empty, a required column is missing, a value
        is not numeric, or a RowID does not contain exactly one comma.
    """
    with open(input_file) as input_fh:
        reader = csv.reader(input_fh)
        headers = next(reader, None)
        if headers is None:
            raise ValueError("%s is empty; expected a header row" % input_file)
        drug_names_index = headers.index("RowID")
        training_label_index = headers.index("Pair_Label")
        score_index = headers.index("Pair_Score")
        input_rows = list(reader)

    # Identity transform unless normalization is requested and there is data to
    # normalize (numpy.mean of an empty list would be NaN).
    mu, sd = 0.0, 1.0
    if normalize and input_rows:
        scores = [float(row[score_index]) for row in input_rows]
        mu = numpy.mean(scores)
        sd = numpy.std(scores)
        if sd == 0:
            # All scores identical: dividing by zero would make every score NaN and
            # silently drop every edge. Keeping sd = 1 preserves the raw comparison
            # (score - mu >= threshold - mu).
            sd = 1.0
        score_threshold = (score_threshold - mu) / sd

    with open(output_file, 'w') as output_fh, open(attr_file, 'w') as attr_fh:
        # The header is written regardless of `normalize`.
        attr_fh.write("NormalizedRegressionScore\n")

        for row in input_rows:
            # The drug pair is stored in a single column as "DRUG1,DRUG2".
            drug1, drug2 = row[drug_names_index].split(',')
            drug1 = _clean_drug_name(drug1)
            drug2 = _clean_drug_name(drug2)

            score = float(row[score_index])
            if normalize:
                score = (score - mu) / sd

            is_known = float(row[training_label_index]) == 1.0

            if score >= score_threshold and (include_known or not is_known):
                output_fh.write("%s %s %s\n" % (drug1, edge_name, drug2))
                attr_fh.write("%s (%s) %s = %f\n" % (drug1, edge_name, drug2, score))

def main(argv=None):
    """
    Command-line entry point: parse options and build the network files.

    :param argv: argument vector including the program name; defaults to sys.argv.
    :returns: 2 on a usage error (including --help), 0 on success.
    """
    if argv is None:
        argv = sys.argv
    program_name = os.path.basename(argv[0]) if argv else "create_network.py"

    edge_type = None
    input_file = None
    output_file = None
    attr_file = None
    score_threshold = 0.0

    try:
        try:
            opts, _args = getopt.getopt(
                argv[1:], "ho:ve:i:s:a:",
                ["help", "output=", "attr-file=", "edge-type=", "input=", "score-threshold="])
        except getopt.error as msg:
            raise Usage(msg)

        for option, value in opts:
            if option in ("-h", "--help"):
                raise Usage(help_message)
            elif option in ("-o", "--output"):
                output_file = value
            elif option in ("-i", "--input"):
                input_file = value
            elif option in ("-a", "--attr-file"):
                attr_file = value
            elif option in ("-e", "--edge-type"):
                edge_type = value
            elif option in ("-s", "--score-threshold"):
                try:
                    score_threshold = float(value)
                except ValueError:
                    raise Usage("Score threshold must be a number, got %r." % value)
            # -v is still accepted for backward compatibility but has no effect.

        if edge_type is None:
            raise Usage("Edge type parameter is required.")

        if input_file is None:
            raise Usage("Input file parameter is required.")

        # Strip only the final extension; keep directories and interior dots intact
        # (e.g. "./data/preds.v2.csv" -> "./data/preds.v2.sif").
        base_name = os.path.splitext(input_file)[0]
        if output_file is None:
            output_file = base_name + ".sif"

        if attr_file is None:
            attr_file = base_name + ".attr"

    except Usage as err:
        sys.stderr.write("%s: %s\n" % (program_name, err.msg))
        sys.stderr.write("\t for help use --help\n")
        return 2

    create_network(input_file, output_file, attr_file, edge_type, score_threshold, normalize=True, include_known=False)
    return 0


if __name__ == "__main__":
    # Propagate main()'s status code to the shell.
    exit_status = main()
    sys.exit(exit_status)
