#!/usr/bin/env python
# encoding: utf-8
"""
generate_input_files.py

Created by Nicholas Tatonetti on 2010-09-05.
Copyright (c) 2010 Stanford University. All rights reserved.
"""

import os
import csv
import sys
import random
import MySQLdb
import tempfile
import operator
from StringIO import StringIO
from namedmatrix import NamedMatrix

# Directory that receives every file written by write_base_files and read by
# load_from_disk.
OUTPUT_DIRECTORY = os.path.expanduser('~/Stanford/AltmanLab/aers/fortran_data/10q4/')
# Data-source selector, also used as the prefix of every output file name.
# 'random' reads the *_random tables; 'aers' reads drug_report_event; any other
# value reads drug_report_event_me. For non-'random' values it is also matched
# against the `database` column of corr_drug_drug / corr_drug_ind.
DATABASE = 'aers'

def _base_path(suffix):
    """Return the path of the output file ``<DATABASE>_<suffix>`` inside OUTPUT_DIRECTORY."""
    return os.path.join(OUTPUT_DIRECTORY, '%s_%s' % (DATABASE, suffix))


def _write_index_file(suffix, header, ids):
    """Write an ``<header>,fortran_index`` CSV for *ids* and return the id -> index dict.

    Fortran indices are 1-based and follow the order of *ids*.
    """
    id2index = dict()
    with open(_base_path(suffix), 'w') as fh:
        writer = csv.writer(fh)
        writer.writerow([header, 'fortran_index'])
        for index, entity_id in enumerate(ids, 1):
            writer.writerow([entity_id, index])
            id2index[entity_id] = index
    return id2index


def _write_count_and_list_files(count_suffix, list_suffix, keys, key2members, member2index):
    """Write one row per key: its member count to one file, its members' indices to another.

    Row i of both files corresponds to ``keys[i]``; members are translated to
    their Fortran index through *member2index*.
    """
    with open(_base_path(count_suffix), 'w') as count_fh:
        with open(_base_path(list_suffix), 'w') as list_fh:
            count_writer = csv.writer(count_fh)
            list_writer = csv.writer(list_fh)
            for key in keys:
                members = key2members[key]
                count_writer.writerow([len(members)])
                list_writer.writerow([member2index[member] for member in members])


def _fill_matrix(matrix, rows, row_pos, col_pos):
    """Store each ``(row_id, col_id, value)`` of *rows* in *matrix* by 0-based position.

    Rows whose ids are missing from *row_pos* / *col_pos* are skipped.
    """
    for row_id, col_id, value in rows:
        if row_id in row_pos and col_id in col_pos:
            matrix[row_pos[row_id], col_pos[col_id]] = value


def write_base_files(c):
    """Export the drug / report / event data as the flat files consumed by the Fortran code.

    Every file is written to OUTPUT_DIRECTORY with the name ``<DATABASE>_<suffix>``:

    * ``entity_counts.csv`` -- one row: #stitch ids, #report ids, #umls ids.
    * ``report_index.csv``, ``umls_index.csv``, ``stitch_index.csv``,
      ``indication_index.csv`` -- id -> 1-based Fortran index (sorted order).
    * ``report_umlscount.txt`` / ``report_umlslist.csv`` -- per report, the
      number of events and their umls indices.
    * ``stitch_reportcount.txt`` / ``stitch_reportlist.csv`` -- per drug, the
      number of reports and their report indices.
    * ``stitch_stitch_est.csv`` / ``stitch_ind_est.csv`` -- NamedMatrix dumps of
      the drug-drug and drug-indication estimates (only estimates > 0.001).
    * ``ind_reportcount.txt`` / ``ind_reportlist.csv`` -- per indication, the
      number of (known) reports and their report indices.

    Parameters
    ----------
    c : DB-API cursor
        Cursor on the AERS database (the script entry point passes a MySQLdb
        cursor). Only ``execute`` and ``fetchall`` are used.
    """
    if DATABASE == 'random':
        query = """
        select stitch_id, report_id, umls_id
        from drug_report_event_random
        join drug_list using (stitch_id)
        where `database` = 'aers'
        """
    else:
        query = """
        select stitch_id, report_id, umls_id
        from drug_report_event%s
        """ % (("" if DATABASE == 'aers' else '_me'))
    c.execute(query)

    stitch2report = dict()
    report2umls = dict()
    umls_id_set = set()
    for stitch_id, report_id, umls_id in c.fetchall():
        stitch2report.setdefault(stitch_id, set()).add(report_id)
        report2umls.setdefault(report_id, set()).add(umls_id)
        umls_id_set.add(umls_id)

    # Every stitch/report id seen is a key of its mapping, so sorting the keys
    # gives the sorted set of distinct ids.
    stitch_ids = sorted(stitch2report)
    report_ids = sorted(report2umls)
    umls_ids = sorted(umls_id_set)

    with open(_base_path('entity_counts.csv'), 'w') as fh:
        csv.writer(fh).writerow([len(stitch_ids), len(report_ids), len(umls_ids)])

    report2index = _write_index_file('report_index.csv', 'report_id', report_ids)
    umls2index = _write_index_file('umls_index.csv', 'umls_id', umls_ids)
    _write_index_file('stitch_index.csv', 'stitch_id', stitch_ids)

    _write_count_and_list_files('report_umlscount.txt', 'report_umlslist.csv',
                                report_ids, report2umls, umls2index)
    _write_count_and_list_files('stitch_reportcount.txt', 'stitch_reportlist.csv',
                                stitch_ids, stitch2report, report2index)

    # 0-based matrix positions. A dict lookup replaces list.index(), which
    # scanned stitch_ids once per fetched row.
    stitch_pos = dict((cid, i) for i, cid in enumerate(stitch_ids))

    if DATABASE == 'random':
        query = """
        select stitch_id1, stitch_id2, estimate
        from corr_drug_drug_random
        where estimate > 0.001
        """
    else:
        query = """
        select stitch_id1, stitch_id2, estimate
        from corr_drug_drug
        where estimate > 0.001
        and `database` = '%s'
        """ % DATABASE
    c.execute(query)

    stitch_stitch = NamedMatrix(None, stitch_ids, stitch_ids)
    _fill_matrix(stitch_stitch, c.fetchall(), stitch_pos, stitch_pos)
    stitch_stitch[stitch_ids, stitch_ids].save_to_file(_base_path('stitch_stitch_est.csv'), False, False)

    if DATABASE == 'random':
        query = """
        select stitch_id, indication, estimate
        from corr_drug_ind_random   
        where estimate > 0.001
        """
    else:
        query = """
        select stitch_id, indication, estimate
        from corr_drug_ind
        where `database` = '%s'
        and estimate > 0.001
        """ % DATABASE
    c.execute(query)
    data = c.fetchall()

    indications = sorted(set(row[1] for row in data))
    ind_pos = dict((ind, i) for i, ind in enumerate(indications))
    stitch_ind = NamedMatrix(None, stitch_ids, indications)
    _fill_matrix(stitch_ind, data, stitch_pos, ind_pos)
    stitch_ind[stitch_ids, :].save_to_file(_base_path('stitch_ind_est.csv'), False, False)

    query = """
    select indication, report_id
    from indication_report_event
    """
    c.execute(query)

    ind2report = dict()
    for ind, rep in c.fetchall():
        # NOTE(review): these report ids are str()-ed, while report_ids above
        # are used exactly as the driver returned them. If drug_report_event
        # yields non-string report ids, the intersection below is always
        # empty -- confirm the column types.
        ind2report.setdefault(ind, set()).add(str(rep))

    # Keep only reports that have an index in report_index.csv.
    reports = set(report_ids)
    ind2report = dict((ind, reps & reports) for ind, reps in ind2report.items())

    _write_index_file('indication_index.csv', 'indication',
                      [ind.strip() for ind in indications])

    # An indication from corr_drug_ind may have no rows in
    # indication_report_event. Write it with zero reports instead of failing
    # with a KeyError.
    ind_reports = dict((ind, ind2report.get(ind, set())) for ind in indications)
    _write_count_and_list_files('ind_reportcount.txt', 'ind_reportlist.csv',
                                indications, ind_reports, report2index)


