#!/usr/bin/env python
#
#
 
"""Build swig imput file from xml encoded header files (see gccxml)."""
__author__ = "Randall J. Radmer"
__version__ = "1.0"
  
 
import sys, os
import getopt
import xml.dom.minidom as minidom
import xpath

#

# Classes listed here get no SWIG declarations of their own.
SKIP_CLASSES = [
    'Kernel',
    'Stream',
    'KernelImpl',
    'StreamImpl',
    'KernelFactory',
    'StreamFactory',
    'Vec3',
    'State',
    'OpenMMException',  # NOTE: this entry arguably should not be here.
]

# Individual methods, spelled "Class::method", that are never generated.
SKIP_METHODS = [
    'Context::getState',
    'Platform::loadPluginsFromDirectory',
]

# A method is suppressed when its return type or any argument type names
# one of these classes.
HIDE_CLASSES = [
    'Kernel',
    'Stream',
    'KernelImpl',
    'StreamImpl',
    'KernelFactory',
    'StreamFactory',
]

# One level of indentation in the generated SWIG file (three spaces).
INDENT = " " * 3


class SwigInputBuilder:
    """Generate SWIG declarations for the OpenMM namespace from gccxml XML.

    The XML document is parsed once in the constructor.  writeSwigFile()
    then writes a ``namespace OpenMM { ... }`` block containing the global
    constants followed by one class declaration per selected class, with
    public base classes emitted before the classes that derive from them.
    """

    def __init__(self, inputFilename, outputFilename=None):
        """Parse the XML input and select the output stream.

        inputFilename  -- path of the gccxml XML file to parse.
        outputFilename -- path of the SWIG file to write; when falsy
                          (None or ""), output goes to sys.stdout.  A file
                          opened here is not closed by this class; the
                          caller owns self.fOut.
        """
        # Cache: gccxml id -> DOM element, filled lazily by getTypeNode.
        self.typeIdDict = {}

        self.doc = minidom.parse(inputFilename)
        if outputFilename:
            self.fOut = open(outputFilename, 'w')
        else:
            self.fOut = sys.stdout

        xPath = "/GCC_XML/Namespace[@name='OpenMM']/@id"
        self._openmmNamespaceID = xpath.findvalue(xPath, self.doc)
        self._orderedClassIDs = self._buildOrderedClassIDs()

    def _buildOrderedClassIDs(self, classIDs=None):
        """Return class ids ordered so every public base precedes its subclasses.

        classIDs -- ids to order.  When empty or None, every Class declared
                    directly in the OpenMM namespace whose name is not in
                    SKIP_CLASSES is used.

        NOTE: public bases reached through _buildBaseIDs are included even
        if their own names appear in SKIP_CLASSES.
        """
        if not classIDs:
            xPath = "/GCC_XML/Class[@context='%s']" % self._openmmNamespaceID
            classIDs = [nodeAttID(classNode)
                        for classNode in xpath.find(xPath, self.doc)
                        if nodeAttName(classNode) not in SKIP_CLASSES]
        orderedClassIDs = []
        for classID in classIDs:
            self._buildBaseIDs(classID, orderedClassIDs)
        return orderedClassIDs

    def _buildBaseIDs(self, classID, excludedClassIDs=None):
        """Append classID to excludedClassIDs, preceded by its public bases.

        Walks public base classes depth-first so that each base is appended
        before any class deriving from it.  Ids already present in the list
        are not appended again.

        excludedClassIDs -- list extended in place; a fresh list is used
                            when it is omitted.
        Returns the (possibly newly created) list.
        """
        if excludedClassIDs is None:
            # Build a new list for each call.  A mutable default argument
            # would be shared between calls.
            excludedClassIDs = []
        xPath = "/GCC_XML/Class[@id='%s']/Base[@access='public']" % classID
        for baseNode in xpath.find(xPath, self.doc):
            baseID = nodeAttType(baseNode)
            if baseID not in excludedClassIDs:
                self._buildBaseIDs(baseID, excludedClassIDs)
        if classID not in excludedClassIDs:
            excludedClassIDs.append(classID)
        return excludedClassIDs

    def getTypeNode(self, typeID,
                    constType=False,
                    referenceType=False,
                    pointerType=False):
        """Resolve typeID to its underlying type element.

        CvQualifiedType, ReferenceType and PointerType wrappers are peeled
        off recursively.  Each wrapper that is seen sets the corresponding
        flag.

        Returns (node, constType, referenceType, pointerType).
        Raises IndexError if no element carries the given id.
        """
        try:
            node = self.typeIdDict[typeID]
        except KeyError:
            xPath = "/GCC_XML/*[@id='%s']" % typeID
            node = xpath.find(xPath, self.doc)[0]
            self.typeIdDict[typeID] = node

        if node.tagName == 'CvQualifiedType':
            try:
                if nodeAttConst(node):
                    constType = True
            except KeyError:
                # The 'const' attribute may be absent (presumably a
                # volatile-only qualifier); treat that as non-const
                # instead of crashing.
                pass
            return self.getTypeNode(nodeAttType(node),
                                    constType=constType,
                                    referenceType=referenceType,
                                    pointerType=pointerType)
        if node.tagName == 'ReferenceType':
            return self.getTypeNode(nodeAttType(node),
                                    constType=constType,
                                    referenceType=True,
                                    pointerType=pointerType)
        if node.tagName == 'PointerType':
            return self.getTypeNode(nodeAttType(node),
                                    constType=constType,
                                    referenceType=referenceType,
                                    pointerType=True)
        return (node, constType, referenceType, pointerType)

    def getTypeString(self, typeID):
        """Return a C++ spelling of typeID, such as "const Foo&" or "double*".

        Only a single const/&/* level is rendered.  The position of 'const'
        relative to '*' is not preserved.
        """
        (node, constType, referenceType, pointerType) = self.getTypeNode(typeID)
        typeString = nodeAttName(node)
        if constType:
            typeString = "const " + typeString
        if referenceType:
            typeString += "&"
        if pointerType:
            typeString += "*"
        return typeString

    def getClassName(self, classID):
        """Return the name of the type underlying classID (qualifiers stripped)."""
        return nodeAttName(self.getTypeNode(classID)[0])

    def getClassNode(self, classID):
        """Return the element of the type underlying classID."""
        return self.getTypeNode(classID)[0]

    def getClassIDs(self):
        """Return the ordered class ids computed at construction time."""
        return self._orderedClassIDs

    def writeSwigFile(self):
        """Write the complete OpenMM namespace block to self.fOut."""
        self.fOut.write("\nnamespace OpenMM {\n\n")
        self.writeGlobalConstants()
        self.writeClassDeclarations()
        self.fOut.write("\n} // namespace OpenMM\n\n")

    def writeGlobalConstants(self):
        """Write each namespace-level Variable as a static constant.

        The initialiser is passed through float(), so a non-numeric 'init'
        value raises ValueError.
        """
        self.fOut.write("/* Global Constants */\n")
        xPath = "/GCC_XML/Variable[@context='%s']" % self._openmmNamespaceID
        for varNode in xpath.find(xPath, self.doc):
            typeString = self.getTypeString(nodeAttType(varNode))
            varName = nodeAttName(varNode)
            varInit = nodeAttInit(varNode)
            self.fOut.write("static %s %s = %s;\n"
                            % (typeString, varName, str(float(varInit))))
        self.fOut.write("\n")

    def writeClassDeclarations(self):
        """Write one class declaration per id from getClassIDs()."""
        self.fOut.write("/* Class Declarations */\n\n")
        for classID in self.getClassIDs():
            self.fOut.write("class %s" % self.getClassName(classID))
            xPath = "/GCC_XML/Class[@id='%s']/Base[@access='public']" % classID
            baseNames = [self.getClassName(nodeAttType(baseNode))
                         for baseNode in xpath.find(xPath, self.doc)]
            if baseNames:
                # C++ separates multiple base specifiers with commas:
                # "class C : public A, public B".
                self.fOut.write(" : %s" % ", ".join(
                    "public %s" % baseName for baseName in baseNames))
            self.fOut.write(" {\n")
            self.fOut.write("public:\n")
            self.writeEnumerations(classID)
            self.writeConstructors(classID)
            self.writeDestructor(classID)
            self.writeMethods(classID)
            self.fOut.write("};\n\n")
        self.fOut.write("\n")

    def writeEnumerations(self, classID):
        """Write the public enumerations declared inside a class."""
        xPath = "/GCC_XML/Enumeration[@context='%s' and @access='public']" % classID
        enumNodes = xpath.find(xPath, self.doc)
        for enumNode in enumNodes:
            self.fOut.write("%senum %s {" % (INDENT, nodeAttName(enumNode)))
            argSep = "\n"
            for valueNode in xpath.find("EnumValue", enumNode):
                vName = nodeAttName(valueNode)
                vInit = nodeAttInit(valueNode)
                self.fOut.write("%s%s%s = %s" % (argSep, 2 * INDENT, vName, vInit))
                argSep = ",\n"
            self.fOut.write("\n%s};\n" % INDENT)
        if enumNodes:
            self.fOut.write("\n")

    def writeConstructors(self, classID):
        """Write the public, non-artificial constructors of a class.

        Nothing is written for a class marked abstract="1", because an
        abstract C++ class cannot be instantiated.
        """
        classNode = self.getClassNode(classID)
        try:
            isAbstract = nodeAttAbstract(classNode) == "1"
        except KeyError:
            # The 'abstract' attribute may be absent; treat as concrete.
            isAbstract = False
        if isAbstract:
            return
        className = self.getClassName(classID)
        xPath = ("/GCC_XML/Constructor[@context='%s' and @access='public'"
                 " and not(@artificial='1')]" % classID)
        for cNode in xpath.find(xPath, self.doc):
            self.fOut.write("%s%s%s(" % (INDENT, self._virtualPrefix(cNode), className))
            self._writeArguments(xpath.find("Argument", cNode))
            self.fOut.write(");\n")

    def writeDestructor(self, classID):
        """Write the public, non-artificial destructor of a class, if any."""
        xPath = ("/GCC_XML/Destructor[@context='%s' and @access='public'"
                 " and not(@artificial='1')]" % classID)
        dNodes = xpath.find(xPath, self.doc)
        if not dNodes:
            return
        dNode = dNodes[0]
        self.fOut.write("%s%s~%s();\n\n"
                        % (INDENT, self._virtualPrefix(dNode), nodeAttName(dNode)))

    def writeMethods(self, classID):
        """Write the public methods of a class.

        A method is skipped if it appears in SKIP_METHODS as
        "Class::method", or if its return type or any argument type names a
        class in HIDE_CLASSES.  Each reference-typed argument is preceded by
        a SWIG ``%apply TYPE OUTPUT { TYPE name };`` directive.
        """
        className = self.getClassName(classID)
        xPath = "/GCC_XML/Method[@context='%s' and @access='public']" % classID
        for methodNode in xpath.find(xPath, self.doc):
            methodName = nodeAttName(methodNode)
            if "%s::%s" % (className, methodName) in SKIP_METHODS:
                continue
            returnsID = nodeAttReturns(methodNode)
            if self.getClassName(returnsID) in HIDE_CLASSES:
                continue
            # Query the Argument children once and reuse them below.
            argNodes = xpath.find("Argument", methodNode)
            if any(self.getClassName(nodeAttType(aNode)) in HIDE_CLASSES
                   for aNode in argNodes):
                continue

            for aNode in argNodes:
                typeID = nodeAttType(aNode)
                isReference = self.getTypeNode(typeID)[2]
                if isReference:
                    aType = self.getTypeString(typeID)
                    # SWIG syntax: "%apply TYPE OUTPUT { TYPE name };".
                    # The type precedes the argument name inside the braces.
                    self.fOut.write("%s%%apply %s OUTPUT { %s %s };\n"
                                    % (INDENT, aType, aType, nodeAttName(aNode)))

            self.fOut.write("%s%s%s %s(" % (INDENT,
                                            self._virtualPrefix(methodNode),
                                            self.getTypeString(returnsID),
                                            methodName))
            self._writeArguments(argNodes)
            self.fOut.write(");\n")

    def _virtualPrefix(self, node):
        """Return "virtual " if node carries virtual="1", otherwise ""."""
        try:
            isVirtual = nodeAttVirtual(node) == "1"
        except KeyError:
            # The 'virtual' attribute may be absent; treat as non-virtual.
            isVirtual = False
        return "virtual " if isVirtual else ""

    def _writeArguments(self, argNodes):
        """Write "type name, type name, ..." for the given Argument nodes."""
        self.fOut.write(", ".join(
            "%s %s" % (self.getTypeString(nodeAttType(aNode)), nodeAttName(aNode))
            for aNode in argNodes))


