"""This module contains code for carrying out experiments.

Functions:
# Code for machine learning (the old approach).
ctts_dataset          Collect observations, train, test, and score a classifier.

scan_abstracts        Look for instances of a lexicon in the abstracts of a
                      file of PMIDs.
find_cooccurrences    Look for cooccurrences of entities from two scans.
collect_observations  Collect observations from training and cooccurrence files.

# Code for active learning.
prioritize_sentence_training  Prioritize a sentence training set for the next
                              round of active learning.

# Miscellaneous.
collect_sentences_with_cooccurrence  Save the sentences that contain a
                                     cooccurrence.
collect_vocabulary    Make a list of the words from a file of sentences.

"""
from __future__ import generators

import sys
from Bio import MultiProc
from Extracto.filefns import tswrite

from Extracto.genename.experiments import \
     _join_filename, _needs_processing, _load_dataset
from Extracto.genedrug import fileformats

def ctts_dataset(training_file, gene_scan_file, drug_scan_file,
                 observations_name, classifier, classifier_name,
                 ignore_unknown=1,
                 training_dataset_name=None, training_dataset_file=None,
                 testing_dataset_name=None, testing_dataset_file=None,
                 prefix="gde", outpath='', nprocs=1, noclobber=1):
    """Run the cooccurrence -> observation -> train -> test -> score pipeline.

    Each stage writes one file whose name is built by _join_filename from
    outpath, prefix, observations_name and, for the later stages, the
    training dataset, classifier and testing dataset names.  A stage runs
    when noclobber is false, or when _needs_processing(output, inputs)
    returns true (presumably: output missing or older than its inputs --
    confirm against _needs_processing).  Otherwise the stage is skipped and a
    message is logged.

    training_file           Training file of known gene-drug relationships;
                            used for observations and for scoring.
    gene_scan_file          Gene scan file (read with fileformats.open_scan).
    drug_scan_file          Drug scan file (read with fileformats.open_scan).
    observations_name       Name component for the cooccurrence and
                            observation files.
    classifier              Classifier object handed to train_classifier.
    classifier_name         Name component for classifier/results/score files.
    ignore_unknown          Passed through to collect_observations.
    training_dataset_file   Optional dataset restricting training data.
    testing_dataset_file    Optional dataset restricting testing data.
    nprocs                  Number of processes for the parallel stages.

    NOTE(review): train_classifier, test_classifier and score_results appear
    only as commented-out code in this module, so those stages raise
    NameError unless the names are provided elsewhere.
    """
    tswrite("Starting analysis on %s.\n" % training_file)
    tswrite("Using gene scan file %s.\n" % gene_scan_file)
    tswrite("Using drug scan file %s.\n" % drug_scan_file)

    def output_name(*parts):
        # Every output file shares the same leading name components.
        return _join_filename(outpath, prefix, observations_name, *parts)

    cooccurrences_file = output_name("cooccurrences")
    observations_file = output_name("observations")
    classifier_file = output_name(
        training_dataset_name, classifier_name, "classifier")
    results_file = output_name(
        training_dataset_name, classifier_name, testing_dataset_name,
        "results")
    score_file = output_name(
        training_dataset_name, classifier_name, testing_dataset_name,
        "score")

    def must_build(outfile, infiles):
        # Rebuild unconditionally when clobbering is allowed; otherwise only
        # when _needs_processing says the output is stale.
        return not noclobber or _needs_processing(outfile, infiles)

    def log_skip(outfile):
        tswrite("%s already exists, skipping.\n" % outfile)

    def log_dataset(dataset_file):
        if not dataset_file:
            tswrite("Using all available data.\n")
        else:
            tswrite("Using data from %s.\n" % dataset_file)

    # Stage 1: pair up gene and drug hits.
    if not must_build(cooccurrences_file, [gene_scan_file, drug_scan_file]):
        log_skip(cooccurrences_file)
    else:
        tswrite("Saving cooccurrences (%s) to %s (%d procs).\n" % (
            observations_name, cooccurrences_file, nprocs))
        find_cooccurrences(gene_scan_file, drug_scan_file, nprocs=nprocs,
                           outhandle=cooccurrences_file)

    # Stage 2: turn cooccurrences into feature vectors.
    if not must_build(observations_file, [cooccurrences_file]):
        log_skip(observations_file)
    else:
        tswrite("Saving observations (%s) to %s (%d procs).\n" % (
            observations_name, observations_file, nprocs))
        collect_observations(training_file, cooccurrences_file,
                             ignore_unknown=ignore_unknown,
                             nprocs=nprocs, outhandle=observations_file)

    # Stage 3: train the classifier.
    if not must_build(classifier_file, [observations_file]):
        log_skip(classifier_file)
    else:
        tswrite("Training classifier %s.\n" % classifier.__class__.__name__)
        log_dataset(training_dataset_file)
        train_classifier(classifier, observations_file,
                         dataset_file=training_dataset_file,
                         outhandle=classifier_file)

    # Stage 4: apply the classifier to the observations.
    if not must_build(results_file, [observations_file, classifier_file]):
        log_skip(results_file)
    else:
        tswrite("Testing classifier from %s (%d procs).\n" % (
            classifier_file, nprocs))
        log_dataset(testing_dataset_file)
        test_classifier(observations_file, classifier_file,
                        dataset_file=testing_dataset_file,
                        nprocs=nprocs, outhandle=results_file)

    # Stage 5: score the results against the training file.
    if not must_build(score_file, [training_file, results_file]):
        log_skip(score_file)
    else:
        tswrite("Summarizing the results from %s.\n" % results_file)
        score_results(training_file, results_file, outhandle=score_file)

    tswrite("Done!\n")

