#!/usr/bin/env python
# encoding: utf-8
"""
insert_results.py

Created by Nicholas Tatonetti on 2010-09-08.
Copyright (c) 2010 Stanford University. All rights reserved.

Reads result rows from standard input and inserts them into MySQL. If the analysis
was run on the Obama cluster, you may want to concatenate/sanitize the inputs first:

for f in *.txt
do
python -c "import sys; print sys.stdin.read().split('The output (if any) follows:\n\n')[1].split('\n\n\nPS')[0]" < $f >> concatenated.txt
done

"""

import os
import csv
import sys
import MySQLdb

# ---------------------------------------------------------------------------
# Run settings. Exactly one assignment per setting is active; the values
# listed under "alt:" are alternative configurations to swap in by hand.
# ---------------------------------------------------------------------------

# Directory containing the <DATABASE>_*_index.csv lookup files.
OUTPUT_DIRECTORY = os.path.expanduser('~/Stanford/AltmanLab/aers/fortran_data/EMR/')
#   alt: '~/Stanford/AltmanLab/aers/fortran_data/'
#   alt: '~/Stanford/AltmanLab/aers/fortran_data/10q4/'

# File-name prefix of the index CSVs inside OUTPUT_DIRECTORY.
DATABASE = 'emr'
#   alt: 'aers', 'medeffect', 'random'

# MySQL database the results are inserted into.
SQL_DB = 'effect_stanford'
#   alt: 'project_aers', 'project_aers_10q4'

# Destination table for single drug / adverse-event rows.
TABLE_NAME = 'pred_drug_events_e'
#   alt: 'pred_drug_events_e5'

# Destination table for drug-pair / adverse-event rows.
PAIR_TABLE_NAME = 'pred_pair_events_e'
#   alt: 'pred_pair_events_e5'

def _load_index(path):
    """Return a dict mapping 1-based row positions to first-column IDs.

    The first row of the CSV at ``path`` is treated as a header and skipped,
    so the first data row maps to key 1. Input rows on stdin refer to drugs
    and events by these 1-based positions.
    """
    with open(path) as handle:
        rows = list(csv.reader(handle))[1:]
    return dict((position, row[0]) for position, row in enumerate(rows, 1))


if __name__ == '__main__':

    # Alternative connection settings:
    # conn = MySQLdb.connect(host="127.0.0.1", port=3307, user="root", passwd="enter_your_password", db=SQL_DB)
    conn = MySQLdb.connect(host="localhost", port=3306, user="root",
                           passwd="enter_your_password", db=SQL_DB)
    cursor = conn.cursor()

    # Table names cannot be bound as query parameters, so they are formatted
    # in here (they come from the module constants, not from the input).
    # Every value is bound by MySQLdb through the %s placeholders, which
    # handles quoting/escaping of the stdin fields.
    single_sql = (
        "INSERT IGNORE INTO %s (stitch_id, umls_id, pvalue, drug_mean, drug_sd, "
        "bg_mean, bg_sd, drug_num, bg_num, drug_cutoff, ind_cutoff) "
        "VALUES (%%s, %%s, NULL, %%s, %%s, %%s, %%s, %%s, %%s, %%s, %%s)"
        % TABLE_NAME
    )
    pair_sql = (
        "INSERT IGNORE INTO %s (stitch_id1, stitch_id2, umls_id, pvalue, "
        "pair_mean, pair_sd, bg_mean, bg_sd, pair_num, bg_num) "
        "VALUES (%%s, %%s, %%s, NULL, %%s, %%s, %%s, %%s, %%s, %%s)"
        % PAIR_TABLE_NAME
    )

    try:
        sys.stderr.write("Loading index dictionaries from %s\n" % OUTPUT_DIRECTORY)
        index2stitch = _load_index(
            os.path.join(OUTPUT_DIRECTORY, '%s_stitch_index.csv' % DATABASE))
        index2umls = _load_index(
            os.path.join(OUTPUT_DIRECTORY, '%s_umls_index.csv' % DATABASE))

        data = [line.split() for line in sys.stdin]
        total = len(data)
        # Report progress roughly 20 times; max() keeps the step non-zero
        # (avoiding a ZeroDivisionError) when there are fewer than 20 rows.
        step = max(total // 20, 1)
        announced_single = False
        announced_pair = False
        skipped = 0

        for i, row in enumerate(data):
            if (i + 1) % step == 0:
                sys.stderr.write("..%d%%.. " % (100 * (i + 1) // total))

            if len(row) == 10:
                # Single drug-adverse event association:
                # stitch index, umls index, then 8 statistics.
                if not announced_single:
                    sys.stderr.write("Inserting single associations.\n")
                    announced_single = True
                params = ([index2stitch[int(row[0])], index2umls[int(row[1])]]
                          + row[2:])
                cursor.execute(single_sql, params)
            elif len(row) == 9:
                # Paired drug-adverse event association:
                # two stitch indices, umls index, then 6 statistics.
                if not announced_pair:
                    sys.stderr.write("Inserting paired associations.\n")
                    announced_pair = True
                params = ([index2stitch[int(row[0])], index2stitch[int(row[1])],
                           index2umls[int(row[2])]] + row[3:])
                cursor.execute(pair_sql, params)
            else:
                # Any other field count (e.g. a blank line) is not inserted;
                # count it so the user can see that rows were dropped.
                skipped += 1

        sys.stderr.write("\n")
        if skipped:
            sys.stderr.write(
                "Skipped %d row(s) that had neither 9 nor 10 fields.\n" % skipped)

        # MySQLdb does not autocommit by default; without an explicit commit,
        # inserts into transactional (e.g. InnoDB) tables are discarded when
        # the connection closes.
        conn.commit()
    finally:
        cursor.close()
        conn.close()

