"""collocation.py

Classes:
LikelihoodRatioModel    Calculates collocation scores based on likelihood ratios.

Functions:
find                    Find the collocations in a list of tokens.

"""
import math
import time

from Bio import mathfns
from Bio.MultiProc import Task, Scheduler

from Extracto import ngram

class LikelihoodRatioModel:
    """LikelihoodRatioModel

    Score each bigram (w1, w2) with a likelihood ratio test between two
    hypotheses:
    H1: P(w2|w1) = P(w2|^w1)    seeing w2 is independent of w1
    H2: P(w2|w1) != P(w2|^w1)   seeing w2 depends on whether I saw w1

    P(w2|w1) is a binomial distribution.  Take all the bigrams with
    w1, and calculate the probability of seeing a specific number of
    w2, given some probability.

    Under H1, p  = P(w2) = C(w2)/N
    Under H2, p1 = P(w2|w1)  = P(w2, w1)/P(w1) = C(w1, w2)/C(w1)
              p2 = P(w2|^w1) = P(w2, ^w1)/P(^w1) = (C(w2)-C(w1, w2)) / (N-C(w1))

    The score stored for each bigram is -2 * log2(lambda), where lambda
    is the likelihood under H1 divided by the likelihood under H2.
    Larger scores mean stronger evidence that w1 and w2 are a collocation.

    Methods:
    create  Create a model.
    load    Load the model from a file.
    save    Save the model to a file.
    find    Find the collocations in a list of tokens.

    """
    def __init__(self, handle=None, score_cutoff=20.0, ignore_case=1):
        """LikelihoodRatioModel([handle][, score_cutoff][, ignore_case])

        handle        Optional file-like object holding a saved model to load.
        score_cutoff  Minimum score for a bigram to count as a collocation.
                      A score of 20.0 means H2 is 2**10 (about 1,000)
                      times more likely than H1.
        ignore_case   If true, tokens are lowercased before counting and
                      before matching.

        """
        self._likelihoods = None   # dict of (w1, w2) -> score once built
        self._score_cutoff = score_cutoff
        self._ignore_case = ignore_case
        if handle:
            self.load(handle)

    def create(self, corpus, process_fn, update_fn=None, nprocs=1):
        """S.create(corpus, process_fn[, update_fn][, nprocs])

        Create a likelihood ratio bigram model.  corpus is a Corpus
        object (indexed with corpus[i], named with corpus.name(i)).
        process_fn should take a handle to a string and return a list
        of lists of tokens.  Each list of tokens should be its own
        sentence or clause -- collocations will not be collected across
        them.  update_fn, if given, is called with each document's name
        before that document is processed.  nprocs Tasks each count every
        nprocs-th document; it must be between 1 and 1000.

        Raises ValueError if nprocs is out of range.

        """
        if nprocs < 1 or nprocs > 1000:
            raise ValueError("nprocs %d out of range" % nprocs)
        if nprocs == 1:
            word_counts, bigram_counts = self._collect_wordcounts(
                0, 1, corpus, process_fn, update_fn)
        else:
            word_counts, bigram_counts = {}, {}

            def save_counts(task, word_counts=word_counts,
                            bigram_counts=bigram_counts):
                # Merge one finished Task's partial counts into the totals.
                wc, bc = task.retval
                for k, v in wc.items():
                    word_counts[k] = word_counts.get(k, 0) + v
                for k, v in bc.items():
                    bigram_counts[k] = bigram_counts.get(k, 0) + v
            scheduler = Scheduler.Scheduler(nprocs, finish_fn=save_counts)
            for i in range(nprocs):
                t = Task.Task(target=self._collect_wordcounts,
                              args=(i, nprocs, corpus, process_fn, update_fn))
                scheduler.add(t)
            # Poll until the scheduler reports no more work.
            while scheduler.run():
                time.sleep(0.01)

        self._likelihoods = self._calc_scores(word_counts, bigram_counts)

    def find(self, tokens):
        """S.find(tokens) -> list of (start, end, score)

        Find all the significant collocations in a list of tokens.
        Only alphanumeric tokens take part.  Each pair of consecutive
        alphanumeric tokens whose bigram scores at least the cutoff is a
        collocation.  Collocations that share a token are merged into one
        span, scored with the lowest of its bigram scores.  start and end
        index into tokens; end is exclusive.

        Raises ValueError if no model has been loaded or created.

        """
        if self._likelihoods is None:
            raise ValueError("No model -- please load or create one.")
        if self._ignore_case:
            tokens = [x.lower() for x in tokens]

        # NOTE(review): this used listfns.indexesof(tokens, ctype.isalnum),
        # but neither module was imported (NameError).  str.isalnum is
        # assumed to be the intended test -- confirm.
        indexes = [i for i, token in enumerate(tokens) if token.isalnum()]

        # Significant bigrams as (start, end, score), end exclusive.
        bigrams = []
        for i1, i2 in zip(indexes, indexes[1:]):
            score = self._likelihoods.get((tokens[i1], tokens[i2]))
            if score is not None and score >= self._score_cutoff:
                bigrams.append((i1, i2 + 1, score))

        # Merge bigrams that overlap, i.e. share a token.  Ends are
        # exclusive, so spans that merely touch stay separate.
        collocs = []
        for start, end, score in bigrams:
            if collocs and start < collocs[-1][1]:
                prev_start, _prev_end, prev_score = collocs[-1]
                collocs[-1] = (prev_start, end, min(prev_score, score))
            else:
                collocs.append((start, end, score))
        return collocs

    def load(self, handle):
        """S.load(handle)

        Load a likelihood ratio model from a file holding one
        "w1 w2 score" line per bigram, as written by save.  Bigrams
        scoring below the cutoff are skipped and blank lines are
        ignored.  The current model is replaced only after the whole
        file has been read.

        """
        likelihoods = {}
        while 1:
            lines = handle.readlines(16384)
            if not lines:
                break
            for line in lines:
                fields = line.split()
                if not fields:
                    continue
                w1, w2, score = fields
                score = float(score)
                # Skip rather than stop, so a low score does not drop the
                # rest of its chunk and file order does not matter.
                if score < self._score_cutoff:
                    continue
                likelihoods[(w1, w2)] = score
        self._likelihoods = likelihoods

    def save(self, handle):
        """S.save(handle)

        Save a likelihood ratio model to a file, one "w1 w2 score" line
        per bigram, sorted by descending score and then alphabetically.

        Raises ValueError if no model has been loaded or created.

        """
        if self._likelihoods is None:
            raise ValueError("No model -- please load or create one.")
        ranked = sorted(self._likelihoods.items(),
                        key=lambda item: (-item[1], item[0]))
        for (w1, w2), score in ranked:
            handle.write("%s %s %g\n" % (w1, w2, score))
        handle.flush()   # once at the end rather than per line

    def _collect_wordcounts(self, start, skip, corpus, process_fn, update_fn):
        """Count unigrams and bigrams over documents start, start+skip, ...

        Returns (unigrams, bigrams), two dicts mapping each n-gram that
        ngram.count produced to its total count.  _calc_scores looks up
        unigrams by bare word and bigrams by (w1, w2) tuple.

        """
        unigrams, bigrams = {}, {}
        for i in range(start, len(corpus), skip):
            handle, name = corpus[i], corpus.name(i)
            if update_fn:
                update_fn(name)
            token_groups = process_fn(handle)

            for tokens in token_groups:
                if self._ignore_case:
                    tokens = [x.lower() for x in tokens]
                # Now count the words and bigrams.
                u = ngram.count(1, tokens)
                for k, v in u.items():
                    unigrams[k] = unigrams.get(k, 0) + v
                b = ngram.count(2, tokens)
                for k, v in b.items():
                    bigrams[k] = bigrams.get(k, 0) + v
        return unigrams, bigrams

    def _calc_scores(self, word_counts, bigram_counts):
        """Return a dict mapping each bigram in bigram_counts to its score.

        word_counts maps a word to its count.  bigram_counts maps a
        (w1, w2) tuple to its count.  The score is -2 * log2(lambda), as
        described in the class docstring.  A bigram whose w1 makes up the
        whole corpus scores 0.0, since P(w2|^w1) is then undefined.

        """
        ln2 = math.log(2.0)

        def log2(x):
            # Same -1000.0 sentinel for log(0) as the old safe_log2 call.
            if x <= 0.0:
                return -1000.0
            return math.log(x) / ln2

        def log_L(k, n, x):
            # log2 of the binomial likelihood x**k * (1-x)**(n-k), without
            # the coefficient (Manning & Schuetze p. 173).  Work in log
            # space because the product underflows to 0.0 for corpus-sized
            # n.  A zero exponent contributes nothing (0 * log 0 == 0).
            total = 0.0
            if k > 0:
                total += k * log2(x)
            if n - k > 0:
                total += (n - k) * log2(1.0 - x)
            return total

        likelihoods = {}
        # Total number of words in the corpus (0.0 for an empty corpus).
        N = float(sum(word_counts.values()))
        for (w1, w2), c12 in bigram_counts.items():
            c1 = float(word_counts[w1])
            c2 = float(word_counts[w2])
            c12 = float(c12)
            if N == c1:
                # Every token is w1: no "not w1" data to compare against.
                likelihoods[(w1, w2)] = 0.0
                continue

            p = c2 / N
            p1 = c12 / c1
            p2 = (c2 - c12) / (N - c1)
            log_lambda = (log_L(c12, c1, p) + log_L(c2 - c12, N - c1, p)
                          - log_L(c12, c1, p1)
                          - log_L(c2 - c12, N - c1, p2))
            likelihoods[(w1, w2)] = -2 * log_lambda
        return likelihoods

def find(tokens, model):
    """find(tokens, model) -> ranges

    Return the (start, end) spans of the collocations that model finds
    in tokens, with the scores dropped.  model may be any object whose
    find(tokens) method returns (start, end, score) triples, such as a
    LikelihoodRatioModel.

    """
    spans = []
    for colloc in model.find(tokens):
        start, end, _score = colloc
        spans.append((start, end))
    return spans
