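"""
Evaluate the performance of drug-adverse event predictions against a gold standard.


USAGE
=====
Run this script as follows:

python evaluate.py <gold_types> <query_types> <evaluation_methods>

Each argument is a comma-separated list:

gold_types          Gold standard(s) to evaluate against: medeffect,
                    medeffect-sider, sider, singlet, future, rareN (events
                    with max_upper <= N), or a number N (SIDER events that
                    are associated with fewer than N drugs).
query_types         Prediction sources to evaluate (e.g. TStat-e5, PRR), or
                    "all" for every available source.
evaluation_methods  Any of: class, precision, prec_v_numpred, and
                    top-precision-N (precision over the top N predictions).

e.g.
python evaluate.py 800 TStat-e5,PRR class,precision
This runs the evaluation on the entire gold standard (all rareness levels)
using the T-statistic and Proportional Reporting Ratio methods. It calculates
the AUROC and AUPR as well as a precision plot over the score range.

"""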

import os
import csv
import sys
import math
import numpy
import random
import MySQLdb
import operator
import rpy2.robjects as robjects

from namedmatrix import NamedMatrix
from pyweka import MachineLearning as ml

def save_prcurve(file_name, lr):
    """Write the precision-recall curve held by *lr* to *file_name* as CSV.

    Each row is ``(lr.plot_data[3][i], lr.plot_data[2][i])``.
    NOTE(review): which series is precision and which is recall depends on how
    ``lr`` fills ``plot_data`` -- confirm against the model class.
    """
    # ``with`` guarantees the file is closed even if writing fails.
    with open(file_name, 'w') as fh:
        csv.writer(fh).writerows(zip(lr.plot_data[3], lr.plot_data[2]))

def save_roccurve(file_name, lr):
    """Write the ROC curve held by *lr* to *file_name* as CSV.

    Each row is ``(lr.plot_data[1][i], lr.plot_data[0][i])``.
    NOTE(review): presumably (FPR, TPR) pairs -- confirm against the model class.
    """
    # ``with`` guarantees the file is closed even if writing fails.
    with open(file_name, 'w') as fh:
        csv.writer(fh).writerows(zip(lr.plot_data[1], lr.plot_data[0]))

def cont_table2(mean, labels, preds):
    """Return a 2x2 contingency table of labels against a score cutoff.

    A prediction counts as "above" when its score ``pred[2] > mean``.
    *labels* and *preds* are parallel sequences; only labels equal to 1
    (positives) or 0 (negatives) are counted.

    Returns ``(positives above, positives at/below,
    negatives above, negatives at/below)``.
    """
    pos_above = pos_below = neg_above = neg_below = 0
    for label, pred in zip(labels, preds):
        score = pred[2]
        # The second comparison is explicit (not ``else``) so that NaN scores
        # fall into neither cell.
        if label == 1:
            if score > mean:
                pos_above += 1
            elif score <= mean:
                pos_below += 1
        elif label == 0:
            if score > mean:
                neg_above += 1
            elif score <= mean:
                neg_below += 1
    return (pos_above, pos_below, neg_above, neg_below)

def cont_table(pvalue, labels, preds):
    """Return a 2x2 contingency table of labels against a p-value cutoff.

    A prediction counts as "significant" when its score ``pred[2] < pvalue``.
    *labels* and *preds* are parallel sequences; only labels equal to 1
    (positives) or 0 (negatives) are counted.

    Returns ``(positives significant, positives not significant,
    negatives significant, negatives not significant)``.
    """
    pos_sig = pos_not = neg_sig = neg_not = 0
    for label, pred in zip(labels, preds):
        score = pred[2]
        # The second comparison is explicit (not ``else``) so that NaN scores
        # fall into neither cell.
        if label == 1:
            if score < pvalue:
                pos_sig += 1
            elif score >= pvalue:
                pos_not += 1
        elif label == 0:
            if score < pvalue:
                neg_sig += 1
            elif score >= pvalue:
                neg_not += 1
    return (pos_sig, pos_not, neg_sig, neg_not)
        
def build_curves(preds, cutoffs, labels):
    """Compute ROC and precision-recall curves and their areas over *cutoffs*.

    Parameters
    ----------
    preds : sequence of rows ``[stitch_id, umls_id, score, (flag, ...)]``.
        If the first row has more than three columns, every row whose
        ``row[3] == 0`` is scored as 0. The caller's rows are not modified.
    cutoffs : score thresholds; a prediction is called positive when
        ``score >= cutoff``. The trapezoidal areas accumulate positively when
        the cutoffs are given in increasing order.
    labels : truth labels (1 = true pair, 0 = not), parallel to *preds*.

    Returns
    -------
    ``((auroc, aupr), (TPRs, FPRs), (PRs, REs))`` with one rate per cutoff.
    A rate whose denominator is zero (no positives, no negatives, or nothing
    called) is reported as 0.0.
    """
    # Score per row after applying the optional 4th-column flag; computed
    # locally instead of overwriting preds[i][2] in the caller's data.
    has_flag = len(preds) > 0 and len(preds[0]) > 3
    pairs = [(0 if (has_flag and row[3] == 0) else row[2], label)
             for row, label in zip(preds, labels)]

    FPRs = list()
    TPRs = list()
    PRs = list()
    REs = list()
    auroc = 0.0
    aupr = 0.0
    for p, cutoff in enumerate(cutoffs):
        called = [label for score, label in pairs if score >= cutoff]
        rejected = [label for score, label in pairs if score < cutoff]
        TP = sum(called)
        FP = len(called) - TP
        FN = sum(rejected)
        TN = len(rejected) - FN

        TPR = TP / float(TP + FN or 1)
        FPR = FP / float(FP + TN or 1)
        PR = TP / float(len(called) or 1)
        RE = TPR

        if p > 0:
            # Trapezoid between the previous point and this one.
            auroc += (FPRs[-1] - FPR) * TPR + 0.5 * (FPRs[-1] - FPR) * (TPRs[-1] - TPR)
            aupr += (REs[-1] - RE) * PR + 0.5 * (REs[-1] - RE) * (PRs[-1] - PR)

        TPRs.append(TPR)
        FPRs.append(FPR)
        PRs.append(PR)
        REs.append(RE)

    return ((auroc, aupr), (TPRs, FPRs), (PRs, REs))

def recall_plot(preds, gold_standard, cutoffs):
    """Per-drug recall of the gold-standard events at each score cutoff.

    Parameters
    ----------
    preds : rows ``(stitch_id, umls_id, score, ...)``; extra columns are ignored.
    gold_standard : dict mapping stitch_id -> set of true umls_ids.
    cutoffs : a prediction is kept when ``score > cutoff``.

    Returns
    -------
    One list per cutoff holding, for every gold-standard drug that has at
    least one kept prediction, ``|true events predicted| / |true events|``.
    """
    recalls = []
    for cutoff in cutoffs:
        # stitch_id -> set of predicted umls_ids above this cutoff.
        predicted = {}
        for row in preds:
            if row[2] > cutoff:
                predicted.setdefault(row[0], set()).add(row[1])

        recalls.append([
            len(events & predicted[stitch_id]) / float(len(events))
            for stitch_id, events in gold_standard.items()
            if stitch_id in predicted  # dict membership: O(1)
        ])
    return recalls

def precision_plot(preds, labels):
    """
    Return the running precision over the top-k predictions, k = 1, 2, ...

    *preds* is assumed to be ordered best-first. It is only paired with
    *labels*, and the output stops at the shorter of the two. Entry k-1 is
    the fraction of the first k labels that equal 1.
    """
    curve = []
    hits = 0
    for k, (_unused_pred, label) in enumerate(zip(preds, labels), 1):
        if label == 1:
            hits += 1
        curve.append(float(hits) / k)
    return curve

def balance_preds(preds, labels):
    """Randomly down-sample to equal numbers of positive and negative predictions.

    *labels* is parallel to *preds*. Only labels equal to 1 or 0 are
    considered. Returns ``(sampled_preds, sampled_labels)``, with the positive
    sample first, labelled ``[1]*size + [0]*size``. Here ``size`` is the size
    of the smaller class.
    """
    positives = [pred for pred, label in zip(preds, labels) if label == 1]
    negatives = [pred for pred, label in zip(preds, labels) if label == 0]

    size = min(len(positives), len(negatives))

    return (random.sample(positives, size) + random.sample(negatives, size),
            [1] * size + [0] * size)

def save_results(results, roc_file=None, pr_file=None):
    """Write AUROC and AUPR comparison tables as CSV.

    Parameters
    ----------
    results : dict keyed by ``"<gold_type>-<query_type>"`` (e.g. different
        rareness cutoffs and scoring methods such as t-stat or PRR). Each value
        must expose a ``results`` mapping with 'AUROC' and 'AUPR' entries.
        Keys are split on their FIRST hyphen, so query types may contain
        hyphens (e.g. '800-EBGM-log2RR'), but gold types must not.
    roc_file, pr_file : output paths. They default to
        ~/Dropbox/deconvolution_results_compare_roc.csv and ..._pr.csv.

    Each table has one row per gold type (sorted as strings) and one column per
    query type. Combinations missing from *results* are left blank.
    """
    if roc_file is None:
        roc_file = '~/Dropbox/deconvolution_results_compare_roc.csv'
    if pr_file is None:
        pr_file = '~/Dropbox/deconvolution_results_compare_pr.csv'

    split_keys = [key.split('-', 1) for key in results]
    gold_types = sorted(set(gold for gold, _ in split_keys))
    query_types = sorted(set(query for _, query in split_keys))

    with open(os.path.expanduser(roc_file), 'w') as fh_roc, \
            open(os.path.expanduser(pr_file), 'w') as fh_pr:
        roc = csv.writer(fh_roc)
        pr = csv.writer(fh_pr)

        roc.writerow(['rareness'] + query_types)
        pr.writerow(['rareness'] + query_types)
        for gold in gold_types:
            models = [results.get('%s-%s' % (gold, name)) for name in query_types]
            roc.writerow([gold] + [m.results['AUROC'] if m is not None else '' for m in models])
            pr.writerow([gold] + [m.results['AUPR'] if m is not None else '' for m in models])

def save_roc_plots(results, out_dir='~/Dropbox'):
    """Write one ROC CSV per entry of *results*.

    Each value must expose ``plot_data``. Rows are
    ``(plot_data[1][i], plot_data[0][i])``, written to
    ``<out_dir>/deconv_result_roc_<key>.csv``. *out_dir* defaults to the
    original ~/Dropbox location, and ``~`` is expanded.
    """
    out_dir = os.path.expanduser(out_dir)
    for key, lr in results.items():
        path = os.path.join(out_dir, 'deconv_result_roc_%s.csv' % key)
        with open(path, 'w') as fh_roc:
            csv.writer(fh_roc).writerows(zip(lr.plot_data[1], lr.plot_data[0]))


def save_precision_plots(results, out_dir='~/Dropbox'):
    """Write each precision curve in *results* to its own CSV file.

    *results* maps a key to a sequence of rows, which are written as-is to
    ``<out_dir>/deconvolution_precision_plot_<key>.csv``. *out_dir* defaults to
    the original ~/Dropbox location, and ``~`` is expanded.
    """
    out_dir = os.path.expanduser(out_dir)
    for key, rows in results.items():
        path = os.path.join(out_dir, 'deconvolution_precision_plot_%s.csv' % key)
        with open(path, 'w') as fh:
            csv.writer(fh).writerows(rows)

def save_precision_means(results, out_dir='~/Dropbox'):
    """Write a gold-type x query-type table of each curve's final precision.

    Parameters
    ----------
    results : dict keyed by ``"<gold_type>-<query_type>"``. Each value is a
        sequence of rows, and the cell written is ``rows[-1][1]``, the second
        column of the last row. Keys are split on their FIRST hyphen, so query
        types may contain hyphens but gold types must not.
    out_dir : directory for
        ``deconvolution_precision_plot_compare_<all keys joined by '_'>.csv``.
        It defaults to the original ~/Dropbox location.

    Numeric gold types are sorted numerically and come before non-numeric
    ones, which are sorted alphabetically. Missing combinations are left blank.
    """
    def gold_sort_key(gold):
        try:
            return (0, float(gold), gold)
        except ValueError:
            return (1, 0.0, gold)

    split_keys = [key.split('-', 1) for key in results]
    gold_types = sorted(set(gold for gold, _ in split_keys), key=gold_sort_key)
    query_types = sorted(set(query for _, query in split_keys))

    path = os.path.join(os.path.expanduser(out_dir),
                        'deconvolution_precision_plot_compare_%s.csv' % '_'.join(results.keys()))
    with open(path, 'w') as fh:
        writer = csv.writer(fh)
        writer.writerow(['rareness'] + query_types)
        for gold in gold_types:
            row = [gold]
            for query in query_types:
                curve = results.get('%s-%s' % (gold, query))
                row.append(curve[-1][1] if curve else '')
            writer.writerow(row)

# SQL for each prediction source, keyed by the query-type name accepted on the
# command line. Every query selects (stitch_id, umls_id, score[, extra columns]).
PRED_QUERIES = {
    'EBGM-log2RR': """
        select stitch_id, umls_id, a.log2RR
        from ebgm_pred_drug_events a
        join psm_drug_events b using (stitch_id, umls_id)
        where (Nij+a.Eij) >= 5
        and a.log2RR > 0
        """,
    'PSM-log2RR': """
        select stitch_id, umls_id, a.log2RR
        from psm_drug_events a
        join ebgm_pred_drug_events b using (stitch_id, umls_id)
        where (Nij+b.Eij) >= 5
        and a.log2RR > 0
        """,
    'PSM-prr': """
        select stitch_id, umls_id, prr
        from psm_drug_events
        where prr > 1
        """,
    'PSM-tstat': """
        select stitch_id, umls_id, t_statistic
        from psm_drug_events a
        join ebgm_pred_drug_events b using (stitch_id, umls_id)
        where a.log2RR > 0
        and b.log2RR > 0
        """,
    'EMR-t': """
        select stitch_id, umls_id, t_statistic
        from effect_stanford.pred_drug_events_e
        where t_statistic > 0
        """,
    # Was "and prr > 1" with no WHERE clause, which is invalid SQL.
    'EMR-prr': """
        select stitch_id, umls_id, prr
        from effect_stanford.pred_drug_events_e
        where prr > 1
        """,
    'EMR-prop-prr': """
        select stitch_id, umls_id, prr
        from effect_stanford.pred_drug_events_prop
        where prr > 1
        """,
    'DrugMean': """
        select stitch_id, umls_id, drug_mean
        from pred_drug_events_b
        """,
    'TStat-b': """
        select stitch_id, umls_id, t_statistic
        from pred_drug_events_b
        where t_statistic > 0
        """,
    'TStat-null': """
        select stitch_id, umls_id, t_statistic
        from pred_drug_events_null
        where t_statistic > 0
        """,
    'TStat-e': """
        select stitch_id, umls_id, t_statistic
        from pred_drug_events_e
        where t_statistic > 0
        """,
    'TStat-e5-bgscore': """
        select stitch_id, umls_id, abs((mean-bg_mean)/bg_sd) as bgscore
        from pred_drug_events_e5
        join event_backgrounds using (umls_id)
        where t_statistic > 0
        and bg_sd > 0;
        """,
    'TStat-e5-bgscore+T': """
        select stitch_id, umls_id, abs((mean-bg_mean)/bg_sd) as bgscore, t_statistic
        from pred_drug_events_e5
        join event_backgrounds using (umls_id)
        where t_statistic > 0
        and bg_sd > 0;
        """,
    'Ratio-e5': """
        select stitch_id, umls_id, drug_num/bg_num as ratio
        from pred_drug_events_e5
        where t_statistic > 0;
        """,
    'TStat-e5': """
        select stitch_id, umls_id, t_statistic
        from pred_drug_events_e5
        where t_statistic > 0
        """,
    'TStat-e5norm': """
        select stitch_id, umls_id, t_dnorm_gtr0
        from pred_drug_events_e5
        where t_statistic > 0
        """,
    'Chisq-e5': """
        select stitch_id, umls_id, chisq
        from pred_drug_events_e5
        where drug_mean > bg_mean
        """,
    # This query was also keyed 'PRR-e-ratio', so the later 'PRR-e-ratio'
    # entry silently replaced it. It is renamed here to keep it reachable.
    'PRR-e5-ratio': """
        select stitch_id, umls_id, p.prr/e.prr as ratio
        from pred_drug_events_e5 e
        join prop_pred_drug_events p using (stitch_id, umls_id)
        where (e.prr > 2 or p.prr > 2)
        """,
    'PRR-e': """
        select stitch_id, umls_id, prr
        from pred_drug_events_e
        where t_statistic > 0
        and prr > 1
        """,
    'PRR-e5': """
        select stitch_id, umls_id, prr
        from pred_drug_events_e5
        where t_statistic > 0
        and prr > 1
        """,
    'PRR-e-ratio': """
        select stitch_id, umls_id, e.prr, p.prr/e.prr
        from pred_drug_events_e e
        join prop_pred_drug_events p using (stitch_id, umls_id)
        where e.prr > 1
        """,
    'PRR': """
        select stitch_id, umls_id, prr
        from prop_pred_drug_events
        where prr > 1
        """,
    'EScore-e5': """
        select stitch_id, umls_id, log10(e_score_de)
        from pred_drug_events_e5
        where e_score_de is not null
        """,
    'IC': """
        select stitch_id, umls_id, ic
        from prop_pred_drug_events
        where a > 3;
        """,
    'Random-DrugMean': """
        select stitch_id, umls_id, drug_mean
        from pred_drug_events_b_random
        where drug_mean > bg_mean
        """,
    'Random-TStat': """
        select stitch_id, umls_id, t_statistic
        from pred_drug_events_b_random
        where t_statistic > 0
        """,
    'Random-PRR': """
        select stitch_id, umls_id, prr
        from prr_pred_drug_events_random
        """,
    'Random-TStat+PRR': """
        select stitch_id, umls_id, t_statistic, prr
        from pred_drug_events_b_random
        join prr_pred_drug_events_random using (stitch_id, umls_id)
        where t_statistic > 0
        """,
    'MostCommon': """
        select stitch_id, umls_id, num_reports
        from pred_drug_events_b
        join event_reportcount using (umls_id)
        """,
    # Using 10q4 database.
    'PRR-e5-10q4': """
        select stitch_id, umls_id, prr
        from project_aers_10q4.pred_drug_events_e5
        where t_statistic > 0
        and prr > 1
        """,
    'TStat-e5-10q4': """
        select stitch_id, umls_id, t_statistic
        from project_aers_10q4.pred_drug_events_e5
        where t_statistic > 0
        """,
    'PSM-10q4': """
        select stitch_id, umls_id, bg_correction
        from project_aers_10q4.psm_drug_events
        where t_statistic > 0
        """,
}


def _log(message):
    """Write *message* plus a newline to stderr (works on Python 2 and 3)."""
    sys.stderr.write("%s\n" % (message,))


def _ratio(numerator, denominator):
    """Return numerator/denominator as a float, or 0 when the denominator is 0."""
    return numerator / float(denominator) if denominator else 0


def _as_threshold(text, gold_type):
    """Parse the numeric cutoff embedded in *gold_type*; raise ValueError if it is not a number."""
    try:
        return float(text)
    except ValueError:
        raise ValueError("Invalid gold standard type: %s" % gold_type)


def _gold_standard_query(gold_type):
    """Return ``(sql, params)`` for the gold standard named *gold_type*.

    Each query selects three columns: stitch_id, umls_id, and a third value
    that is not used when building the gold standard. Numeric cutoffs taken
    from *gold_type* are passed as query parameters rather than formatted into
    the SQL.
    """
    if gold_type == 'medeffect':
        return """
            select stitch_id, umls_id, num_reports
            from gold_drug_ae_me
            """, None
    if gold_type == 'medeffect-sider':
        return """
            select stitch_id, umls_id, num_reports
            from gold_drug_ae_me
            left join gold_drug_ae sider using (stitch_id, umls_id)
            where sider.gold is NULL
            """, None
    if gold_type == 'sider':
        return """
            select stitch_id, umls_id, gold
            from gold_drug_ae
            """, None
    if gold_type == 'singlet':
        return """
            select stitch_id, umls_id, 1 as gold
            from singlet_ae_pvals_z
            where pvalue < 1e-10
            """, None
    if 'rare' in gold_type:
        # e.g. 'rare100' -> 100
        rarity = gold_type.split('e')[1]
        return """
            select stitch_id, umls_id, gold
            from gold_drug_ae_freqs
            where max_upper <= %s
            """, (_as_threshold(rarity, gold_type),)
    if gold_type == 'future':
        return """
            select stitch_id, umls_id, 1 as gold
            from gold_drug_ae_future
            """, None
    # Otherwise gold_type is a number: keep events linked to fewer drugs.
    return """
        select stitch_id, umls_id, drug_count
        from gold_drug_ae
        join (
            select umls_id, count(stitch_id) as drug_count
            from gold_drug_ae
            group by umls_id
        ) dc using (umls_id)
        where drug_count < %s
        """, (_as_threshold(gold_type, gold_type),)


def _load_gold_standard(cursor, gold_type):
    """Query the gold standard and return ``(gold_standard, gold_umls_ids)``.

    gold_standard maps stitch_id -> set of umls_ids. gold_umls_ids is the set
    of all umls_ids present in the gold standard.
    """
    _log("Querying gold standard (%s)..." % gold_type)
    query, params = _gold_standard_query(gold_type)
    _log("Querying the GOLD standard (%s)...\n%s" % (gold_type, query))
    if params:
        _log("params: %s" % (params,))
    print(cursor.execute(query, params))

    gold_standard = dict()
    gold_umls_ids = set()
    for stitch_id, umls_id, _value in cursor.fetchall():
        gold_standard.setdefault(stitch_id, set()).add(umls_id)
        gold_umls_ids.add(umls_id)
    return gold_standard, gold_umls_ids


def _labelled_predictions(rows, gold_standard, gold_umls_ids):
    """Keep prediction rows for drugs and events covered by the gold standard, and label them.

    Returns ``(preds, labels, pred_dict, label_dict)``:
      preds      -- rows ``[stitch_id, umls_id, float values...]``
      labels     -- 1 if the pair is in the gold standard, else 0
      pred_dict  -- "stitch_id,umls_id" -> list of float values
      label_dict -- "stitch_id,umls_id" -> label
    """
    preds, labels = [], []
    pred_dict, label_dict = dict(), dict()
    for row in rows:
        stitch_id, umls_id = row[:2]
        # Check if we even have data on this drug and event.
        if umls_id not in gold_umls_ids or stitch_id not in gold_standard:
            continue
        values = [float(v) for v in row[2:]]
        label = 1 if umls_id in gold_standard[stitch_id] else 0
        labels.append(label)
        preds.append([stitch_id, umls_id] + values)
        key = "%s,%s" % (stitch_id, umls_id)
        pred_dict[key] = values
        label_dict[key] = label
    return preds, labels, pred_dict, label_dict


def _logistic_model(pred_dict, label_dict):
    """Build a feature matrix from *pred_dict* and pass it to pyweka's Logistic model.

    Logs ``cross_validate()`` to stderr and returns the model object.
    *pred_dict* must be non-empty, and all of its value lists must have the
    same length.
    """
    n_features = len(next(iter(pred_dict.values())))
    feat = NamedMatrix(None, list(pred_dict.keys()),
                       ["Feature_%s" % j for j in range(n_features)])
    feat_labels = [label_dict[key] for key in feat.rownames]
    for i, key in enumerate(feat.rownames):
        for j, value in enumerate(pred_dict[key]):
            feat[i, j] = value

    lr = ml.Logistic(feat, feat_labels)
    _log(lr.cross_validate())
    return lr


def _precision_by_rank(preds, labels, step=100):
    """Return ``(rank, precision)`` for ranks 1, 1+step, 1+2*step, ...

    Predictions are ranked by descending (score, label), with score = pred[2].
    """
    ranked = sorted(((pred[2], label) for pred, label in zip(preds, labels)), reverse=True)
    plot_data = list()
    hits = 0
    for rank, (_score, label) in enumerate(ranked, 1):
        hits += label
        if (rank - 1) % step == 0:
            plot_data.append((rank, hits / float(rank)))
    return plot_data


def _precision_over_threshold(query, preds, labels):
    """Return rows ``[threshold, precision, number_called]``.

    If *query* mentions 'pvalue', scores are treated as p-values. The
    thresholds are then the exponents 2, 12, ..., 212 (a row is called when
    score < 1e-exponent). Otherwise, 100 evenly spaced thresholds across the
    observed score range are used (a row is called when score > threshold).
    Precision is 0 when nothing is called.
    """
    precision_plot_rows = list()
    if 'pvalue' in query:
        _log("Computing precision over p value...")
        for exponent in range(2, 220, 10):
            cutoff = float("1e-%d" % exponent)
            called = [label for pred, label in zip(preds, labels) if pred[2] < cutoff]
            precision_plot_rows.append([exponent, _ratio(sum(called), len(called)), len(called)])
        return precision_plot_rows

    _log("Computing precision over value range...")
    if not preds:
        return precision_plot_rows
    scores = [pred[2] for pred in preds]
    min_score = min(scores)
    max_score = max(scores)
    delta = (max_score - min_score) / 100.0
    if delta > 0:
        thresholds = numpy.arange(min_score + delta, max_score + delta, delta)
    else:
        # All scores are equal; numpy.arange cannot take a zero step.
        thresholds = [max_score]
    for threshold in thresholds:
        called = [label for score, label in zip(scores, labels) if score > threshold]
        precision_plot_rows.append([threshold, _ratio(sum(called), len(called)), len(called)])
    return precision_plot_rows


def _top_n_precision(pred_dict, gold_standard, top_n):
    """Running precision over the *top_n* predictions, ranked by descending feature values.

    Returns ``[index, precision]`` rows, where index is 0-based.
    """
    ranked = sorted(((values, key.split(',')) for key, values in pred_dict.items()),
                    reverse=True)
    precision_plot_data = list()
    num_true = 0
    for total, (_values, (drug_id, event_id)) in enumerate(ranked[:top_n]):
        if event_id in gold_standard[drug_id]:
            num_true += 1
        precision_plot_data.append([total, num_true / float(total + 1.0)])
    return precision_plot_data


def _main(argv):
    """Run every requested (gold type, query type, evaluation) combination.

    Returns the results dict keyed by "<gold_type>-<query_type>".
    NOTE(review): results is not written anywhere. Call save_results /
    save_precision_plots / save_precision_means as appropriate. Also, each
    evaluation stores its output under the same key, so when several
    evaluation methods are requested only the last one is kept.
    """
    if len(argv) < 4:
        sys.exit("usage: python evaluate.py <gold_types> <query_types|all> <eval_types>")

    # Get command line options for the gold standard.
    gold_types = argv[1].split(',')

    # Get command line options for the data sources. Validate them up front,
    # before any database work.
    if argv[2].lower() == 'all':
        query_types = sorted(PRED_QUERIES)
    else:
        query_types = argv[2].split(',')
    for query_type in query_types:
        if query_type not in PRED_QUERIES:
            raise Exception("Query type is invalid: %s" % query_type)

    # Get the command line options for the evaluation methods.
    eval_types = argv[3].split(',')
    # The top N can be given on the command line, e.g. top-precision-1000.
    top_prec_eval_types = [x for x in eval_types if 'top-precision' in x]
    top_n = int(top_prec_eval_types[0].split('-')[-1]) if top_prec_eval_types else None

    db = MySQLdb.connect(host="localhost", port=3307, user="root", passwd="enter_your_password", db="project_aers")
    # db = MySQLdb.connect(host="127.0.0.1", port=3306, user="root", passwd="enter_your_password",db="project_aers")
    try:
        # Named 'cursor' (not 'c') so no loop variable can clobber it.
        cursor = db.cursor()
        results = dict()
        for gold_type in gold_types:
            gold_standard, gold_umls_ids = _load_gold_standard(cursor, gold_type)

            for query_type in query_types:
                query = PRED_QUERIES[query_type]
                _log("Querying predictions (%s)...\n%s" % (query_type, query))
                print(cursor.execute(query))

                preds, labels, pred_dict, label_dict = _labelled_predictions(
                    cursor.fetchall(), gold_standard, gold_umls_ids)
                _log("sum(labels) = %d, ratio of positives = %f"
                     % (sum(labels), _ratio(sum(labels), len(labels))))

                key = '%s-%s' % (gold_type, query_type)

                if 'class' in eval_types:
                    _log("Calculating the performance of %s." % query_type)
                    if pred_dict:
                        results[key] = _logistic_model(pred_dict, label_dict)
                    else:
                        _log("No predictions overlap the gold standard; skipping %s." % query_type)

                if 'prec_v_numpred' in eval_types:
                    # Precision over the number of predictions.
                    results[key] = _precision_by_rank(preds, labels)

                if 'precision' in eval_types:
                    results[key] = _precision_over_threshold(query, preds, labels)

                if top_n is not None:
                    _log("Computing top %d precision plot..." % top_n)
                    results[key] = _top_n_precision(pred_dict, gold_standard, top_n)
    finally:
        db.close()
    return results


if __name__ == '__main__':
    _main(sys.argv)

    ######## Odds Plots #########
    # print >> sys.stderr, "Building p-value odds plot..."
    # rng = range(2,200,10)
    # for i in rng:
    #     odds = robjects.r("fisher.test(matrix(c%s, nrow=2, byrow=T))$estimate" % str(cont_table( float("1e-%d" % i), labels, preds)))[0]
    #     print i,odds

    # print >> sys.stderr, "Building value based odds plot..."
    # min_score = min([x[2] for x in preds])
    # max_score = max([x[2] for x in preds])
    # delta = (max_score-min_score)/30.0
    # rng = numpy.arange(min_score, max_score+delta, delta)
    # for i in rng:
    #     odds = robjects.r("fisher.test(matrix(c%s, nrow=2, byrow=T))$estimate" % str(cont_table2( i, labels, preds)))[0]
    #     print i,odds

    # P-Value Odds Plot R Code
    # print "odds = rep(0,20)"
    # print "vals = c(%s)" % ",".join(map(str,rng))
    # for j,i in enumerate(rng):
    #     print "x = c%s\nodds[%d] <- fisher.test( matrix(x, nrow=2, byrow=T) )$estimate" % (str(cont_table( float("1e-%d" % i) )), j+1)
    #
    # print """plot(vals,odds, main="Odds of being a True Drug-Adverse Event Pair\nUsing Singlet Reports", xlab="-log10 of the p-value", type="o", ylab="Odds Ratio")"""

    # Drug Mean Odds Plot
    # rng = map(lambda x: x/400., range(0,20))
    # print "vals = c(%s)" % ",".join(map(str,rng))
    # for j,i in enumerate(rng):
    #     print "x = c%s\nodds[%d] <- fisher.test( matrix(x, nrow=2, byrow=T) )$estimate" % (str(cont_table2( i )), j+1)
