"""
Sample the PSM chosen reports to generate variance estimates of the reporting frequencies.

"""

import os
import sys
import numpy
import random
import MySQLdb

import rpy2.robjects as robjects

print >> sys.stderr, "Running psm_sample.py version 0.8"

# Number of random subsamples drawn per drug, and the fraction of each cohort
# that goes into every subsample.
N_SAMPLES = 100
SAMPLE_SIZE = 0.1 # 10%

# Keep a named reference to the connection instead of chaining
# connect(...).cursor(): a named connection can be committed/closed explicitly,
# and on MySQLdb versions where the cursor holds only a weak reference to its
# connection it also keeps the connection alive for the whole run.
# Alternate database: db="project_aers"
conn = MySQLdb.connect(host="localhost", port=3307, user="root", passwd="enter_your_password", db="project_aers_10q4")
c = conn.cursor()

# Drugs that still have at least one event row without a sampled case mean.
query = """
select distinct stitch_id
from psm_drug_events
where case_mu is null
"""
c.execute(query)
drugs = [r[0] for r in c.fetchall()]

print >> sys.stderr, "Found %d drugs that need to be completed." % len(drugs)

# Index the reports by the event they mention:
# umls_id -> set of report_ids, taken from the event_report table.
event2report = {}
query = """
select report_id, umls_id
from event_report
"""
c.execute(query)

for report_id, umls_id in c.fetchall():
    # setdefault creates the set the first time an event is seen.
    event2report.setdefault(umls_id, set()).add(report_id)

# result_file = open('/tmp/psm_sample_results.sql','w')

def _draw_sample(report_ids):
    """Return a random SAMPLE_SIZE fraction of *report_ids* as a set.

    At least one report is always drawn, so the sample is never empty and its
    size can safely be used as a denominator. *report_ids* must be non-empty
    (random.sample raises ValueError otherwise).
    """
    n_draw = max(1, int(SAMPLE_SIZE * len(report_ids)))
    return set(random.sample(report_ids, n_draw))


# For every pending drug: repeatedly subsample its case and control reports,
# compute the fraction of each subsample that mentions each pending event, and
# store the mean / standard deviation of those fractions in psm_drug_events.
for i, drug in enumerate(drugs):

    print >> sys.stderr, "Working on drug %d of %d." % (i+1, len(drugs))

    # Reports chosen for this drug, flagged by cohort (1 = case, 0 = control).
    # Values are passed as driver parameters rather than interpolated into the
    # SQL text, so ids containing quotes cannot break the statement.
    query = """
    select report_id, cohort
    from psm_chosen
    where stitch_id = %s
    """
    c.execute(query, (drug,))

    print >> sys.stderr, ".",

    cohorts = c.fetchall()
    cases = [report_id for report_id, cohort in cohorts if cohort == 1]
    controls = [report_id for report_id, cohort in cohorts if cohort == 0]

    if not cases or not controls:
        # An empty cohort cannot be sampled (and would give a zero
        # denominator). Skip this drug instead of aborting the whole run; its
        # rows keep case_mu NULL and will be picked up again next time.
        print >> sys.stderr, "skipped (%d cases, %d controls)" % (len(cases), len(controls))
        continue

    # Events of this drug that still lack sampled statistics.
    query = """
    select umls_id
    from psm_drug_events
    where case_mu is null
    and stitch_id = %s
    """
    c.execute(query, (drug,))
    events = [r[0] for r in c.fetchall()]

    print >> sys.stderr, ".",

    # Per event: one reporting frequency per subsample.
    case_mus = dict((uid, []) for uid in events)
    control_mus = dict((uid, []) for uid in events)

    # The loop index is unused; '_' avoids shadowing the outer drug index i.
    for _ in range(N_SAMPLES):

        case_sample = _draw_sample(cases)
        control_sample = _draw_sample(controls)

        # Sample sizes do not depend on the event, so compute them once.
        n_case = float(len(case_sample))
        n_control = float(len(control_sample))

        for uid in events:
            reports = event2report[uid]
            case_mus[uid].append(len(reports & case_sample) / n_case)
            control_mus[uid].append(len(reports & control_sample) / n_control)

    print >> sys.stderr, ".",

    update_query = """
    update psm_drug_events
    set case_mu = %s, case_sd = %s, control_mu = %s, control_sd = %s
    where stitch_id = %s and umls_id = %s
    """
    for uid in events:
        # Convert numpy scalars to plain floats so the driver formats them as
        # numeric literals.
        case_mu = float(numpy.mean(case_mus[uid]))
        case_sd = float(numpy.std(case_mus[uid]))

        cont_mu = float(numpy.mean(control_mus[uid]))
        cont_sd = float(numpy.std(control_mus[uid]))

        c.execute(update_query, (case_mu, case_sd, cont_mu, cont_sd, drug, uid))

    # MySQLdb leaves autocommit off by default. Committing after each drug
    # persists the finished rows on transactional tables, so an interrupted
    # run resumes with only the remaining "case_mu is null" rows.
    c.connection.commit()

    print >> sys.stderr, "|"
    