def _attValue(node, attName):
    """Return the string value of attribute attName on a DOM element.

    Raises KeyError when the element has no such attribute.  Callers that
    treat an attribute as optional catch that KeyError.
    """
    return node.attributes[attName].value

def nodeAttAbstract(node):
    """Return the element's 'abstract' attribute value."""
    return _attValue(node, 'abstract')

def nodeAttConst(node):
    """Return the element's 'const' attribute value."""
    return _attValue(node, 'const')

def nodeAttID(node):
    """Return the element's 'id' attribute value."""
    return _attValue(node, 'id')

def nodeAttInit(node):
    """Return the element's 'init' attribute value."""
    return _attValue(node, 'init')

def nodeAttName(node):
    """Return the element's 'name' attribute value."""
    return _attValue(node, 'name')

def nodeAttReturns(node):
    """Return the element's 'returns' attribute value."""
    return _attValue(node, 'returns')

def nodeAttType(node):
    """Return the element's 'type' attribute value."""
    return _attValue(node, 'type')

def nodeAttVirtual(node):
    """Return the element's 'virtual' attribute value."""
    return _attValue(node, 'virtual')



def parseCommandLine():
    """Parse sys.argv for this script.

    Options:
      -h        print usage and exit
      -i FILE   gccxml XML input file (required)
      -o FILE   SWIG output file (defaults to stdout)

    Returns (remaining positional arguments, input filename, output
    filename).  The output filename is "" when -o is not given.
    usageError() is called on -h, on an unrecognised option or missing
    option argument, and when -i is missing.
    """
    try:
        opts, args_proper = getopt.getopt(sys.argv[1:], 'hi:o:')
    except getopt.GetoptError:
        # Report bad command lines with the usage message, not a traceback.
        usageError()
    inputFilename = ""
    outputFilename = ""
    for option, parameter in opts:
        if option == '-h':
            usageError()
        elif option == '-i':
            inputFilename = parameter
        elif option == '-o':
            outputFilename = parameter
    if not inputFilename:
        usageError()
    return (args_proper, inputFilename, outputFilename)

def main():
    """Command-line entry point: parse options and write the SWIG file."""
    _, inputFilename, outputFilename = parseCommandLine()
    sBuilder = SwigInputBuilder(inputFilename, outputFilename)
    try:
        sBuilder.writeSwigFile()
    finally:
        # SwigInputBuilder opens the output file but does not close it.
        # Close it here so buffered output is flushed.  stdout stays open.
        if sBuilder.fOut is not sys.stdout:
            sBuilder.fOut.close()


def usageError():
    """Print a one-line usage summary to stdout and exit with status 1."""
    progName = os.path.basename(sys.argv[0])
    message = 'usage: %s -i inputXmlFilename [-o outputSwigFilename]\n' % progName
    sys.stdout.write(message)
    sys.exit(1)

if __name__ == "__main__":
    # Run the generator only when executed as a script, not on import.
    main()


