#!/usr/bin/env python
# encoding: utf-8
"""
evaluate_medeffect.py

Evaluate the logistic regression scores using the MedEffect database.

Created by Nicholas Tatonetti on 2011-02-15.
Copyright (c) 2011 Stanford University. All rights reserved.
"""

import os
import sys
import csv
import numpy
import MySQLdb
import operator
import rpy2.robjects as robjects
from pyweka import MachineLearning as ml
from namedmatrix import NamedMatrix

# Shared connection and cursor to the local MySQL server (port 3307) holding
# the ``effect_aers`` schema; the script body below issues all its queries
# through ``conn``.
# NOTE(review): credentials are hard-coded here; consider loading them from a
# config file or environment variables.
db = MySQLdb.connect(
    host="localhost",
    port=3307,
    user="root",
    passwd="dummy_password",
    db="effect_aers",
)
conn = db.cursor()

# Root directory holding one sub-directory of model outputs per indication.
PATH_TO_DATA_DIR = os.path.expanduser('~/Stanford/AltmanLab/cocktails/methods/')

# Supported indications and the UMLS concept id of the adverse event used as
# the gold-standard label for each.  The id is kept pre-quoted because the
# script body splices it verbatim into a SQL ``in (...)`` clause.
ADVERSE_EVENT_BY_INDICATION = {
    'renal_impairment': "'C0341697'",  # renal impairment
    'hypertension': "'C0020538'",      # hypertension
}

# Select which configuration to evaluate; change INDICATION to switch.
INDICATION = 'hypertension'
ADVERSE_EVENT = ADVERSE_EVENT_BY_INDICATION[INDICATION]


if __name__ == '__main__':

    # Map AERS drug names to STITCH compound identifiers.
    conn.execute("select drug_name, stitch_id from aers2stitch;")
    aers2stitch = dict(conn.fetchall())

    # Locate the per-drug ("singlet") score file for this indication.  The
    # listing is sorted so the choice is deterministic if several files match
    # (os.listdir returns entries in arbitrary order).
    indication_dir = os.path.join(PATH_TO_DATA_DIR, INDICATION)
    singlet_files = sorted(f for f in os.listdir(indication_dir) if '_drugs_' in f)
    if not singlet_files:
        raise IOError("no '*_drugs_*' score file found in %s" % indication_dir)
    singlet_file = singlet_files[0]

    # Read the singlet rows (drug, score, label), skipping the header row and
    # closing the file afterwards.
    with open(os.path.join(indication_dir, singlet_file)) as fh:
        reader = csv.reader(fh)
        next(reader, None)
        singlets = list(reader)

    # Keep the highest score per STITCH id (several AERS names may map to the
    # same compound).  Scores are parsed as floats: csv yields strings, and
    # comparing strings would order them lexicographically, not numerically.
    preds = dict()
    for drug, score, _label in singlets:
        sid = aers2stitch.get(drug)
        if sid is None:
            continue
        score = float(score)
        if sid not in preds or score > preds[sid]:
            preds[sid] = score

    # Gold standard from MedEffect: per compound, the max over its rows of
    # (umls_id in ADVERSE_EVENT), i.e. whether any of its reported events is
    # the adverse event of interest.
    # NOTE(review): ADVERSE_EVENT is a trusted module constant; if it ever
    # comes from user input this must become a parameterized query.
    query = """
    select stitch_id, max(umls_id in (%s)) as label
    from effect_medeffect.singlet_event_report_count
    group by stitch_id;
    """ % ADVERSE_EVENT
    conn.execute(query)
    gold = dict(conn.fetchall())

    # Evaluate only compounds with both a prediction and a gold label.  A
    # sorted list gives a reproducible row order that is shared by the
    # feature matrix and the label vector.
    drugs = sorted(set(gold) & set(preds))

    feat = NamedMatrix(None, drugs, ['score'])
    for i, sid in enumerate(drugs):
        feat[i, 0] = preds[sid]

    labels = [gold[sid] for sid in drugs]

    # Fit a single-feature logistic regression and cross-validate it.
    lr = ml.Logistic(feat, labels)
    lr.cross_validate()
    
    