"""

Functions:
make_vector      Make a feature vector.

Classes:
Abstract  The number of times a drug/gene seen in same abstract.
Sentence  The number of times a drug/gene seen in same sentence.
Series    The number of times a drug/gene seen in same series.


Keyword   The number of times a drug/gene has a keyword.

"""
from Extracto import memoize

class Abstract:
    """Feature: how many distinct abstracts contain the gene/drug pair."""
    def score(self, gene, drug, cooccurrences):
        # The first field of every cooccurrence is its PMID; the score is
        # the number of distinct PMIDs, i.e. abstracts shared by the pair.
        distinct_pmids = set([cooc[0] for cooc in cooccurrences])
        return len(distinct_pmids)

class Sentence:
    """Feature: how many distinct sentences contain both the gene and drug."""
    def score(self, gene, drug, cooccurrences):
        # Only cooccurrences whose gene and drug sentence numbers agree are
        # kept; each (pmid, sentence number) location is counted once.
        locations = set([(pmid, g_snum)
                         for pmid, g_snum, d_snum, _grange, _drange
                         in cooccurrences
                         if g_snum == d_snum])
        return len(locations)

def _find_npranges(sentence):
    """Return the (start, end) ranges npseries.find reports for sentence."""
    from Extracto import npseries
    np_ranges = npseries.find(sentence)
    return np_ranges
# Cache results keyed on the sentence's string form.
_find_npranges = memoize.memoize(
    _find_npranges, args2key=lambda sentence: str(sentence))

def _pmid2sentences(pmid):
    """Return the sentences of the preprocessed title+abstract of pmid."""
    from Extracto import medpreproc
    return _doc2sentences(medpreproc.preprocess_title_and_abstract(pmid))

def _pmid2abbrevs(pmid, score_cutoff):
    """Return the abbreviations found in the preprocessed title+abstract
    of pmid, via _doc2abbrevs with the given score_cutoff."""
    from Extracto import medpreproc
    return _doc2abbrevs(medpreproc.preprocess_title_and_abstract(pmid),
                        score_cutoff)

def _doc2sentences(document):
    """Return the sentences docfns.extract_sentences finds in document."""
    from Extracto import docfns
    sentences = docfns.extract_sentences(document)
    return sentences
# Cache results keyed on the document's string form.
_doc2sentences = memoize.memoize(
    _doc2sentences, args2key=lambda document: str(document))

def _doc2abbrevs(document, score_cutoff):
    """Return the abbreviations docfns.extract_abbrevs finds in document.

    score_cutoff is forwarded to docfns.extract_abbrevs.  It is also part
    of the memoization key below, so it must actually be honored;
    otherwise different cutoffs would be cached separately yet return the
    same result.

    """
    from Extracto import docfns
    return docfns.extract_abbrevs(document, score_cutoff=score_cutoff)
# Cache results keyed on (str(document), score_cutoff).
_doc2abbrevs = memoize.memoize(_doc2abbrevs, args2key=lambda x, y: (str(x), y))

class Series:
    """Feature: sentences where the gene and drug occur only together
    inside a noun-phrase series."""
    def score(self, gene, drug, cooccurrences):
        """Return the number of (pmid, sentence) locations in which every
        same-sentence cooccurrence has its gene range and drug range both
        overlapping a single noun-phrase range from npseries.

        gene, drug      Unused; accepted for the common score interface.
        cooccurrences   List of (pmid, gene sentnum, drug sentnum,
                        gene range, drug range).

        """
        from Extracto import rangefns
        from Extracto import tokenfns
        from Extracto import medpreproc

        # NOTE(review): grange/drange are compared against noun-phrase
        # ranges shifted by the sentence's offset in the document, i.e.
        # they are treated as document coordinates here, whereas Keyword
        # slices the sentence string with them -- confirm which
        # coordinate system the cooccurrence ranges actually use.

        # (pmid, snum) -> noun-phrase ranges in document coordinates, so a
        # sentence holding several cooccurrences is analyzed only once.
        np_cache = {}

        def _np_ranges(pmid, snum):
            # Compute (once per sentence) the shifted noun-phrase ranges.
            key = pmid, snum
            if key not in np_cache:
                document = medpreproc.preprocess_title_and_abstract(pmid)
                sentences = _doc2sentences(document)
                offsets = tokenfns.find_offsets(sentences, document)
                offset = offsets[snum]
                np_cache[key] = [(s + offset, e + offset)
                                 for (s, e) in _find_npranges(sentences[snum])]
            return np_cache[key]

        all_locations = {}   # (pmid, snum) -> 1
        not_in_series = {}   # (pmid, snum) -> 1
        for pmid, gsnum, dsnum, grange, drange in cooccurrences:
            if gsnum != dsnum:   # only cooccurrences in the same sentence
                continue
            location = pmid, gsnum
            if location in not_in_series:
                continue
            all_locations[location] = 1
            for np_range in _np_ranges(pmid, gsnum):
                if rangefns.overlaps(np_range, grange) and \
                   rangefns.overlaps(np_range, drange):
                    # Gene and drug fall inside the same noun phrase.
                    break
            else:
                # No single noun phrase covers both mentions, so this
                # sentence no longer counts as "only in a series".
                not_in_series[location] = 1
        return len(all_locations) - len(not_in_series)
        