def scan_abstracts(pmid_file, lexicon_file, nprocs=1, outhandle=None):
    """Scan the title and abstract of each PMID for phrases from a lexicon.

    pmid_file     File with one PMID per line.
    lexicon_file  File with one phrase per line; comments are removed with
                  comments.remove_many and trailing whitespace is stripped.
    nprocs        Number of processes.  Work is split with MultiProc.run:
                  worker `start` handles lines start, start+skip, ...
    outhandle     Destination for the scan (default sys.stdout), written
                  through fileformats.open_scan.  Each hit is written as
                  (pmid, sentence number, lexicon entry, matched text,
                  start, end, score), where start/end are the hit's
                  sentence positions shifted by that sentence's offset from
                  tokenfns.find_offsets.

    PMIDs for which medpreproc raises KeyError (pmid does not exist) are
    skipped.
    """
    outhandle = outhandle or sys.stdout

    from Extracto import comments
    from Extracto import lexfind
    from Extracto import medpreproc
    from Extracto import markup_consts
    from Extracto import tokenfns

    # Read the lexicon and close the file instead of leaking the handle.
    lexicon_handle = open(lexicon_file)
    try:
        phrases = lexicon_handle.readlines()
    finally:
        lexicon_handle.close()
    phrases = comments.remove_many(phrases)
    phrases = [x.rstrip() for x in phrases]

    scan_handle = fileformats.open_scan(outhandle, 'w')
    def do_some(start, skip, pmid_file, phrases, scan_handle):
        # Each worker opens (and closes) its own handle on the PMID file.
        pmid_handle = open(pmid_file)
        try:
            z = -1
            for line in pmid_handle:
                z += 1
                if z % skip != start:
                    continue
                pmid = line.rstrip()

                try:
                    document = medpreproc.preprocess_title_and_abstract(pmid)
                except KeyError:     # pmid does not exist
                    continue
                sentences = document.extract(markup_consts.SENTENCE)
                offsets = tokenfns.find_offsets(sentences, document)
                for snum in range(len(sentences)):
                    sentence = sentences[snum]
                    str_sentence = str(sentence)
                    offset = offsets[snum]

                    hits = lexfind.find(
                        sentence, phrases, ignore_case=1,
                        all_boundaries_equal=1, include_abbreviations=1)
                    for entry, s, e, score in hits:
                        # s:e indexes str_sentence; adding the sentence
                        # offset gives (presumably) document positions.
                        doc_s, doc_e = s+offset, e+offset
                        scan_handle.write(
                            pmid, snum, entry, str_sentence[s:e],
                            doc_s, doc_e, score)
        finally:
            pmid_handle.close()
        scan_handle.flush()
    MultiProc.run(nprocs, do_some, (pmid_file, phrases, scan_handle))

