"""xml_format.py

This module provides functions that read and write mstring objects in
an XML format.

Functions:
load      Load an mstring object.
load_str  Load an mstring object from a string.
save      Save an mstring object.
save_str  Save an mstring object to a string.

"""
# LAYOUT OF THE XML DOCUMENT
# <STRING>XXXXX</STRING>
# <MARKUP name="" value="" start="" end=""></MARKUP>   (one per markup)

import string

import mx.TextTools as TT

import mstring
import sax_helper

# Element names used in the serialized form: one STRING element holds the
# text, and every markup is written as an (empty) MARKUP element.
STRING_ELEM, MARKUP_ELEM = "STRING", "MARKUP"

def load(handle):
    """load(handle) -> mstring

    Read all of the XML text from the file-like object handle and
    parse it into an mstring (see load_str).

    """
    text = handle.read()
    return load_str(text)

def load_str(s):
    """load_str(s) -> mstring

    Parse the XML text s and build an mstring from it.

    """
    # The XML parser chokes on unprintable characters.  Each one is swapped
    # for a single space, so the text keeps its length and the markup
    # indexes still line up.
    cleaned = _replace_unprintable(s, ' ')
    return _sax2mstring(sax_helper.xml2sax(cleaned))

def save(handle, ms, pretty=0):
    """save(handle, ms[, pretty])

    Write the mstring ms as XML to the file-like object handle.  pretty
    is passed through to save_str.

    """
    handle.write(save_str(ms, pretty=pretty))

def save_str(ms, pretty=0):
    """save_str(ms[, pretty]) -> str

    Serialize the mstring ms to a string of XML.  If pretty is true,
    newline events are inserted (via _prettify) before the XML is
    generated.

    Raises ValueError if ms is not an mstring.mstring instance.

    """
    if not isinstance(ms, mstring.mstring):
        raise ValueError("string should be a mstring object")
    # Convert the string and its markups into SAX events; the markup
    # attribute values are escaped along the way.
    events = _mstring2sax(ms)
    if pretty:
        events = _prettify(events)
    return sax_helper.sax2xml(events)

def _mstring2sax(ms):
    """_mstring2sax(ms) -> SAX events

    Convert an mstring into a list of SAX-style (event_name, params)
    tuples: a STRING element whose character data is ms._string,
    followed by one empty MARKUP element per (name, value, start, end)
    entry in ms._markups.

    The MARKUP attribute values are escaped with _escape; the character
    data of the STRING element is not escaped here.

    """
    events = [
        ('startElement', (STRING_ELEM, {})),
        ('characters', (ms._string,)),
        ('endElement', (STRING_ELEM,)),
        ]

    for name, value, start, end in ms._markups:
        # Escape while building the attributes rather than making a
        # second pass over every event afterwards.
        attrs = {
            "name" : _escape(name),
            "value" : _escape(value),
            "start" : _escape(str(start)),
            "end" : _escape(str(end)),
            }
        events.append(('startElement', (MARKUP_ELEM, attrs)))
        events.append(('endElement', (MARKUP_ELEM,)))

    return events

def _sax2mstring(events):
    """_sax2mstring(events) -> mstring

    Build an mstring from the SAX events of a parsed XML document.  The
    character data inside the STRING element becomes the string ('' if
    there is none).  Every MARKUP element becomes a
    (name, value, start, end) markup, with start and end converted to
    int.

    Raises SyntaxError if character data is found inside STRING elements
    more than once (after consecutive 'characters' events are grouped).

    """
    # The work is done in a few separate passes over the events, kept
    # apart for readability.

    # Group together consecutive 'characters' events.
    events = sax_helper.munge_characters(events)

    # Unescape the attributes in place.
    for name, params in events:
        if name == 'startElement':
            # NOTE(review): relies on the private _attrs dict of the
            # attributes object that sax_helper produces.
            attrs = params[1]._attrs
            for key, value in attrs.items():
                attrs[key] = _unescape(value)

    # Now separate out the string text and the markup attributes.
    s = None
    markup_attrs = []
    in_string = 0
    for event, args in events:
        if event == 'characters':
            if in_string:
                if s is not None:
                    raise SyntaxError("Found multiple string events")
                s, = args
        elif event == 'startElement':
            entity, attrs = args
            if entity == STRING_ELEM:
                in_string = 1
            elif entity == MARKUP_ELEM:
                markup_attrs.append(attrs)
        elif event == 'endElement':
            entity, = args
            if entity == STRING_ELEM:
                in_string = 0

    if s is None:
        s = ''

    # Make a list of the markups.
    markuplist = []
    for attrs in markup_attrs:
        name, value, start, end = (attrs["name"], attrs["value"],
                                   attrs["start"], attrs["end"])
        # str(): keep plain strings rather than Unicode.
        markuplist.append((str(name), str(value), int(start), int(end)))

    # Now create the mstring object.
    ms = mstring.mstring(str(s))
    ms._markups = markuplist
    return ms