class Abbreviation:
    """Feature: how many abstracts list the gene/drug pair among their
    extracted (prefix, abbreviation) pairs, in either order."""
    def score(self, gene, drug, cooccurrences):
        # Compare case-insensitively and ignore order by normalizing every
        # pair to a sorted tuple of lowercased strings.
        def _normalize(first, second):
            return tuple(sorted((first.lower(), second.lower())))

        target = _normalize(gene, drug)
        checked = set()
        num_abbrevs = 0
        for pmid, _gsnum, _dsnum, _grange, _drange in cooccurrences:
            if pmid not in checked:
                checked.add(pmid)
                # Each abstract contributes at most one, however many
                # matching abbreviation entries it has.
                if any(_normalize(prefix, abbrev) == target
                       for prefix, abbrev, _score
                       in _pmid2abbrevs(pmid, 0.03)):
                    num_abbrevs += 1
        return num_abbrevs

class Keyword:
    """Feature: sentences in which a keyword occurs among the words that
    relate the gene and the drug."""
    def __init__(self, keyword):
        # Compared against Porter-stemmed words, so callers pass stems
        # such as "metabol".
        self.keyword = keyword

    def score(self, gene, drug, cooccurrences):
        """Return the number of distinct (pmid, sentence) locations where
        at least one same-sentence cooccurrence has self.keyword among the
        stemmed words returned by _get_words_for_relationship.

        gene, drug      Unused; accepted for the common score interface.
        cooccurrences   List of (pmid, gene sentnum, drug sentnum,
                        gene range, drug range).

        """
        from Extracto import stem
        # XXX
        from Extracto.genedrug.experiments import _get_words_for_relationship
        from Extracto import medpreproc

        # (pmid, snum) -> 1 for sentences already counted.  A sentence is
        # recorded only once the keyword is found, so later cooccurrences
        # in the same sentence are still checked if earlier ones failed;
        # the result no longer depends on cooccurrence order.
        counted = {}
        for pmid, gsnum, dsnum, grange, drange in cooccurrences:
            if gsnum != dsnum:   # make sure same sentence
                continue
            location = pmid, gsnum
            if location in counted:
                continue

            document = medpreproc.preprocess_title_and_abstract(pmid)
            str_sentence = str(_doc2sentences(document)[gsnum])

            # NOTE(review): the ranges are used here as offsets into the
            # sentence string, while Series treats them as document
            # offsets -- confirm which is correct.
            (gs, ge), (ds, de) = grange, drange
            gene_text, drug_text = str_sentence[gs:ge], str_sentence[ds:de]
            between = _get_words_for_relationship(
                str_sentence, gene_text, gs, drug_text, ds)
            stemmed = [stem.porter(word) for word in between]
            if self.keyword in stemmed:
                counted[location] = 1
        return len(counted)


##class Keyword(BySentenceMixin, Feature.Feature):
##    def __init__(self, *keywords):
##        from Bio import trie
##        Feature.Feature.__init__(self)
##        assert keywords, "No keywords specified."
##        self.keywords = keywords
##        self.keywords_trie = trie.trie()
##        for key in self.keywords:
##            self.keywords_trie[key] = 1

##    def analyze_sentence(self, pmid, sentnum, document, sentences, contents):
##        from Bio import triefind

##        # Make sure there's a drug and gene in this sentence.
##        if not _has_loctypes(contents, "drug", "gene"):
##            return

##        # Make sure the keyword exists.
##        str_sentence = str(sentences[sentnum])
##        x = triefind.find_words(str_sentence, self.keywords_trie)
##        if not x:
##            return

##        self.value += 1

# Feature extractors, in feature-vector order.  Commented-out entries are
# currently disabled.
FEATURES = [
    Abstract(),
    #Sentence(),
    #Series(),
    #Abbreviation(),
    #Keyword("metabol"),
    #Keyword("phenotyp"),
    ]
# Bound score methods, one per enabled feature, in the same order.
FEATURE_SCORE_FNS = tuple(feature.score for feature in FEATURES)

def make_vector(gene, drug, cooccurrences):
    """Make a feature vector describing the relationship.

    cooccurrences is a list of:
        pmid, gene sentnum, drug sentnum, gene range, drug range

    Returns a list holding one score per entry of FEATURE_SCORE_FNS, in
    that order.

    """
    vector = []
    for score_fn in FEATURE_SCORE_FNS:
        vector.append(score_fn(gene, drug, cooccurrences))
    return vector
