"""
Adapted version of Propensity Score Matching to identify matched control reports.

Algorithm:
For each drug, d:
    1. Find significantly correlated drugs and indications (using pre-calculated matrices).
    2. Order the drugs and indications, "features," by their correlation statistic.
    3. Using all significant (after correction) features fit a logistic regression.
        a. Y = 1{report lists d}
    4. Fit a model using half of the features.
        a. Test against full model:
            1) if not significantly different, repeat step 4
            2) if significantly different, use previous model and exit
    5. Use fitted logistic regression model to assign scores to each report
        Note: only reports that have a non-zero value for at least one of the features are scored
    6. Store these scores for later use in the matching algorithm.

"""

import os
import sys
import csv
import MySQLdb
import random

import rpy2.robjects as robjects

print >> sys.stderr, "Running psm.py version 0.5"

# Alternative (disabled): db="project_aers" with otherwise identical settings.

# Keep the connection bound to a name instead of chaining .cursor() onto the
# connect() call. Some MySQLdb releases give a cursor only a weak reference to
# its connection, so an otherwise unreferenced connection can be garbage
# collected, after which cursor calls fail with ReferenceError.
db_connection = MySQLdb.connect(host="localhost", port=3307, user="root", passwd="enter_your_password", db="project_aers_10q4")

# Module-level cursor used for the queries in this script.
c = db_connection.cursor()

# Collect every stitch_id that already has a row in psm_models.
query = """
select stitch_id
from psm_models
"""
c.execute(query)
# Each fetched row is a one-column tuple; unpack it to get the bare id.
completed_drugs = [stitch_id for (stitch_id,) in c.fetchall()]

# Load positively correlated drug-indication pairs for the 'aers' database.
# The `corrected < 0.05` condition sits behind a `#` inside the SQL text, so it
# is not applied by the server.
query = """
select stitch_id, indication, estimate
from corr_drug_ind
where `database` = 'aers'
and estimate > 0
# and corrected < 0.05;
"""
c.execute(query)

# Maps stitch_id -> list of (indication, estimate) tuples, in fetch order.
drug_ind_corr = dict()

for sid, ind, est in c.fetchall():
    drug_ind_corr.setdefault(sid, list()).append((ind, est))

# Load positively correlated drug-drug pairs for the 'aers' database.
# The `corrected < 0.05` condition sits behind a `#` inside the SQL text, so it
# is not applied by the server.
query = """
select stitch_id1, stitch_id2, estimate
from corr_drug_drug
where `database` = 'aers'
and estimate > 0
# and corrected < 0.05;
"""
c.execute(query)

# Maps stitch_id1 -> list of (stitch_id2, estimate) tuples, in fetch order.
drug_drug_corr = dict()

for sid1, sid2, est in c.fetchall():
    # The membership test must use sid1, the key being filled. Testing any
    # other variable either resets the list on every row (dropping earlier
    # pairs) or skips initialisation and raises KeyError on the append.
    if sid1 not in drug_drug_corr:
        drug_drug_corr[sid1] = list()

    drug_drug_corr[sid1].append((sid2, est))

# Drugs to model: those present in both correlation tables, in sorted order.
drugs = sorted(set(drug_drug_corr).intersection(drug_ind_corr))

# Load every (drug, report) pair and index it in both directions.
query = """
select stitch_id, report_id
from drug_report
"""
c.execute(query)

drug2report = dict()    # stitch_id -> set of report_ids listing that drug
report2drug = dict()    # report_id -> set of stitch_ids listed on that report
for sid, rid in c.fetchall():
    drug2report.setdefault(sid, set()).add(rid)
    report2drug.setdefault(rid, set()).add(sid)

# Load every (indication, report) pair and index it in both directions.
query = """
select indication, report_id
from indication_report
"""
c.execute(query)

ind2report = dict()    # indication -> set of report_ids listing that indication
report2ind = dict()    # report_id -> set of indications listed on that report
for ind, rid in c.fetchall():
    ind2report.setdefault(ind, set()).add(rid)
    report2ind.setdefault(rid, set()).add(ind)

# Fit one logistic regression per drug and record a score for every report
# that entered the fit.
#   results.csv rows: drug, report id, score
#   models.csv rows:  drug, number of features, number of cases, number of controls
outfh = open('results.csv', 'w')
writer = csv.writer(outfh)

outfh2 = open('models.csv','w')
writer2 = csv.writer(outfh2)

# Shared default for reports with no recorded drugs/indications, so building
# the indicator columns does not allocate a new empty set per report.
no_features = frozenset()

