#!/usr/bin/env python

"""Calculate features describing an abbreviation's alignment to a prefix.


Functions:
make_vector    Make a feature vector describing an alignment


"""
# Possible other features:
# - allow suffixes at the end?  Interleukin (IL-6), alternative splicing (AS-I)
import re

from Bio import listfns

from Extracto import ctype
from Extracto import texhyphen
from Extracto import refns


# A word is a maximal run of regex word characters.  \w matches letters,
# digits and the underscore, so the \d inside the class is redundant.
_find_words_re = re.compile(r"([\w\d]+)")
def _find_words(string, start=0):
    """Return a list of (start, end) spans of the words in string.

    Matching begins at offset start.  Each span is whatever .span()
    reports for a match object returned by refns.re_findall.
    """
    return [match.span() for match in
            refns.re_findall(_find_words_re, string, start=start)]

def _is_word_start(string, i):
    """Return a true value if position i of string begins a word.

    Position 0 always counts (returns 1).  Otherwise the character just
    before i must be whitespace or punctuation according to ctype.
    """
    if i != 0:
        previous = string[i-1]
        return ctype.isspace(previous) or ctype.ispunct(previous)
    return 1

def _is_word_end(string, i):
    """Return a true value if position i of string ends a word.

    The last position of string always counts (returns 1).  Otherwise the
    character just after i must be whitespace or punctuation according
    to ctype.
    """
    last = len(string) - 1
    if i != last:
        following = string[i+1]
        return ctype.isspace(following) or ctype.ispunct(following)
    return 1

class LowerAbbrev:
    # Fraction of the abbreviation's alphabetic letters that are lower case.
    def score(self, prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words):
        """Return (# lower-case letters in abbrev) / nletters.

        Returns 0.0 when nletters is zero.
        """
        if nletters:
            lower_positions = listfns.indexesof(
                abbrev, lambda ch: ctype.isalpha(ch) and ch.islower())
            return len(lower_positions) / float(nletters)
        return 0.0

class WordBegin:
    # Fraction of the abbreviation's letters whose aligned prefix
    # position starts a word.
    def score(self, prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words):
        """Sum _is_word_start over the aligned alphabetic prefix positions
        and divide by nletters.  Returns 0.0 when nletters is zero."""
        if nletters:
            starts = [_is_word_start(prefix, pos) for pos in alpha_indexes]
            return float(sum(starts)) / nletters
        return 0.0

class WordEnd:
    # Fraction of the abbreviation's letters whose aligned prefix position
    # ends a word.  This lets an abbreviation align at the end of a word:
    #   Glutathione S-transferases (GSTs)
    #   interferon (IFN)
    def score(self, prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words):
        """Sum _is_word_end over the aligned alphabetic prefix positions
        and divide by nletters.  Returns 0.0 when nletters is zero."""
        if nletters:
            ends = [_is_word_end(prefix, pos) for pos in alpha_indexes]
            return float(sum(ends)) / nletters
        return 0.0

class SyllableBoundary:
    # Fraction of the abbreviation's letters that are aligned on a
    # syllable boundary in the prefix.
    def __init__(self):
        self._syllable_splitter = texhyphen.SyllableSplitter()

    def score(self, prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words):
        """Return the fraction of nletters whose aligned prefix position
        begins a syllable.

        prefix_words is a list of (start, end) spans of the words in
        prefix.  The first character of a word is always treated as a
        syllable boundary.  A word is only run through the syllable
        splitter when an aligned position falls inside it somewhere other
        than (only) its first character.

        Returns 0.0 when nothing alphabetic is aligned or when nletters is
        zero.  The nletters case previously raised ZeroDivisionError.
        """
        if not alpha_indexes or not nletters:
            return 0.0
        # The aligned prefix positions in ascending order, so they can be
        # walked in step with prefix_words (assumed left-to-right).
        aligned_indexes = sorted(alpha_indexes)
        naligned = len(aligned_indexes)

        splitter = self._syllable_splitter
        syllable_indexes = {}   # prefix positions that begin a syllable
        next_aligned = 0
        for start, end in prefix_words:
            # Optimization: only split this word if something is aligned
            # inside it.  Advance to the first aligned index on or after
            # the start of this word.
            while next_aligned < naligned and \
                  aligned_indexes[next_aligned] < start:
                next_aligned += 1
            # No more aligned indexes, so no later word needs splitting.
            if next_aligned >= naligned:
                break
            # Nothing aligned inside this word; skip it.
            if aligned_indexes[next_aligned] >= end:
                continue
            # If the word's first character is the only aligned one,
            # record it without splitting: a word start is a syllable.
            if aligned_indexes[next_aligned] == start and \
               (next_aligned+1 == naligned or
                aligned_indexes[next_aligned+1] >= end):
                syllable_indexes[start] = 1
                continue

            # Something is aligned in the middle of this word.  Split it.
            syllables = splitter.split(prefix[start:end])
            if len(syllables) == 1:
                # Common case: the only syllable begins at the word start.
                syllable_indexes[start] = 1
            else:
                # Assumes the syllables concatenate back to the word, so
                # running lengths give the boundary offsets.
                # TODO: confirm against texhyphen.SyllableSplitter.
                i = start
                for s in syllables:
                    syllable_indexes[i] = 1
                    i += len(s)

        count = 0
        for i in alpha_indexes:
            if i in syllable_indexes:
                count += 1
        return float(count) / nletters

