"""
MEDEFFECT
Computes the correlated drugs for each drug.
"""

import csv
import os
import sys
import tempfile

import MySQLdb
import numpy
import rpy2.robjects as robjects

from namedmatrix import NamedMatrix

db = MySQLdb.connect(host="localhost", port=3307, user="root", passwd="enter_your_password",db="project_aers")
c = db.cursor()

# Get all the reports for each id.
query = """
select stitch_id, report_id
from effect_medeffect.report_drug
join effect_medeffect.drug2stitch using (drug_product_id);
"""
c.execute(query)

drug_report = dict()

for stitch_id, isr_report_id in c.fetchall():
    if stitch_id not in drug_report:
        drug_report[stitch_id] = set()
    drug_report[stitch_id].add( isr_report_id )

# Inputs for the R correlation tests. Each row of `data` is a 2x2 contingency
# table [x, y, z, w] for a drug pair A=cid, B=cid2:
#   x = reports mentioning both A and B
#   y = reports mentioning A but not B
#   z = reports mentioning B but not A
#   w = reports mentioning neither A nor B
# `keys[i]` is the [cid, cid2] pair for `data[i]`; corr[cid] counts how many
# pairs were tested for cid (used later as the Bonferroni multiplier).
data = []
keys = []
corr = dict()

# Only drugs with at least this many reports are tested against the others.
MIN_REPORTS = 100

drug_cids = sorted(k for k, v in drug_report.items() if len(v) >= MIN_REPORTS)
sorted_keys = sorted(drug_report.keys())

sys.stderr.write("Building data for R batch..\n")

# Total number of distinct reports, i.e.
#   select count(distinct report_id)
#   from report_drug
#   join drug2stitch using (drug_product_id);
# drug_report was filled from exactly that join, so the union of its report
# sets is the same count. Deriving it here keeps it in sync with the data
# instead of relying on a hard-coded snapshot (previously 201224).
total_reports = len(set().union(*drug_report.values()))

for i, cid in enumerate(drug_cids):
    if i % 10 == 0:
        sys.stderr.write("Building data for drug %d of %d\n" % (i, len(drug_cids)))

    # Loop-invariant for the inner loop.
    reports_a = drug_report[cid]
    n_a = len(reports_a)

    correction = 0
    for cid2 in sorted_keys:
        if cid2 == cid:
            continue

        reports_b = drug_report[cid2]
        x = len(reports_a & reports_b)
        if x > 0:
            # Derive the remaining cells from x instead of building more sets:
            # |A-B| = |A|-x, |B-A| = |B|-x, |A|B| = |A|+|B|-x.
            n_b = len(reports_b)
            y = n_a - x
            z = n_b - x
            w = total_reports - (n_a + n_b - x)

            data.append([x, y, z, w])
            keys.append([cid, cid2])
            correction += 1

    corr[cid] = correction

# Run the correlation tests in R.
#
# Each contingency row [x, y, z, w] is expanded in R into two 0/1 indicator
# vectors over all reports (one per drug) and passed to cor.test; the script
# writes one (p.value, estimate) line per input row, in input order.

# Save the data to file. mkstemp returns an open descriptor plus the path;
# write through that descriptor so it is not leaked.
in_fd, tmp_file = tempfile.mkstemp(suffix=".csv")
with os.fdopen(in_fd, 'w') as tmp_fh:
    csv.writer(tmp_fh).writerows(data)

# R writes the results by path, so only the name is needed here.
out_fd, out_file = tempfile.mkstemp(suffix=".csv")
os.close(out_fd)

script = """
data <- read.csv('%s', header=F)
run.cor.test <- function(row) {
x = c(rep(1,row[1]), rep(1,row[2]), rep(0,row[3]), rep(0,row[4]))
y = c(rep(1,row[1]), rep(0,row[2]), rep(1,row[3]), rep(0,row[4]))
res <- cor.test(x,y)
return(c( res$p.value, res$estimate ))
}
result <- t(apply(data, 1, run.cor.test))
write.csv(result, file='%s')
""" % (tmp_file, out_file)

try:
    if data:
        robjects.r(script)
        with open(out_file) as out_fh:
            reader = csv.reader(out_fh)
            next(reader)  # skip the header row written by write.csv
            # row[0] is the row name written by write.csv; the rest are
            # (p.value, estimate).
            correlations = [[float(v) for v in row[1:]] for row in reader]
    else:
        # read.csv fails on an empty file, so skip R when nothing was built.
        correlations = []
finally:
    os.remove(tmp_file)
    os.remove(out_file)

# Store each pair's p-value, Bonferroni-corrected p-value (p-value times the
# number of tests run for cid, capped at 1) and correlation estimate.
# Self-pairs never reach `keys` because they are skipped when the contingency
# tables are built, so no cid != cid2 check is needed here.
insert_sql = (
    "INSERT INTO corr_drug_drug "
    "(stitch_id1, stitch_id2, pvalue, corrected, estimate, `database`) "
    "VALUES (%s, %s, %s, %s, %s, %s)"
)
insert_rows = [
    (cid, cid2, pval, min(1.0, float(pval) * corr[cid]), estimate, 'medeffect')
    for (cid, cid2), (pval, estimate) in zip(keys, correlations)
]
# Parameterized so the driver handles quoting/escaping of the ids.
c.executemany(insert_sql, insert_rows)
# MySQLdb connections do not autocommit; without this the inserts are
# discarded for transactional tables when the script exits.
db.commit()