def find_cooccurrences(scan1_file, scan2_file, abbrev_cutoff=0.03,
                       nprocs=1, outhandle=None):
    """Pair every hit from scan 1 with every hit from scan 2 in the same PMID.

    scan1_file, scan2_file  Scan files read with fileformats.open_scan.
    abbrev_cutoff           Hits whose score is below this are dropped;
                            hits with score None are always kept.
    nprocs                  Number of processes used via MultiProc.run.
    outhandle               Destination (default sys.stdout), written via
                            fileformats.open_cooccurrences.

    For each PMID (processed in sorted order), every (scan-1 hit, scan-2 hit)
    pair whose (start, end) ranges do not overlap is written as
    (pmid, snum1, entry1, phrase1, start1, end1,
     snum2, entry2, phrase2, start2, end2).  Hits in different sentences of
    the same PMID are paired too.
    """
    from Extracto import rangefns

    # Default to stdout, consistent with the other writers in this module.
    outhandle = outhandle or sys.stdout

    # pmid -> list of (snum, 1 or 2, lex entry, phrase, start, end)
    locations = {}
    for scan_file, scan_type in [(scan1_file, 1), (scan2_file, 2)]:
        for x in fileformats.open_scan(scan_file):
            pmid, snum, lexicon_entry, phrase, start, end, score = x
            if score is not None and score < abbrev_cutoff:
                continue
            locations.setdefault(pmid, []).append(
                (snum, scan_type, lexicon_entry, phrase, start, end))

    cooc_handle = fileformats.open_cooccurrences(outhandle, 'w')
    pmids = list(locations.keys())
    pmids.sort()

    def do_some(start, skip, pmids, locations, cooc_handle):
        for z in range(start, len(pmids), skip):
            pmid = pmids[z]

            # Sort by sentence number (then the rest of the tuple), so the
            # output order is deterministic.
            data = locations[pmid][:]
            data.sort()
            # Split once by scan instead of rescanning all hits per pair;
            # filtering keeps the sorted order, so the output is unchanged.
            hits1 = [x for x in data if x[1] == 1]
            hits2 = [x for x in data if x[1] == 2]
            for snum1, type1, entry1, phrase1, start1, end1 in hits1:
                for snum2, type2, entry2, phrase2, start2, end2 in hits2:
                    # Make sure not same word.
                    if rangefns.overlaps((start1, end1), (start2, end2)):
                        continue
                    cooc_handle.write(pmid,
                                      snum1, entry1, phrase1, start1, end1,
                                      snum2, entry2, phrase2, start2, end2)
        cooc_handle.flush()
    MultiProc.run(nprocs, do_some, (pmids, locations, cooc_handle))

