#!/usr/bin/env python
# encoding: utf-8
"""
evalution_table.py

Created by Nicholas Tatonetti on 2010-08-26.
Copyright (c) 2010 Stanford University. All rights reserved.
"""

import os
import sys
import MySQLdb


# SQL used by the script. Other sources used in earlier experiments (kept only
# as a record): predictions from ``singlet_ae_pvals`` or ``pred_pair_events_b``
# and pair labels from ``effect_medeffect.doublet_event_report_count``.
GOLD_STANDARD_QUERY = """
    select stitch_id, umls_id
    from gold_drug_ae
    # from effect_medeffect.singlet_event_report_count
    """

PREDICTION_QUERY = """
    select stitch_id, umls_id
    from sig_drug_events
    where corr_pvalue < 1e-100
    # where drug_mean > bg_mean
    """

# Parameterized: the driver escapes every value instead of the values being
# spliced into the SQL text.
INSERT_QUERY = (
    "INSERT INTO eval_acc_sig_drug_events "
    "(stitch_id, accuracy, `precision`, recall, label_count, pred_count) "
    "VALUES (%s, %s, %s, %s, %s, %s)"
)


def group_by_drug(rows):
    """Group ``(stitch_id, umls_id)`` pairs by drug.

    :param rows: iterable of 2-tuples, e.g. ``cursor.fetchall()`` for a
        ``select stitch_id, umls_id`` query.
    :returns: dict mapping each stitch_id to the set of umls_ids seen with it.
    """
    grouped = {}
    for stitch_id, umls_id in rows:
        grouped.setdefault(stitch_id, set()).add(umls_id)
    return grouped


def compute_drug_metrics(gold_standard, predictions):
    """Score each drug's predicted events against the gold standard.

    The evaluation universe is every umls_id predicted for *any* drug, so
    accuracy = (TP + TN) / |universe|, where TN counts universe events that
    are neither labelled nor predicted for the drug.

    Drugs without gold-standard labels (or without predictions) are skipped,
    because recall is undefined for them.

    :param gold_standard: dict stitch_id -> set of labelled umls_ids.
    :param predictions: dict stitch_id -> set of predicted umls_ids.
    :returns: list of ``(stitch_id, accuracy, precision, recall,
        label_count, pred_count)`` tuples, one per evaluable drug.
    """
    universe = set()
    for predicted in predictions.values():
        universe.update(predicted)

    rows = []
    for stitch_id, predicted in predictions.items():
        labels = gold_standard.get(stitch_id)
        if not labels or not predicted:
            continue
        num_tp = len(predicted & labels)
        num_tn = len((universe - labels) - predicted)
        accuracy = float(num_tp + num_tn) / len(universe)
        precision = float(num_tp) / len(predicted)
        recall = float(num_tp) / len(labels)
        rows.append((stitch_id, accuracy, precision, recall,
                     len(labels), len(predicted)))
    return rows


if __name__ == '__main__':
    # Connection settings may be overridden from the environment so the
    # password does not have to live in the source file.
    db = MySQLdb.connect(
        host=os.environ.get("AERS_DB_HOST", "localhost"),
        port=int(os.environ.get("AERS_DB_PORT", "3307")),
        user=os.environ.get("AERS_DB_USER", "root"),
        passwd=os.environ.get("AERS_DB_PASSWD", "enter_your_password"),
        db=os.environ.get("AERS_DB_NAME", "project_aers"),
    )
    try:
        c = db.cursor()

        sys.stderr.write("Querying gold standard...\n")
        print(c.execute(GOLD_STANDARD_QUERY))  # execute() returns the row count
        gold_standard = group_by_drug(c.fetchall())

        sys.stderr.write("Querying predictions...\n")
        print(c.execute(PREDICTION_QUERY))
        predictions = group_by_drug(c.fetchall())

        rows = compute_drug_metrics(gold_standard, predictions)
        skipped = len(predictions) - len(rows)
        if skipped:
            sys.stderr.write(
                "Skipped %d predicted drug(s) with no gold-standard labels.\n"
                % skipped)

        if rows:
            c.executemany(INSERT_QUERY, rows)
        # DB-API connections are not autocommit by default; without this,
        # inserts into a transactional table are discarded when we exit.
        db.commit()
    finally:
        db.close()
    