# try/finally guarantees both output files are closed (and flushed) even when
# a lookup or the R call fails part-way through the run.
try:
    for drugi, drug in enumerate(drugs):

        # Drugs already present in psm_models are not refitted.
        if drug in completed_drugs:
            continue

        print >> sys.stderr, "Working on drug %s, %d of %d" % (drug, drugi, len(drugs))

        # Each feature is (estimate, id, kind); the drug itself is excluded.
        features = [(r[1],r[0],'drug') for r in drug_drug_corr[drug] if r[0] != drug]
        features += [(r[1],r[0],'ind') for r in drug_ind_corr[drug]]

        # NOTE(review): the sort is ascending by estimate, so capping at 200
        # keeps the 200 *least* correlated features. If the most correlated
        # ones were intended this needs reverse=True -- confirm before changing,
        # since it alters every fitted model.
        features = sorted(features)[:200]

        case_reports = drug2report[drug]

        # Controls: reports listing at least one feature but not the drug.
        # They are derived from the full case set, before any down-sampling,
        # so a case that is sampled out can never reappear as a control.
        control_reports = set()
        for _, feature, feature_type in features:
            if feature_type == 'drug':
                control_reports |= (drug2report[feature] - case_reports)
            else:
                control_reports |= (ind2report[feature] - case_reports)

        # Cap the design matrix at 20,000 cases and 100,000 controls.
        if len(case_reports) > 2e4:
            case_reports = random.sample(case_reports, int(2e4))

        if len(control_reports) > 1e5:
            control_reports = random.sample(control_reports, int(1e5))

        case_reports = sorted(case_reports)
        control_reports = sorted(control_reports)

        # Row order of the design matrix: all cases first, then all controls.
        reports = case_reports + control_reports

        # One 0/1 indicator vector per feature, rendered as an R c(...) literal.
        columns = list()
        for est, feature, feature_type in features:
            if feature_type == 'drug':
                membership = report2drug
            else:
                membership = report2ind
            columns.append("c(%s)" % ','.join(['1' if feature in membership.get(rid, no_features) else '0' for rid in reports]))

        # Response vector, aligned with `reports`: 1 for cases, 0 for controls.
        Y = "c(%s)" % ",".join(["1"]*len(case_reports) + ["0"]*len(control_reports))

        data = "cbind(%s)" % ",".join(columns)

        num_features = len(columns)

        # as.data.frame() on an unnamed matrix names its columns V1..Vn.
        regression = " + ".join(["data$V%d" % (i+1) for i in range(num_features)])

        rcode = """
        data <- as.data.frame(%s)
        Y <- %s
        fit <- glm(Y ~ %s, family=binomial)
        """ % (data, Y, regression)
        robjects.r(rcode)

        # TODO: step 4 of the module docstring is not performed; the full model
        # is always used. The disabled draft repeatedly halved num_features,
        # refitted as fit1, compared with anova(fit, fit1), and stopped once
        # p < 0.05 or num_features <= 10 (otherwise fit <- fit1).

        # NOTE(review): predict() is called without type=, which for a glm
        # yields values on the link (log-odds) scale rather than probabilities.
        robjects.r("scores <- predict(fit)")

        robjects.r("write.csv(scores, file='/tmp/scores.csv', quote=F)")

        # Column 1 holds the score; the first row is the header. Close the
        # handle right away so a long run does not leak one descriptor per drug.
        with open('/tmp/scores.csv') as scores_fh:
            report_scores = [row[1] for row in csv.reader(scores_fh)][1:]

        # zip() below would silently truncate; a length mismatch would mean the
        # scores are attached to the wrong reports, so fail loudly instead.
        if len(report_scores) != len(reports):
            raise RuntimeError("drug %s: got %d scores for %d reports" % (drug, len(report_scores), len(reports)))

        writer2.writerow([drug, len(columns), len(case_reports), len(control_reports)])

        for score, rid in zip(report_scores, reports):
            writer.writerow([drug, rid, score])
finally:
    outfh.close()
    outfh2.close()

# Copy results.csv to results_plus_cohort.csv, appending a cohort column:
# 1 when the report lists the drug (case), 0 otherwise (control).

# csv.reader yields every field as a string, while drug2report holds whatever
# types the database driver returned. Matching on the string form (which is
# what csv.writer emitted) keeps the lookup correct for numeric ids as well as
# string ids; for ids that are already strings this is a no-op.
drug_key_by_text = dict((str(key), key) for key in drug2report)

with open('results.csv') as infh:
    with open('results_plus_cohort.csv','w') as outfh:
        reader = csv.reader(infh)
        writer = csv.writer(outfh)

        # Rows for one drug are contiguous in results.csv, so the stringified
        # case set is rebuilt only when the drug changes.
        current_sid = None
        current_cases = frozenset()

        for sid, rid, score in reader:
            if sid != current_sid:
                current_sid = sid
                current_cases = set(str(case_rid) for case_rid in drug2report[drug_key_by_text[sid]])

            if rid in current_cases:
                writer.writerow([sid, rid, score, 1])
            else:
                writer.writerow([sid, rid, score, 0])