def _xml_length(sax_event):
    # Guess the length a sax event would be in XML.
    name, params = sax_event
    if name == 'characters':
        chars, = params
        return len(chars)
    elif name == 'startElement':
        elem, attrs = params
        attrlen = 0
        for name, value in attrs.items():
            attrlen += len(name) + len(value) + 3   # name="value"
            attrlen += 1    # add a space before each attribute
        return len(elem) + 2 + attrlen              # add the brackets
    elif name == 'endElement':
        elem, = params
        return len(elem) + 3     # </elem>
    else:
        raise NotImplementedError, "I don't know how to handle %s" % name

def _prettify(events, colwidth=80, lead=0, trail=0, newline='\n'):
    """_prettify(events[, colwidth][, lead][, trail][, newline]) -> events

    Make a list of SAX events print out prettily in XML.  events is a
    list of SAX events.  colwidth is the maximum number of columns per
    line.  lead is the number of leading characters reserved on the
    first line and trail is the number of trailing characters on the
    last one.  newline is the string inserted to break a line.

    Lines are only broken between top-level groups of events, i.e. at
    events that begin while no element is open.  A newline is always
    inserted before a group that starts with a MARKUP element.  Any
    other group gets one when it would reach colwidth on a non-empty
    line.  A final newline is always appended.

    Raises NotImplementedError if trail is non-zero.  Raises SyntaxError
    for an endElement without a matching startElement.

    """
    if trail:
        raise NotImplementedError("XXX not implemented")

    # Find indexes that are legal breakpoints: every event that begins
    # while no element is open starts a new top-level group.
    breakpoints = []
    open_elems = {}   # element name -> number of starts not yet ended
    for i in range(len(events)):
        name, params = events[i]
        if not open_elems:
            breakpoints.append(i)
        if name == 'startElement':
            elem, attrs = params
            open_elems[elem] = open_elems.get(elem, 0) + 1
        elif name == 'endElement':
            elem, = params
            if elem not in open_elems:
                raise SyntaxError(
                    "Found end element without start: %s" % elem)
            open_elems[elem] = open_elems[elem] - 1
            if not open_elems[elem]:
                del open_elems[elem]

    # Now iterate through the groups, adding newlines where necessary.
    # A list (not map()) so it can be sliced below.
    lengths = [_xml_length(event) for event in events]
    pretty_events = []
    col = lead
    breakpoints.append(len(events))    # sentinel: end of the last group
    for i in range(len(breakpoints) - 1):
        index, next_index = breakpoints[i], breakpoints[i+1]
        name, params = events[index]
        # Measure the whole group before choosing a branch, so that
        # `col` is also updated correctly for MARKUP groups.
        group_length = 0
        for length in lengths[index:next_index]:
            group_length = group_length + length
        if name == 'startElement' and params[0] == MARKUP_ELEM:
            add_newline = 1        # every markup starts on its own line
        elif col != 0 and col + group_length >= colwidth:
            add_newline = 1
        else:
            add_newline = 0
        if add_newline:
            pretty_events.append(('characters', (newline,)))
            col = 0
        pretty_events.extend(events[index:next_index])
        col = col + group_length
    # End with a newline.
    pretty_events.append(('characters', (newline,)))

    return pretty_events

def _replace_unprintable(str, replacement):
    """_replace_unprintable(str, replacement) -> string

    Return a copy of str with every character that is not in
    string.printable replaced by replacement.  (The parameter name
    shadows the builtin str inside this function.)

    """
    # With logic=0, TT.set builds the complement of string.printable,
    # i.e. the set of characters to be replaced.
    bad_chars = TT.set(string.printable, 0)
    replacements = []
    pos = TT.setfind(str, bad_chars, 0)
    while pos >= 0:
        replacements.append((replacement, pos, pos + 1))
        pos = TT.setfind(str, bad_chars, pos + 1)
    return TT.multireplace(str, replacements)

_XML_ESCAPES = [
    ("&quot;", '"'),
    ("&apos;", "'"),
    ("&lt;", '<'),
    ("&gt;", '>'),
    ("&amp;", '&')    # _escape/_unescape expects this to be last.
    ]

def _escape(data):
    """_escape(data) -> str

    Replace each character listed in _XML_ESCAPES with its entity, for
    use in an XML attribute value.  The table is processed back to
    front, so '&' is escaped first and the '&' of entities added
    afterwards is not escaped again.

    """
    escapes = _XML_ESCAPES[:]
    escapes.reverse()          # "&amp;" is last in the table; do it first
    for entity, char in escapes:
        data = data.replace(char, entity)
    return data

def _unescape(data):
    """_unescape(data) -> str

    Undo _escape: replace each entity in _XML_ESCAPES with its character.
    The table is processed front to back, so "&amp;" is decoded last and
    text such as "&amp;lt;" becomes "&lt;" rather than "<".

    """
    for entity, char in _XML_ESCAPES:
        # str.replace is much faster here than TextTools.replace.
        data = data.replace(entity, char)
    return data

