"""
Script to train on singlets and predict on pairs.

Generalized from the diabetes analogue.

"""

import os
import csv
import sys
import numpy
from namedmatrix import NamedMatrix
from pyweka import MachineLearning as ml
from pylab import *
import random

PATH_TO_DATA_DIR = os.path.expanduser('~/Stanford/AltmanLab/cocktails/methods/')
PATH_TO_SHARE_DIR = os.path.expanduser('~/Dropbox/DocDump/')

class TrainClassifier(object):
    """
    Train a classifier to identify a drug class based on their adverse events.

    The singlet and pair frequency matrices are loaded lazily from
    PATH_TO_DATA_DIR/data and cached on the class, so repeated calls to the
    accessors return the same NamedMatrix objects.
    """

    # Class-level caches filled on first access by the two accessors below.
    _sing_freqs = None
    _pair_freqs = None

    @staticmethod
    def _load_frequency_matrix(filename):
        """
        Load a tab-delimited frequency matrix from PATH_TO_DATA_DIR/data.

        Spaces in the column names are replaced by underscores and the
        matrix size is printed to stdout.
        """
        matrix = NamedMatrix.load_from_file(
            os.path.join(PATH_TO_DATA_DIR, 'data', filename), delimiter='\t')
        matrix.colnames = [x.replace(' ', '_') for x in matrix.colnames]
        # Single-argument print(...) emits the same text under Python 2 and 3.
        print(matrix.size())
        return matrix

    @staticmethod
    def singlet_frequency_matrix():
        """Return the (cached) singlet drug / adverse-event frequency matrix."""
        if TrainClassifier._sing_freqs is None:
            TrainClassifier._sing_freqs = TrainClassifier._load_frequency_matrix(
                'aers_singlet_drug_ae_freq_all_gtr10.txt')
        return TrainClassifier._sing_freqs

    @staticmethod
    def pair_frequency_matrix():
        """Return the (cached) drug-pair / adverse-event frequency matrix."""
        if TrainClassifier._pair_freqs is None:
            TrainClassifier._pair_freqs = TrainClassifier._load_frequency_matrix(
                'aers_pair_ae_freq_gtr5.txt')
        return TrainClassifier._pair_freqs

    @staticmethod
    def _singlet_labels(rownames, positive_drugs):
        """Return 1 for every row name that is a positive drug, else 0."""
        positive = set(positive_drugs)
        return [1 if name in positive else 0 for name in rownames]

    @staticmethod
    def _pair_labels(rownames, positive_drugs):
        """
        Return 1 for every comma-separated row name that contains at least
        one positive drug, else 0.
        """
        # Build the set once rather than once per row.
        positive = set(positive_drugs)
        return [0 if positive.isdisjoint(name.split(',')) else 1
                for name in rownames]

    @staticmethod
    def _candidate_feature_counts(num_features, num_positive):
        """
        Return the feature counts to try for one feature list.

        The target is min(num_features, num_positive / 5); counts from
        target-5 through target+9 are tried, clamped to [1, num_features].
        Without the clamp a non-positive count would slice the feature list
        from the wrong end (or train on nothing), and counts beyond the list
        length would just retrain the identical model.
        """
        target = min(num_features, int(num_positive / 5.))
        lowest = max(1, target - 5)
        highest = min(num_features, target + 9)
        return list(range(lowest, highest + 1))

    @staticmethod
    def find_best_classifier(positive_drugs, features_lists):
        """
        Iterate through the feature sets and train a LR classifier.

        positive_drugs -- names of the drugs labelled as positive examples.
        features_lists -- list of feature-name lists; for each list, several
                          leading prefixes are tried as the feature set.

        Each candidate is trained on the singlet matrix, cross validated and
        then tested on the pair matrix; the candidate with the lowest pair
        'TestError' wins (the first one seen wins ties).

        Returns a dict with keys:
          'lr'       -- the best ml.Logistic instance (None if no candidate
                        was evaluated),
          'features' -- column names the best classifier was trained on,
          'auroc'    -- 'AUROC' of the best classifier on the pairs,
          'errors'   -- one row per candidate: [feature list index, number of
                        features, cross-validation 'TrainingError',
                        cross-validation 'TestError', pair 'TestError'].
        """
        freqs = TrainClassifier.singlet_frequency_matrix()
        pairs = TrainClassifier.pair_frequency_matrix()

        labels = TrainClassifier._singlet_labels(freqs.rownames, positive_drugs)
        sys.stderr.write("%s\n" % sum(labels))
        pairs_labels = TrainClassifier._pair_labels(pairs.rownames, positive_drugs)
        sys.stderr.write("%s\n" % sum(pairs_labels))

        best_lr = None
        best_auc = None
        best_features = None
        best_test_error = None
        errors = []

        for i, features in enumerate(features_lists):
            counts = TrainClassifier._candidate_feature_counts(
                len(features), len(positive_drugs))
            for numFeatures in counts:
                feat = freqs[:, features[:numFeatures]]
                test = pairs[:, feat.colnames]
                train = feat[:, test.colnames] # just in case any weren't avail

                lr = ml.Logistic(train, labels)
                cross_val_results = lr.cross_validate()
                results = lr.test(test, pairs_labels)

                progress = (i, numFeatures,
                            cross_val_results['AUROC'], results['AUROC'],
                            cross_val_results['TrainingError'],
                            cross_val_results['TestError'])
                sys.stderr.write(" ".join(str(v) for v in progress) + "\n")

                errors.append([i, numFeatures,
                               cross_val_results['TrainingError'],
                               cross_val_results['TestError'],
                               results['TestError']])

                if best_lr is None or results['TestError'] < best_test_error:
                    best_lr = lr
                    best_auc = results['AUROC']
                    # Record the columns the model was actually trained on.
                    # NOTE(review): presumably identical to feat.colnames
                    # unless the pair matrix lacks some features -- confirm
                    # against NamedMatrix column-selection behaviour.
                    best_features = train.colnames
                    best_test_error = results['TestError']

        return {'lr': best_lr, 'features': best_features, 'auroc': best_auc, 'errors': errors}

