#!/usr/bin/env python
# encoding: utf-8
"""
clinical_eval.py

Created by Nicholas Tatonetti on 2011-02-19.
Copyright (c) 2011 Stanford University. All rights reserved.
"""

import os
import csv
import sys
import MySQLdb

# Single MySQL connection shared by every query helper in this module; the
# helpers run their SQL through the module-level cursor `conn`.
# NOTE(review): credentials are hard-coded here; consider loading them from a
# config file or environment variables instead.
db = MySQLdb.connect(
    db="clinical_stanford",
    host="localhost",
    port=3307,
    user="root",
    passwd="dummy_password",
)
conn = db.cursor()

# Root directory that holds one sub-directory per indication (the script
# joins it with the indication name given on the command line).
PATH_TO_DATA_DIR = os.path.expanduser('~/Stanford/AltmanLab/cocktails/methods/')

def name2stitch(drug_name):
    """Resolve a drug name (or name fragment) to a STITCH compound id.

    The name is matched as a substring with SQL ``LIKE '%<name>%'``.  The
    ``effect_aers.aers2stitch`` table is searched first; if it has no match,
    the brand and generic names of ``med2stitch`` joined with
    ``medication_dx`` are searched.  When a lookup returns several ids, the
    first row is used and a warning is written to stderr.

    :param drug_name: drug name or fragment to look up.
    :returns: the matching ``stitch_id``, or ``None`` if no table matches.
    """
    # Bind the pattern as a driver parameter instead of splicing it into the
    # SQL text, so quotes inside drug names cannot break or inject SQL.
    pattern = '%' + drug_name + '%'
    lookups = (
        ("""
        select distinct stitch_id
        from effect_aers.aers2stitch
        where drug_name like %s;
        """, (pattern,)),
        ("""
        select distinct stitch_id
        from med2stitch join medication_dx using (med_id)
        where brand_name like %s
        or generic_name like %s;
        """, (pattern, pattern)),
    )

    for query, params in lookups:
        num_results = conn.execute(query, params)
        if num_results > 0:
            # NOTE(review): the queries have no ORDER BY, so "first" is
            # whatever row order the server happens to return.
            stitch_id = conn.fetchall()[0][0]
            if num_results > 1:
                print >> sys.stderr, "Found %d matches for '%s', using the first one (%s)." % (num_results, drug_name, stitch_id)
            return stitch_id

    return None

def single_cases(sid, icd9s, pair_patients):
    """Count patients exposed to one drug and how many later got the outcome.

    A patient counts as exposed if they have a prescription whose medication
    maps to ``sid``; the exposure date is their earliest such order.  Such a
    patient is "inflicted" if a diagnosis matching any pattern in ``icd9s``
    is dated at least one day after that exposure date.  Patients listed in
    ``pair_patients`` are excluded from both counts.

    NOTE(review): because of the final WHERE filter, a patient whose only
    matching diagnoses fall on or before the exposure date is dropped from
    the results entirely, rather than counted as non-inflicted.

    :param sid: STITCH id of the drug.
    :param icd9s: non-empty sequence of SQL ``LIKE`` patterns for the
        ``icd9`` column (e.g. ``'401%'``).
    :param pair_patients: iterable of patient ids to exclude.
    :returns: ``(total_patients, inflicted_patients)``.
    :raises ValueError: if ``icd9s`` is empty.
    """
    if not icd9s:
        raise ValueError("At least one ICD9 pattern is required.")

    # One bound placeholder per pattern; the driver handles quoting.
    icd9_clause = "(" + " or ".join(["icd9 like %s"] * len(icd9s)) + ")"

    query = """
    select patient_id, min(datediff(diag_date, first_order))
    from
    (
        select patient_id, min(order_date) as first_order
        from clinical_stanford.med2stitch
        join clinical_stanford.prescriptions using (med_id)
        where stitch_id = %s
        group by patient_id
    ) patients
    left join
    (
        select patient_id, diag_date
        from clinical_stanford.diagnoses
        where {icd9_clause}
    ) diags using (patient_id)
    where (datediff(diag_date, first_order) > 0 or diag_date is NULL)
    group by patient_id;
    """.format(icd9_clause=icd9_clause)
    # Placeholder order in the text: stitch_id first, then the ICD9 patterns.
    conn.execute(query, (sid,) + tuple(icd9s))

    # A set makes the per-row exclusion check O(1) instead of O(len(list)).
    excluded = set(pair_patients)
    other = [(pid, days) for pid, days in conn.fetchall() if pid not in excluded]

    # `days` is NULL (None) when the patient has no post-exposure diagnosis.
    inflicted = sum(1 for _, days in other if days is not None)

    return (len(other), inflicted)


