"""
identify_features.py

Derive an adverse event (AE) profile for a set of drugs
by finding AEs that are enriched when compared to
a background of correlated indications.
"""

import os
import sys
import csv
import numpy
import MySQLdb
import operator

# Root directory under which one sub-directory per indication is created.
PATH_TO_DATA_DIR = os.path.expanduser('~/Stanford/AltmanLab/cocktails/methods/')

# Module-level MySQL connection and cursor, shared by the script body below.
# The connection is opened as soon as this module is loaded.
db = MySQLdb.connect(
    host="localhost",
    port=3307,
    user="root",
    passwd="dummy_password",
    db="effect_aers",
)
c = db.cursor()

if __name__ == '__main__':

    # Build the sorted, de-duplicated list of AERS drug names whose SIDER
    # label indications match a LIKE pattern, and save it (one name per row)
    # to <PATH_TO_DATA_DIR>/<indication_name>/aers_drug_names.txt.
    #
    # The indication is selected on the command line:
    #   indication_name        short name, used as the output sub-directory
    #   indication_query_term  SQL LIKE pattern matched against
    #                          effect_sider.indications.indication
    #
    # Pairs that have been used with this script:
    #   renal_failure            "%renal fail%"
    #   cholesterol              "%cholest%"
    #   calcemia                 "%calcem%"
    #   diabetes                 "%diabete%"
    #   glycemia                 "%glycem%"
    #   weight_loss              "%weight los%"
    #   chronic_kidney_disease   "%chronic kidn%"
    #   anemia                   "%anemia%"
    #   heart_failure            "%heart fail%"
    if len(sys.argv) != 3:
        sys.stderr.write(
            "Usage: %s <indication_name> <indication_query_term>\n"
            "  e.g. %s diabetes '%%diabete%%'\n" % (sys.argv[0], sys.argv[0]))
        sys.exit(1)
    indication_name, indication_query_term = sys.argv[1:3]

    sys.stderr.write("Building drug set for %s using sider matched to %s\n" % (indication_name, indication_query_term))

    # The LIKE pattern is passed as a bound parameter (note: no quotes around
    # the %s placeholder) so the driver quotes and escapes it; the term is
    # never spliced into the SQL text.
    query = """
select drug_name
from effect_aers.aers2stitch
join
(
    select stitch_full as stitch_id
    from effect_sider.label_map
    join
    (
        select distinct label_id
        from effect_sider.indications
        where indication like %s
    ) a using (label_id)
    group by stitch_full
) b using (stitch_id)
    """
    c.execute(query, (indication_query_term,))

    # Several AERS rows can map to the same name; keep each name once, sorted.
    drug_names = sorted(set(row[0] for row in c.fetchall()))

    output_dir = os.path.join(PATH_TO_DATA_DIR, indication_name)
    aers_names_file_path = os.path.join(output_dir, 'aers_drug_names.txt')
    sys.stderr.write("Found %d drugs matched query, saving to %s\n" % (len(drug_names), aers_names_file_path))

    # makedirs also creates missing parent directories, which mkdir would
    # refuse to do on a fresh checkout.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # The with-block closes the file even if a write fails part-way through.
    with open(aers_names_file_path, 'w') as outfh:
        writer = csv.writer(outfh)
        for drug_name in drug_names:
            writer.writerow([drug_name])
    