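"""
Match propensity-score controls to cases, drug by drug.

For every drug in psm_scores that does not yet appear in psm_chosen, the cases
(cohort 1) are binned by propensity score. Controls (cohort 0) are then sampled,
with replacement, from the same bins, and cases whose bin contains no controls
are thrown out. Matched rows are written to psm_chosen and per-drug counts to
psm_models.

"""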

import os
import sys
import csv
import math
import random
import MySQLdb

import rpy2.robjects as robjects

print >> sys.stderr, "Running psm_chosen.py version 0.7"

# Binning parameter for the propensity-score matching below: the range of case
# scores is split into equal-width bins.
N_BINS = 20
# Number of controls drawn (with replacement) for every matched case.
NCONT_MULT = 10

# Keep a named reference to the connection rather than only its cursor: some
# MySQLdb versions give the cursor just a weak proxy to the connection, and the
# matching loop needs the connection to commit.
# (Older run used db="project_aers".)
conn = MySQLdb.connect(host="localhost", port=3307, user="root", passwd="enter_your_password", db="project_aers_10q4")
c = conn.cursor()

# Every drug that has propensity scores.
c.execute("""
select distinct stitch_id
from psm_scores
""")
drugs = [row[0] for row in c.fetchall()]

# Drugs that already have rows in psm_chosen were matched by a previous run and
# are skipped. A set gives O(1) membership tests in the loop below.
c.execute("""
select distinct stitch_id
from psm_chosen
""")
completed_drugs = set(row[0] for row in c.fetchall())

def _bin_edges(lo, hi, n_bins):
    """Return n_bins + 1 ascending edges that split [lo, hi] into n_bins bins.

    The bins have equal width, except that the top edge is padded by 1% of
    |hi|. The padding keeps controls scoring slightly above the highest case
    eligible. Using |hi| keeps the pad non-negative even when hi <= 0.
    """
    # float() avoids Python 2 integer division when scores are integers.
    width = (hi - lo) / float(n_bins)
    edges = [lo + k * width for k in range(n_bins)]
    edges.append(hi + 0.01 * abs(hi))
    return edges


def _in_bin(score, bin_min, bin_max, closed):
    """True if score falls in [bin_min, bin_max), or [bin_min, bin_max] when closed."""
    return bin_min <= score < bin_max or (closed and score == bin_max)


def _match_cases(cases, controls, n_bins, ncont_mult):
    """Stratify cases by score and sample controls from the same strata.

    cases    -- non-empty list of (score, report_id, cohort), sorted by score
    controls -- list of (score, report_id, cohort)

    The case score range is cut into n_bins bins (see _bin_edges). The last bin
    is closed on the right so the highest-scoring case is always binned. For
    every bin holding at least one case and one control, the bin's cases are
    kept, and ncont_mult controls per case are drawn with replacement from the
    bin's controls. Cases in bins without controls are dropped.

    Returns (sampled_controls, valid_cases, warning_flag, warning_cases).
    warning_flag is the number of bins that had fewer controls than cases;
    warning_cases is the number of cases in those bins.
    """
    edges = _bin_edges(cases[0][0], cases[-1][0], n_bins)

    sampled_controls = []
    valid_cases = []
    warning_flag = 0
    warning_cases = 0

    for b in range(n_bins):
        bin_min, bin_max = edges[b], edges[b + 1]
        closed = (b == n_bins - 1)

        bin_cases = [x for x in cases if _in_bin(x[0], bin_min, bin_max, closed)]
        if not bin_cases:
            continue

        bin_controls = [x for x in controls if _in_bin(x[0], bin_min, bin_max, closed)]
        if not bin_controls:
            # No controls to match against: these cases are removed.
            continue

        if len(bin_controls) < len(bin_cases):
            print >> sys.stderr, "Warning: Not many controls available, may have bias."
            warning_flag += 1
            warning_cases += len(bin_cases)

        n_draws = len(bin_cases) * ncont_mult
        sampled_controls.extend(random.choice(bin_controls) for _ in range(n_draws))
        valid_cases.extend(bin_cases)

    return sampled_controls, valid_cases, warning_flag, warning_cases


for drug in drugs:

    if drug in completed_drugs:
        continue

    # Parameterized query: the driver quotes the value instead of string formatting.
    c.execute("""
    select report_id, score, cohort
    from psm_scores
    where stitch_id = %s
    """, (drug,))
    report_score = c.fetchall()

    cases = sorted((score, report_id, cohort) for report_id, score, cohort in report_score if cohort == 1)
    if not cases:
        # Nothing to match; previously this crashed on cases[0].
        print >> sys.stderr, "%s: no cases found, skipping." % drug
        continue

    controls = [(score, report_id, cohort) for report_id, score, cohort in report_score if cohort == 0]

    sampled_controls, valid_cases, warning_flag, warning_cases = _match_cases(cases, controls, N_BINS, NCONT_MULT)

    print >> sys.stderr, "%s: Matched %d controls to %d cases, %d cases removed." % (drug, len(sampled_controls), len(valid_cases), len(cases)-len(valid_cases))

    c.execute(
        "update psm_models set chosen_ncases = %s, warning_flag = %s, warning_cases = %s where stitch_id = %s",
        (len(valid_cases), warning_flag, warning_cases, drug))

    c.executemany(
        "insert into psm_chosen values (%s, %s, %s)",
        [(drug, report_id, cohort) for score, report_id, cohort in sampled_controls + valid_cases])

    # MySQLdb turns autocommit off. Commit per drug so that finished drugs are
    # persisted and are skipped (via completed_drugs) if the script is re-run.
    c.connection.commit()

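# Once psm_chosen is populated, the psm_drug_events table can be built.
# The notes below are NOT executed by this script: they record the SQL and R
# steps that are run by hand afterwards.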
"""
# Create the event_report table (if it does not already exist).
create table event_report
select distinct report_id, umls_id
from drug_event_report;

# Compute case_n/control_n and insert drug-event pairs
insert into psm_drug_events (`stitch_id`, `umls_id`, `case_n`, `control_n`)
select stitch_id, umls_id, sum(cohort = 1) as case_n, sum(cohort = 0) as control_n
from psm_chosen
join event_report using (report_id)
group by stitch_id, umls_id;

# we don't care about drug-event associations that aren't reported.
delete from psm_drug_events where case_n = 0;

# set the totals
update psm_drug_events
join
(
	select stitch_id, sum(cohort = 1) as x, sum(cohort = 0) as y
	from psm_chosen
	group by stitch_id
) a using (stitch_id)
set case_total = a.x, control_total = a.y;

# now we are ready to run "psm_sample.py"

# after that is complete we can finish it off with these queries

update psm_drug_events
set prr = case_mu/control_mu;

# set the t statistic

update psm_drug_events
set t_statistic = (case_mu - control_mu) / sqrt((case_sd*case_sd)/100 + (control_sd*control_sd)/100);

update psm_drug_events
set df = (((case_sd*case_sd + control_sd*control_sd)/100)*((case_sd*case_sd + control_sd*control_sd)/100)) / ((case_sd*case_sd/100)*(case_sd*case_sd/100)/(99) + (control_sd*control_sd/100)*(control_sd*control_sd/100)/(99));

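# p-values need to be calculated with R.
# NOTE(review): this connects on port 3306, but psm_chosen.py uses port 3307 for
# the same database; confirm which one is correct.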
library(RMySQL)
mycon <- dbConnect(MySQL(), user='root', dbname="project_aers_10q4", host="localhost", port=3306, password='enter_your_password')

data <- dbGetQuery(mycon, "select * from psm_drug_events")

pt_apply <- function(row) {
    pt(-abs(as.numeric(row[12])), as.numeric(row[13]))
}

data$pvalue = apply(data, 1, pt_apply)
write.csv(data,file="~/psm_drug_events_pvalues.csv",quote=F,row.names=F)

update psm_drug_events
join
(
	select stitch_id, count(*) num_hypoth
	from psm_drug_events
	group by stitch_id
) a using (stitch_id)
set corrected = least(1,pvalue*num_hypoth);

# Set the expected number of reports and disproportionality stats (from DuMouchel paper)
update psm_drug_events
set Eij = ((case_n+control_n)*(case_total))/(case_total+control_total);

update psm_drug_events
set RR = case_n/Eij;

update psm_drug_events
set log2rr = log2(RR);

# Create the offsides database
create table offsides
select stitch_id, d.name as drug, umls_id, e.name as event, case_n as observed, Eij as expected, RR, corrected as corr_pvalue, t_statistic
from psm_drug_events
join drugs d using (stitch_id)
join project_aers.umls2name e using (umls_id)
where corrected < 0.05
and t_statistic > 0
and RR > 2
group by stitch_id, umls_id;

"""
