#!/usr/bin/env python
# encoding: utf-8
"""
clinical_eval.py

Created by Nicholas Tatonetti on 2011-02-19.
Copyright (c) 2011 Stanford University. All rights reserved.
"""

import os
import csv
import sys
import MySQLdb
import rpy2.robjects as robjects

# Connection settings can be overridden from the environment; the defaults are
# the values this script previously hard-coded, so existing setups keep working.
# NOTE(review): with host="localhost" the MySQL client library usually connects
# through the local Unix socket and ignores `port`, while the R connection in
# the __main__ block uses TCP 127.0.0.1:3306 -- confirm both reach the same server.
db = MySQLdb.connect(
    host=os.environ.get('CLINICAL_DB_HOST', 'localhost'),
    port=int(os.environ.get('CLINICAL_DB_PORT', '3307')),
    user=os.environ.get('CLINICAL_DB_USER', 'root'),
    passwd=os.environ.get('CLINICAL_DB_PASSWORD', 'dummy_password'),
    db=os.environ.get('CLINICAL_DB_NAME', 'clinical_stanford'),
)
# Shared cursor used by name2stitch() for drug-name -> stitch_id lookups.
conn = db.cursor()

# Root directory holding one sub-directory per indication; the prediction file
# is read from, and the result/failed CSVs are written to, that sub-directory.
# Override with the COCKTAILS_DATA_DIR environment variable.
PATH_TO_DATA_DIR = os.path.expanduser(
    os.environ.get('COCKTAILS_DATA_DIR', '~/Stanford/AltmanLab/cocktails/methods/'))

def _first_stitch_id(query, params, drug_name):
    """Run one stitch_id lookup on the shared cursor and return the first id.

    :param query: SQL selecting ``stitch_id`` with ``%s`` placeholders.
    :param params: values bound to the placeholders by MySQLdb.
    :param drug_name: original name, used only for the log message.
    :returns: the ``stitch_id`` of the first returned row, or ``None`` when
        the query matched nothing.
    """
    num_results = conn.execute(query, params)
    if num_results > 0:
        # There is no ORDER BY, so "first" is whatever order the server returns.
        stitch_id = conn.fetchall()[0][0]
        if num_results > 1:
            sys.stderr.write("Found %d matches for '%s', using the first one (%s).\n"
                             % (num_results, drug_name, stitch_id))
        return stitch_id
    return None


def name2stitch(drug_name):
    """Map a drug name to a STITCH compound identifier.

    Performs a substring (``LIKE '%name%'``) match against
    ``effect_aers.aers2stitch.drug_name`` first; if that finds nothing, it
    falls back to the brand and generic names joined through
    ``med2stitch``/``medication_dx``.  When several identifiers match, the
    first row is used and a note is written to stderr.

    :param drug_name: name (or fragment of a name) to search for.
    :returns: the matching ``stitch_id``, or ``None`` if neither lookup matches.
    """
    # The pattern is bound as a query parameter instead of being spliced into
    # the SQL text, so quotes or backslashes in drug names cannot break the query.
    # NOTE(review): '%' and '_' inside drug_name still act as LIKE wildcards.
    pattern = '%' + drug_name + '%'

    stitch_id = _first_stitch_id("""
    select distinct stitch_id
    from effect_aers.aers2stitch
    where drug_name like %s;
    """, (pattern,), drug_name)
    if stitch_id is not None:
        return stitch_id

    return _first_stitch_id("""
    select distinct stitch_id
    from med2stitch join medication_dx using (med_id)
    where brand_name like %s
    or generic_name like %s;
    """, (pattern, pattern), drug_name)

# Header of the results CSV: pair identity, the ANOVA p-value for adding the
# treatment group to the model, then the three Tukey contrast estimates and
# their p-values (assumed to match the order multcomp returns the contrasts).
_RESULT_HEADER = [
    'drug_pair',
    'sid1',
    'sid2',
    'anova_pvalue',
    'drug1 - combo',
    'drug2 - combo',
    'drug2 - drug1',
    'drug1 - combo pvalue',
    'drug2 - combo pvalue',
    'drug2 - drug1 pvalue',
]

