"""
identify_features.py

Derive an adverse event (AE) profile for a set of drugs
by finding AEs that are enriched when compared to
a background of correlated indications.
"""

import os
import sys
import csv
import numpy
import MySQLdb
import operator
import rpy2.robjects as robjects

# Module-level database connection, opened as soon as this module is imported.
# NOTE(review): host, port and credentials are hard-coded here -- consider
# loading them from configuration instead.
db = MySQLdb.connect(host="localhost", port=3307, user="root", passwd="dummy_password", db="effect_aers")

# To replicate the features used to find the pravastatin + paroxetine combination:
#   - use a contingency frequency cutoff of 0.01
#   - USE_WILCOX_TEST = False
#   - MULT_HYPT_CORRECT = False
#   - take the top 7 features.

# Minimum number of reports required for an indication / a drug.
# NOTE: "MIMIMUM" is a misspelling of "MINIMUM"; the name is kept unchanged
# because it is referenced elsewhere in this file.
INDICATION_MINIMUM_REPORTS = 10
DRUG_MIMIMUM_REPORTS = 10
# Frequency thresholds; one feature set is produced per threshold.
CONTINGENCY_FREQ_CUTOFFS = [0.01, 0.025, 0.05, 0.075, 0.1]
# Events whose p-value is above this cutoff are not reported as features.
EVENT_PVALUE_CUTOFF = 0.05

# NOTE(review): these two switches are not referenced anywhere in the visible
# part of this file -- confirm whether they are still needed.
USE_WILCOX_TEST = False
MULT_HYPT_CORRECT = False

# Root directory holding one sub-directory of input/output files per drug set.
PATH_TO_DATA_DIR = os.path.expanduser('~/Stanford/AltmanLab/cocktails/methods/')