def read_first_column(path):
    """
    Return the first field of every non-empty row of the CSV file at *path*.

    Blank lines are skipped instead of raising IndexError, and the file is
    closed as soon as it has been read.
    """
    with open(path) as handle:
        return [row[0] for row in csv.reader(handle) if row]


def fetch_positive_drugs(aers_names):
    """
    Query the local effect_aers database for the distinct drug names of
    reports whose single indication is one of *aers_names*.

    The names are passed as query parameters rather than interpolated into
    the SQL text, so quotes inside a name cannot break the statement.
    Returns an empty list without touching the database when *aers_names*
    is empty (an empty IN () list is not valid SQL).
    """
    if not aers_names:
        return []

    import MySQLdb
    db = MySQLdb.connect(host="localhost", port=3307, user="root", passwd="dummy_password", db="effect_aers")
    try:
        c = db.cursor()
        # One %s placeholder per name; the driver does the quoting.
        placeholders = ",".join(["%s"] * len(aers_names))
        query = """
        select distinct drug_name
        from effect_aers.singlet_drugs
        join effect_aers.report_indcount using (isr_report_id)
        join effect_aers.singlet_indications using (isr_report_id)
        where num_indications = 1
        and indication_descrip_term in (%s);
        """ % placeholders
        c.execute(query, tuple(aers_names))
        return [row[0] for row in c.fetchall()]
    finally:
        db.close()