# R code that loads one row per patient into the R data frame `data`, for three
# groups: 'combo' (prescribed both drugs with order dates < 36 days apart),
# 'drug1' and 'drug2'.  For each patient it averages lab values in the `days`
# before the reference order date (baseline) and the `days` after it
# (treatment).  Filled via %-formatting with (sid1, sid2, table, days).
# NOTE(review): the single-drug groups do not exclude patients who are also in
# the combo group -- confirm this is intended.
_LAB_DATA_R_TEMPLATE = """
drug1 <- '%s'
drug2 <- '%s'
table <- '%s'
days <- %s

rs <- dbSendQuery(mycon, paste("
select patient_id, 'combo' as `group`, datediff(order_date, birth_date)/365.25 as age, race, gender, avg(a.lab_value) as baseline, avg(b.lab_value) as treatment, avg(b.lab_value)-avg(a.lab_value) as `change`
from
(
    select patient_id, greatest(max(order_date1), max(order_date2)) as order_date, birth_date, ethnic_group_id as race, gender
    from
    (
        select patient_id, birth_date, ethnic_group_id, gender, order_date as order_date1
        from prescriptions
        join clinical_stanford.patients using (patient_id)
        join med2stitch using (med_id)
        where stitch_id = '",drug1,"'
    ) drug1
    join
    (
        select patient_id, order_date as order_date2
        from prescriptions
        join med2stitch using (med_id)
        where stitch_id = '",drug2,"'
    ) drug2 using (patient_id)
    where abs(datediff(order_date1, order_date2)) < 36
    group by patient_id
) as patients
join ",table," a using (patient_id)
join ",table," b using (patient_id)
where datediff(a.result_date, order_date) <= 0
and datediff(a.result_date, order_date) > -",days,"
and datediff(b.result_date, order_date) > 0
and datediff(b.result_date, order_date) < ",days,"
group by patient_id

union

select patient_id, 'drug1' as `group`, datediff(order_date, birth_date)/365.25 as age, race, gender, avg(a.lab_value) as baseline, avg(b.lab_value) as treatment, avg(b.lab_value)-avg(a.lab_value) as `change`
from
(
    select patient_id, max(order_date) as order_date, birth_date, ethnic_group_id as race, gender
    from prescriptions
    join clinical_stanford.patients using (patient_id)
    join med2stitch using (med_id)
    where stitch_id = '",drug1,"'
    group by patient_id
) as patients
join ",table," a using (patient_id)
join ",table," b using (patient_id)
where datediff(a.result_date, order_date) <= 0
and datediff(a.result_date, order_date) > -",days,"
and datediff(b.result_date, order_date) > 0
and datediff(b.result_date, order_date) < ",days,"
group by patient_id

union

select patient_id, 'drug2' as `group`, datediff(order_date, birth_date)/365.25 as age, race, gender, avg(a.lab_value) as baseline, avg(b.lab_value) as treatment, avg(b.lab_value)-avg(a.lab_value) as `change`
from
(
    select patient_id, max(order_date) as order_date, birth_date, ethnic_group_id as race, gender
    from prescriptions
    join clinical_stanford.patients using (patient_id)
    join med2stitch using (med_id)
    where stitch_id = '",drug2,"'
    group by patient_id
) as patients
join ",table," a using (patient_id)
join ",table," b using (patient_id)
where datediff(a.result_date, order_date) <= 0
and datediff(a.result_date, order_date) > -",days,"
and datediff(b.result_date, order_date) > 0
and datediff(b.result_date, order_date) < ",days,"
group by patient_id;
", sep=""))

data <- fetch(rs, n=-1)

data$group <- as.factor(data$group)
data$race <- as.factor(data$race)
data$gender <- as.factor(data$gender)
"""


def _read_first_column(path):
    """Return the set of first-column values in the CSV at *path* (blank rows skipped)."""
    with open(path) as fh:
        return set(row[0] for row in csv.reader(fh) if row)


def _lookup_stitch_id(drug_name):
    """Resolve *drug_name* with name2stitch(), retrying with only its first word."""
    stitch_id = name2stitch(drug_name)
    if stitch_id is None:
        stitch_id = name2stitch(drug_name.split(' ')[0])
    return stitch_id


def _fit_lab_models(sid1, sid2, table, days, log_values):
    """Fit the lab-value models for one drug pair in the embedded R session.

    Loads the per-patient data, compares a model without the treatment group
    (fit0) against one with it (fit1) by ANOVA, then runs Tukey contrasts
    between the groups.  Requires the R connection `mycon` and the multcomp
    library to be set up already.

    :returns: ``(anova_pvalue, contrast_estimates, contrast_pvalues)``.
    :raises: whatever rpy2 raises when the R code fails.
    """
    robjects.r(_LAB_DATA_R_TEMPLATE % (sid1, sid2, table, days))

    if log_values:
        anova_table = robjects.r("""
        fit0 <- lm(log(treatment) ~ log(baseline) + age + race + gender, data)
        fit1 <- lm(log(treatment) ~ log(baseline) + age + race + gender + group, data)
        anova(fit0, fit1)
        """)
    else:
        anova_table = robjects.r("""
        fit0 <- lm(treatment ~ baseline + age + race + gender, data)
        fit1 <- lm(treatment ~ baseline + age + race + gender + group, data)
        anova(fit0, fit1)
        """)
    sys.stderr.write(str(robjects.r("summary(fit1)")) + "\n")

    # Sixth column (presumably Pr(>F)), second row: fit1 compared against fit0.
    pvalue = anova_table[5][1]

    tukey = robjects.r("""
    summary(glht(fit1, linfct = mcp(group = "Tukey")))
    """)
    sys.stderr.write(str(tukey) + "\n")

    test = tukey[list(tukey.getnames()).index('test')]
    estimates = list(test[list(test.getnames()).index('coefficients')])
    pvalues = list(test[list(test.getnames()).index('pvalues')])
    return pvalue, estimates, pvalues