def collect_observations(training_file, cooccurrences_file, ignore_unknown=1,
                         nprocs=1, outhandle=None):
    """Build one feature vector per (gene, drug) pair from its cooccurrences.

    training_file       Training file of (gene, drug, has_rel) records.
    cooccurrences_file  File read with fileformats.open_cooccurrences; the
                        first entity is treated as the gene and the second
                        as the drug.
    ignore_unknown      If true, pairs absent from training_file are skipped;
                        otherwise they are kept with has_rel == -1.
    nprocs              Number of processes used via MultiProc.run.
    outhandle           Destination (default sys.stdout), written via
                        fileformats.open_observations.  Each record is
                        (gene, drug, has_rel, vector, index), where index is
                        the pair's position in sorted pair order.

    Vectors come from features.make_vector(gene, drug, occurrences), where
    occurrences is the list of
    (pmid, gene snum, drug snum, (gstart, gend), (dstart, dend)) tuples for
    the pair.
    """
    from Extracto.genedrug import features

    outhandle = outhandle or sys.stdout

    has_relationship = {}   # (gene, drug) -> has relationship
    for x in fileformats.open_training(training_file):
        gene, drug, has_rel = x
        has_relationship[(gene, drug)] = has_rel

    # (gene, drug) -> list of (pmid, gsnum, dsnum, grange, drange)
    cooccurrences = {}
    for x in fileformats.open_cooccurrences(cooccurrences_file):
        pmid, gsnum, gentry, gphrase, gstart, gend, \
              dsnum, dentry, dphrase, dstart, dend = x
        gene, drug = gentry, dentry
        if ignore_unknown and (gene, drug) not in has_relationship:
            continue
        cooccurrences.setdefault((gene, drug), []).append(
            (pmid, gsnum, dsnum, (gstart, gend), (dstart, dend)))

    # Sorted order fixes each pair's index regardless of nprocs.
    pairs = list(cooccurrences.keys())
    pairs.sort()
    observations_handle = fileformats.open_observations(outhandle, 'w')
    def do_some(start, skip,
                has_relationship, cooccurrences, pairs, observations_handle):
        for z in range(start, len(pairs), skip):
            (gene, drug) = pairs[z]
            has_rel = has_relationship.get((gene, drug), -1)
            vector = features.make_vector(
                gene, drug, cooccurrences[(gene, drug)])
            observations_handle.write(gene, drug, has_rel, vector, z)
        observations_handle.flush()
    MultiProc.run(nprocs, do_some, (
        has_relationship, cooccurrences, pairs, observations_handle))

def _load_phrase_locations(scan_file, abbrev_cutoff=0.03):
    """Index a scan file by lexicon phrase.

    Returns a dict mapping each lexicon phrase to a list of
    (pmid, sentnum, tp_start, tp_end) tuples in file order.  A hit is kept
    when its score is None or is not below abbrev_cutoff.
    """
    phrase_locations = {}
    for pmid, snum, lex_phrase, text_phrase, start, end, score in \
            fileformats.open_scan(scan_file):
        # Only the lexicon phrase is used as a key; the matched text phrase
        # is not indexed.
        if score is None or not score < abbrev_cutoff:
            bucket = phrase_locations.setdefault(lex_phrase, [])
            bucket.append((pmid, snum, start, end))
    return phrase_locations

##def train_classifier(classifier, observations_file,
##                     dataset_file=None, outhandle=None):
##    from Extracto import Classifier
##    dataset = None
##    if dataset_file:
##        dataset = _load_dataset(dataset_file)
##    outhandle = outhandle or sys.stdout

##    training_set = []
##    results = []
##    for x in fileformats.open_observations(observations_file):
##        gene, drug, has_relationship, vector, index = x
##        if dataset and index not in dataset:
##            continue
##        training_set.append(vector)
        
##        # If the relationship isn't in the gold standard, ignore it.

##        if has_relationship == -1:
##            continue
##            #has_relationship = 0
##        results.append(has_relationship)

##    classifier.train(training_set, results)
##    Classifier.save(classifier, outhandle)

##def test_classifier(observations_file, classifier_file,
##                    dataset_file=None, nprocs=1, outhandle=None):
##    from Extracto import Classifier
##    outhandle = outhandle or sys.stdout
##    classifier = Classifier.load(classifier_file)
##    results_handle = fileformats.open_results(outhandle, 'w')
##    dataset = None
##    if dataset_file:
##        dataset = _load_dataset(dataset_file)

##    def do_some(start, skip, dataset, observations_file,
##                classifier, results_handle):
##        z = -1
##        for x in fileformats.open_observations(observations_file):
##            z += 1
##            if z % skip != start:
##                continue
##            gene, drug, has_relationship, vector, index = x
##            if dataset and index not in dataset:
##                continue

