"""
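Find the indications that are significantly co-reported with each drug.

For every (drug, indication) pair that occurs together at least once, a 2x2
contingency table of report counts is tested with Fisher's exact test. Each
p-value is Bonferroni-corrected by the number of tests run for that drug, and
the results are written to significant_indications_batch.csv.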
"""

import csv
import sys
import numpy
import MySQLdb
import rpy2.robjects as robjects
from namedmatrix import NamedMatrix

# Connection to the AERS project database (the password is a placeholder to
# be filled in locally).
db = MySQLdb.connect(
    host="localhost",
    port=3307,
    user="root",
    passwd="enter_your_password",
    db="project_aers",
)
c = db.cursor()

# Map each STITCH drug id to the set of report ids in which that drug appears.
query = """
select stitch_id, report_id
from effect_medeffect.report_drug
join effect_medeffect.drug2stitch using (drug_product_id);
"""
c.execute(query)

drug_report = dict()
for drug_id, report_id in c.fetchall():
    drug_report.setdefault(drug_id, set()).add(report_id)



# Per-drug, per-indication report counts.
# NOTE(review): this query is empty. The loop below expects it to return rows
# of (stitch_id, indication, report_count). TODO: fill in the SQL.
query = """

"""

# The original never executed this query, so fetchall() re-read the cursor
# already exhausted by the drug/report query above and every tally stayed empty.
if not query.strip():
    sys.exit("The drug/indication count query is empty; fill it in before running.")
c.execute(query)

drug_ind_repcount = dict()  # cid -> {indication: count}
drug_repcount = dict()      # cid -> sum of counts over all indications
ind_repcounts = dict()      # indication -> sum of counts over all drugs

for cid, ind, count in c.fetchall():
    count = int(count)
    drug_ind_repcount.setdefault(cid, dict())[ind] = count
    drug_repcount[cid] = drug_repcount.get(cid, 0) + count
    ind_repcounts[ind] = ind_repcounts.get(ind, 0) + count

indications = sorted(ind_repcounts.keys())
# NOTE(review): this sums (drug, indication) counts, so a report listing
# several indications is counted more than once. Confirm this is intended.
total_reports = sum(ind_repcounts.values())

def apply_fisher_test(tables):
    """Return a two-sided Fisher's exact test p-value for each 2x2 table.

    `tables` is a list of [n11, n12, n21, n22] integer rows. Each row is read
    as the matrix [[n11, n12], [n21, n22]]. The p-values are returned as
    Python floats, in the same order as `tables`. All tables are sent to R
    in a single batched call.

    NOTE(review): the original script called this function without defining
    it. This implementation is inferred from the "R batch" message and the
    commented-out fisher_test() usage; confirm it matches the intended one.
    """
    if not tables:
        return []
    columns = list(zip(*tables))
    batch_fisher = robjects.r('''
        function(n11, n12, n21, n22) {
            sapply(seq_along(n11), function(i)
                fisher.test(matrix(c(n11[i], n12[i], n21[i], n22[i]),
                                   nrow = 2, byrow = TRUE))$p.value)
        }
    ''')
    result = batch_fisher(*[robjects.IntVector(col) for col in columns])
    return [float(p) for p in result]


data = []      # one [both, drug_only, ind_only, neither] row per test
keys = []      # [cid, indication] for each row of `data`, same order
corr = dict()  # cid -> number of tests run for that drug (Bonferroni factor)

sys.stderr.write("Building data for R batch..\n")

# `drug_cids` was never defined. Test every drug that has indication counts,
# in a deterministic order.
drug_cids = sorted(drug_ind_repcount)

for cid in drug_cids:
    ind_counts = drug_ind_repcount[cid]
    n_drug = drug_repcount[cid]
    n_tests = 0
    for ind in indications:
        n_both = ind_counts.get(ind, 0)  # reports with both drug and indication
        if n_both <= 0:
            continue  # only test indications actually seen with this drug
        n_ind = ind_repcounts[ind]
        n_drug_only = n_drug - n_both                       # drug, not indication
        n_ind_only = n_ind - n_both                         # indication, not drug
        n_neither = total_reports - n_drug - n_ind + n_both  # neither
        data.append([n_both, n_drug_only, n_ind_only, n_neither])
        keys.append([cid, ind])
        n_tests += 1
    corr[cid] = n_tests

pvalues = apply_fisher_test(data)

# Write one line per tested (drug, indication) pair. The columns are cid,
# indication, the raw p-value, and the p-value Bonferroni-corrected by the
# number of tests for that drug (capped at 1). Despite the file name, every
# tested pair is written, not only the significant ones.
# Using `with` closes the file even if a write fails part-way through.
with open('significant_indications_batch.csv', 'w') as outfh:
    for (cid, ind), pval in zip(keys, pvalues):
        adjusted = min(1, float(pval) * corr[cid])
        outfh.write("%s,%s,%s,%e\n" % (cid, ind, pval, adjusted))