def pair_cases(sid1, sid2, icd9s, max_gap_days=36):
    """Count patients exposed to a drug pair and how many later got the outcome.

    A patient counts as exposed to the pair when they have an order of the
    drug mapped to ``sid1`` and an order of the drug mapped to ``sid2``
    fewer than ``max_gap_days`` days apart.  Over all such qualifying order
    pairs, the patient's exposure date is the earlier of the two drugs'
    latest order dates.  The patient is "inflicted" if a diagnosis matching
    any pattern in ``icd9s`` is dated at least one day after that exposure
    date.

    NOTE(review): as in ``single_cases``, a patient whose only matching
    diagnoses fall on or before the exposure date is dropped from the
    results entirely.

    :param sid1: STITCH id of the first drug.
    :param sid2: STITCH id of the second drug.
    :param icd9s: non-empty sequence of SQL ``LIKE`` patterns for ``icd9``.
    :param max_gap_days: orders of the two drugs must be strictly fewer than
        this many days apart (default 36, the original hard-coded window).
    :returns: ``(pair_patient_ids, total_patients, inflicted_patients)``.
    :raises ValueError: if ``icd9s`` is empty.
    """
    if not icd9s:
        raise ValueError("At least one ICD9 pattern is required.")

    # One bound placeholder per pattern; the driver handles quoting.
    icd9_clause = "(" + " or ".join(["icd9 like %s"] * len(icd9s)) + ")"

    query = """
    select patient_id, min(datediff(diag_date, order_date))
    from
    (
        select patient_id, least(max(order_date1), max(order_date2)) as order_date
        from
        (
            select patient_id, order_date as order_date1
            from prescriptions
            join med2stitch using (med_id)
            where stitch_id = %s
        ) drug1
        join
        (
            select patient_id, order_date as order_date2
            from prescriptions
            join med2stitch using (med_id)
            where stitch_id = %s
        ) drug2 using (patient_id)
        where abs(datediff(order_date1, order_date2)) < %s
        group by patient_id
    ) patients
    left join
    (
        select patient_id, diag_date
        from clinical_stanford.diagnoses
        where {icd9_clause}
    ) diags using (patient_id)
    where (datediff(diag_date, order_date) > 0 or diag_date is NULL)
    group by patient_id;
    """.format(icd9_clause=icd9_clause)
    # Placeholder order in the text: sid1, sid2, gap window, ICD9 patterns.
    conn.execute(query, (sid1, sid2, max_gap_days) + tuple(icd9s))

    pair_data = list(conn.fetchall())
    pair_patients = [pid for pid, _ in pair_data]

    # `days` is NULL (None) when the patient has no post-exposure diagnosis.
    inflicted = sum(1 for _, days in pair_data if days is not None)

    return (pair_patients, len(pair_data), inflicted)

# Column names of clinical_results.csv, in the order _evaluate_pair emits them.
_RESULT_HEADER = [
    'drug_pair', 'sid1', 'sid2',
    'pair_total', 'pair_inflicted', 'pair_freq',
    'sid1_total', 'sid1_inflicted', 'sid1_freq',
    'sid2_total', 'sid2_inflicted', 'sid2_freq',
    'odds1', 'odds2',
]


def _first_column(path):
    """Return the first field of every non-empty row of the CSV file at *path*."""
    with open(path) as fh:
        return [row[0] for row in csv.reader(fh) if row]


def _resolve_stitch(drug_name):
    """Map *drug_name* to a STITCH id, retrying with only its first word.

    Returns ``None`` when neither lookup succeeds.
    """
    sid = name2stitch(drug_name)
    if sid is None:
        sid = name2stitch(drug_name.split(' ')[0])
    return sid


