#!/usr/bin/env python

"""This module contains code to align an abbreviation to the text outside it.

Functions:
align    Align an abbreviation to its prefix string.

"""
import math

from Bio import pairwise2
from Bio import listfns

def align(prefix, abbrev):
    """align(prefix, abbrev) -> list of alignments

    An alignment is a list of indexes.  The list is parallel to abbrev
    and indicates the index of prefix where each letter of abbrev is
    aligned, or None where that letter is not aligned to any letter of
    prefix.

    Matching is case-insensitive.  The alignments come from
    pairwise2.align.globalxx.  That is a global alignment with identity
    scoring and no gap penalties.  Each distinct index alignment appears
    only once in the returned list, in no particular order.

    Raises AssertionError if every candidate gap character already
    occurs in prefix or abbrev.

    """
    # Pick a gap character that occurs in neither string, so gaps in the
    # aligned output cannot be confused with real letters.
    gap_chars = ["~", "@", "$", "!", "?", "&"]
    used_chars = set(prefix) | set(abbrev)
    for gap in gap_chars:
        if gap not in used_chars:
            break
    else:
        raise AssertionError("I could not find a gap character for %s %s" % (
            prefix, abbrev))

    # Do case insensitive matching.
    prefix, abbrev = prefix.lower(), abbrev.lower()

    # To optimize, only align the letters that appear in both the prefix
    # and the abbrev; any other letter can never be matched.  Keep maps
    # from each position in the reduced strings back to the position in
    # the full (lowercased) strings.
    in_both = set(prefix) & set(abbrev)
    abbrev2oldabbrev = [i for i, c in enumerate(abbrev) if c in in_both]
    prefix2oldprefix = [i for i, c in enumerate(prefix) if c in in_both]
    short_abbrev = ''.join([abbrev[i] for i in abbrev2oldabbrev])
    short_prefix = ''.join([prefix[i] for i in prefix2oldprefix])

    # Do the alignments.
    aligns = pairwise2.align.globalxx(short_prefix, short_abbrev,
                                      gap_char=gap)

    # The same aligned pair may be reported more than once.  Convert each
    # distinct pair only once.  Key the results by the index tuple so that
    # different pairs mapping to the same indexes are also deduplicated.
    seen_pairs = set()
    alignments = {}  # (aligned indexes) -> 1
    for prefix_align, abbrev_align, _score, _begin, _end in aligns:
        pair = (prefix_align, abbrev_align)
        if pair in seen_pairs:
            continue
        seen_pairs.add(pair)
        alignment = _make_alignment_fast(
            len(abbrev), prefix_align, abbrev_align,
            prefix2oldprefix, abbrev2oldabbrev,
            gap)
        alignments[tuple(alignment)] = 1
    # list() so callers get a real list on Python 3 too (keys() is a view).
    return list(alignments.keys())

def _make_alignment(abbrevlen, prefix_align, abbrev_align, gap):
    alignment = [None] * abbrevlen
    abbrev_i = prefix_i = 0
    for p, a in zip(prefix_align, abbrev_align):
        if a != gap:
            if p == a:
                alignment[abbrev_i] = prefix_i
            abbrev_i += 1
        if p != gap:
            prefix_i += 1
    return alignment

def _make_alignment_fast(abbrevlen, prefix_align, abbrev_align,
                         prefix2oldprefix, abbrev2oldabbrev, gap):
    alignment = [None] * abbrevlen
    abbrev_i = prefix_i = 0
    for p, a in zip(prefix_align, abbrev_align):
        if a != gap:
            if p == a:
                oai = abbrev2oldabbrev[abbrev_i]
                opi = prefix2oldprefix[prefix_i]
                alignment[oai] = opi
            abbrev_i += 1
        if p != gap:
            prefix_i += 1
    return alignment

class _cutoff_within:
    def __init__(self, cutoff):
        self._cutoff = cutoff
    def __call__(self, score, pos, all_scores):
        return math.fabs(score - all_scores[0]) < self._cutoff
        
def _within_half_fn(score, pos, all_scores):
    return math.fabs(score - all_scores[0]) < 0.5

def _aligned(char1, char2):
    # Whether two characters are aligned (i.e. not a gap).
    if not ctype.isalnum(char1) or not ctype.isalnum(char2):
        return 0
    return char1.lower() == char2.lower()

# Swap in the C implementations from the optional ``caligner`` extension
# module when it can be imported; otherwise the pure-Python definitions
# above stay in effect.
try:
    import caligner
except ImportError:
    pass
else:
    (_within_half_fn,
     _make_alignment,
     _make_alignment_fast,
     _aligned) = (caligner._within_half_fn,
                  caligner._make_alignment,
                  caligner._make_alignment_fast,
                  caligner._aligned)
