"""
deconv_based.py

Alternative implementation that builds on the deconvolution of drug adverse
events that we developed previously.
"""

import os
import csv
import sys
import numpy
import MySQLdb
import operator
import StringIO
import rpy2.robjects as robjects
from namedmatrix import NamedMatrix
from pyweka import MachineLearning as ml
from identify_features import fisher_test

# Root of the per-drug-set data: each profile identifier (e.g. "diabetes") has
# its own sub-directory below this path.
PATH_TO_DATA_DIR = os.path.expanduser("~/Stanford/AltmanLab/cocktails/methods/")
# Name of the file, inside an identifier's sub-directory, that stores the
# correlated indications.
CORR_IND_FILE_NAME = "correlated_indications.csv"
# Indications backed by fewer single-indication reports than this are ignored.
INDICATION_MINIMUM_REPORTS = 10
# Adverse events are kept as features only when their p-value is at most this.
EVENT_PVALUE_CUTOFF = 0.05

class AdverseEventProfile(object):
    """
    AdverseEventProfile, initialize and analyze an adverse event (AE) profile
    for a set of drugs.

    Construction loads any previously saved correlated indications; if none
    exist, call build_indications() before identify_features().
    """
    def __init__(self, identifier, drug_ids, db):
        """
        @param  identifier  A string identifier for this AdverseEventProfile object;
                            also the sub-directory of PATH_TO_DATA_DIR holding its
                            CORR_IND_FILE_NAME file.
        @param  drug_ids    A list of STITCH drug ids, that define the drug set.
        @param  db          An open database connection (MySQLdb in this module);
                            one cursor is taken from it and used for all queries.
        """
        super(AdverseEventProfile, self).__init__()
        self.c = db.cursor()

        self.identifier = identifier
        self.drug_ids = drug_ids

        # Lazily loaded caches, both keyed by UMLS event id:
        #   event_ind_freqs:  {umls_id: {indication term: report frequency}}
        #   event_drug_freqs: {umls_id: {stitch_id: drug_mean}}
        self.event_ind_freqs = None
        self.event_drug_freqs = None

        # Correlated indication terms; stays None when no saved file is found.
        self.indications = None
        self.load_indications()

    @staticmethod
    def _sql_placeholders(n):
        """Return n comma-separated MySQLdb placeholders ("%s,%s,...") for an "in (...)" clause."""
        return ",".join(["%s"] * n)

    def identify_features(self, cutoff):
        """
        Finds adverse event features that are enriched above the background of
        the correlated indications.

        @param  cutoff  Passed unchanged to identify_features.fisher_test; its
                        meaning is defined there.
        @return A list of (umls_id, pvalue) tuples, sorted by ascending p-value,
                containing only events with pvalue <= EVENT_PVALUE_CUTOFF.
        @raise  ValueError  if the correlated indications were never loaded/built.
        """
        if self.event_ind_freqs is None:
            self._load_event_indicatons_frequencies()

        if self.event_drug_freqs is None:
            self._load_event_drugs_frequencies()

        # Only events observed for both the drugs and the indications can be tested.
        events = sorted(set(self.event_ind_freqs) & set(self.event_drug_freqs))
        # set().union(...) also copes with "no events at all", where
        # reduce(operator.or_, []) without an initial value raises TypeError.
        drugs = set().union(*[x.keys() for x in self.event_drug_freqs.values()])
        inds = set().union(*[x.keys() for x in self.event_ind_freqs.values()])

        pvalues = []
        for event in events:
            pvalue = fisher_test(event, inds, drugs, self.event_drug_freqs, self.event_ind_freqs, cutoff)
            pvalues.append((pvalue, event))

        return [(event, pvalue) for pvalue, event in sorted(pvalues) if pvalue <= EVENT_PVALUE_CUTOFF]

    def _load_event_drugs_frequencies(self):
        """
        Loads the significant adverse events of self.drug_ids from the
        deconvolution results (project_aers.sig_drug_events).

        Fills self.event_drug_freqs as {umls_id: {stitch_id: float(drug_mean)}}.
        """
        self.event_drug_freqs = dict()
        if not self.drug_ids:
            # "in ()" is invalid SQL, and with no drugs there is nothing to load.
            return

        query = """
        select stitch_id, umls_id, drug_mean
        from project_aers.sig_drug_events
        where stitch_id in (%s);
        """ % self._sql_placeholders(len(self.drug_ids))
        # Bound parameters: ids are escaped by the driver instead of hand-quoted.
        self.c.execute(query, tuple(self.drug_ids))

        for drug_id, umls_id, drug_mean in self.c.fetchall():
            self.event_drug_freqs.setdefault(umls_id, dict())[drug_id] = float(drug_mean)

    def _load_event_indicatons_frequencies(self):
        """
        Loads the adverse event frequencies of the correlated indications.

        Only reports listing exactly one indication (num_indications = 1) are
        counted. For each (indication, UMLS event) pair the frequency is the
        number of distinct such reports with the event divided by all such
        reports for the indication. Indications with fewer than
        INDICATION_MINIMUM_REPORTS such reports are skipped.
        Fills self.event_ind_freqs as {umls_id: {indication term: frequency}}.

        @raise  ValueError  if self.indications has not been loaded or built yet.
        """
        if self.indications is None:
            raise ValueError(
                "No correlated indications for %r; load them from file or run "
                "build_indications() first." % (self.identifier,))

        self.event_ind_freqs = dict()
        if not self.indications:
            # "in ()" is invalid SQL, and with no indications there is nothing to load.
            return

        marks = self._sql_placeholders(len(self.indications))
        query = """
        select indication_descrip_term, umls_id, count(distinct(isr_report_id)) as num_reports, total_singlet_reports
        from effect_aers.indications
        join effect_aers.report_indcount using (isr_report_id)
        join effect_aers.reactions using (isr_report_id)
        join effect_aers.aers2umls using (event_descrip_term)
        join 
        (
            select indication_descrip_term, count(distinct(isr_report_id)) as total_singlet_reports
            from effect_aers.indications
            join effect_aers.report_indcount using (isr_report_id)
            where indication_descrip_term in (%s)
            and num_indications = 1
            group by indication_descrip_term
        ) a using (indication_descrip_term)
        where indication_descrip_term in (%s)
        and num_indications = 1
        group by indication_descrip_term, umls_id;
        """ % (marks, marks)
        # The indication list is bound twice: once for the subquery, once for the outer filter.
        self.c.execute(query, tuple(self.indications) * 2)

        for ind_term, ae_term, num_reports, total_reports in self.c.fetchall():
            if int(total_reports) < INDICATION_MINIMUM_REPORTS:
                continue
            freq = float(num_reports) / float(total_reports)
            self.event_ind_freqs.setdefault(ae_term, dict())[ind_term] = freq

    def load_indications(self):
        """
        Loads stored correlated indications, if available.

        Reads the first column of PATH_TO_DATA_DIR/<identifier>/CORR_IND_FILE_NAME
        into self.indications. When the file is missing, self.indications is left
        unchanged and a hint is written to stderr.
        """
        indications_path = os.path.join(PATH_TO_DATA_DIR, self.identifier, CORR_IND_FILE_NAME)

        if not os.path.exists(indications_path):
            sys.stderr.write("No indications file found, you can generate the file with the build_indications method.\n")
            return

        with open(indications_path) as fh:
            # Skip blank rows, which have no first column.
            self.indications = [r[0] for r in csv.reader(fh) if r]
        # Any cached indication frequencies belonged to a previous indication list.
        self.event_ind_freqs = None
        sys.stderr.write("Found %d correlated indications.\n" % len(self.indications))

    def _fetch_indication_reports(self):
        """Map each indication term to the set of singlet-report ids (num_drugs = num_verbatim = 1) listing it."""
        query = """
        select indication_descrip_term, isr_report_id
        from effect_aers.indications
        join effect_aers.report_drugcount using (isr_report_id)
        where num_drugs = num_verbatim
        and num_drugs = 1
        """
        self.c.execute(query)

        indication_reports = dict()
        for ind, rid in self.c.fetchall():
            indication_reports.setdefault(ind, set()).add(rid)
        return indication_reports

    def _fetch_drug_report_ids(self):
        """Return the ids of singlet reports (with an indication) for any drug in self.drug_ids."""
        if not self.drug_ids:
            # "in ()" is invalid SQL; no drugs means no drug reports.
            return set()

        query = """
        select isr_report_id
        from effect_aers.drugs
        join effect_aers.report_drugcount using (isr_report_id)
        join effect_aers.indications using (isr_report_id)
        join effect_aers.aers2stitch using (drug_name)
        where stitch_id in (%s)
        and num_verbatim = num_drugs
        and num_verbatim = 1
        """ % self._sql_placeholders(len(self.drug_ids))
        self.c.execute(query, tuple(self.drug_ids))
        return set(r[0] for r in self.c.fetchall())

    def _fetch_singlet_report_ids(self):
        """Return the ids of all singlet reports that list at least one indication."""
        query = """
        select isr_report_id
        from effect_aers.report_drugcount
        join effect_aers.indications using (isr_report_id)
        where num_drugs = num_verbatim
        and num_drugs = 1;
        """
        self.c.execute(query)
        return set(r[0] for r in self.c.fetchall())

    @staticmethod
    def _contingency_pvalue(both, drugs, indication, neither):
        """
        Return the p-value of a 2x2 association test run in R, or 1.0 when both == 0.

        R's matrix() fills column-major, so the table rows are (indication present,
        absent) and the columns are (drug present, absent). Fisher's exact test is
        used when `both` or `indication` is below 60, otherwise a chi-squared test.
        """
        if both == 0:
            return 1.0
        test = 'fisher.test' if (both < 60 or indication < 60) else 'chisq.test'
        result = robjects.r("%s(matrix(c(%d, %d, %d, %d), nrow=2))" % (test, both, drugs, indication, neither))
        return result[list(result.names).index('p.value')][0]

    def build_indications(self, pvalue_cutoff=0.05):
        """
        Build a set of indications enriched among the singlet reports of the drug set.

        For each indication, the singlet reports are split by "lists the indication"
        and "reports one of self.drug_ids", and the resulting 2x2 table is tested.
        Indications with p-value <= pvalue_cutoff are written (one row each:
        term, both, indication-only, drug-only, neither, p-value) to
        PATH_TO_DATA_DIR/<identifier>/CORR_IND_FILE_NAME and stored in
        self.indications.

        @param  pvalue_cutoff  Maximum p-value for an indication to be kept
                               (default 0.05, the previously hard-coded value).
        """
        indication_reports = self._fetch_indication_reports()
        drug_report_ids = self._fetch_drug_report_ids()
        singlet_report_ids = self._fetch_singlet_report_ids()

        results = []
        for i, (indication_term, report_ids) in enumerate(indication_reports.items()):
            if i % 100 == 0:
                sys.stderr.write(".")  # progress marker

            both = len(report_ids & drug_report_ids)
            indication = len(report_ids - drug_report_ids)
            drugs = len(drug_report_ids - report_ids)
            neither = len(singlet_report_ids - (report_ids | drug_report_ids))

            pvalue = self._contingency_pvalue(both, drugs, indication, neither)
            results.append([indication_term, both, indication, drugs, neither, pvalue])
        sys.stderr.write("\n")

        correlated_indications = [row for row in results if row[-1] <= pvalue_cutoff]

        # Save the correlated indications for the future (creating the directory if needed).
        out_dir = os.path.join(PATH_TO_DATA_DIR, self.identifier)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        with open(os.path.join(out_dir, CORR_IND_FILE_NAME), 'w') as outfh:
            csv.writer(outfh).writerows(correlated_indications)

        self.indications = [r[0] for r in correlated_indications]
        # Cached indication frequencies were computed for the previous list; force a reload.
        self.event_ind_freqs = None