class IdentifyFeatures(object):
    """
    Static class with methods to identify features for the drug set and given indications.

    Every ``get_*`` method runs its query through the class attribute
    ``_db_conn``, which the caller must set to a DB-API cursor beforehand.
    The statistical tests are evaluated in R through ``rpy2``.
    """

    # DB-API cursor shared by all query helpers; assigned by the caller.
    _db_conn = None

    @staticmethod
    def _r_pvalue(r_expression):
        """Evaluate an R test expression and return the ``p.value`` of its result."""
        result = robjects.r(r_expression)
        return result[list(result.names).index('p.value')][0]

    @staticmethod
    def wilcox_test(event, inds, drugs, event_drug_freqs, event_ind_freqs):
        """
        Compare the frequencies of `event` across `drugs` against those across
        `inds` with R's ``wilcox.test`` (default arguments).

        Missing (event, drug) or (event, indication) pairs count as frequency 0.0.
        Both `event_drug_freqs` and `event_ind_freqs` must contain `event`.

        Returns the test's p-value when the mean drug frequency is strictly
        greater than the mean indication frequency, otherwise 1.0. Also returns
        1.0 when either group is empty.
        """
        drug_freqs = [event_drug_freqs[event].get(drug_name, 0.0) for drug_name in drugs]
        inds_freqs = [event_ind_freqs[event].get(ind_name, 0.0) for ind_name in inds]

        # Nothing to compare; also avoids taking the mean of an empty list.
        if not drug_freqs or not inds_freqs:
            return 1.0

        if numpy.mean(drug_freqs) > numpy.mean(inds_freqs):
            # The event is more common in the drug set when compared to the indication set.
            return IdentifyFeatures._r_pvalue(
                'wilcox.test(c(%s), c(%s))' % (",".join(map(str, drug_freqs)),
                                               ",".join(map(str, inds_freqs))))

        return 1.0

    @staticmethod
    def fisher_test(event, inds, drugs, event_drug_freqs, event_ind_freqs, cutoff):
        """
        Run R's ``fisher.test`` on a 2x2 table counting how many `drugs` and how
        many `inds` report `event` with a frequency strictly above `cutoff`.

        Missing (event, drug) or (event, indication) pairs count as frequency 0.
        Both `event_drug_freqs` and `event_ind_freqs` must contain `event`.

        Returns the p-value, or 1.0 (without calling R) when no drug exceeds
        the cutoff.
        """
        drug_cutoff = [1 if event_drug_freqs[event].get(drug_name, 0) > cutoff else 0 for drug_name in drugs]
        ind_cutoff = [1 if event_ind_freqs[event].get(ind_name, 0) > cutoff else 0 for ind_name in inds]

        drugs_above = sum(drug_cutoff)
        if drugs_above == 0:
            return 1.0

        inds_above = sum(ind_cutoff)
        # Cell order: drugs above, drugs not above, indications above, indications not above.
        return IdentifyFeatures._r_pvalue(
            'fisher.test(matrix(c(%d,%d,%d,%d), nrow=2))' % (
                drugs_above, len(drug_cutoff) - drugs_above,
                inds_above, len(ind_cutoff) - inds_above))

    @staticmethod
    def get_event_indication_frequencies(indications, min_reports = 10):
        """
        Query, for each of `indications`, the fraction of its single-indication
        reports that mention each adverse event.

        Indications whose total report count is below `min_reports` are dropped.

        Returns a dict ``{event_term: {indication_term: frequency}}``; an empty
        dict when `indications` is empty (no query is issued).
        """
        indications = tuple(indications)
        if not indications:
            # An empty "in ()" list is not valid SQL.
            return dict()

        # One driver placeholder per name, so the driver does the quoting.
        placeholders = ",".join(["%s"] * len(indications))
        query = """
        select indication_descrip_term, event_descrip_term, count(distinct(isr_report_id)) as num_reports, total_singlet_reports
        from effect_aers.singlet_indications
        join effect_aers.report_indcount using (isr_report_id)
        join effect_aers.singlet_reactions using (isr_report_id)
        join 
        (
            select indication_descrip_term, count(distinct(isr_report_id)) as total_singlet_reports
            from effect_aers.singlet_indications
            join effect_aers.report_indcount using (isr_report_id)
            where indication_descrip_term in (%s)
            and num_indications = 1
            group by indication_descrip_term
        ) a using (indication_descrip_term)
        where indication_descrip_term in (%s)
        and num_indications = 1
        group by indication_descrip_term, event_descrip_term;
        """ % (placeholders, placeholders)
        # The name list appears twice in the query, so it is passed twice.
        IdentifyFeatures._db_conn.execute(query, indications * 2)

        event_ind_freqs = dict()
        for ind_term, ae_term, num_reports, total_reports in IdentifyFeatures._db_conn.fetchall():
            if int(total_reports) < min_reports:
                continue
            # setdefault keeps the first value seen for a given (event, indication) pair.
            event_ind_freqs.setdefault(ae_term, dict()).setdefault(
                ind_term, float(num_reports) / float(total_reports))

        return event_ind_freqs

    @staticmethod
    def get_event_drug_frequencies_cached(aers_drug_names, min_reports = 10):
        """
        Load event frequencies for `aers_drug_names` from the
        ``project_cocktails`` tables, keeping rows whose ``count`` is at least
        `min_reports`.

        Returns a dict ``{event: {drug: frequency}}``; an empty dict when
        `aers_drug_names` is empty (no query is issued).
        """
        aers_drug_names = tuple(aers_drug_names)
        if not aers_drug_names:
            return dict()

        placeholders = ",".join(["%s"] * len(aers_drug_names))
        # NOTE(review): the join below has no condition, so it is a cross join.
        # It presumably needs something like "using (singlet)" -- confirm
        # against the table schema before relying on this method.
        # "%%s" survives the formatting below as the driver placeholder for
        # min_reports.
        query = """
        select singlet, event, freq
        from project_cocktails.ae_singlet_drug_ae_freq
        join project_cocktails.ae_singlet_report_count
        where count >= %%s
        and singlet in (%s);
        """ % placeholders
        IdentifyFeatures._db_conn.execute(query, (min_reports,) + aers_drug_names)

        event_drug_freqs = dict()
        for drug, event, freq in IdentifyFeatures._db_conn.fetchall():
            event_drug_freqs.setdefault(event, dict())[drug] = freq

        return event_drug_freqs

    @staticmethod
    def get_event_drug_frequences(aers_drug_names, min_reports = 10):
        """
        Compute, for each of `aers_drug_names`, the fraction of its reports
        that mention each adverse event.

        Drugs with fewer than `min_reports` distinct reports are dropped.

        Returns a dict ``{event: {drug: frequency}}``; an empty dict when
        `aers_drug_names` is empty (no query is issued).
        """
        aers_drug_names = tuple(aers_drug_names)
        if not aers_drug_names:
            return dict()

        # Pull out the (drug, event, report) rows for the drugs in the drug set.
        query = """
        select drug_name, event_descrip_term, isr_report_id
        from effect_aers.singlet_drugs
        join effect_aers.singlet_reactions using (isr_report_id)
        where drug_name in (%s);
        """ % ",".join(["%s"] * len(aers_drug_names))
        IdentifyFeatures._db_conn.execute(query, aers_drug_names)

        # drug -> event -> set of report ids mentioning that event.
        drug_event_reports = dict()
        for drug, event, report_id in IdentifyFeatures._db_conn.fetchall():
            drug_event_reports.setdefault(drug, dict()).setdefault(event, set()).add(report_id)

        event_drug_freqs = dict()
        for drug, event_reports in drug_event_reports.items():
            # A report can list several events, so count distinct report ids.
            num_drug_reports = len(set().union(*event_reports.values()))
            if num_drug_reports < min_reports:
                continue

            for event, reports in event_reports.items():
                event_drug_freqs.setdefault(event, dict())[drug] = float(len(reports)) / float(num_drug_reports)

        return event_drug_freqs

    @staticmethod
    def generate_features_sets(event_ind_freqs, event_drug_freqs, contingency_cutoffs, pvalue_cutoff):
        """
        Run the Fisher test on every event present in both frequency tables,
        once per value in `contingency_cutoffs`.

        Returns one ``(features, contingency_cutoff, pvalue_cutoff)`` tuple per
        cutoff, where `features` is a list of ``(event, pvalue)`` pairs with
        ``pvalue <= pvalue_cutoff``, sorted by ascending p-value (ties by
        event). Empty frequency tables yield empty feature lists.
        """
        events = sorted(set(event_ind_freqs) & set(event_drug_freqs))
        # Union of the inner dicts' keys; unlike reduce(), this is safe on empty input.
        drugs = set().union(*event_drug_freqs.values())
        inds = set().union(*event_ind_freqs.values())

        features_sets = []
        # Find significant enriched/depleted adverse events for the drug set over the background.
        for contingency_cutoff in contingency_cutoffs:
            pvalues = sorted(
                (IdentifyFeatures.fisher_test(event, inds, drugs, event_drug_freqs,
                                              event_ind_freqs, contingency_cutoff), event)
                for event in events)

            features = [(event, pvalue) for pvalue, event in pvalues if pvalue <= pvalue_cutoff]
            features_sets.append((features, contingency_cutoff, pvalue_cutoff))

        return features_sets


