#!/usr/bin/env python

"""This module contains code to find candidate abbreviations.

Functions:
find        Find abbreviations and their prefix strings.

"""
import re
import mx.TextTools as TT
from Bio import listfns

from Extracto import tokenizer
from Extracto import parentheses
from Extracto import ctype

MAX_WORDS_IN_PREFIX = 15



def _find_possible_abbrevs(str):
    """Pair each parenthesized span in str with the text that precedes it.

    Returns a list of (prefix, abbrev, prefix_grows) tuples, one per
    usable parenthesized span, ordered by the sorted span positions.
    prefix_grows is:
        0   the text inside the parentheses was taken as the long form,
            i.e. the pattern "abbrev (prefix)"
        -1  the text inside the parentheses was taken as the
            abbreviation, i.e. the pattern "prefix (abbrev)"

    """
    # Get all the parentheses in the sentence, ignoring the ones that
    # are bounded by alphanumeric characters on both sides.  This is a
    # little more lenient than using parentheses.find.  That one will
    # discard some things that aren't parenthetical statements, but
    # are still parts of abbreviations, e.g.:
    #    staphylococcal enterotoxin B (SEB)-induced
    #    T cell receptor ( TCR )beta
    parens = []
    for span in parentheses.find_all(str):
        s, e = span
        hemmed_in = (s and ctype.isalnum(str[s-1]) and
                     e < len(str)-1 and ctype.isalnum(str[e+1]))
        if not hemmed_in:
            parens.append(span)
    parens.sort()

    # Characters that separate words when counting the words inside a
    # parenthetical statement.  Built once, since it never changes.
    word_breaking_set = TT.set(TT.whitespace + "-")

    # Make a list of prefix and abbreviations.
    abbrevs = []   # prefix, abbrev, prefix_grows
    for i, (start, end) in enumerate(parens):
        # I'm looking for the pattern:   outside (inside)
        # - outside is everything from the beginning of the string, or
        # the beginning of the enclosing parenthetical statement.  For
        # example, in AAA (ZZZ (YYY)), ZZZ is outside, without the AAA
        # part.
        # - inside is everything inside this parentheses, up to a
        # semicolon, excluding other parenthetical statements.  In the
        # previous example, if ZZZ were inside, YYY should be removed.

        # Figure out where the outside starts.  If I'm inside another
        # paren, then use the start of that paren.  The spans are
        # sorted, so the last enclosing span found wins.
        outside_start = 0
        for enclosing_start, enclosing_end in parens[:i]:
            if enclosing_end > end:
                outside_start = enclosing_start + 1
        # Remove any parenthetical statements that appear in the
        # outside text, then collapse it.
        outside = TT.collapse(parentheses.remove(str[outside_start:start]))

        # Figure out where the inside is.  This is everything up to a
        # semicolon.
        # NOTE(review): the alnum test above reads str[e+1] as the
        # character following the span, while this slice stops at
        # end-1.  Whether both are right depends on whether
        # parentheses.find_all returns an inclusive or exclusive end
        # index -- confirm against that module.
        inside = str[start+1:end-1]
        semicolon = inside.find(";")
        if semicolon >= 0:
            inside = inside[:semicolon]
        inside = TT.collapse(parentheses.remove(inside))

        # Make sure there's actually some text here.  outside can be
        # empty if a parenthesis occurs at the beginning of a
        # sentence, or right inside another parenthesis.
        if not inside or not outside:
            continue

        # Count the number of words based on whitespace and dashes.
        # e.g. "leucine zipper-like"  is 3 words
        # A word is a string with an alphabetical character.
        # A list comprehension (rather than filter) keeps len() valid
        # even where filter returns an iterator.
        words = [w for w in TT.setsplit(inside, word_breaking_set)
                 if TT.setfind(w, TT.alpha_set) >= 0]

        if len(words) > 2:
            # More than 2 words inside, so assume the pattern:
            # abbrev (prefix)
            # The abbreviation is the last word before the paren.
            abbrevs.append((inside, outside.split()[-1], 0))
        else:
            # Otherwise, assume the pattern:
            # prefix (abbrev)
            abbrevs.append((outside, inside, -1))

    return abbrevs