##            x = classifier.calculate(vector)
##            p0, p1 = x[0], x[1]
##            prediction = p1 > p0
##            if has_relationship == -1:
##                is_correct = (prediction == 0)
##            else:
##                is_correct = (prediction == has_relationship)

##            results_handle.write(gene, drug, has_relationship, prediction,
##                                 p0, p1, is_correct)
##        results_handle.flush()
##    args = dataset, observations_file, classifier, results_handle
##    MultiProc.run(nprocs, do_some, args)

##def score_results(training_file, results_file, outhandle=None):
##    from Extracto import recprec
##    outhandle = outhandle or sys.stdout

##    has_relationship = {}   # (gene, drug) -> has relationship
##    for x in fileformats.open_training(training_file):
##        gene, drug, has_rel = x
##        has_relationship[(gene, drug)] = has_rel

##    results = []    # gene, drug, has_relationship, p1
##    scored = {}   # (gene, drug) -> 1
##    for x in fileformats.open_results(results_file):
##        gene, drug, has_rel, prediction, p0, p1, is_correct = x
##        scored[(gene, drug)] = 1
##        results.append((gene, drug, has_rel, p1))

##    for (gene, drug), has_rel in has_relationship.items():
##        if (gene, drug) not in scored:
##            results.append((gene, drug, has_rel, -99999))

##    # Sort by descending score, then gene, then drug.
##    schwartz = [(-x[3], x[0], x[1], x) for x in results]
##    schwartz.sort()
##    results = [x[-1] for x in schwartz]

##    # Calculate a recall/precision curve.
##    x = [x[2] for x in results]
##    has_rel_curve = [(x==1) for x in x]
    
##    total_relationships = len([x for x in has_relationship.values() if x])
##    curve = recprec.calc_tradeoff_curve(has_rel_curve, total_relationships)
##    fscores = []
##    for rec, prec in curve:
##        if not rec and not prec:
##            f = 0.0
##        else:
##            f = 2.*prec*rec/(prec+rec)
##        fscores.append(f)
##    max_fscore = max(fscores)

##    total = len(results)

##    score_data = fileformats.ScoreData()
##    score_data.pairs_tested = total
##    score_data.num_relationships = total_relationships
##    score_data.max_fscore = max_fscore
##    for i in range(len(results)):
##        gene, drug, has_relationship, score = results[i]
##        rec, prec = curve[i]
##        x = gene, drug, has_relationship, score, rec, prec
##        score_data.data.append(x)
##    fileformats.open_score(outhandle, 'w').write(score_data)