def main(correlated_indications, aers_names, drug_based, min_ind_reports, min_drug_reports, conting_cutoffs, pvalue_cutoff):
    """
    Build the feature sets for one drug set.

    The background frequencies come from `correlated_indications`. The
    comparison frequencies come from `aers_names`, which are treated as drug
    names when `drug_based` is true and as indication names otherwise.

    Returns the list produced by IdentifyFeatures.generate_features_sets:
    one (features, contingency_cutoff, pvalue_cutoff) tuple per entry in
    `conting_cutoffs`.
    """
    print >> sys.stderr, "Extracting adverse event frequencies for correlated indications."
    background_freqs = IdentifyFeatures.get_event_indication_frequencies(correlated_indications, min_ind_reports)

    if not drug_based:
        print >> sys.stderr, "Extracting adverse event frequencies for target indications."
        comparison_freqs = IdentifyFeatures.get_event_indication_frequencies(aers_names, min_ind_reports)
    else:
        print >> sys.stderr, "Extracting adverse event frequencies for the drug set."
        comparison_freqs = IdentifyFeatures.get_event_drug_frequences(aers_names, min_drug_reports)

    print >> sys.stderr, "Finding significantly enriched events using the fisher exact method."
    return IdentifyFeatures.generate_features_sets(background_freqs, comparison_freqs, conting_cutoffs, pvalue_cutoff)

def load_first_column(path):
    """
    Return the first field of every non-empty row of the CSV file at `path`.

    Blank lines are skipped, and the file is closed before returning.
    """
    with open(path) as infh:
        return [row[0] for row in csv.reader(infh) if row]


if __name__ == '__main__':

    # Usage: <script> <indication_name>
    # <indication_name> is the sub-directory of PATH_TO_DATA_DIR that holds
    # the input files and receives the output feature files.
    if len(sys.argv) < 2:
        sys.exit("usage: %s <indication_name>" % sys.argv[0])

    indication_name = sys.argv[1]
    data_dir = os.path.join(PATH_TO_DATA_DIR, indication_name)

    # The presence of a drug-name file selects drug-based mode; otherwise the
    # target set is read as indication names.
    drug_names_path = os.path.join(data_dir, 'aers_drug_names.txt')
    drug_based = os.path.exists(drug_names_path)

    IdentifyFeatures._db_conn = db.cursor()
    print >> sys.stderr, "Identifying features for %s drug set." % indication_name

    indications = load_first_column(os.path.join(data_dir, 'correlated_indications.csv'))
    print >> sys.stderr, "Found %d cached correlated indications for the drug set." % len(indications)

    if drug_based:
        aers_names = load_first_column(drug_names_path)
        print >> sys.stderr, 'Loaded drug names for set: %s, found %d' % (indication_name, len(aers_names))
    else:
        aers_names = load_first_column(os.path.join(data_dir, 'aers_indication_names.txt'))
        print >> sys.stderr, 'Loaded indication names for set: %s, found %d' % (indication_name, len(aers_names))

    features_sets = main(indications, aers_names, drug_based, INDICATION_MINIMUM_REPORTS, DRUG_MIMIMUM_REPORTS, CONTINGENCY_FREQ_CUTOFFS, EVENT_PVALUE_CUTOFF)

    for features, conting_cutoff, pval_cutoff in features_sets:

        # NOTE: cutoffs are rounded to two decimals in the file name, so a
        # cutoff such as 0.025 does not appear exactly. The naming is kept
        # as-is because other tooling may look for these file names.
        feature_file = 'features_fisher_%.2f_%.2f.csv' % (conting_cutoff, pval_cutoff)

        # The with-block closes the file even if writing fails.
        with open(os.path.join(data_dir, feature_file), 'w') as outfh:
            csv.writer(outfh).writerows(features)
    