def _clean(abbrevs):
    """Tidy up the prefix of every (prefix, abbrev, prefix_grows) candidate.

    Returns a new list of (prefix, abbrev, prefix_grows) tuples.  The
    result can be longer than the input: a "for ..." long form whose
    abbreviation starts with "f" yields two candidates, one with and
    one without the leading "for".

    """
    result = []
    for prefix, abbrev, prefix_grows in abbrevs:
        # The prefix should never contain parenthetical statements.
        prefix = parentheses.remove(prefix)

        if prefix_grows == -1:
            # The prefix is everything to the left of the abbreviation,
            # which may be most of a long sentence.  Trim it
            # conservatively so that the alignment isn't thrown off:
            # allow 3 words per abbreviation character, capped at
            # MAX_WORDS_IN_PREFIX words.
            tokens = tokenizer.tokenize_str(prefix)
            # Positions of the tokens that ctype.isalnum accepts.
            word_positions = listfns.indexesof(tokens, ctype.isalnum)
            limit = min(len(abbrev) * 3, MAX_WORDS_IN_PREFIX)
            if len(word_positions) > limit:
                # Keep everything from the limit-th word (counted from
                # the end) onwards.
                prefix = ''.join(tokens[word_positions[-limit]:])
            prefix = TT.setstrip(prefix, TT.whitespace_set)
        elif prefix_grows == 0 and prefix.lower().startswith("for "):
            # A common pattern is:
            # <abbreviation> (for <long form>)
            # If the abbreviation does not start with "f", the "for"
            # cannot be part of the long form, so it is just stripped.
            # If it does start with "f", the "for" may or may not
            # belong to the long form, so keep both variants and let
            # later stages decide.
            if abbrev[0].lower() == 'f':
                result.append((prefix, abbrev, prefix_grows))
            prefix = TT.setstrip(prefix[3:], TT.whitespace_set)

        result.append((prefix, abbrev, prefix_grows))
    return result

def _filter(abbrevs):
    """Drop candidates that can't possibly be abbreviations.

    Takes and returns a list of (prefix, abbrev, prefix_grows) tuples.
    This is only a rough screen; the surviving candidates keep their
    original order.

    """
    survivors = []
    for prefix, abbrev, prefix_grows in abbrevs:
        # The checks short-circuit in this order, so the later ones
        # only run on candidates that passed the earlier ones.
        rejected = (
            # The abbreviation is longer than the prefix.
            len(abbrev) > len(prefix)
            # The abbreviation doesn't have any letters.
            or TT.setfind(abbrev, TT.alpha_set) < 0
            # The abbreviation has more words than the prefix.
            or len(abbrev.split()) > len(prefix.split())
            # The abbreviation is more than 2 words.
            or len(TT.setsplit(abbrev, TT.whitespace_set)) > 2
            # The abbreviation appears verbatim, as whole words, in
            # the prefix.  The match is case-sensitive, so a true
            # abbreviation that differs only in case still survives,
            # e.g. Utah Test Appraising Health (UTAH).
            or re.search(r"\b%s\b" % re.escape(abbrev), prefix)
        )
        if not rejected:
            survivors.append((prefix, abbrev, prefix_grows))
    return survivors

def find(str):
    """find(str) -> list of prefix, abbrev, prefix_grows

    Look for candidate abbreviations in str.  Candidates are generated
    from the parenthesized text in the string, their prefixes are
    cleaned up, and the ones that can't be abbreviations are thrown
    out.

    prefix_grows is a number -1, 0, or 1 that tells in which direction
    the prefix may be expanded to take in more text:
        -1  the prefix can include more text to the left
         0  the text of the prefix cannot change
         1  the prefix can include more text to the right
    The candidate generator in this module only produces -1 and 0.

    """
    candidates = _find_possible_abbrevs(str)
    return _filter(_clean(candidates))
