"""
enriched_indications.py
@author Nicholas P. Tatonetti
@copyright 2010 Stanford University
=================

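Define a drug set based on a given indication term (indication_descrip_term). We say that drugs "treat" the indication
when they are significantly enriched for the indication: p <= 0.05 on the 2x2 contingency table below, using
Fisher's exact test when the counts are small and a chi-squared test otherwise.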

Here's the general algorithm.

for each indication, with the drug set fixed:
                 Indication             !Indication
                +---------------------+---------------------+
    Drug Set    | Num Reports with    | Num Reports w/ Drug |
                | Drug & Indication   | and not Indication  |
                +---------------------+---------------------+
    !Drug Set   | Num Reports with    | Num Reports w/o     |
                | Indication & !Drug  | Drug or Indication  |
                +---------------------+---------------------+
"""

import os
import sys
import csv
import random
import MySQLdb
import operator
import rpy2.robjects as robjects

# Root directory holding one sub-directory per named drug/indication set.
# Can be overridden with AERS_DATA_DIR; the default is the original hard-coded path.
PATH_TO_DATA_DIR = os.path.expanduser(os.environ.get('AERS_DATA_DIR', '~/Stanford/AltmanLab/cocktails/methods/'))

# Shared MySQL connection, opened at import time.
# Each setting can be overridden through an AERS_DB_* environment variable, so
# credentials need not be edited into the source. The defaults reproduce the
# original hard-coded connection exactly.
db = MySQLdb.connect(
    host=os.environ.get('AERS_DB_HOST', "localhost"),
    port=int(os.environ.get('AERS_DB_PORT', '3307')),
    user=os.environ.get('AERS_DB_USER', "root"),
    passwd=os.environ.get('AERS_DB_PASSWD', "dummy_password"),
    db=os.environ.get('AERS_DB_NAME', "project_aers"),
)

class EnrichedIndications(object):
    """
    Class for finding enriched indications.

    All methods are static and share state through class attributes.
    ``_db_conn`` must be assigned a DB-API cursor (the ``__main__`` section
    assigns ``db.cursor()``) before any query method runs. The indication ->
    report-id dictionary and the set of all singlet report ids are each queried
    once and then cached on the class.
    """
    # Cache: indication_descrip_term -> set of isr_report_ids; filled on first use.
    _indication_reports = None
    # Cache: set of every isr_report_id in effect_aers.singlet_indications.
    _singlet_report_ids = None
    # DB-API cursor used for every query; assigned by the caller.
    _db_conn = None

    @staticmethod
    def get_indication_report_dictionary():
        """
        Return {indication_descrip_term: set of isr_report_ids} built from
        effect_aers.singlet_indications.

        The query runs only on the first call; later calls return the cached
        dictionary.
        """
        if EnrichedIndications._indication_reports is None:
            # Get report ids for reports with indication.
            query = """
            select indication_descrip_term, isr_report_id
            from effect_aers.singlet_indications
            """

            EnrichedIndications._db_conn.execute(query)
            indication_reports = dict()
            for ind, rid in EnrichedIndications._db_conn.fetchall():
                # setdefault replaces dict.has_key(), which does not exist in Python 3.
                indication_reports.setdefault(ind, set()).add(rid)
            # Publish the cache only once it is complete, so a failed fetch
            # cannot leave a half-built dictionary cached.
            EnrichedIndications._indication_reports = indication_reports

        return EnrichedIndications._indication_reports

    @staticmethod
    def get_drug_report_ids(aers_drug_names, temp_table):
        """
        Return the set of isr_report_ids whose singlet_drugs.drug_name matches an
        id in effect_aers.<temp_table>.

        The join restricts results to reports that also have a row in
        effect_aers.singlet_indications.

        :param aers_drug_names: not used by the query. The drug names are read
            from ``temp_table``, which the caller must already have filled.
        :param temp_table: scratch table name in effect_aers. It is interpolated
            into the SQL unescaped, so it must be a trusted identifier.
        """
        query = """
        select isr_report_id
        from effect_aers.%s
        join effect_aers.singlet_drugs on (drug_name = id)
        join effect_aers.singlet_indications using (isr_report_id)
        """ % temp_table
        EnrichedIndications._db_conn.execute(query)
        return set(row[0] for row in EnrichedIndications._db_conn.fetchall())

    @staticmethod
    def get_indication_report_ids(aers_indication_names, temp_table):
        """
        Return the set of isr_report_ids whose indication_descrip_term matches an
        id in effect_aers.<temp_table>.

        :param aers_indication_names: not used by the query. The terms are read
            from ``temp_table``, which the caller must already have filled.
        :param temp_table: scratch table name in effect_aers. It is interpolated
            into the SQL unescaped, so it must be a trusted identifier.
        """
        query = """
        select isr_report_id
        from effect_aers.%s
        join effect_aers.singlet_indications on (indication_descrip_term = id)
        """ % temp_table
        EnrichedIndications._db_conn.execute(query)
        return set(row[0] for row in EnrichedIndications._db_conn.fetchall())

    @staticmethod
    def get_singlet_report_ids():
        """
        Return the set of all isr_report_ids in effect_aers.singlet_indications.

        This set is the universe of reports for the contingency tables. It is
        queried once and then cached.
        """
        if EnrichedIndications._singlet_report_ids is None:
            query = """
            select isr_report_id
            from effect_aers.singlet_indications;
            """
            EnrichedIndications._db_conn.execute(query)
            EnrichedIndications._singlet_report_ids = set(
                row[0] for row in EnrichedIndications._db_conn.fetchall())

        return EnrichedIndications._singlet_report_ids

    @staticmethod
    def find_correlated_indications(drug_report_ids):
        """
        Test every indication for enrichment among ``drug_report_ids``.

        For each indication, a 2x2 table (both / indication only / drugs only /
        neither) is built against the universe of singlet report ids and tested
        in R:
          - Fisher's exact test when the "both" or "indication only" cell is
            below 30;
          - otherwise a chi-squared test.
        Indications that share no report with the set get p = 1.0, and R is not
        called for them.

        :param drug_report_ids: set of isr_report_ids forming the query set.
        :returns: list of [indication_term, both, indication, drugs, neither,
            pvalue] rows with pvalue <= 0.05.
        """
        sys.stderr.write("Finding enriched indications for the drug set.\n")
        indication_reports = EnrichedIndications.get_indication_report_dictionary()

        sys.stderr.write("Extracting all singlet report ids.\n")
        singlet_report_ids = EnrichedIndications.get_singlet_report_ids()

        # Loop-invariant parts of the "neither" cell (see inclusion-exclusion below).
        num_singlets = len(singlet_report_ids)
        num_drug_in_singlets = len(drug_report_ids & singlet_report_ids)

        results = []
        for i, (indication_term, report_ids) in enumerate(indication_reports.items()):
            if i % 100 == 0:
                sys.stderr.write(".")

            both_ids = report_ids & drug_report_ids
            both = len(both_ids)
            indication = len(report_ids) - both   # |R - D|
            drugs = len(drug_report_ids) - both   # |D - R|
            # |S - (R | D)| = |S| - |S & R| - |S & D| + |S & R & D|.
            # This equals len(singlet_report_ids - (report_ids | drug_report_ids))
            # without rebuilding an |S|-sized set for every indication.
            neither = (num_singlets
                       - len(report_ids & singlet_report_ids)
                       - num_drug_in_singlets
                       + len(both_ids & singlet_report_ids))

            if both == 0:
                pvalue = 1.0
            else:
                # R fills matrix() column-major: rows are (indication, !indication),
                # columns are (drug set, !drug set).
                test_name = 'fisher.test' if (both < 30 or indication < 30) else 'chisq.test'
                result = robjects.r("%s(matrix(c(%d, %d, %d, %d), nrow=2))" % (test_name, both, drugs, indication, neither))
                pvalue = result[list(result.names).index('p.value')][0]

            results.append([indication_term, both, indication, drugs, neither, pvalue])
        sys.stderr.write("\n")

        return [row for row in results if row[-1] <= 0.05]

def main(aers_names, drug_based=True):
    """
    Find indications correlated with a set of AERS drug names or indication terms.

    The names are loaded into a scratch table ``effect_aers.temp_idsN`` (N drawn
    at random from 0-1000) so the report-id queries can join against it. The
    table is dropped again before returning, including when a query raises.

    :param aers_names: iterable of names to insert into the scratch table. These
        are drug names when ``drug_based`` is true, indication terms otherwise.
    :param drug_based: select the query reports by drug (True) or by
        indication (False).
    :returns: the list produced by
        ``EnrichedIndications.find_correlated_indications``.
    """
    cursor = EnrichedIndications._db_conn
    temp_table = 'temp_ids%d' % random.randint(0, 1000)

    sys.stderr.write("Creating a tempory table (%s) to hold identifiers.\n" % temp_table)
    cursor.execute("drop table if exists effect_aers.%s;" % temp_table)
    query = """
    CREATE TABLE effect_aers.%s (
      `id` varchar(50) NOT NULL DEFAULT '',
      PRIMARY KEY (`id`)
    ) ENGINE=MyISAM DEFAULT CHARSET=latin1;
    """ % temp_table
    cursor.execute(query)

    try:
        # Bind each name as a query parameter (MySQLdb uses %s placeholders)
        # instead of splicing it between hand-written double quotes. The names
        # come from input files, and any name containing a quote or a backslash
        # broke or altered the statement.
        insert_query = "insert into effect_aers.%s values (%%s)" % temp_table
        for name in aers_names:
            cursor.execute(insert_query, (name,))

        if drug_based:
            sys.stderr.write("Extracting drug report ids.\n")
            report_ids = EnrichedIndications.get_drug_report_ids(aers_names, temp_table)
        else:
            sys.stderr.write("Extracting indication report ids.\n")
            report_ids = EnrichedIndications.get_indication_report_ids(aers_names, temp_table)

        sys.stderr.write("Finding correlated indications.\n")
        correlated_indications = EnrichedIndications.find_correlated_indications(report_ids)
    finally:
        # Always remove the scratch table, so a failed run does not leave it behind.
        sys.stderr.write("Dropping temporary table.\n")
        cursor.execute("drop table effect_aers.%s;" % temp_table)

    return correlated_indications

if __name__ == '__main__':
    # Usage: enriched_indications.py SET_NAME
    # Reads the name list for SET_NAME from PATH_TO_DATA_DIR/SET_NAME/ and writes
    # the enriched indications to correlated_indications.csv in that same directory.

    if len(sys.argv) < 2:
        sys.stderr.write("usage: %s SET_NAME\n" % sys.argv[0])
        sys.exit(1)

    indication_name = sys.argv[1]
    set_dir = os.path.join(PATH_TO_DATA_DIR, indication_name)

    EnrichedIndications._db_conn = db.cursor()

    drug_names_path = os.path.join(set_dir, 'aers_drug_names.txt')
    indication_names_path = os.path.join(set_dir, 'aers_indication_names.txt')
    drug_based = os.path.exists(drug_names_path)
    indication_based = os.path.exists(indication_names_path)

    if not drug_based and not indication_based:
        # The files are looked up in the set's data directory, not in the
        # working directory as the old message claimed.
        raise Exception("Either aers_drug_names.txt or aers_indication_names.txt must be present in %s." % set_dir)

    # A drug list takes precedence when both files exist.
    names_path = drug_names_path if drug_based else indication_names_path
    with open(names_path) as infh:
        # The first CSV column holds the name. Blank lines parse as [] and
        # would make row[0] raise IndexError, so skip them.
        aers_names = [row[0] for row in csv.reader(infh) if row]

    if drug_based:
        sys.stderr.write('Loaded drug names for set: %s, found %d\n' % (indication_name, len(aers_names)))
    else:
        sys.stderr.write('Loaded indication names for set: %s, found %d\n' % (indication_name, len(aers_names)))

    correlated_indications = main(aers_names, drug_based)

    # No mkdir needed: set_dir must exist, because a name file was just read from it.
    with open(os.path.join(set_dir, 'correlated_indications.csv'), 'w') as outfh:
        csv.writer(outfh).writerows(correlated_indications)