if __name__ == '__main__':

    # eg. usage
    # python clinical_lab_eval.py renal_impairment CR 400 log
    # python clinical_lab_eval.py cholesterol CHOL 100 log
    # Append 'restart' to discard previous results and start over.

    if len(sys.argv) < 4:
        raise SystemExit("usage: %s indication base_name days [log] [restart]" % sys.argv[0])

    indication = sys.argv[1]
    base_name = sys.argv[2]
    days = sys.argv[3]
    # `days` is spliced verbatim into the R/SQL code, so insist on an integer.
    try:
        int(days)
    except ValueError:
        raise SystemExit("days must be an integer, got %r" % days)
    log_values = len(sys.argv) > 4 and sys.argv[4] == 'log'

    suffix = '_log' if log_values else ''
    result_file_name = 'clinical_results_labs_%s_%s%s.csv' % (base_name, days, suffix)
    failed_file_name = 'failed_pairs_labs_%s_%s%s.csv' % (base_name, days, suffix)

    indication_dir = os.path.join(PATH_TO_DATA_DIR, indication)
    result_file_path = os.path.join(indication_dir, result_file_name)
    failed_file_path = os.path.join(indication_dir, failed_file_name)

    if 'restart' in sys.argv:
        # Start from scratch; tolerate outputs that were never created.
        for stale_path in (result_file_path, failed_file_path):
            if os.path.exists(stale_path):
                os.unlink(stale_path)

    pred_path = os.path.join(indication_dir, 'clinical_preds_norankfilt_gtr0.txt')

    if not os.path.exists(pred_path):
        raise Exception("You must first make a prediction file and store it at %s." % pred_path)

    # A header is needed unless a previous run already wrote something.
    write_header = not (os.path.exists(result_file_path)
                        and os.path.getsize(result_file_path) > 0)

    # Pairs already recorded in either output file are skipped, so an
    # interrupted run can be resumed.
    completed_pairs = set()
    if os.path.exists(result_file_path):
        completed_pairs |= _read_first_column(result_file_path)
        completed_pairs.discard(_RESULT_HEADER[0])
        sys.stderr.write("Found %d completed pairs already, will skip those.\n" % len(completed_pairs))

    if os.path.exists(failed_file_path):
        completed_pairs |= _read_first_column(failed_file_path)

    table = "temp_%s_labs" % base_name.lower()

    with open(result_file_path, 'a+') as outfh, \
            open(failed_file_path, 'a+') as failedfh, \
            open(pred_path) as predfh:
        writer = csv.writer(outfh)
        failed_writer = csv.writer(failedfh)

        if write_header:
            writer.writerow(_RESULT_HEADER)
            outfh.flush()

        robjects.r("""
        library(multcomp)
        library(RMySQL)
        mycon <- dbConnect(MySQL(), user='root', dbname='clinical_stanford', host='127.0.0.1', port=3306, password='dummy_password')
        """)

        # Each prediction row must have exactly five tab-separated columns;
        # only the comma-separated drug pair is used here.
        for _rank, drug_pair, _score, _label, _min_rank in csv.reader(predfh, delimiter='\t'):

            if drug_pair in completed_pairs:
                # Already recorded (as a result or a failure); do not write it again.
                sys.stderr.write("Already found %s, skipping.\n" % drug_pair)
                continue

            names = drug_pair.split(',')
            drug1, drug2 = names[0], names[1]

            sid1 = _lookup_stitch_id(drug1)
            if sid1 is None:
                sys.stderr.write("Failed to find a stitch id for %s, skipping '%s'\n" % (drug1, drug_pair))
                failed_writer.writerow([drug_pair])
                failedfh.flush()
                continue

            sid2 = _lookup_stitch_id(drug2)
            if sid2 is None:
                sys.stderr.write("Failed to find a stitch id for %s, skipping '%s'\n" % (drug2, drug_pair))
                failed_writer.writerow([drug_pair])
                failedfh.flush()
                continue

            if sid1 == sid2:
                sys.stderr.write("Drugs mapped to same identifier, skipping %s.\n" % drug_pair)
                continue

            sys.stderr.write("Fitting linear model for %s (%s and %s):\n" % (drug_pair, sid1, sid2))

            try:
                pvalue, estimates, pvalues = _fit_lab_models(sid1, sid2, table, days, log_values)
            except Exception as err:
                # Any R/database error for this pair is logged; the pair is not
                # written anywhere, so a later run will retry it.
                sys.stderr.write("Model fitting failed for %s: %s\n" % (drug_pair, err))
                pvalue = None
                estimates = None
                pvalues = None

            sys.stderr.write("%s %s %s\n" % (pvalue, estimates, pvalues))

            if pvalue is not None:
                writer.writerow([drug_pair, sid1, sid2, pvalue] + estimates + pvalues)
                outfh.flush()

