#!/usr/bin/env python

"""This module contains code to handle training data.


Functions:
load            Load a file of training data.
featurize       Turn training data into features usable for classification.
load_features   Load training data as features.

"""
from __future__ import nested_scopes
import operator

from Bio import MultiProc

from Extracto import datafile
import features
import aligner
import support

# Marker character embedded in the STR field of a training file.  load()
# records the position each marker points at (the index of the character
# that follows it once all markers are removed) and strips the markers out.
ABBREV_MARKER = '~'

def load(handle_or_file):
    """load(handle_or_file) -> list of (string, abbrev, aligned indexes)

    Read training data from handle_or_file, which is opened via
    support.open_handle_or_file.  After datafile.clean, every remaining
    line must have exactly three "|"-separated fields: PMID|STR|ABBREV.
    The PMID is read but discarded.

    Within STR, each ABBREV_MARKER character marks an aligned position.
    The returned string is STR with all markers removed.  The aligned
    indexes are, in the order the markers appear, the index in that
    cleaned string of the character following each marker.

    Raises SyntaxError if there is no data at all or any line does not
    have exactly three fields.

    """
    handle = support.open_handle_or_file(handle_or_file)

    lines = datafile.clean(handle.readlines())
    records = [line.split("|") for line in lines]
    # Make sure every line contains PMID, STR and ABBREV (and nothing else).
    malformed = [r for r in records if len(r) != 3]
    if not records or malformed:
        raise SyntaxError(
            "training data should be in format PMID|STR|ABBREV")

    data = []
    for pmid, text, abbrev in records:
        indexes = []
        i = 0    # number of non-marker characters seen so far
        for ch in text:
            if ch == ABBREV_MARKER:
                # The next real character will land at position i once
                # the markers are stripped out.
                indexes.append(i)
            else:
                i += 1
        data.append((text.replace(ABBREV_MARKER, ""), abbrev, indexes))
    return data

def _featurize_some(start, skip, data, update_fn):
    """_featurize_some(start, skip, data, update_fn) -> xs, ys

    Featurize every skip-th record of data, beginning at index start.
    Each record is (prefix, abbrev, aligned_indexes).  For every candidate
    alignment returned by aligner.align(prefix, abbrev), append a feature
    vector from features.make_vector to xs.  Also append to ys a label
    that is true only when the alignment's non-None indexes are non-empty
    and equal the sorted aligned_indexes.  If update_fn is not None, it is
    called as update_fn(x, y) for each pair as it is produced.

    The records in data are not modified.

    """
    xs, ys = [], []
    for i in range(start, len(data), skip):
        prefix, abbrev, aligned_indexes = data[i]
        # Sort a copy so the caller's training data is not mutated as a
        # side effect of featurizing it.
        expected = list(aligned_indexes)
        expected.sort()
        for alignment in aligner.align(prefix, abbrev):
            # expected lists the indexes of prefix that are aligned
            # somewhere to abbrev.  It does not indicate which character
            # of abbrev is aligned, while alignment does.  To compare
            # them, drop the entries of alignment that are None.
            indexes = [idx for idx in alignment if idx is not None]

            x = features.make_vector(prefix, abbrev, indexes)
            # An empty alignment is never a positive example.
            y = operator.truth(indexes and (indexes == expected))
            xs.append(x)
            ys.append(y)
            if update_fn is not None:
                update_fn(x, y)
    return xs, ys

def featurize(data, nprocs=None, update_fn=None):
    """featurize(data[, nprocs][, update_fn]) -> xs, ys

    Make the training data features that can be used in a classifier.
    data is a list of (string, abbreviation, aligned indexes), as returned
    by load.  xs is a list of feature vectors and ys the parallel list of
    true/false labels, one pair per candidate alignment.

    nprocs is the number of processes passed to MultiProc.run (default 1).
    It must be between 1 and 100 inclusive; otherwise ValueError is raised.
    update_fn, if not None, is called as update_fn(x, y) for each pair.
    NOTE(review): with nprocs > 1 update_fn presumably runs inside worker
    processes, so its side effects may not reach the caller -- confirm
    against MultiProc.

    """
    if nprocs is None:
        nprocs = 1
    if nprocs < 1 or nprocs > 100:
        raise ValueError("Hah!  You can't run with %d processes" % nprocs)
    xs, ys = [], []
    # retvals holds the (xs, ys) results of the _featurize_some calls;
    # concatenate them into a single pair of lists.
    retvals = MultiProc.run(nprocs, _featurize_some, fn_args=(data, update_fn))
    for nx, ny in retvals:
        xs.extend(nx)
        ys.extend(ny)
    return xs, ys

def load_features(handle_or_file, nprocs=None, update_fn=None):
    """load_features(handle_or_file) -> xs, ys

    Convenience wrapper: read the training data with load(), then pass it
    to featurize() with the given nprocs and update_fn.

    """
    training_data = load(handle_or_file)
    return featurize(training_data, update_fn=update_fn, nprocs=nprocs)
