"""
Randomization significance testing.

eg.
python random_predict.py renal_impairment 17 100

"""

import os
import csv
import sys
import numpy
from namedmatrix import NamedMatrix
from pyweka import MachineLearning as ml
from pylab import *
import random

PATH_TO_DATA_DIR = os.path.expanduser('~/Stanford/AltmanLab/cocktails/methods/')

def _read_first_column(path):
    """Return the first field of every CSV row in the file at `path`.

    The file is opened in a `with` block so the handle is always closed.
    """
    with open(path) as fh:
        return [row[0] for row in csv.reader(fh)]


def _normalize_colnames(names):
    """Return `names` with spaces turned into underscores and apostrophes
    and commas removed."""
    return [x.replace(' ', '_').replace("'", "").replace(',', '') for x in names]


def _query_positive_drugs(aers_names):
    """Return the distinct drug names the database reports for `aers_names`.

    Selects `drug_name` for reports that have exactly one indication and
    whose indication term is one of `aers_names`.

    Raises ValueError if `aers_names` is empty, since `IN ()` is not a
    valid query.
    """
    # Imported lazily so the database driver is only required when the
    # cached drug-name file is missing.
    import MySQLdb

    if not aers_names:
        raise ValueError("no indication names given; cannot query for positive drugs")

    # One driver-side placeholder per name: the driver escapes each value,
    # so names containing quote characters cannot break the statement.
    placeholders = ",".join(["%s"] * len(aers_names))
    query = """
    select distinct drug_name
    from effect_aers.singlet_drugs
    join effect_aers.report_indcount using (isr_report_id)
    join effect_aers.singlet_indications using (isr_report_id)
    where num_indications = 1
    and indication_descrip_term in (%s);
    """ % placeholders

    # NOTE(review): connection credentials are hard-coded here; consider
    # moving them to configuration.
    db = MySQLdb.connect(host="localhost", port=3307, user="root", passwd="dummy_password", db="effect_aers")
    try:
        c = db.cursor()
        c.execute(query, tuple(aers_names))
        return [row[0] for row in c.fetchall()]
    finally:
        db.close()


def _record_ranks(ranks, scored_pairs):
    """Append each pair's 1-based rank (by descending score) to `ranks`.

    `scored_pairs` is an iterable of `(score, pair_name)` tuples and
    `ranks` maps pair name -> list of ranks, one entry per call.
    `ranks` is modified in place.
    """
    ordered = sorted(scored_pairs, reverse=True)
    for rank, (score, pair) in enumerate(ordered, 1):
        ranks.setdefault(pair, []).append(rank)


if __name__ == '__main__':

    if len(sys.argv) < 4:
        sys.exit("usage: python random_predict.py <indication_name> <num_features> <num_reps>")

    indication_name = sys.argv[1]
    num_features = int(sys.argv[2])
    num_reps = int(sys.argv[3])

    indication_dir = os.path.join(PATH_TO_DATA_DIR, indication_name)

    # Prefer the cached list of positive drugs; fall back to the database.
    drug_names_path = os.path.join(indication_dir, 'aers_drug_names.txt')
    if os.path.exists(drug_names_path):
        positive_drugs = _read_first_column(drug_names_path)
    else:
        aers_names = _read_first_column(os.path.join(indication_dir, 'aers_indication_names.txt'))
        positive_drugs = _query_positive_drugs(aers_names)

    # A set makes the per-row membership test below constant time.
    positive_drugs = set(positive_drugs)

    singles = NamedMatrix.load_from_file(os.path.join(PATH_TO_DATA_DIR,'data','aers_singlet_drug_ae_freq_all_gtr10.txt'), delimiter='\t')
    singles.colnames = _normalize_colnames(singles.colnames)
    pairs = NamedMatrix.load_from_file(os.path.join(PATH_TO_DATA_DIR,'data','aers_pair_ae_freq_gtr5.txt'), delimiter='\t')
    pairs.colnames = _normalize_colnames(pairs.colnames)

    # pair name -> list of ranks, one per repetition
    ranks = dict()
    labels = [1 if x in positive_drugs else 0 for x in singles.rownames]
    # Only features present in both matrices can be trained on singles and
    # then applied to pairs.
    common_features = sorted(set(pairs.colnames) & set(singles.colnames))

    for i in range(num_reps):
        # Progress marker on stderr every 10 repetitions.
        if (i+1)%10 == 0:
            print >> sys.stderr, ".",

        # Randomly select a set of $num_features features and train a logistic regression classifier.
        random_features = sorted(random.sample(common_features, num_features))
        feat = singles[:,random_features]
        lr = ml.Logistic(feat, labels)
        # The result is not used here; the call is kept in case it has
        # side effects on `lr` -- TODO confirm against pyweka.
        lr.cross_validate()
        preds = lr.predict(pairs[:,lr.coefficients.keys()])
        _record_ranks(ranks, [(preds[j,0], preds.rownames[j]) for j in range(preds.size(0))])

    out_path = os.path.join(indication_dir, 'cocktails_%s_random_ranks_%dfeat.csv' % (indication_name, num_features))
    with open(out_path, 'w') as outfh:
        writer = csv.writer(outfh)
        # NOTE: the header names only the first two columns; every row
        # additionally carries the rank from each repetition. The header is
        # left unchanged so the output format stays the same.
        writer.writerow(["Drug Pair", "Min Rank"])
        for pair, pair_ranks in ranks.items():
            writer.writerow([pair, min(pair_ranks)] + pair_ranks)