def _evaluate_pair(drug_pair, icd9s):
    """Build the clinical_results.csv row for ``"drug1,drug2"``.

    Returns ``None`` (after printing the reason to stderr) when the pair is
    malformed, either drug has no STITCH id, or there is too little data to
    form the frequency ratios.
    """
    names = drug_pair.split(',')
    if len(names) < 2:
        print >> sys.stderr, "Expected 'drug1,drug2' but got '%s', skipping." % drug_pair
        return None

    sids = []
    for drug in (names[0], names[1]):
        sid = _resolve_stitch(drug)
        if sid is None:
            print >> sys.stderr, "Failed to find a stitch id for %s, skipping '%s'" % (drug, drug_pair)
            return None
        sids.append(sid)
    sid1, sid2 = sids

    pair_patients, pair_total, pair_inflicted = pair_cases(sid1, sid2, icd9s)
    sid1_total, sid1_inflicted = single_cases(sid1, icd9s, pair_patients)
    sid2_total, sid2_inflicted = single_cases(sid2, icd9s, pair_patients)

    # A zero inflicted count would make a frequency (an odds denominator)
    # zero. A positive inflicted count also implies a positive total, so the
    # divisions below are safe.
    if pair_total == 0 or pair_inflicted == 0 or sid1_inflicted == 0 or sid2_inflicted == 0:
        print >> sys.stderr, "Not enough data for %s, skipping." % drug_pair
        return None

    pair_freq = pair_inflicted / float(pair_total)
    sid1_freq = sid1_inflicted / float(sid1_total)
    sid2_freq = sid2_inflicted / float(sid2_total)

    # NOTE(review): despite the column names, these are ratios of outcome
    # frequencies (pair vs. single drug), not odds ratios.
    odds1 = pair_freq / sid1_freq
    odds2 = pair_freq / sid2_freq

    return [drug_pair, sid1, sid2,
            pair_total, pair_inflicted, pair_freq,
            sid1_total, sid1_inflicted, sid1_freq,
            sid2_total, sid2_inflicted, sid2_freq,
            odds1, odds2]


def main(argv=None):
    """Evaluate predicted drug pairs for one indication against clinical data.

    Reads ``<PATH_TO_DATA_DIR>/<indication>/clinical_preds.txt`` (tab
    separated: rank, drug_pair, score, label, min_rank). For each
    ``"drug1,drug2"`` pair, it appends a row to ``clinical_results.csv`` or,
    when the pair cannot be evaluated, appends the pair to
    ``failed_pairs.csv``. Pairs already listed in either file are skipped,
    so an interrupted run can be resumed.

    Usage:
        python clinical_eval.py renal_impairment 58%
        python clinical_eval.py hypertension 401%,402%,404%,405%
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        raise SystemExit("usage: python clinical_eval.py INDICATION ICD9[,ICD9...]")

    indication = argv[1]
    # str.split never returns an empty list (''.split(',') == ['']), so drop
    # blank entries before checking that at least one code was given.
    icd9s = [code.strip() for code in argv[2].split(',') if code.strip()]
    if not icd9s:
        raise Exception("You must enter some ICD9 codes to use as the outcome variable.")

    data_dir = os.path.join(PATH_TO_DATA_DIR, indication)
    pred_path = os.path.join(data_dir, 'clinical_preds.txt')
    if not os.path.exists(pred_path):
        raise Exception("You must first make a prediction file and store it at %s." % pred_path)

    result_file_path = os.path.join(data_dir, 'clinical_results.csv')
    failed_file_path = os.path.join(data_dir, 'failed_pairs.csv')

    # A set gives O(1) membership tests for the per-pair skip check.
    completed_pairs = set()
    write_header = True
    if os.path.exists(result_file_path):
        finished = _first_column(result_file_path)
        completed_pairs.update(finished)
        num_done = sum(1 for pair in finished if pair != _RESULT_HEADER[0])
        print >> sys.stderr, "Found %d completed pairs already, will skip those." % num_done
        # An existing but empty results file still needs its header.
        write_header = not finished

    if os.path.exists(failed_file_path):
        completed_pairs.update(_first_column(failed_file_path))

    with open(result_file_path, 'a+') as outfh:
        with open(failed_file_path, 'a+') as failedfh:
            writer = csv.writer(outfh)
            failed_writer = csv.writer(failedfh)

            if write_header:
                writer.writerow(_RESULT_HEADER)

            with open(pred_path) as predfh:
                for _rank, drug_pair, _score, _label, _min_rank in csv.reader(predfh, delimiter='\t'):
                    if drug_pair in completed_pairs:
                        # Already recorded by an earlier run (as a result or
                        # as a failure). Do not re-log it as failed.
                        print >> sys.stderr, "Already found %s, skipping." % drug_pair
                        continue

                    row = _evaluate_pair(drug_pair, icd9s)
                    if row is None:
                        failed_writer.writerow([drug_pair])
                        failedfh.flush()
                    else:
                        writer.writerow(row)
                        outfh.flush()
                    completed_pairs.add(drug_pair)


if __name__ == '__main__':
    main()
    
    

