"""
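Create the four additional files needed to run discover_ae_p4_pairs.f90

FILES TO GENERATE
=================
aers_pairs2stitch.txt      tab-separated stitch-index pairs, one row per drug pair (no header)
aers_pair_stitch_est.csv   drug-pair x drug estimate matrix (no row/column labels)
aers_pair_ind_est.csv      drug-pair x indication estimate matrix (no row/column labels)
aers_pair_index.csv        header "drug1,drug2,index", then one row per drug pair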
"""

import os
import sys
import csv
import MySQLdb
from namedmatrix import NamedMatrix

def _read_index_map(path, skip_invalid=False):
    """Read a two-column CSV of ``name,index`` rows into a dict.

    Args:
        path: CSV file whose rows each have exactly two fields.
        skip_invalid: if True, rows whose second field is not an integer
            (e.g. a header line) are silently skipped; otherwise the
            ValueError from int() propagates.

    Returns:
        dict mapping the first field (str) to the second field as an int.
    """
    mapping = dict()
    with open(path) as fh:
        for name, index in csv.reader(fh):
            try:
                mapping[name] = int(index)
            except ValueError:
                if not skip_invalid:
                    raise
    return mapping


# Read in the index maps.
# Non-integer rows in the stitch index are skipped (as before); the
# indication index is read strictly, so a malformed row raises.
stitch2index = _read_index_map('aers_stitch_index.csv', skip_invalid=True)
indication2index = _read_index_map('aers_indication_index.csv')

# Read in the drug pairs and give each distinct, usable pair a 1-based index.
# Writes:
#   aers_pairs2stitch.txt - tab-separated (stitch index of drug 1, of drug 2)
#   aers_pair_index.csv   - header "drug1,drug2,index", then one row per pair
# Afterwards len(pair2index) equals the number of rows written, which the
# estimate matrices below rely on for their row count.
pair2index = dict()
index = 1

# Open the input first so a missing input does not truncate the outputs.
with open('aers_pairs.txt') as pairs_fh, \
        open('aers_pairs2stitch.txt', 'w') as fh1, \
        open('aers_pair_index.csv', 'w') as fh2:
    writer1 = csv.writer(fh1, delimiter='\t')
    writer2 = csv.writer(fh2, delimiter=',')
    writer2.writerow(["drug1", "drug2", "index"])

    for cid1, cid2 in csv.reader(pairs_fh, delimiter='\t'):
        # Skip self-pairs and pairs with a drug absent from the stitch index.
        if cid1 == cid2 or cid1 not in stitch2index or cid2 not in stitch2index:
            continue

        key = '%s,%s' % (cid1, cid2)
        if key in pair2index:
            # A repeated pair would get a second index while pair2index keeps
            # only one entry, so len(pair2index) would undercount the rows and
            # later estimate lookups could index past the end of the matrix.
            sys.stderr.write("Duplicate pair skipped (%s, %s).\n" % (cid1, cid2))
            continue
        pair2index[key] = index

        writer1.writerow([stitch2index[cid1], stitch2index[cid2]])
        writer2.writerow([cid1, cid2, index])

        index += 1

def _load_estimates(cursor, query, pair2index, col2index, out_path):
    """Build a pair x column estimate matrix from *query* and save it.

    Args:
        cursor: open DB cursor used to run *query*.
        query: SQL returning rows of (pair1, pair2, column_key, estimate).
        pair2index: maps "pair1,pair2" to a 1-based row index.
        col2index: maps column_key to a 1-based column index.
        out_path: file the matrix is saved to, without row/column labels.

    Returns:
        The filled NamedMatrix. Rows whose pair or column key is unknown
        are skipped; cells with no matching row keep NamedMatrix's initial
        value.
    """
    cursor.execute(query)
    rows = cursor.fetchall()

    matrix = NamedMatrix(None,
                         map(str, range(len(pair2index))),
                         map(str, range(len(col2index))))

    for pair1, pair2, col_key, estimate in rows:
        key = '%s,%s' % (pair1, pair2)
        if key not in pair2index or col_key not in col2index:
            continue
        # Both index maps are 1-based; the matrix is 0-based.
        matrix[pair2index[key] - 1, col2index[col_key] - 1] = float(estimate)

    matrix.save_to_file(out_path, row_labels=False, column_labels=False)
    return matrix


# Load in the estimates data.
# The password can be supplied via AERS_DB_PASSWORD instead of editing the
# source; the original placeholder remains the fallback.
conn = MySQLdb.connect(host="localhost", port=3307, user="root",
                       passwd=os.environ.get("AERS_DB_PASSWORD",
                                             "enter_your_password"),
                       db="project_aers")
try:
    c = conn.cursor()

    # Drug-pair x drug estimates.
    corr_pair_drug = _load_estimates(c, """
select pair1, pair2, other, estimate
from corr_pair_drug
""", pair2index, stitch2index, 'aers_pair_stitch_est.csv')

    # Drug-pair x indication estimates.
    corr_pair_ind = _load_estimates(c, """
select pair1, pair2, indication, estimate
from corr_pair_ind
""", pair2index, indication2index, 'aers_pair_ind_est.csv')
finally:
    # The connection was previously never closed.
    conn.close()