def load_from_disk(output_directory=None, database=None):
    """Reload the index and relation files previously written to disk.

    Reads ``<database>_stitch_index.csv``, ``<database>_report_index.csv``,
    ``<database>_umls_index.csv``, ``<database>_stitch_reportlist.csv`` and
    ``<database>_report_umlslist.csv`` from *output_directory*.

    Parameters
    ----------
    output_directory : str, optional
        Directory holding the files; defaults to OUTPUT_DIRECTORY.
    database : str, optional
        File-name prefix; defaults to DATABASE.

    Returns
    -------
    tuple
        ``((stitch_ids, stitch2report, stitch2index),
        (report_ids, report2umls, report2index),
        (umls_ids, umls2index))``. Ids are the strings read back from the
        CSV files. The ``*2index`` dicts map an id to its 1-based Fortran
        index. ``stitch2report`` / ``report2umls`` map an id to a set of ids.
    """
    if output_directory is None:
        output_directory = OUTPUT_DIRECTORY
    if database is None:
        database = DATABASE

    def _path(suffix):
        return os.path.join(output_directory, '%s_%s' % (database, suffix))

    def _read_ids(suffix):
        # First column of an index file, without the header row.
        with open(_path(suffix)) as fh:
            return [row[0] for row in csv.reader(fh)][1:]

    stitch_ids = _read_ids('stitch_index.csv')
    report_ids = _read_ids('report_index.csv')
    umls_ids = _read_ids('umls_index.csv')

    # Fortran indices are 1-based.
    stitch2index = dict((cid, i) for i, cid in enumerate(stitch_ids, 1))
    index2stitch = dict(enumerate(stitch_ids, 1))
    report2index = dict((rid, i) for i, rid in enumerate(report_ids, 1))
    index2report = dict(enumerate(report_ids, 1))
    umls2index = dict((uid, i) for i, uid in enumerate(umls_ids, 1))
    index2umls = dict(enumerate(umls_ids, 1))

    # Row i (1-based) of a list file belongs to the entity with Fortran index i.
    stitch2report = dict()
    with open(_path('stitch_reportlist.csv')) as fh:
        for i, report_indices in enumerate(csv.reader(fh), 1):
            stitch2report[index2stitch[i]] = set(
                index2report[int(report_index)] for report_index in report_indices)

    report2umls = dict()
    with open(_path('report_umlslist.csv')) as fh:
        for i, umls_indices in enumerate(csv.reader(fh), 1):
            report2umls[index2report[i]] = set(
                index2umls[int(umls_index)] for umls_index in umls_indices)

    return ((stitch_ids, stitch2report, stitch2index),
            (report_ids, report2umls, report2index),
            (umls_ids, umls2index))
    


if __name__ == '__main__':

    # Connect to the project_aers_10q4 MySQL database and export the files.
    # NOTE(review): credentials are hard-coded placeholders; supply real ones
    # before running.
    conn = MySQLdb.connect(host="localhost", port=3307, user="root", passwd="enter_your_password", db="project_aers_10q4")
    c = conn.cursor()
    try:
        write_base_files(c)
    finally:
        # Release the cursor and connection even if the export fails part-way.
        c.close()
        conn.close()
    