"""This module wraps machine learning classifiers into objects.

training_set is a list of feature vectors (observations).
results is a parallel list of output classes; each class must be 0 or 1.
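
For example (illustrative data only):
    training_set = [[1.0, 0.0], [0.0, 1.0]]
    results = [0, 1]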

Classes:
NaiveBayes
NaiveBayes_Nltk

MaximumEntropy
MaximumEntropy_Biopython
MaximumEntropy_Nltk

LogisticRegression

SVM

Functions:
save   Save a classifier to a file.
load   Load a classifier from a file.

"""
from Extracto import memoize

def load(file_or_handle):
    """load(file_or_handle) -> classifier"""
    import pickle
    import filefns
    handle = filefns.openfh(file_or_handle, 'rb')
    return pickle.load(handle)
    
def save(classifier, file_or_handle):
    """save(classifier, file_or_handle)"""
    import pickle
    import filefns
    # pickle protocol 1 is binary, so open the handle in binary mode
    handle = filefns.openfh(file_or_handle, 'wb')
    pickle.dump(classifier, handle, 1)
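
# A minimal usage sketch (hypothetical data and filename, for illustration
# only):
#
#     nb = NaiveBayes()
#     nb.train([[1.0, 0.0], [0.0, 1.0]], [0, 1])
#     save(nb, "nb.model")
#     nb2 = load("nb.model")
#     log_probs = nb2.calculate([1.0, 0.0])   # dict of class -> log probability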
    
class Classifier:
    """Does supervised classification.

    Methods:
    train      Train the classifier from some training data.
    calculate  Calculate the probability of a new observation.

    """
    def __init__(self, binner=None, feature_selector=None):
        """S([binner][, feature_selector])

        Create a new classifier.  binner is a Binner object used to
        discretize data.  feature_selector is a FeatureSelector object
        used to select significant features.

        """
        self.binner = binner
        self.feature_selector = feature_selector
        self.classifier_data = None
        self.trained = 0
    def train(self, training_set, results):
        """S.train(training_set, results)

        Train the classifier from some training data.  training_set is
        a list of observations (feature vectors).  results is a list
        of the output classes.

        """
        self.trained = 1
        if self.binner:
            self.binner.train(training_set, results)
            training_set = self.binner.bin_many_vectors(training_set)
        if self.feature_selector:
            self.feature_selector.train(training_set, results)
            training_set = [self.feature_selector.select(x)
                            for x in training_set]
        self.classifier_data = self._train(training_set, results)
    def calculate(self, x):
        """S.calculate(x) -> dict of class->log probability"""
        assert self.trained, "classifier is not trained!"
        if self.binner:
            x = self.binner.bin_vector(x)
        if self.feature_selector:
            x = self.feature_selector.select(x)
        return self._calculate(self.classifier_data, x)

    # Implement the methods below in a derived class.
    def _train(self, training_set, results):
        """S._train(training_set, results) -> parameters for classifier"""
        raise NotImplementedError
    def _calculate(self, params, x):
        """S._calculate(params, x) -> dict of class->log probability"""
        raise NotImplementedError
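
# Sketch of how a subclass plugs into this protocol (a hypothetical
# majority-class classifier, not part of this module):
#
#     class Majority(Classifier):
#         def _train(self, training_set, results):
#             # The "parameters" are just the log frequency of each class.
#             import math
#             counts = {}
#             for y in results:
#                 counts[y] = counts.get(y, 0) + 1
#             logp = {}
#             for y, c in counts.items():
#                 logp[y] = math.log(c / float(len(results)))
#             return logp
#         def _calculate(self, params, x):
#             # Ignore x; every observation gets the class priors.
#             return params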

class NaiveBayes(Classifier):
    def __init__(self, binner=None, feature_selector=None):
        Classifier.__init__(
            self, binner=binner, feature_selector=feature_selector)
    def _train(self, training_set, results):
        from Bio import NaiveBayes
        return NaiveBayes.train(training_set, results, typecode='d')
    def _calculate(self, params, x):
        from Bio import NaiveBayes
        return NaiveBayes.calculate(params, x, scale=1)

class NaiveBayes_Nltk(Classifier):
    def __init__(self, binner=None, feature_selector=None):
        Classifier.__init__(
            self, binner=binner, feature_selector=feature_selector)

    def _make_feature_fns(self, training_set, results):
        feature_fns = []
        for i in range(len(training_set[0])):
            fn = _NltkNBFeature(i)
            feature_fns.append(fn)
        return feature_fns
    def _fns2detectors(self, feature_fns):
        from nltk.classifier.feature import FunctionFeatureDetector
        return [FunctionFeatureDetector(x) for x in feature_fns]
    
    def _train(self, training_set, results):
        from Bio import listfns
        from nltk import token
        from nltk import classifier
        from nltk.classifier import naivebayes
        feature_fns = self._make_feature_fns(training_set, results)
        feature_detectors = self._fns2detectors(feature_fns)

        fd_list = naivebayes.SimpleFDList(feature_detectors)

        train_toks = []
        for i in range(len(training_set)):
            x, y = training_set[i], results[i]
            x = tuple(x)
            tok = token.Token(classifier.LabeledText(x, y))
            train_toks.append(tok)

        trainer = naivebayes.NBClassifierTrainer(fd_list)
        clfy = trainer.train(train_toks)

        from nltk import probability
        label_fdist = probability.FreqDist()
        for y in results:
            label_fdist.inc(y)
        label_pdist = probability.MLEProbDist(label_fdist)
        fval_pdist = clfy.fval_probdist()
        labels = listfns.items(results)
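        # Return the pieces needed to rebuild an NBClassifier in _calculate.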
        return fd_list, labels, label_pdist, fval_pdist
    def _calculate(self, nb, x):
        import math
        from nltk.classifier import naivebayes
        from nltk import token
        
        fd_list, labels, label_pdist, fval_pdist = nb
        x = tuple(x)
        tok = token.Token(x)
        
        clfy = naivebayes.NBClassifier(
            fd_list, labels, label_pdist, fval_pdist)
        probs = clfy.distribution_dictionary(tok)
        for k in probs:
            probs[k] = math.log(probs[k])
        return probs

class MaximumEntropy(Classifier):
    def __init__(self, binner=None, feature_selector=None):
        Classifier.__init__(
            self, binner=binner, feature_selector=feature_selector)
    def _make_feature_fns(self, training_set, results):
        from Bio import listfns
        klasses = listfns.items(results)
        klasses.sort()
        num_dimensions = len(training_set[0])
        
        feature_fns = []
        for i in range(num_dimensions):
            for k in klasses:
                fn = _MEFeature(i, k)
                feature_fns.append(fn)
        return feature_fns
    def _train(self, training_set, results):
        from Extracto import MaxEntropy
        feature_fns = self._make_feature_fns(training_set, results)
        x = MaxEntropy.train(training_set, results, feature_fns)
        return x
    def _calculate(self, me, x):
        from Extracto import MaxEntropy
        probs = MaxEntropy.calculate(me, x)
        probdict = {}
        for y, p in zip(me.ys, probs):
            probdict[y] = p
        return probdict

class MaximumEntropy_Biopython(Classifier):
    def __init__(self, binner=None, feature_selector=None):
        Classifier.__init__(
            self, binner=binner, feature_selector=feature_selector)
    def _make_feature_fns(self, num_dimensions, klasses):
        feature_fns = []
        for i in range(num_dimensions):
            for klass in klasses:
                fn = _MEFeature(i, klass)
                feature_fns.append(fn)
        return feature_fns
    def _train(self, training_set, results):
        from Bio import MaxEntropy
        from Bio import listfns
        klasses = listfns.items(results)
        klasses.sort()
        num_dimensions = len(training_set[0])
        
        feature_fns = self._make_feature_fns(num_dimensions, klasses)
        x = MaxEntropy.train(training_set, results, feature_fns)
        return x
    def _calculate(self, me, x):
        from Bio import MaxEntropy
        probs = MaxEntropy.calculate(me, x)
        probdict = {}
        for y, p in zip(me.ys, probs):
            probdict[y] = p
        return probdict

class MaximumEntropy_Nltk(Classifier):
    def __init__(self, binner=None, feature_selector=None):
        Classifier.__init__(
            self, binner=binner, feature_selector=feature_selector)
    def _make_feature_fns(self, training_set, results):
        feature_fns = []
        for i in range(len(training_set[0])):
            fn = _NltkMEFeature(i, 1)
            feature_fns.append(fn)
        return feature_fns
    def _fns2detectors(self, feature_fns):
        from nltk.classifier.feature import FunctionFeatureDetector
        return [FunctionFeatureDetector(x) for x in feature_fns]
    def _train(self, training_set, results):
        from Bio import listfns
        from nltk import token
        from nltk import classifier
        from nltk.classifier import maxent
        feature_fns = self._make_feature_fns(training_set, results)
        feature_detectors = self._fns2detectors(feature_fns)

        base_fd_list = maxent.SimpleFDList(feature_detectors)
        fd_list = maxent.GIS_FDList(base_fd_list)

        train_toks = []
        for i in range(len(training_set)):
            x, y = training_set[i], results[i]
            x = tuple(x)
            tok = token.Token(classifier.LabeledText(x, y))
            train_toks.append(tok)

        trainer = maxent.IISMaxentClassifierTrainer(fd_list)
        clfy = trainer.train(train_toks)

        # Weights for each of my features, the 1-feature, and the correction
        # feature.
        weights = clfy.weights()
        labels = listfns.items(results)
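        # Return the pieces needed to rebuild the classifier in _calculate.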
        return fd_list, labels, weights
    def _calculate(self, me, x):
        import math
        from nltk.classifier import maxent
        from nltk import token
        
        fd_list, labels, weights = me
        x = tuple(x)
        tok = token.Token(x)
        
        clfy = maxent.ConditionalExponentialClassifier(
            fd_list, labels, weights)
        probs = clfy.distribution_dictionary(tok)
        for k in probs:
            probs[k] = math.log(probs[k])
        return probs

class _MEFeature:
    def __init__(self, dimension, klass):
        self.dimension = dimension
        self.klass = klass
    def __call__(self, x, y):
        if y != self.klass:
            return 0
        return x[self.dimension]
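
# For example, _MEFeature(2, 1) fires only for class 1 and returns the value
# of dimension 2 of the observation:
#     _MEFeature(2, 1)([0.5, 0.1, 0.9], 1)  ->  0.9
#     _MEFeature(2, 1)([0.5, 0.1, 0.9], 0)  ->  0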

class _NltkMEFeature:
    def __init__(self, dimension, klass):
        self.dimension = dimension
        self.klass = klass
    def __call__(self, token):
        vector, klass = token.text(), token.label()
        return vector[self.dimension] and klass == self.klass

class _NltkNBFeature:
    def __init__(self, dimension):
        self.dimension = dimension
    def __call__(self, token):
        vector = token.text()
        return vector[self.dimension]

class LogisticRegression(Classifier):
    def __init__(self, binner=None, feature_selector=None):
        Classifier.__init__(
            self, binner=binner, feature_selector=feature_selector)
    def _train(self, training_set, results):
        """S._train(training_set, results) -> parameters for classifier"""
        from Bio import LogisticRegression
        return LogisticRegression.train(training_set, results)
    def _calculate(self, params, x):
        from Bio import LogisticRegression
        p0, p1 = LogisticRegression.calculate(params, x)
        probdict = {0:p0, 1:p1}
        return probdict

class SVM(Classifier):
    # This requires libsvm to be installed and on the PYTHONPATH.
    def __init__(self, binner=None, feature_selector=None,
                 kernel="RBF", degree=None, C=None, epsilon=None):
        """SVM([binner][, feature_selector][, kernel][, degree][, C][, epsilon])

        kernel can be "RBF", "LINEAR", "POLY", or "SIGMOID".  The
        default is "RBF".  degree is an integer and is only meaningful
        for some kernels.  C and epsilon are parameters to tune the
        SVM optimization function (see Burges' tutorial for an
        explanation).
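
        A minimal usage sketch (parameter values are illustrative, not
        defaults):

            clf = SVM(kernel="POLY", degree=3, C=10.0)
            clf.train(training_set, results)
            log_probs = clf.calculate(x)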

        """
        self.kernel, self.degree = kernel, degree
        self.C, self.epsilon = C, epsilon
        
        Classifier.__init__(
            self, binner=binner, feature_selector=feature_selector)
        
    def _train(self, training_set, results):
        """S._train(training_set, results) -> parameters for classifier"""
        import os
        import tempfile
        import svm

        if not hasattr(svm, self.kernel):
            raise ValueError("Invalid kernel %s" % self.kernel)
        keywds = {
            "svm_type" : svm.EPSILON_SVR,
            "kernel_type" : getattr(svm, self.kernel)
            }
        member2params = [('degree', 'degree'), ('C', 'C'), ('epsilon', 'eps')]
        for my_name, svm_name in member2params:
            value = getattr(self, my_name)
            if value is not None:
                keywds[svm_name] = value
        params = svm.svm_parameter(**keywds)

        dataset = svm.svm_problem(results, training_set)
        model = svm.svm_model(dataset, params)

        # The model object is not pickleable, so I have to serialize
        # it before I return it.  Unfortunately, the svm API requires
        # the model to be saved to a file!  I'll have to create a temp
        # file and read it out from there.  This is why you should
        # always pass around handles.  Grrrr....
        filename = tempfile.mktemp()
        try:
            model.save(filename)
            data = open(filename).read()
        finally:
            if os.path.exists(filename):
                os.unlink(filename)
        return data
        
    def _calculate(self, data, x):
        from Extracto.MaxEntropy import _safe_log

        model = _svm_data2model(data)
        y = model.predict(x)
        # Bound y to [0, 1]
        y = min(1, max(0, y))
        p0, p1 = _safe_log(1.0-y), _safe_log(y)
        #p0, p1 = -y, y   # use distance from hyperplane for score
        return {0:p0, 1:p1}


def _svm_data2model(data):
    import sys
    import os
    import tempfile
    global svm  # Required so __del__ can use it in Python <= 2.1
    import svm

    # The original __del__ code in svm.py causes an exception when the
    # program is quitting:
    # Exception exceptions.AttributeError: "'None' object has no attribute 'svm_destroy_model'" in <method svm_model.__del__ of svm_model instance at 684d14> ignored
    # This happens because __del__ expects a reference to svmc, which
    # might already have been garbage collected.  Thus, I'm going to
    # write a version of the destructor that saves a reference to
    # svmc so it doesn't go away.
    class my_svm_model(svm.svm_model):
        def __del__(self, svmc=svm.svmc):
            svm.svm_model.__del__(self)
    # Load the model from the serialized data.
    filename = tempfile.mktemp()
    try:
        open(filename, 'w').write(data)
        model = my_svm_model(filename)
    finally:
        if os.path.exists(filename):
            os.unlink(filename)
    return model

_svm_data2model = memoize.memoize(_svm_data2model)