def prioritize_sentence_training(
    sentences_file, observations_name, iteration,
    classifier,
    outpath="", prefix="act", nprocs=1, noclobber=1):
    """Run one iteration of the sentence-level active-learning pipeline.

    The stages are:
      1. build a vocabulary from the sentences;
      2. format the sentences as observations;
      3. train a sentence classifier;
      4. write a prioritized sentence training file.

    Output names are built by _join_filename from outpath, prefix,
    observations_name and the zero-padded iteration number ("%03d").  A stage
    runs when noclobber is false or _needs_processing(output, inputs) returns
    true; otherwise it is skipped and logged.

    NOTE(review): format_sentences_as_pyml, train_sentence_classifier and
    prioritize_training_set appear only as commented-out code in this
    module, so stages 2-4 raise NameError unless those names are provided
    elsewhere.  The commented-out format_sentences_as_pyml also takes a
    leading training_file argument that the call below does not pass.
    """
    tswrite("Starting prioritization on %s.\n" % sentences_file)

    str_iteration = "%03d" % iteration
    vocabulary_file = _join_filename(
        outpath, prefix, observations_name, str_iteration, "vocabulary")
    sentence_observations_file = _join_filename(
        outpath, prefix, observations_name, str_iteration,
        "sentence_observations")
    sentence_classifier_file = _join_filename(
        outpath, prefix, observations_name, str_iteration,
        "sentence_classifier")
    out_sentences_file = _join_filename(
        outpath, prefix, observations_name, str_iteration,
        "sentence_training")

    if not noclobber or \
       _needs_processing(vocabulary_file, [sentences_file]):
        tswrite("Making vocabulary file.\n")
        collect_vocabulary(sentences_file, outhandle=vocabulary_file)
    else:
        # Log the skip like every other stage does.
        tswrite("%s already exists, skipping.\n" % vocabulary_file)

    if not noclobber or \
       _needs_processing(sentence_observations_file,
                         [sentences_file, vocabulary_file]):
        tswrite("Saving observations (%s) to %s (%d procs).\n" % (
            observations_name, sentence_observations_file, nprocs))
        format_sentences_as_pyml(
            sentences_file, vocabulary_file,
            nprocs=nprocs, outhandle=sentence_observations_file)
    else:
        tswrite("%s already exists, skipping.\n" % sentence_observations_file)

    if not noclobber or \
       _needs_processing(sentence_classifier_file,
                         [sentences_file, sentence_observations_file]):
        tswrite("Training classifier.\n")
        train_sentence_classifier(
            classifier, sentences_file, sentence_observations_file,
            outhandle=sentence_classifier_file)
    else:
        tswrite("%s already exists, skipping.\n" % sentence_classifier_file)

    if not noclobber or \
       _needs_processing(
        out_sentences_file, [
        sentences_file, sentence_observations_file,
        sentence_classifier_file]):
        tswrite("Prioritizing training set.\n")
        prioritize_training_set(
            sentences_file, sentence_observations_file,
            sentence_classifier_file,
            outhandle=out_sentences_file)
    else:
        tswrite("%s already exists, skipping.\n" % out_sentences_file)

    tswrite("Done!\n")

def _get_words_for_relationship(sentence, word1, index1, word2, index2):
    """Return the words of `sentence` other than the two given mentions.

    cooccurrence.split_sentence_into_words returns five pieces.  The 1st,
    3rd and 5th (left, middle and right word lists) are concatenated.  The
    2nd and 4th (presumably the two mention words) are dropped.
    """
    from Extracto import cooccurrence

    pieces = cooccurrence.split_sentence_into_words(
        sentence, word1, index1, word2, index2)
    left_words, _, middle_words, _, right_words = pieces
    words = left_words + middle_words
    return words + right_words

def collect_vocabulary(sentences_file, allowed_classes=None,
                       select_features_fn=None,
                       outhandle=None):
    """Write the vocabulary of a sentences file as (word, index) records.

    sentences_file      File read with fileformats.open_sentences.
    allowed_classes     If given, only sentences whose has_rel value is in
                        this collection contribute words.
    select_features_fn  Optional function applied to the word -> document
                        frequency dict; the keys of its result become the
                        vocabulary.
    outhandle           Destination (default sys.stdout), written via
                        fileformats.open_vocabulary.

    Words come from _get_words_for_relationship (the sentence minus the gene
    and drug mentions).  Each word is counted at most once per sentence.
    Words are written in sorted order with indices 0, 1, 2, ...
    """
    outhandle = outhandle or sys.stdout

    docfreq = {}   # word -> number of sentences that contain it
    for record in fileformats.open_sentences(sentences_file):
        (pmid, snum, gene, gs, ge, drug, ds, de,
         sentence, has_rel, shows_rel) = record

        # Skip sentences outside the requested training classes.
        if allowed_classes is not None and has_rel not in allowed_classes:
            continue
        in_sentence = {}
        gene_word = sentence[gs:ge]
        drug_word = sentence[ds:de]
        for word in _get_words_for_relationship(
                sentence, gene_word, gs, drug_word, ds):
            in_sentence[word] = 1
        for word in in_sentence.keys():
            docfreq[word] = docfreq.get(word, 0) + 1

    if select_features_fn is not None:
        docfreq = select_features_fn(docfreq)

    vocab = list(docfreq.keys())
    vocab.sort()
    vocab_handle = fileformats.open_vocabulary(outhandle, 'w')
    index = 0
    for word in vocab:
        vocab_handle.write(word, index)
        index = index + 1
    vocab_handle.flush()