class HasNeighbor:
    # Fraction of the aligned prefix positions whose left neighbour
    # (position - 1) is also aligned.
    def score(self, prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words):
        """Return (# aligned positions i with i-1 also aligned) divided by
        len(indexes).

        Returns 0.0 when nothing is aligned.  Previously that case raised
        ZeroDivisionError.
        """
        if not indexes:
            return 0.0
        count = 0
        for i in indexes:
            if i - 1 in indexes:
                count += 1
        return float(count) / len(indexes)
        
class Aligned:
    # Fraction of the abbreviation that's aligned.
    def score(self, prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words):
        """Return len(indexes) / len(abbrev).

        indexes is a dictionary keyed by the aligned prefix positions (as
        built by make_vector), so two abbreviation characters aligned to
        the same position count once.  Returns 0.0 for an empty abbrev
        instead of raising ZeroDivisionError.
        """
        if not abbrev:
            return 0.0
        return float(len(indexes)) / len(abbrev)

class LettersAligned:
    # Number of aligned alphabetic prefix positions relative to the
    # number of letters in the abbreviation.
    # May be linearly dependent upon the other features.
    def score(self, prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words):
        """Return len(alpha_indexes) / nletters, or 0.0 if nletters is 0."""
        if nletters:
            naligned_letters = len(alpha_indexes)
            return naligned_letters / float(nletters)
        return 0.0

class Capitalization:
    # Fraction of the abbreviation's characters that are aligned to a
    # prefix character of the same case.
    # This is probably not very informative.
    def score(self, prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words):
        """Return (# aligned character pairs whose islower() agrees)
        divided by len(abbrev).

        Unaligned characters (None in alignment) never count.  Returns 0.0
        for an empty abbrev instead of raising ZeroDivisionError.
        """
        if not abbrev:
            return 0.0
        count = 0
        for ai, pi in enumerate(alignment):
            if pi is None:
                continue
            # Characters without case (digits, punctuation) report
            # islower() False, so they "match" any upper-case character.
            if prefix[pi].islower() == abbrev[ai].islower():
                count += 1
        return float(count) / len(abbrev)

class UnusedWords:
    # Words in the prefix with nothing aligned in them, normalized by the
    # length of the abbreviation.
    def score(self, prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words):
        """Return (# words with no aligned position) / len(abbrev).

        prefix_words is a list of (start, end) spans.  Words on the side
        the prefix can grow toward are ignored:
          prefix_grows == -1  ignore words ending before the first aligned
                              position;
          prefix_grows == 1   ignore words starting after the last aligned
                              position;
          anything else       ignore nothing.
        Returns 0.0 when nothing is aligned.
        """
        # XXX should I skip things like "the"?
        if not indexes:
            return 0.0

        # sorted() instead of keys().sort(): keys() is not a list on Py3.
        aligned_indexes = sorted(indexes)
        naligned = len(aligned_indexes)

        next_aligned = 0
        nunused = 0
        for start, end in prefix_words:
            # Ignore the words that appear either too far to the right
            # or left of the prefix.
            if (prefix_grows == -1 and end <= aligned_indexes[0]) or \
               (prefix_grows == 1 and start > aligned_indexes[-1]):
                continue

            # Find the next aligned index that's on or after this word.
            while next_aligned < naligned and \
                  aligned_indexes[next_aligned] < start:
                next_aligned += 1
            # Unused if no aligned index remains, or the next one lies
            # past the end of this word.
            if next_aligned >= naligned or \
               aligned_indexes[next_aligned] >= end:
                nunused += 1

        # Divide by the length of the abbreviation, so longer
        # abbreviations tolerate more unused words.
        return float(nunused) / len(abbrev)

