import os
import csv
import sys
import MySQLdb

from pyweka import MachineLearning as ml
from namedmatrix import NamedMatrix

def save_roc_plots(lr, fn, out_dir='~/Dropbox'):
    """Write the ROC curve points of a cross-validated model to a CSV file.

    The file is named ``cocktails_roc_<fn>_<AUROC>.csv`` (AUROC formatted
    with ``%f``) and is created inside *out_dir* (``~`` is expanded).

    :param lr: model object exposing ``results['AUROC']`` and a two-element
        ``plot_data`` sequence (in this script, a pyweka ``Logistic`` after
        ``cross_validate()``).
    :param fn: tag embedded in the file name, e.g. ``'diab-pa'``.
    :param out_dir: directory to write into; defaults to ``~/Dropbox``, the
        location that was previously hard-coded.
    """
    name = 'cocktails_roc_%s_%f.csv' % (fn, lr.results['AUROC'])
    path = os.path.join(os.path.expanduser(out_dir), name)
    # ``with`` guarantees the handle is closed even if writing a row fails.
    with open(path, 'w') as fh_roc:
        # Each row is (plot_data[1], plot_data[0]) -- note the swapped order;
        # presumably (FPR, TPR) for plotting, TODO confirm against pyweka.
        csv.writer(fh_roc).writerows(zip(lr.plot_data[1], lr.plot_data[0]))

# Connection to the database holding the per-condition ``<prefix>_pred_pairs``
# tables and the ``aers2stitch`` drug-name -> STITCH id mapping. The password
# may be supplied via COCKTAILS_DB_PASSWD so real credentials need not live in
# source; the previous literal is kept as the default.
db = MySQLdb.connect(host="localhost", port=3306, user="root",
                     passwd=os.environ.get("COCKTAILS_DB_PASSWD", "dummy_password"),
                     db="project_cocktails_pred")
c = db.cursor()

# Prefixes of the ``<prefix>_pred_pairs`` tables to evaluate.
table_prefixes = ['all', 'chol', 'depr', 'diab', 'heptox', 'htn', 'livdys', 'renimp', 'suic']

try:
    for prefix in table_prefixes:

        sys.stderr.write("Building ROC Curves for %s\n" % prefix)
        # Scores subquery: predicted pairs with both drug names mapped to STITCH
        # ids (is_mode = 0 rows only). DDI subquery: pairs listed in
        # compound_ddi.veterans, tagged with interaction = 1. The LEFT JOIN leaves
        # ``interaction`` NULL for scored pairs that have no veterans entry.
        # The table name cannot be a bound parameter, hence the interpolation;
        # ``prefix`` only ever comes from the hard-coded list above.
        # NOTE(review): only the (drug1, drug2) orientation is matched -- a DDI
        # stored as (drug2, drug1) would count as "no interaction"; and several
        # veterans rows for one pair would duplicate its score row. Confirm.
        query = """
        select *
        from
        (
            select a.stitch_id as stitch_id1, b.stitch_id as stitch_id2, pair_score, label
            from %s_pred_pairs
            join aers2stitch a
            join aers2stitch b
            where a.drug_name = drug1
            and b.drug_name = drug2
            and is_mode = 0
        ) scores
        left join
        (
            select type, a.stitch_id as stitch_id1, b.stitch_id as stitch_id2, 1 as interaction
            from compound_ddi.veterans
            join compound_ddi.veterans_drug2stitch a on (a.drug = drug1)
            join compound_ddi.veterans_drug2stitch b on (b.drug = drug2)
        ) ddi using (stitch_id1, stitch_id2);
        """ % prefix
        c.execute(query)

        data = c.fetchall()
        # Row names of the feature matrix: "stitch_id1,stitch_id2".
        pairs = ["%s,%s" % (x[0], x[1]) for x in data]

        # Single-feature matrix (the pair score), one row per result row.
        nm = NamedMatrix(None, pairs, ["feature"])
        pair_labels = []  # labels stored in <prefix>_pred_pairs
        va_labels = []    # 1 if the pair appears in the veterans DDI table, else 0

        # Columns: stitch_id1, stitch_id2, pair_score, label, type, interaction.
        # ``_ddi_type`` avoids shadowing the builtin ``type``.
        for i, (_p1, _p2, score, label, _ddi_type, va) in enumerate(data):
            nm[i, 0] = score
            pair_labels.append(label)
            va_labels.append(1 if va is not None else 0)

        # Evaluate the same score feature against both label sets, in the same
        # order as before: "-pa" (predicted-pair labels), then "-va" (veterans).
        for labels, suffix in ((pair_labels, "-pa"), (va_labels, "-va")):
            lr = ml.Logistic(nm, labels)
            lr.cross_validate()
            save_roc_plots(lr, prefix + suffix)
finally:
    # Release the cursor and connection even if a query or model fit fails.
    c.close()
    db.close()
