"""
Discover adverse events significantly associated with drugs.
"""

import os
import csv
import sys
import random
import MySQLdb
import tempfile
import operator
import rpy2.robjects as robjects
from namedmatrix import NamedMatrix

# Upper bound on the number of background reports kept; larger sets are
# randomly subsampled down to this size.
MAX_REPORTS = 120000
# Initial exponent for the drug-drug correlation cutoff: background drugs must
# satisfy `corrected < 1e-DRUG_DRUG_EXP` in corr_drug_drug.
DRUG_DRUG_EXP = 200
# NOTE(review): not referenced anywhere in this script.
DRUG_IND_EXP = 200
# Minimum number of rows (report/event pairs) the background query must
# return before the correlation cutoff stops being relaxed.
MIN_BG_REPORTS = 50

# Connection settings can be overridden from the environment so credentials
# do not have to be edited into the source; the defaults are the original
# hard-coded values.
db = MySQLdb.connect(
    host=os.environ.get("AERS_DB_HOST", "localhost"),
    port=int(os.environ.get("AERS_DB_PORT", "3307")),
    user=os.environ.get("AERS_DB_USER", "root"),
    passwd=os.environ.get("AERS_DB_PASSWORD", "enter_your_password"),
    db=os.environ.get("AERS_DB_NAME", "project_aers"),
)
c = db.cursor()

# Define the drug we are interested in (a STITCH id such as
# 'CID000005090', Vioxx), taken from the command line.
if len(sys.argv) < 2:
    sys.exit("usage: %s STITCH_ID" % sys.argv[0])
cid = sys.argv[1]

print >> sys.stderr, "%s: Building the drug foreground data matrix and file." % cid
# Fetch every (report, drug, event) row recorded for this drug. The id comes
# from the command line, so it is passed as a query parameter rather than
# formatted into the SQL text.
query = """
select report_id, stitch_id, umls_id
from drug_report_event_me
where stitch_id = %s;
"""
c.execute(query, (cid,))

report_data = list(c.fetchall())
report_ids = sorted(set(r[0] for r in report_data))
umls_ids = sorted(set(r[2] for r in report_data))

print >> sys.stderr, "%s: Found %d foreground drug reports." % (cid, len(report_ids))
if not report_ids:
    # With no foreground reports there are no event columns, and the
    # background filter and R analysis below cannot produce anything useful.
    sys.exit("%s: No foreground reports found; aborting." % cid)

# Map each report id to the set of event (umls) ids it mentions.
drug_data = dict()
for report_id, stitch_id, umls_id in report_data:
    drug_data.setdefault(report_id, set()).add(umls_id)

# Cleanup
del report_data

# Write a report x event 0/1 indicator matrix as CSV. mkstemp returns an open
# OS-level descriptor; wrap it rather than discarding (leaking) it.
fd, drug_data_file = tempfile.mkstemp(suffix=".csv")
with os.fdopen(fd, 'w') as drug_file:
    writer = csv.writer(drug_file)
    writer.writerow(['RowID'] + umls_ids)
    for report_id in report_ids:
        # Look up with the key exactly as fetched from the DB; converting it
        # to str first would miss non-string (e.g. integer) keys.
        row = [1 if umls_id in drug_data[report_id] else 0 for umls_id in umls_ids]
        writer.writerow([str(report_id)] + row)

print >> sys.stderr, "%s: Building the drug background data matrix and file." % cid
# Background rows are the report/event pairs of drugs whose correlation with
# this drug passes `corrected < 1e-EXP` in corr_drug_drug (medeffect only).
# If too few rows come back, EXP is halved (integer division) to relax the
# cutoff and the query is retried.
bg_query = """
select report_id, umls_id
from drug_report_event_me join
(
    select stitch_id2 as stitch_id
    from corr_drug_drug
    where stitch_id1 = %s
    and corrected < %s
    and `database` = 'medeffect'
) a
using (stitch_id);
"""
drug_drug_exp = DRUG_DRUG_EXP
while True:
    # Parse the same literal the original SQL used, e.g. 1e-200.
    threshold = float("1e-%d" % drug_drug_exp)
    num_bg_reports = c.execute(bg_query, (cid, threshold))
    if num_bg_reports >= MIN_BG_REPORTS:
        break
    if drug_drug_exp == 0:
        # The cutoff is already 1e-0 and halving 0 stays 0, so retrying can
        # never succeed (previously this looped forever).
        sys.exit("%s: Failed to retrieve enough background reports (%d) even at 1e-0; aborting." % (cid, num_bg_reports))
    print >> sys.stderr, "%s: Failed to retrieve enough background reports (%d), reducing significance threshold (1e-%d)." % (cid, num_bg_reports, drug_drug_exp)
    drug_drug_exp //= 2

drug_report_ids = report_ids
umls_id_set = set(umls_ids)
# Keep only background rows whose event also occurs in the foreground data.
report_data = [row for row in c.fetchall() if row[1] in umls_id_set]

# Background reports exclude any report already in the foreground set.
report_ids = sorted(set(r[0] for r in report_data) - set(drug_report_ids))
if len(report_ids) > MAX_REPORTS:
    report_ids = random.sample(report_ids, MAX_REPORTS)

