"""String that carries markups (annotations): name-value pairs attached
either to the whole string or to a range of its characters.

Classes:
mstring

"""
import types

from Bio import listfns

import rangefns
import markup_consts

# Sentinel used as both the start and the end position of a markup that
# applies to the whole string rather than to a character range.
WHOLE_STRING = -1

class mstring:
    """A string whose text can carry markups (name-value annotations).

    A markup is stored as a tuple (name, value, start, end).  start and
    end delimit the character range [start, end) it covers, or are both
    WHOLE_STRING when the markup applies to the entire string.  start ==
    end is allowed and marks a zero-length position such as a boundary.
    Slicing and concatenation carry markups along with the text they
    cover.

    Methods:
    add_markup
    markups
    extract
    collapse

    string interface:
    __len__
    __getitem__
    __add__
    __eq__ / __ne__
    lots of the string functions

    """
    def __init__(self, string):
        """mstring(string) -> instance"""
        self._string = string
        # List of (name, value, start char, end char) tuples.
        self._markups = []

    def add_markup(self, name, value, start=WHOLE_STRING, end=WHOLE_STRING):
        """S.add_markup(name, value[, start][, end])

        Attach the markup (name, value) to characters [start, end) of S.
        When start and end are both left at WHOLE_STRING the markup applies
        to the whole string.

        Raises IndexError if the range does not lie inside the string.
        """
        is_whole = start == WHOLE_STRING and end == WHOLE_STRING
        if not is_whole and \
           (start > end or start < 0 or end > len(self._string)):
            raise IndexError("Invalid index (%d:%d)" % (start, end))
        # XXX Check to prevent duplicates?
        self._markups.append((name, value, start, end))

    def markups(self, name=None, value=None, index=None, restrict_fn=None):
        """S.markups([name][, value][, index][, restrict_fn]) -> list

        Return the (name, value, start, end) markups that satisfy every
        criterion given, in the order they are stored:
          name, value -- must compare equal to the markup's name / value.
          index       -- the markup's range must contain this character
                         position; whole-string markups are skipped.
          restrict_fn -- called with a dict with keys 'name', 'value',
                         'start' and 'end'; the markup is kept only if the
                         call returns a true value.
        """
        found = []
        for markup in self._markups:
            n, v, s, e = markup
            if name is not None and name != n:
                continue
            if value is not None and value != v:
                continue
            if index is not None:
                if s == WHOLE_STRING or e == WHOLE_STRING:
                    continue
                if index < s or index >= e:
                    continue
            if restrict_fn is not None:
                data = {'name': n, 'value': v, 'start': s, 'end': e}
                if not restrict_fn(data):
                    continue
            found.append(markup)
        return found

    def extract(self, name):
        """S.extract(name) -> list of mstring objects

        Return the pieces of S covered by the markups called name.  The
        collected ranges are passed through rangefns.munge before slicing;
        a whole-string markup contributes the range (0, len(S)).
        """
        ranges = []
        for n, v, s, e in self._markups:
            if n != name:
                continue
            if s == e == WHOLE_STRING:
                ranges.append((0, len(self)))
                # Once I have the whole string, there are no more ranges.
                break
            ranges.append((s, e))
        ranges = rangefns.munge(ranges)
        return [self[s:e] for s, e in ranges]

    def collapse(self):
        """S.collapse() -> mstring with clean whitespace

        Every whitespace character becomes a single space, and each run of
        consecutive whitespace is reduced to its first character.  Markups
        are shifted so they stay aligned with the remaining characters.
        """
        ns = self[:]
        i = 0
        while i < len(ns):
            if ns._string[i].isspace():
                if i > 0 and ns._string[i-1].isspace():
                    # The previous character is already a space: delete this
                    # one.  i stays put because the next character shifts
                    # into position i.
                    ns._del(i)
                else:
                    ns._string = ns._string[:i] + ' ' + ns._string[i+1:]
                    i += 1
            else:
                i += 1
        return ns

    def _del(self, index):
        # Delete the character at index in place, shifting markups so they
        # stay aligned with the remaining text.  This is a private function
        # that should not be exported: mstrings should be immutable for the
        # client.  The simplest thing to do would be
        # self[:index] + self[index+1:], but this would break up markups.
        self._string = self._string[:index] + self._string[index+1:]
        markups = []
        for n, v, s, e in self._markups:
            if (s, e) == (index, index + 1):
                # This markup covered only the deleted character; it now
                # covers nothing.  Zero-length markups that were already
                # present (boundaries) are not affected and are kept.
                continue
            if s != WHOLE_STRING and s > index:
                s -= 1
            if e != WHOLE_STRING and e > index:
                e -= 1
            # Keep well-formed markups, including zero-length ones and
            # whole-string ones (start == end == WHOLE_STRING).
            if e >= s:
                markups.append((n, v, s, e))
        self._markups = markups

    def __str__(self):
        return self._string

    def __repr__(self):
        return "mstring(%r)" % self._string

    def __len__(self):
        return len(self._string)

    def __getitem__(self, key):
        """S[i] or S[i:j] -> mstring carrying the overlapping markups"""
        if type(key) is not types.SliceType:
            return self._getitem(key)
        if key.step is not None and key.step != 1:
            raise TypeError("mstring does not support extended slicing")
        # Omitted bounds may arrive as None, e.g. from extended-slice syntax
        # such as S[1::1].
        start, end = key.start, key.stop
        if start is None:
            start = 0
        if end is None:
            end = len(self._string)
        return self._getslice(start, end)

    def _getslice(self, start, end):
        # Return a new mstring holding characters [start, end) together with
        # the markups that overlap that range, clipped and shifted into the
        # new coordinates.  If not explicit, end is set to sys.maxint, so it
        # is clamped to the string length first.
        if end > len(self._string):
            end = len(self._string)
        # Do some checking to make sure the indexes are reasonable.
        if start < 0 or start > len(self._string):
            raise ValueError("start %d out of range" % start)
        if end < 0 or end > len(self._string):
            raise ValueError("end %d out of range" % end)

        # Create a new object for the slice.
        ns = mstring(self._string[start:end])
        # Now get the proper markups.  Whole-string markups are copied to
        # every slice unchanged.
        markups = []
        newlen = len(ns._string)
        for n, v, s, e in self._markups:
            if s != WHOLE_STRING and e != WHOLE_STRING:
                # Ignore markups that don't overlap: those that start after
                # the slice, end at or before its start, or start exactly at
                # its end and extend past it.  A zero-length markup sitting
                # exactly at the end of the slice is kept, since it might
                # denote some sort of boundary (e.g. end of sentence).
                s, e = s-start, e-start
                if s > newlen or e <= 0 or (s == newlen and e > s):
                    continue
                if s < 0:
                    s = 0
                if e > newlen:
                    e = newlen
            markups.append((n, v, s, e))
        ns._markups = markups
        return ns

    def _getitem(self, index):
        # Single character (negative indexes count from the end) as a
        # one-character mstring.
        if index < 0:
            if index < -len(self._string):
                raise IndexError("mstring index out of range")
            index = index + len(self._string)
        elif index >= len(self._string):
            raise IndexError("mstring index out of range")
        return self._getslice(index, index+1)

    def __add__(self, string):
        """S + string -> new mstring

        The text of string is appended via str().  If string carries markups
        (e.g. another mstring), they are shifted past the end of S and
        appended, skipping duplicates.
        """
        # NOTE(review): whole-string markups of either operand end up
        # applying to the whole concatenation -- confirm this is intended.
        ns = self[:]
        offset = len(ns._string)
        ns._string = ns._string + str(string)
        if hasattr(string, '_markups'):
            seen = listfns.asdict(ns._markups)
            for n, v, s, e in string._markups:
                if s != WHOLE_STRING and e != WHOLE_STRING:
                    s, e = s+offset, e+offset
                markup = (n, v, s, e)
                if not seen.has_key(markup):
                    seen[markup] = 1
                    ns._markups.append(markup)
        return ns

    def __eq__(self, other):
        # Another mstring is equal when both text and markups match.
        # Anything else (str, unicode, None, ...) is compared with the text
        # alone, so unrelated objects compare unequal instead of raising.
        if hasattr(other, '_markups') and hasattr(other, '_string'):
            return self._string == other._string and \
                   self._markups == other._markups
        return self._string == other

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        return not self.__eq__(other)

    def lstrip(self):
        """S.lstrip() -> copy of S without leading whitespace"""
        text = self._string
        first = 0
        while first < len(text) and text[first].isspace():
            first += 1
        # Slice once instead of dropping one character at a time.
        return self[first:]

    def rstrip(self):
        """S.rstrip() -> copy of S without trailing whitespace"""
        text = self._string
        end = len(text)
        while end > 0 and text[end-1].isspace():
            end -= 1
        # Slice once instead of dropping one character at a time.
        return self[:end]

    def strip(self):
        """S.strip() -> copy of S without leading or trailing whitespace"""
        ns = self.lstrip()
        ns = ns.rstrip()
        return ns

    def join(self, sequence):
        """S.join(sequence) -> mstring

        Concatenate the items of sequence with S between consecutive items.
        Concatenation uses +, so markups of S and of mstring items are kept.
        """
        ns = mstring("")
        for i in range(len(sequence)):
            if i > 0:
                ns = ns + self
            ns = ns + sequence[i]
        return ns

    def split(self, sep=None, maxsplit=-1):
        """S.split([sep[, maxsplit]]) -> list of mstring objects

        Split like str.split, but return slices of S so each piece keeps
        the markups that overlap it.  sep may itself be an mstring.
        """
        if hasattr(sep, '_string'):
            sep = sep._string
        text = self._string
        pieces = text.split(sep, maxsplit)
        result = []
        pos = 0
        for piece in pieces:
            if sep is None:
                # Whitespace splitting: only whitespace separates pieces, and
                # a piece never starts with whitespace, so the next match
                # from pos is the piece's real position.
                pos = text.index(piece, pos)
            result.append(self[pos:pos + len(piece)])
            pos = pos + len(piece)
            if sep is not None:
                # Explicit separator: the next piece starts right after it.
                pos = pos + len(sep)
        return result

    def replace(self, old, new, maxsplit=None):
        """S.replace(old, new[, maxsplit]) -> mstring

        Replace occurrences of old by new, scanning left to right; when
        maxsplit is given, at most that many replacements are made.  Text
        inserted by a replacement is never searched again, so new may
        contain old.
        """
        old_text = str(old)
        # The text actually inserted is str(new) (see __add__).
        inserted = len(str(new))
        # An empty old matches at every position, so also step over one
        # original character after each insertion to guarantee progress
        # (this matches str.replace).
        if old_text:
            skip = 0
        else:
            skip = 1
        ns = self[:]
        nreplaced = 0
        start = 0
        while maxsplit is None or nreplaced < maxsplit:
            i = ns._string.find(old_text, start)
            if i < 0:
                break
            ns = ns[:i] + new + ns[i+len(old_text):]
            # Resume after the inserted text, not inside it.
            start = i + inserted + skip
            nreplaced += 1
        return ns

    class _apply_on_string:
        # Callable that runs a str method on the text of an mstring and
        # returns a copy of the mstring (markups included) holding the
        # result.
        # NOTE(review): markups are copied unchanged, which is only
        # consistent when the method preserves the text's length (e.g.
        # translate with deleted characters would not) -- confirm.
        def __init__(self, mstring, attr):
            self.mstring = mstring
            self.attr = attr
        def __call__(self, *args, **keywds):
            ns = self.mstring[:]
            fn = getattr(self.mstring._string, self.attr)
            ns._string = fn(*args, **keywds)
            return ns

    class _apply_converted:
        # Callable that runs a str method on the text of an mstring with its
        # first argument converted by str(), returning the plain result.
        def __init__(self, mstring, attr):
            self.mstring = mstring
            self.attr = attr
        def __call__(self, arg, *args, **keywds):
            fn = getattr(self.mstring._string, self.attr)
            return fn(str(arg), *args, **keywds)

    def __getattr__(self, attr):
        # Delegate a fixed set of string methods to the underlying text:
        # predicates return their plain result, searches accept mstring
        # arguments, and case/translate methods return a new mstring.
        if attr in ['isalnum', 'isalpha', 'isdigit', 'islower', 'isspace',
                    'istitle', 'isupper']:
            return getattr(self._string, attr)
        elif attr in ['endswith', 'startswith',
                      'find', 'rfind', 'index', 'rindex',
                      'count']:
            return self._apply_converted(self, attr)
        elif attr in ['lower', 'upper',
                      'capitalize', 'swapcase', 'title', 'translate']:
            return self._apply_on_string(self, attr)
        raise AttributeError("no attribute %r" % attr)
