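"""Wrappers for querying relational databases.

Classes:
Database         Base class; subclasses implement query().
MySQLDatabase    Handles accesses to MySQL databases.
OracleDatabase   NOT IMPLEMENTED

Functions:
myescape         Escape a string for use inside a MySQL string literal.

"""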
import sys

import mx.TextTools as TT

class Database:
    """Abstract interface for the database wrappers in this module.

    Subclasses override query() to run a query and return its results.
    """

    def query(self, query):
        """S.query(query) -> results.  Must be overridden by subclasses."""
        raise NotImplementedError()

class MySQLDatabase(Database):
    """Run queries against a MySQL server through MySQLdb.

    The keyword arguments given to the constructor are stored and passed
    unchanged to MySQLdb.connect() each time a query is run.
    """
    # This implementation creates a new database connection and a new
    # cursor every time 'query' is called.  Keeping database connections
    # open caused too many problems, such as:
    # - Frequent "Lost connection to MySQL server" errors that required
    #   recreating the connection.
    # - When a MySQLdb database is open as a global variable, the garbage
    #   collector fails with an AttributeError (no _cursor) at the end of
    #   the program.
    # - General instability.

    # Substrings of MySQLdb.OperationalError messages that indicate a
    # transient connection problem worth retrying.
    _transient_errors = (
        "Lost connection to MySQL server",
        "Can't connect to MySQL server",
        )
    # Seconds to keep retrying transient errors before giving up.
    _timeout = 600
    # Seconds to wait between retries, so an unavailable server is not
    # hammered with connection attempts in a tight loop.
    _retry_delay = 1

    def __init__(self, **keywds):
        self._import_mysql()
        self._connection_params = keywds

    def _import_mysql(self):
        # Import MySQLdb lazily, as a module-level global, so this module's
        # namespace is not polluted (and MySQLdb need not be installed)
        # unless a MySQLDatabase is actually created.
        global MySQLdb
        import MySQLdb

    def _create_connection(self, params):
        return MySQLdb.connect(**params)

    def query(self, query):
        """S.query(query) -> list of result rows

        Run 'query' on a fresh connection and return every fetched row.
        A MySQLdb.OperationalError whose message contains one of
        _transient_errors is retried, pausing _retry_delay seconds between
        attempts, until _timeout seconds have elapsed; then
        AssertionError("Database timed out") is raised.  Any other
        OperationalError propagates immediately.  At least one attempt is
        always made.
        """
        import time
        start = time.time()
        while 1:
            try:
                return self._query(query)
            except MySQLdb.OperationalError:
                # sys.exc_info() instead of "except ... as" keeps this
                # compatible with old Python 2 releases.
                if not self._is_transient_error(sys.exc_info()[1]):
                    raise
            if time.time() - start >= self._timeout:
                raise AssertionError("Database timed out")
            time.sleep(self._retry_delay)

    def _is_transient_error(self, error):
        """Return True if 'error' looks like a retryable connection failure."""
        message = str(error)
        for fragment in self._transient_errors:
            if fragment in message:
                return True
        return False

    def _query(self, query):
        """Execute 'query' on a new connection and return all rows as a list.

        The cursor and connection are always closed, even when execute() or
        a fetch raises, so failed attempts (which query() may retry many
        times) do not leak connections.
        """
        db = self._create_connection(self._connection_params)
        try:
            cursor = db.cursor()
            try:
                cursor.execute(query)
                # Keep calling fetchone() until it signals the end with None.
                results = list(iter(cursor.fetchone, None))
            finally:
                cursor.close()
        finally:
            db.close()
        return results
            
def myescape(text):
    """myescape(text) -> text escaped for use inside a MySQL string literal

    Backslash-escapes the characters that MySQL's own
    mysql_real_escape_string() escapes (backslash, NUL, newline, carriage
    return, single and double quotes, Ctrl-Z), plus tab.  The result is
    only safe inside a quoted literal when the server is not running in
    NO_BACKSLASH_ESCAPES mode; parameterized queries are preferable
    whenever the driver supports them.
    """
    # Backslash must be handled first; otherwise the backslashes inserted
    # for the other characters would themselves be doubled.
    replacements = (
        ('\\', '\\\\'),
        ('\0', '\\0'),
        ('\n', '\\n'),
        ('\r', '\\r'),
        ('\t', '\\t'),
        ("'", "\\'"),
        ('"', '\\"'),
        ('\x1a', '\\Z'),
        )
    for char, escaped in replacements:
        text = text.replace(char, escaped)
    return text