##def format_sentences_as_pyml(
##    training_file, sentences_file, vocabulary_file,
##    default_klasses=None, nprocs=1, outhandle=None):
##    outhandle = outhandle or sys.stdout
##    cooc2classes = cooc2classes or {}  # (drug, gene) -> class

##    from Extracto import pyml_format
    
##    word2index = {}
##    for word, index in fileformats.open_vocabulary(vocabulary_file):
##        word2index[word] = index

##    cooc2klasses = {}
##    for gene, drug, klasses_str in fileformats.open_training(training_file):
##        klasses = klasses_str.split(";")
##        cooc2klasses[(gene, drug)] = klasses
        
##    data = []  # list of (id, klass, dict)
##    for x in fileformats.open_sentences(sentences_file):
##        pmid, snum, gene, gs, ge, drug, ds, de, \
##              sentence, has_rel, shows_rel = x
##        if has_rel != 1:
##            continue
##        id = "%s-%d" % (pmid, snum)

##        klasses = cooc2classes.get((gene, drug), default_klasses)
##        if klasses is None:
##            raise AssertionError, "Missing cooccurrence"

##        dict = {}
##        gene_word, drug_word = sentence[gs:ge], sentence[ds:de]
##        words = experiments._get_words_for_relationship(
##            sentence, gene_word, gs, drug_word, ds)
##        for w in words:
##            # If the word is not in the vocabulary, ignore it.
##            if w not in word2index:
##                continue
##            i = word2index[w]
##            dict[i] = dict.get(i, 0) + 1

##        data.append((id, klasses, dict))
##    pyml_format.save_sparse(outhandle, data)

##def train_sentence_classifier(
##    classifier, sentences_file, sentence_observations_file,
##    dataset_file=None, outhandle=None):
##    from Extracto import Classifier
##    outhandle = outhandle or sys.stdout
##    dataset = None
##    if dataset_file:
##        dataset = _load_dataset(dataset_file)

##    has_relationship = {} # (pmid, sentnum, gs, ds) -> has_relationship
##    # Load the training
##    for x in fileformats.open_sentence_training(sentences_file):
##        pmid, snum, gene, gs, drug, ds, sentence, has_rel = x
##        k = pmid, snum, gs, ds
##        has_relationship[k] = has_rel

##    training_set = []
##    results = []
##    for x in fileformats.open_sentence_observations(
##        sentence_observations_file):
##        pmid, sentnum, gs, ds, vector, index = x
##        if dataset and index not in dataset:
##            continue
##        k = pmid, sentnum, gs, ds
##        if has_relationship[k] not in [0, 1]:
##            continue
##        vector = vector[:4]
##        training_set.append(vector)
##        results.append(has_rel)

##    classifier.train(training_set, results)
##    Classifier.save(classifier, outhandle)

##def prioritize_training_set(
##    sentences_file, sentence_observations_file, classifier_file,
##    outhandle=None):
##    import math
##    from Extracto import Classifier
##    outhandle = outhandle or sys.stdout
##    classifier = Classifier.load(classifier_file)

##    priorities = {}  # pmid, sentnum, gs, ds -> priority
    
##    for x in fileformats.open_sentence_observations(
##        sentence_observations_file):
##        pmid, sentnum, gs, ds, vector, index = x
##        k = pmid, sentnum, gs, ds

##        vector = vector[:4]

##        x = classifier.calculate(vector)
##        p0, p1 = math.exp(x[0]), math.exp(x[1])
##        # I want the observations closest to the hyperplane.
##        priority = 0.5 - math.fabs(p1-0.5)
##        priorities[k] = priority

##    data = []   # priority, training data
##    for x in fileformats.open_sentence_training(sentences_file):
##        pmid, snum, gene, gs, drug, ds, sentence, has_rel = x
##        if has_rel in [0, 1]:
##            priority = -1
##        else:
##            k = pmid, snum, gs, ds
##            priority = priorities[k]
##        data.append((priority, x))

