"""
Train_Predict

Unlike the original train_predict script, this version uses the 0.4dev version
of PyWeka, which includes a feature selection class that wraps forward feature
selection inside cross-validation.
"""

import os
import csv
import sys
import numpy

from namedmatrix import NamedMatrix
from pyweka import MachineLearning as ml

PATH_TO_DATA_DIR = os.path.expanduser('~/Stanford/AltmanLab/cocktails/methods/')
PATH_TO_SHARE_DIR = os.path.expanduser('~/Dropbox/DocDump/')


def read_positive_drugs(path):
    """Return the set of drug names found in the first column of the CSV at *path*.

    Blank lines (which ``csv.reader`` yields as empty rows) are skipped rather
    than raising ``IndexError``. The file handle is closed before returning.
    """
    with open(path) as handle:
        return set(row[0] for row in csv.reader(handle) if row)


def is_positive_pair(pair_name, positive_drugs):
    """Return 1 if any comma-separated drug in *pair_name* is in *positive_drugs*, else 0."""
    return 1 if any(drug in positive_drugs for drug in pair_name.split(',')) else 0


def main(indication_name):
    """Train on single-drug AE frequencies for *indication_name*, then score drug pairs.

    Steps:
      1. Label each single drug 1/0 by membership in the indication's
         ``aers_drug_names.txt``.
      2. Run PyWeka forward feature selection ("fisher.test", 4 folds,
         args=[0.01]) with a logistic model.
      3. Fit a logistic model on the selected features that are also columns
         of the pair matrix, and predict both single drugs and drug pairs.
      4. Save pair and single-drug predictions as CSV files in the
         indication's directory.
    """
    indication_dir = os.path.join(PATH_TO_DATA_DIR, indication_name)

    sys.stderr.write("Running train_predict2 for %s.\n" % indication_name)

    # A set gives O(1) membership tests when labelling every matrix row below.
    positive_drugs = read_positive_drugs(os.path.join(indication_dir, 'aers_drug_names.txt'))

    sys.stderr.write("Loading the single drug data...\n")
    singl = NamedMatrix.load_from_file(
        os.path.join(PATH_TO_DATA_DIR, 'data', 'aers_singlet_drug_ae_freq_all_gtr10.txt'),
        delimiter='\t')
    singl.colnames = [x.replace(' ', '_') for x in singl.colnames]
    labels = [1 if x in positive_drugs else 0 for x in singl.rownames]

    # Allow at least 15 features, or one per five positive drugs if larger.
    # Floor division keeps this an int under both Python 2 and Python 3.
    fs = ml.FeatureSelection(singl, labels, ml.Logistic, max_features=max(15, sum(labels) // 5))
    best_features_dict = fs.forward_selection(statistical_test="fisher.test", num_folds=4, args=[0.01])
    best_features = best_features_dict['SelectedFeatures']

    sys.stderr.write("Forward Feature Selection:\n")
    sys.stderr.write("  %d features, Training Error: %f, Testing %f\n" % (
        len(best_features), best_features_dict['TrainingError'], best_features_dict['TestError']))

    sys.stderr.write("Loading the paired drug data...\n")
    pairs = NamedMatrix.load_from_file(
        os.path.join(PATH_TO_DATA_DIR, 'data', 'aers_pair_ae_freq_gtr5.txt'),
        delimiter='\t')
    pairs.colnames = [x.replace(' ', '_') for x in pairs.colnames]
    # A pair row is positive if any of its comma-separated drugs is positive.
    pairs_labels = [is_positive_pair(x, positive_drugs) for x in pairs.rownames]

    # The final model may only use features present in both matrices.
    common_features = sorted(set(best_features) & set(pairs.colnames))

    lr = ml.Logistic(singl[:, common_features], labels)
    # NOTE(review): the return value is discarded; presumably cross_validate()
    # updates internal model state. Confirm against PyWeka 0.4dev.
    lr.cross_validate()

    drug_pred = lr.predict(singl[:, common_features], labels)
    pair_pred = lr.predict(pairs[:, common_features], pairs_labels)

    # Keep only rows whose name contains a comma, i.e. "drug1,drug2" pairs.
    pair_rownames = [row for row in pair_pred.rownames if ',' in row]
    pred_matrix = NamedMatrix(None, pair_rownames, ["Pair_Score", "Pair_Label", "Drug1_Score", "Drug2_Score"])

    # Build the lookup set once instead of scanning rownames for every pair.
    scored_drugs = set(drug_pred.rownames)
    for i, key in enumerate(pair_rownames):
        # Assumes exactly one comma per pair name (two drugs).
        drug1, drug2 = key.split(',')
        pred_matrix[i, 0] = pair_pred[key, 0]
        pred_matrix[i, 1] = pair_pred[key, 1]
        for col, drug in ((2, drug1), (3, drug2)):
            if drug in scored_drugs:
                pred_matrix[i, col] = drug_pred[drug, 0]
            else:
                # Drugs absent from the single-drug data fall back to the
                # model intercept. NOTE(review): if predict() returns
                # probabilities, the raw intercept is on a different (logit)
                # scale. Confirm this fallback is intended.
                pred_matrix[i, col] = lr.coefficients['Intercept']

    num_features = len(best_features)
    pred_matrix.save_to_file(os.path.join(
        indication_dir, 'cocktails2_%s_predictions_pairs_feat%d.csv' % (indication_name, num_features)))
    drug_pred.save_to_file(os.path.join(
        indication_dir, 'cocktails2_%s_predictions_drugs_feat%d.csv' % (indication_name, num_features)))


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit("usage: %s INDICATION_NAME" % sys.argv[0])
    main(sys.argv[1])
    