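"""
Build the Fortran input files for the EMR "adverse event reports".

Reads the tab-separated raw_mapped_* files from OUTPUT_DIRECTORY and writes
emr_* count, list and index files whose entity indices are 1-based
(Fortran-style).
"""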

import os
import csv
import sys
from namedmatrix import NamedMatrix

# Directory the raw_mapped_* inputs are read from and the emr_* files are
# written to.
OUTPUT_DIRECTORY = os.path.expanduser('~/Stanford/AltmanLab/aers/fortran_data/EMR/')

sys.stderr.write("Loading reports...\n")

mapped_reports = os.path.join(OUTPUT_DIRECTORY, 'raw_mapped_reports.txt')

# The report id is the first tab-separated column. Duplicates collapse, and the
# sorted order defines each report's 1-based (Fortran-style) index.
with open(mapped_reports) as report_file:
    reports = sorted(set(row[0] for row in csv.reader(report_file, delimiter='\t')))

# rep2index: raw report id -> 1-based index.
rep2index = dict((repid, i) for i, repid in enumerate(reports, 1))

sys.stderr.write("Loading indications...\n")

mapped_indications = os.path.join(OUTPUT_DIRECTORY, 'raw_mapped_indications.txt')

# ind2rep: ICD-9 code -> set of 1-based report indices that list it.
# Rows are (icd9, report id). A report id absent from raw_mapped_reports.txt
# raises KeyError, which aborts the run deliberately.
ind2rep = dict()
with open(mapped_indications) as indication_file:
    for icd9, repid in csv.reader(indication_file, delimiter='\t'):
        ind2rep.setdefault(icd9, set()).add(rep2index[repid])

# Sorted codes; the position in this list defines each indication's index.
indications = sorted(ind2rep)

ind2index = dict((indication, i) for i, indication in enumerate(indications, 1))

sys.stderr.write("Writing indication data...\n")

# Line i (1-based, following `indications` order) of each file describes
# indication i:
#   emr_ind_reportcount.txt  - how many reports mention it
#   emr_ind_reportlist.csv   - those report indices, comma separated
#   emr_indication_index.csv - "<icd9>,<index>" (after a header line)
with open(os.path.join(OUTPUT_DIRECTORY, 'emr_ind_reportcount.txt'), 'w') as ind_reportcount, \
        open(os.path.join(OUTPUT_DIRECTORY, 'emr_ind_reportlist.csv'), 'w') as ind_reportlist, \
        open(os.path.join(OUTPUT_DIRECTORY, 'emr_indication_index.csv'), 'w') as indication_index:
    indication_index.write("indication,fortran_index\n")
    for i, ind in enumerate(indications, 1):
        # Sorted so the written list does not depend on set iteration order.
        report_ids = sorted(ind2rep[ind])
        ind_reportcount.write("%d\n" % len(report_ids))
        ind_reportlist.write(",".join(map(str, report_ids)) + "\n")
        indication_index.write("%s,%d\n" % (ind, i))

sys.stderr.write("Loading reactions...\n")

mapped_reactions = os.path.join(OUTPUT_DIRECTORY, 'raw_mapped_reactions.txt')

# rep2event: raw report id (the string, not its index) -> set of event ids.
# NOTE(review): event ids are presumably UMLS concept ids, given the
# emr_*umls* output names; confirm against the mapping step.
rep2event = dict()
with open(mapped_reactions) as reaction_file:
    for repid, uid in csv.reader(reaction_file, delimiter='\t'):
        rep2event.setdefault(repid, set()).add(uid)

# Every event seen in any report, sorted; the position defines its 1-based index.
events = sorted(set().union(*rep2event.values()))
event2index = dict((event, i) for i, event in enumerate(events, 1))

sys.stderr.write("Writing reaction data...\n")

# One line per report, in `reports` (index) order. A report with no mapped
# reactions is written with a count of 0 and an empty list instead of aborting
# the whole run with a KeyError.
with open(os.path.join(OUTPUT_DIRECTORY, 'emr_report_umlscount.txt'), 'w') as report_umlscount, \
        open(os.path.join(OUTPUT_DIRECTORY, 'emr_report_umlslist.csv'), 'w') as report_umlslist, \
        open(os.path.join(OUTPUT_DIRECTORY, 'emr_report_index.csv'), 'w') as report_index:
    report_index.write("report,fortran_index\n")
    for repid in reports:
        # Sorted event indices, so the output is reproducible across runs.
        event_ids = sorted(event2index[uid] for uid in rep2event.get(repid, ()))
        report_index.write("%s,%d\n" % (repid, rep2index[repid]))
        report_umlscount.write("%d\n" % len(event_ids))
        report_umlslist.write(",".join(map(str, event_ids)) + "\n")

# Mapping from event id to its 1-based index, after a header line.
with open(os.path.join(OUTPUT_DIRECTORY, 'emr_umls_index.csv'), 'w') as event_index:
    event_index.write("event,fortran_index\n")
    for i, event in enumerate(events, 1):
        event_index.write("%s,%d\n" % (event, i))

sys.stderr.write("Loading drugs...\n")

mapped_drugs = os.path.join(OUTPUT_DIRECTORY, 'raw_mapped_drugs.txt')

# drug2rep: STITCH id -> set of 1-based report indices listing that drug.
# Rows are (report id, stitch id); note the column order differs from the
# indications file. An unknown report id raises KeyError.
drug2rep = dict()
with open(mapped_drugs) as drug_file:
    for repid, stitch_id in csv.reader(drug_file, delimiter='\t'):
        drug2rep.setdefault(stitch_id, set()).add(rep2index[repid])

# Sorted STITCH ids; the position in this list defines each drug's index.
drugs = sorted(drug2rep)

drug2index = dict((drug, i) for i, drug in enumerate(drugs, 1))

sys.stderr.write("Writing drug data...\n")

# Line i (1-based, following `drugs` order) of each file describes drug i:
# its "<stitch id>,<index>" row, its report count and its report index list.
with open(os.path.join(OUTPUT_DIRECTORY, 'emr_stitch_reportcount.txt'), 'w') as drug_reportcount, \
        open(os.path.join(OUTPUT_DIRECTORY, 'emr_stitch_reportlist.csv'), 'w') as drug_reportlist, \
        open(os.path.join(OUTPUT_DIRECTORY, 'emr_stitch_index.csv'), 'w') as drug_index:
    drug_index.write("drug,fortran_index\n")
    for i, drug in enumerate(drugs, 1):
        # Sorted so the written list does not depend on set iteration order.
        report_ids = sorted(drug2rep[drug])
        drug_index.write("%s,%d\n" % (drug, i))
        drug_reportcount.write("%d\n" % len(report_ids))
        drug_reportlist.write(",".join(map(str, report_ids)) + "\n")

# One line: number of drugs, reports, events and indications, in that order.
with open(os.path.join(OUTPUT_DIRECTORY, 'emr_entity_counts.csv'), 'w') as entity_counts:
    entity_counts.write("%d,%d,%d,%d\n" % (len(drugs), len(reports), len(events), len(indications)))

sys.stderr.write("Loading drug-drug correlation data..\n")

drug_drug_corr = os.path.join(OUTPUT_DIRECTORY, 'raw_drug_drug_corr_est.txt')

# Rows and columns follow `drugs` order; drug2index is 1-based, so subtract 1.
ddcorr = NamedMatrix(None, drugs, drugs)

with open(drug_drug_corr) as corr_file:
    for stitch_id1, stitch_id2, estimate in csv.reader(corr_file, delimiter='\t'):
        # NOTE(review): `estimate` is the raw string from the file; presumably
        # NamedMatrix coerces it to a number on assignment -- confirm.
        ddcorr[drug2index[stitch_id1] - 1, drug2index[stitch_id2] - 1] = estimate

# Save alongside the other emr_* files. The previous bare filename wrote into
# whatever directory the script happened to be run from.
ddcorr.save_to_file(os.path.join(OUTPUT_DIRECTORY, 'emr_stitch_stitch_est.csv'),
                    row_labels=False, column_labels=False)

sys.stderr.write("Loading drug-ind correlation data...\n")

drug_ind_corr = os.path.join(OUTPUT_DIRECTORY, 'raw_drug_ind_corr_est.txt')

# Rows follow `drugs` order and columns follow `indications` order; both
# index dicts are 1-based, so subtract 1.
dicorr = NamedMatrix(None, drugs, indications)

with open(drug_ind_corr) as corr_file:
    for stitch_id, indication, estimate in csv.reader(corr_file, delimiter='\t'):
        # NOTE(review): `estimate` is the raw string from the file; presumably
        # NamedMatrix coerces it to a number on assignment -- confirm.
        dicorr[drug2index[stitch_id] - 1, ind2index[indication] - 1] = estimate

# Save alongside the other emr_* files. The previous bare filename wrote into
# whatever directory the script happened to be run from.
dicorr.save_to_file(os.path.join(OUTPUT_DIRECTORY, 'emr_stitch_ind_est.csv'),
                    row_labels=False, column_labels=False)