# Indication names noted by the original author (the trailing markers are
# kept verbatim from those notes; their meaning is not documented here):
#   "renal_failure" X, "cholesterol" **, "calcemia" X, "diabetes" **,
#   "glycemia" **, 'weight_loss' *, 'anemia' X, 'nsaids' *, 'cox2' *,
#   'suicide' **, 'suicide2' X, 'blood_pressure' **, 'hypertension'
# Frozen for the sake of the paper: "diabetes_paxil_prava_freeze" **
def main(argv):
    """
    Train on singlets and write pair predictions for one indication.

    argv[1] is the indication name, i.e. a sub-directory of PATH_TO_DATA_DIR.
    The positive drugs are read from 'aers_drug_names.txt' in that directory
    when it exists, otherwise looked up in the database from
    'aers_indication_names.txt'. Every file in the directory whose name
    contains 'features' supplies one feature list.

    Writes three CSV files into the indication directory (pair predictions,
    drug predictions, and the per-candidate error table).

    Returns None on success, 2 when the indication argument is missing and
    1 when no classifier could be trained.
    """
    if len(argv) < 2:
        sys.stderr.write("usage: %s <indication_name>\n" % argv[0])
        return 2

    indication_name = argv[1]
    indication_dir = os.path.join(PATH_TO_DATA_DIR, indication_name)

    # Load up the positive drugs (id'ed from AERS)
    drug_names_path = os.path.join(indication_dir, 'aers_drug_names.txt')
    if os.path.exists(drug_names_path):
        positive_drugs = read_first_column(drug_names_path)
    else:
        aers_names = read_first_column(os.path.join(indication_dir, 'aers_indication_names.txt'))
        positive_drugs = fetch_positive_drugs(aers_names)

    # Sorted so that ties between feature files are broken the same way on
    # every run, independent of the directory listing order.
    available_feature_files = sorted(f for f in os.listdir(indication_dir) if 'features' in f)
    features_lists = [
        [name.replace(' ', '_') for name in read_first_column(os.path.join(indication_dir, f))]
        for f in available_feature_files
    ]

    results = TrainClassifier.find_best_classifier(positive_drugs, features_lists)
    best_lr = results['lr']
    best_features = results['features']
    if best_lr is None:
        sys.stderr.write("no classifier could be trained for %s (no usable features)\n" % indication_name)
        return 1

    positive = set(positive_drugs)
    freqs = TrainClassifier.singlet_frequency_matrix()
    labels = [1 if x in positive else 0 for x in freqs.rownames]
    pairs = TrainClassifier.pair_frequency_matrix()
    pairs_labels = [0 if positive.isdisjoint(x.split(',')) else 1 for x in pairs.rownames]

    # Get out the pair predictions
    # The two calls below are kept from the original even though their
    # results are unused. NOTE(review): presumably they prepare internal
    # state needed by predict() -- confirm against pyweka.
    best_lr.cross_validate()
    best_lr.test(pairs[:, best_features], pairs_labels)

    drug_pred = best_lr.predict(freqs[:, best_features], labels)
    pair_pred = best_lr.predict(pairs[:, best_features], pairs_labels)

    pair_rownames = [row for row in pair_pred.rownames if ',' in row]
    pred_matrix = NamedMatrix(None, pair_rownames, ["Pair_Score", "Pair_Label", "Drug1_Score", "Drug2_Score"])

    scored_drugs = set(drug_pred.rownames)
    for i, key in enumerate(pair_rownames):
        drug1, drug2 = key.split(',')
        pred_matrix[i,0] = pair_pred[key,0]
        pred_matrix[i,1] = pair_pred[key,1]

        # Drugs without a singlet prediction fall back to the model intercept.
        for column, drug in ((2, drug1), (3, drug2)):
            if drug in scored_drugs:
                pred_matrix[i,column] = drug_pred[drug,0]
            else:
                pred_matrix[i,column] = best_lr.coefficients['Intercept']

    pred_matrix.save_to_file(os.path.join(indication_dir, 'cocktails_%s_predictions_pairs_feat%d.csv' % (indication_name, len(best_features))))
    drug_pred.save_to_file(os.path.join(indication_dir, 'cocktails_%s_predictions_drugs_feat%d.csv' % (indication_name, len(best_features))))

    with open(os.path.join(indication_dir, 'cocktails_%s_errors.csv' % indication_name), 'w') as outfh:
        csv.writer(outfh).writerows(results['errors'])


if __name__ == '__main__':
    sys.exit(main(sys.argv))