print >> sys.stderr, "%s: Found %d background drug reports." % (cid, len(report_ids))
if not report_ids:
    sys.exit("%s: No usable background reports after filtering; aborting." % cid)

# Map each background report id to the set of (foreground) events it mentions.
bg_data = dict()
for report_id, umls_id in report_data:
    bg_data.setdefault(report_id, set()).add(umls_id)

# Cleanup
del report_data

# Write the background 0/1 indicator matrix with the same columns as the
# foreground one. Keep mkstemp's descriptor instead of leaking it.
fd, bg_data_file = tempfile.mkstemp(suffix=".csv")
with os.fdopen(fd, 'w') as bg_file:
    writer = csv.writer(bg_file)
    writer.writerow(['RowID'] + umls_ids)
    for report_id in report_ids:
        # Look up with the DB-native key; a str() copy would miss int keys.
        row = [1 if umls_id in bg_data[report_id] else 0 for umls_id in umls_ids]
        writer.writerow([str(report_id)] + row)

# Last bit of cleanup
del drug_data
del bg_data

print >> sys.stderr, "%s: Running the statistical analysis." % cid

result_file = os.path.expanduser("~/Stanford/AltmanLab/aers/medeffect/part2.b/%s.csv" % cid)
# R program: load both indicator matrices (first CSV column = row names).
# For each event column it draws `nmeans` random subsamples from the
# background and foreground columns, and takes each subsample's mean. It then
# compares the two sets of means with t.test (R defaults). The p-value, means
# and standard deviations are written to result_file.
# NOTE(review): if there is only one event column, `[,-1]` drops the data
# frame to a vector and ncol() is NULL; confirm that case cannot occur.
r_string = """
bg_data <- read.csv('%s')
rownames(bg_data) <- bg_data[,1]
bg_data <- bg_data[,-1]

drug_data <- read.csv('%s')
rownames(drug_data) <- drug_data[,1]
drug_data <- drug_data[,-1]

results <- matrix(rep(c(NA,NA,NA,NA,NA),ncol(drug_data)), ncol=5)
colnames(results) <- c("pvalue","drug_mean","bg_mean","drug_sd","bg_sd")
rownames(results) <- colnames(drug_data)

nmeans <- 1000
ssize1 <- round(max(10,0.01*nrow(bg_data)))
ssize2 <- round(max(10,0.01*nrow(drug_data)))
write(paste("%s:","nmeans=",nmeans,"sizes:",ssize1,ssize2, "lengths:",nrow(bg_data),nrow(drug_data)), stderr())

for (i in 1:nrow(results)) {
    write(paste("Working on ",i,"of ",nrow(results)), stderr())
	vals1 <- bg_data[,i]
	vals2 <- drug_data[,i]
	
	v1means <- rep(0, nmeans)
	v1means <- unlist(lapply(v1means,function(x){mean(sample(vals1,ssize1))}))
	
	v2means <- rep(0, nmeans)
	v2means <- unlist(lapply(v2means,function(x){mean(sample(vals2,ssize2))}))
	
	results[i,1] <- t.test(v1means, v2means)$p.value
	results[i,2] <- mean(v2means)
	results[i,3] <- mean(v1means)
	results[i,4] <- sd(v2means)
	results[i,5] <- sd(v1means)
}

write.csv(results,file="%s",quote=F)
""" % (bg_data_file, drug_data_file, cid, result_file)

# Remove the temporary matrices even if the R run raises, so failed runs do
# not leave large CSV files behind. os.remove avoids shelling out to `rm`.
try:
    robjects.r(r_string)
    print >> sys.stderr, "%s: Completed run, result saved to %s" % (cid, result_file)
finally:
    for tmp_path in (bg_data_file, drug_data_file):
        try:
            os.remove(tmp_path)
        except OSError as err:
            # Best-effort cleanup: report it, but don't mask the real outcome.
            print >> sys.stderr, "%s: could not remove %s: %s" % (cid, tmp_path, err)

# Load the R results into pred_drug_events_b_me. Each CSV data row is
# (umls_id, pvalue, drug_mean, bg_mean, drug_sd, bg_sd), with umls_id coming
# from the R matrix row names; the header row is skipped. Values are passed as
# query parameters instead of being formatted into the SQL text.
insert_query = (
    "INSERT INTO pred_drug_events_b_me "
    "(stitch_id, umls_id, pvalue, drug_mean, bg_mean, drug_sd, bg_sd) "
    "values (%s,%s,%s,%s,%s,%s,%s)"
)
with open(result_file) as result_handle:
    reader = csv.reader(result_handle)
    next(reader)
    for row in reader:
        # R writes missing numbers as the bare tokens NA/NaN; store them as
        # SQL NULL rather than producing an invalid statement.
        stats = [None if value in ('NA', 'NaN') else value for value in row[1:]]
        c.execute(insert_query, tuple([cid, row[0]] + stats))

# MySQLdb does not autocommit by default, so without an explicit commit the
# inserts are rolled back on transactional tables when the script exits.
db.commit()


