"""This module provides code to find abbreviations in text.


Classes:
AbbreviationFinder  Object that finds abbreviations.

Functions:
find    Find the abbreviations from a string.
train   Train a set of parameters for scoring abbreviations.
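
Example (illustrative only; the score shown is hypothetical and
depends on the trained parameter file):

    >>> find("heat shock protein (HSP)")
    [('heat shock protein', 'HSP', 0.97)]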

"""
import random
import string

from Bio import LogisticRegression
from Bio import stringfns

from Extracto import parentheses
from Extracto import rangefns
import aligner
import candidate
import features
import paramfile
import training_data


DEFAULT_PARAMS = "pharm_abbrev1000.lr"


class AbbreviationFinder:
    """Class that finds abbreviations.

    Methods:
    find     Find the abbreviations from a string.

    """
    def __init__(self, param_file=None):
        if not param_file:
            param_file = DEFAULT_PARAMS
        self.params = paramfile.load(param_file)

    def find(self, text):
        """S.find(text) -> list of (prefix, abbrev, score)"""
        text = str(text)
        abbs = []
        candidates = candidate.find(text)
        for prefix, abbrev, prefix_grows in candidates:
            # Get all the alignments.
            alignments = aligner.align(prefix, abbrev)
            # Try to put the good alignments first.  If I find a
            # really good alignment, then I can quit early without
            # looking at the rest.
            alignments.sort()
            alignments.reverse()

            # Score all the alignments and get the best one.
            max_prob = None
            best_abb = None  # tuple of (prefix, abbrev, prob)
            for align in alignments:
                # Make sure something's aligned.
                indexes = filter(lambda x: x is not None, align)
                if not indexes:
                    continue
                x = features.make_vector(prefix, abbrev, prefix_grows, align)
                probs = LogisticRegression.calculate(self.params, x)
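                # probs is [p(not an abbreviation), p(abbreviation)],
                # so probs[1] is the score for the positive class.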
                # Make sure this has a higher probability than the previous best.
                if max_prob is not None and probs[1] <= max_prob:
                    continue
                minprefix = _find_minimum_prefix(prefix, prefix_grows, indexes)
                max_prob = probs[1]
                best_abb = minprefix, abbrev, max_prob
                # If I've already found a really high scoring
                # alignment, then quit.
                if max_prob > 0.9:
                    break
            if max_prob is not None:
                abbs.append(best_abb)
        return abbs

def _find_minimum_prefix(prefix, prefix_grows, indexes):
    # I want the smallest amount of prefix that includes the indexes.
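    # prefix_grows says which direction the candidate prefix can extend:
    # -1 appears to mean extra words may hang off the left (so trim from
    # the left), and 1 that extra words may hang off the right (so trim
    # from the right); anything else leaves both ends alone.  A
    # hypothetical example: with prefix_grows == -1, prefix "the heat
    # shock protein", and indexes covering only "heat shock protein",
    # the leading "the " gets dropped.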

    start, end = 0, len(prefix)
    if prefix_grows == -1:
        # Find the first aligned word, and back up to a whitespace.
        start = min(indexes)
        start = stringfns.rfind_anychar(prefix, string.whitespace, start) + 1
    elif prefix_grows == 1:
        # Find the last aligned word, and move to the next whitespace.
        end = max(indexes)+1
        end = stringfns.find_anychar(prefix, string.whitespace, end)
        if end == -1:
            end = len(prefix)
    
    # Strip out the punctuation and whitespace, being careful not to
    # cut into parentheses.
    parens = parentheses.find_all(prefix, "()", "[]", "{}", "<>")
    parens = filter(lambda x, r=(start, end): rangefns.overlaps(r, x), parens)
    
    if parens:
        first_parens = min([x[0] for x in parens])
        last_parens = max([x[1] for x in parens])
    else:
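        # Sentinel values that never win the min/max comparisons below.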
        first_parens, last_parens = len(prefix), 0

    bad_chars = string.whitespace + string.punctuation
    start = stringfns.find_anychar(prefix, bad_chars, index=start, negate=1)
    start = min(start, first_parens)
    end = stringfns.rfind_anychar(prefix, bad_chars, index=end-1, negate=1) + 1
    end = max(end, last_parens)
    return prefix[start:end]

def find(text, param_file=None):
    """find(text[, param_file]) -> list of (prefix, abbrev, score)"""
    return AbbreviationFinder(param_file).find(text)

def train(training_file, param_file, nodups=1, nprocs=None,
          update_feature_fn=None,
          max_features=None,
          train_fn=LogisticRegression.train,
          train_fn_args=(), train_fn_keywds=None):
    """train(training_file, param_file)

    Train a set of parameters that can be used to score abbreviations.
    training_file is a handle or the name of a training file.
    param_file is a handle or the name of the file to save the
    parameters to.  If a full path is given, the parameters are saved
    to that file; otherwise, they are saved in the data directory.
    If nodups is true, duplicate feature vectors are removed.  If
    max_features is given, the examples are randomly subsampled down
    to that many.
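
    Example (illustrative only; the file names here are hypothetical):

        >>> train("abbrevs.train", "my_params.lr", max_features=5000)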

    """
    if train_fn_keywds is None:
        train_fn_keywds = {}
    xs, ys = training_data.load_features(
        training_file, nprocs=nprocs, update_fn=update_feature_fn)
    if nodups:
        xs, ys = _uniq_features(xs, ys)
    if max_features is not None and len(xs) > max_features:
        xs, ys = _choose_random_features(xs, ys, max_features)
        
    params = train_fn(xs, ys, *train_fn_args, **train_fn_keywds)
    paramfile.save(param_file, params)

def _uniq_features(xs, ys):
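    # Remove duplicate feature vectors (and their labels), keeping the
    # first occurrence of each.  Vectors are keyed by their tuple form
    # so they can be hashed.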
    seen = {}
    nxs, nys = [], []
    for i in range(len(xs)):
        x, y = tuple(xs[i]), ys[i]
        if seen.has_key(x):
            continue
        nxs.append(x)
        nys.append(y)
        seen[x] = 1
    return nxs, nys

def _choose_random_features(xs, ys, nfeatures):
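    # Randomly subsample the examples down to nfeatures, keeping each
    # feature vector paired with its label.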
    if len(xs) <= nfeatures:
        return xs, ys
    choices = range(len(xs))
    random.shuffle(choices)
    nxs, nys = [], []
    for i in choices[:nfeatures]:
        nxs.append(xs[i])
        nys.append(ys[i])
    return nxs, nys
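

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): it assumes the default
    # parameter file is installed and that this sentence yields a
    # candidate abbreviation.  Scores depend on the trained parameters.
    for prefix, abbrev, score in find("heat shock protein (HSP)"):
        print prefix, abbrev, score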