if __name__ == '__main__':
    # Example run: derive enriched AE features for the drugs listed for one
    # indication, then cross-validate a logistic-regression model on them.
    indication_name = 'diabetes'

    db = MySQLdb.connect(host="localhost", port=3307, user="root", passwd="dummy_password", db="project_aers")
    try:
        c = db.cursor()

        # One AERS drug name per row (first CSV column); blank rows are skipped.
        drug_name_file_path = os.path.join(PATH_TO_DATA_DIR, indication_name, 'aers_drug_names.txt')
        with open(drug_name_file_path) as fh:
            drug_names = [r[0] for r in csv.reader(fh) if r]
        if not drug_names:
            sys.exit("No drug names found in %s." % drug_name_file_path)

        # Map AERS drug names to STITCH ids; bound parameters keep quotes in
        # names from breaking the query.
        query = """
        select stitch_id
        from effect_aers.aers2stitch
        where drug_name in (%s);
        """ % ",".join(["%s"] * len(drug_names))
        c.execute(query, tuple(drug_names))

        stitch_ids = [r[0] for r in c.fetchall()]

        profile = AdverseEventProfile(indication_name, stitch_ids, db)
        features = profile.identify_features(0.01)
        if not features:
            sys.exit("No enriched adverse event features found for %s." % indication_name)

        # Significant event frequencies for all drugs, serialized as
        # "stitch_id,umls_id,drug_mean" lines for NamedMatrix to parse.
        query = """
        select stitch_id, umls_id, drug_mean
        from project_aers.sig_drug_events;
        """
        c.execute(query)
        string_data = "\n".join([",".join(map(str, row)) for row in c.fetchall()])
        freqs = NamedMatrix.load_from_file(StringIO.StringIO(string_data))
    finally:
        db.close()

    # Keep only the enriched event columns; label a row 1 when it belongs to
    # the drug set (presumably rows are stitch ids -- see NamedMatrix).
    feat = freqs[:, [r[0] for r in features]]
    stitch_id_set = set(stitch_ids)
    labels = [1 if rowname in stitch_id_set else 0 for rowname in feat.rownames]

    lr = ml.Logistic(feat, labels)
    lr.cross_validate()
    