##    # Sort by decreasing priority.
##    schwartz = [(-x[0], x[1][0], x[1][1], x) for x in data]
##    schwartz.sort()
##    data = [x[-1] for x in schwartz]

##    training_outhandle = fileformats.open_sentence_training(
##        outhandle, 'w')
##    for priority, x in data:
##        training_outhandle.write(*x)

def collect_sentences_with_cooccurrence(
    training_file, cooccurrence_file, outhandle=None):
    """Write the sentence text for each same-sentence cooccurrence.

    training_file      Training file of (gene, drug, has_rel) records.
    cooccurrence_file  File read with fileformats.open_cooccurrences.
    outhandle          Destination (default sys.stdout), written via
                       fileformats.open_sentences.

    Cooccurrences whose two hits are in different sentences are skipped.
    For each remaining cooccurrence, the abstract is preprocessed with
    medpreproc and the sentence is extracted with docfns.  A record is then
    written:
    (pmid, snum, gene entry, gstart, gend, drug entry, dstart, dend,
     sentence text, has_rel, -1).
    has_rel comes from training_file and is -1 for pairs not listed there.

    NOTE(review): gstart/gend/dstart/dend are copied unchanged from the
    cooccurrence file.  scan_abstracts records positions shifted by the
    sentence's offset, whereas collect_vocabulary slices the written
    sentence text with these values (sentence[gs:ge]).  Confirm whether
    they should be made sentence-relative here.
    """
    outhandle = outhandle or sys.stdout

    from Extracto import medpreproc
    from Extracto import docfns

    has_relationship = {}   # (gene, drug) -> has relationship
    for x in fileformats.open_training(training_file):
        gene, drug, has_rel = x
        has_relationship[(gene, drug)] = has_rel

    sentences_handle = fileformats.open_sentences(outhandle, 'w')
    # Reuse the sentences of the most recent PMID: consecutive rows often
    # share a PMID, and preprocessing an abstract for every row is wasteful.
    cached_pmid = None
    sentences = None
    for x in fileformats.open_cooccurrences(cooccurrence_file):
        pmid, gsnum, gentry, gphrase, gstart, gend, \
              dsnum, dentry, dphrase, dstart, dend = x
        if gsnum != dsnum:
            continue

        if sentences is None or pmid != cached_pmid:
            document = medpreproc.preprocess_title_and_abstract(pmid)
            sentences = docfns.extract_sentences(document)
            cached_pmid = pmid
        str_sentence = str(sentences[gsnum])

        has_rel = has_relationship.get((gentry, dentry), -1)
        sentences_handle.write(pmid, gsnum,
                               gentry, gstart, gend,
                               dentry, dstart, dend,
                               str_sentence, has_rel, -1)
    sentences_handle.flush()

##def count_words_in_sentences(sentences_file, allowed_classes=None,
##                             outhandle=None):
##    outhandle = outhandle or sys.stdout

##    docfreq = {}   # word -> document frequency
##    for x in fileformats.open_sentences(sentences_file):
##        pmid, snum, gene, gs, ge, drug, ds, de, \
##              sentence, has_rel, shows_rel = x

##        # Only the words in my training set matter.
##        if allowed_classes is not None and has_rel not in allowed_classes:
##            continue
##        words = {}
##        gene_word, drug_word = sentence[gs:ge], sentence[ds:de]
##        for word in _get_words_for_relationship(
##            sentence, gene_word, gs, drug_word, ds):
##            words[word] = 1
##        for word in words:
##            docfreq[word] = docfreq.get(word, 0) + 1

##    # Sort by decreasing frequency.
##    schwartz = [(-x[-1], x) for x in docfreq.items()]
##    schwartz.sort()
##    wordcounts = [x[-1] for x in schwartz]

##    wordcounts_handle = fileformats.open_wordcounts(outhandle, 'w')
##    for word, count in wordcounts:
##        wordcounts_handle.write(word, count)
##    wordcounts_handle.flush()