class AlignsPerWord:
    # Count the number of aligned characters per word.  Abbreviations
    # often have multiple characters per word.
    def score(self, prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words):
        """Return len(indexes) / (# words containing an aligned position).

        The numerator counts every aligned prefix position, including any
        that fall outside a word (e.g. punctuation).  Returns 0.0 when
        nothing is aligned or no aligned position lies inside a word.
        """
        if not indexes:
            return 0.0
        naligned = len(indexes)

        # sorted() instead of keys().sort(): keys() is not a list on Py3.
        aligned_indexes = sorted(indexes)
        next_aligned = 0
        nwords = 0
        for start, end in prefix_words:
            # Find the next aligned index that's on or after this word.
            while next_aligned < naligned and \
                  aligned_indexes[next_aligned] < start:
                next_aligned += 1
            # If there's no more aligned indexes, then I'm done.
            if next_aligned >= naligned:
                break
            # If this word is before the next index, then move on.
            if aligned_indexes[next_aligned] >= end:
                continue
            # Otherwise, this word is used.
            nwords += 1

        if not nwords:
            return 0.0
        return float(naligned) / nwords
    

# The features make_vector computes.  The order of this list fixes the
# order of the scores in the vector make_vector returns (through
# FEATURE_SCORE_FNS and the pure-Python _apply_score_fns).
FEATURES = [
    # Describes the abbreviation.
    LowerAbbrev(),
    # Describe the alignment.
    Aligned(),
    UnusedWords(),
    AlignsPerWord(),
    # Describe where the letters aligned.
    WordBegin(),
    WordEnd(),
    SyllableBoundary(),
    HasNeighbor(),
    ]
# Bound score methods of FEATURES, in the same order.
FEATURE_SCORE_FNS = tuple(map(lambda feature: feature.score, FEATURES))


def make_vector(prefix, abbrev, prefix_grows, alignment):
    """make_vector(prefix, abbrev, prefix_grows, alignment) -> vector

    Make a feature vector describing this alignment.  alignment is a
    list parallel to abbrev indicating where each character is aligned
    to prefix: each entry is an index into prefix, or None when that
    character of abbrev is not aligned.  prefix_grows is passed through
    unchanged to the features (UnusedWords compares it against -1 and 1).

    The vector holds one score per function in FEATURE_SCORE_FNS, in the
    same order, as returned by _apply_score_fns (the pure-Python version
    returns a list).

    """
    # Number of alphabetic characters in the abbreviation.  Counted with
    # a list comprehension because len(filter(...)) only works where
    # filter returns a sequence (Python 2).
    nletters = len([c for c in abbrev if ctype.isalpha(c)])
    # Make a dictionary keyed by the prefix positions used by the
    # alignment.  None marks an unaligned abbreviation character, so drop
    # it.
    indexes = listfns.asdict(alignment)
    indexes.pop(None, None)
    # The subset of those positions that hold alphabetic characters.
    alpha_indexes = {}
    for i in indexes:
        if ctype.isalpha(prefix[i]):
            alpha_indexes[i] = 1
    # (start, end) spans of all the words in the prefix.
    prefix_words = _find_words(prefix)

    params = (prefix, abbrev, prefix_grows, nletters, alignment,
              indexes, alpha_indexes, prefix_words)
    return _apply_score_fns(FEATURE_SCORE_FNS, params)

def _apply_score_fns(score_fns, params):
    return [x(*params) for x in score_fns]

# Try and load C implementations of functions.  If I can't,
# then just ignore and use the pure python implementations.
# Only _apply_score_fns has a C replacement here.  Rebinding the
# module-level name is enough, because make_vector looks the global
# _apply_score_fns up each time it is called.
# NOTE(review): this is an implicit relative import under Python 2.  Under
# Python 3 it is absolute, so a cfeatures module inside the Extracto
# package would not be found.  Confirm how cfeatures is installed.
try:
    import cfeatures
except ImportError:
    # Deliberate best-effort: the C extension is optional.
    pass
else:
    _apply_score_fns = cfeatures._apply_score_fns
