"""
Create header file for instantiating templated classes and methods in PySimTK simtkcommon module.
"""
import sys
import re

class Instantiator(object):
    """
    Build and write a C++ header that instantiates SimTK small-vector templates.

    All generated text is produced eagerly by the constructor and accumulated
    in three strings:

    * ``template_instantiation_string`` -- ``template class SimTK::...;`` lines,
    * ``aliases_string`` -- ``typedef`` lines placed in ``pyplusplus::aliases``,
    * ``instantiate_method_string`` -- the body of a static ``instantiate()``
      function whose statements exercise operators on each type so the
      compiler instantiates them.

    ``generate_instantiations`` writes those pieces to ``header_file_name``.
    """

    # Matches "Vec<3, negator<double>, 4>" (class, size, element, stride) and
    # "UnitVec<double, 3>" (class, element, stride -- size is implicit).
    _VEC_NAME_RE = re.compile(r'([^<]+)<((\d+), )?([^,]+), (\d+)>')

    def __init__(self, header_file_name):
        """
        Generate all header text for the supported sizes and strides.

        header_file_name -- path that ``generate_instantiations`` will write.
        No file is touched here; only the in-memory strings are built.
        """
        self.header_file_name = header_file_name
        self.instantiate_method_string = ""
        self.template_instantiation_string = ""
        self.aliases_string = ""
        self.instantiate_vecs([2], [1, 2, 3])
        # columns and diagonals of Mat33 and Mat66 yield these strides
        self.instantiate_vecs([3], [1, 3, 4, 6, 7])
        self.instantiate_vecs([4], [1, 4, 5])
        self.instantiate_vecs([6], [1, 6, 7])
        # self.instantiate_vecs([15], [1])
        # self.instantiate_vecs([21], [1])
        # cross products
        self._instantiate_cross_products([3], [1, 3, 4, 6, 7])
        self._instantiate_cross_products([2], [1, 2, 3])

    def _instantiate_cross_products(self, sizes, strides):
        """Emit a cross-product snippet for every ordered pair of Vec/Row names."""
        for type_str1 in self.iter_vec_row_names(sizes, strides):
            for type_str2 in self.iter_vec_row_names(sizes, strides):
                self.write_cross_product_methods(type_str1, type_str2)

    def generate_instantiations(self):
        """
        Write the complete header to ``self.header_file_name`` (overwriting it).

        The ``with`` block guarantees the file is closed even if one of the
        writers raises part-way through.
        """
        with open(self.header_file_name, "w") as header:
            self.write_beginning(header)
            self.write_template_instantiations(header)
            self.write_aliases(header)
            self.write_instantiate_method(header)
            self.write_end(header)

    def write_aliases(self, file):
        """Write the accumulated typedefs wrapped in ``pyplusplus::aliases``."""
        file.write('namespace pyplusplus { namespace aliases {\n')
        file.write(self.aliases_string)
        file.write("}} // namespace pyplusplus::aliases\n")

    def write_template_instantiations(self, file):
        """Write the accumulated ``template class SimTK::...;`` lines."""
        file.write(self.template_instantiation_string)

    def write_instantiate_method(self, file):
        """Write the static ``instantiate()`` function wrapping all snippets."""
        file.write("static void instantiate() {\n")
        file.write(self.instantiate_method_string)
        file.write("} // instantiate()\n")

    def get_vec_components(self, type_str):
        """
        Split a small-vector class name into its template components.

        "Vec<3, negator<double>, 4>" -> ("Vec", "3", "negator<double>", "4").
        Names without an explicit size ("UnitVec<double, 3>") get size "3".

        Returns a tuple of strings (vec, size, elt, stride).
        Raises ValueError when type_str does not have one of those forms.
        """
        match = self._VEC_NAME_RE.match(type_str)
        if not match:
            raise ValueError('Could not parse %s' % type_str)
        vec, _, size, elt, stride = match.groups()
        if not size:
            size = "3" # UnitVec
        return (vec, size, elt, stride)

    def _alias_for(self, type_str):
        """
        Return the typedef alias for type_str.

        Examples: "Vec3", "NegRow2", "Vec3_4" (stride 4), "UnitVec3_3".
        """
        (vec, size, elt, stride) = self.get_vec_components(type_str)
        alias = vec + size
        if elt == 'negator<double>':
            alias = "Neg" + alias
        if stride != '1':
            alias = "%s_%s" % (alias, stride)
        return alias

    def instantiate_vecs(self, sizes, strides):
        """
        Emit instantiations, aliases and operator snippets for small vecs.

        sizes   -- iterable of vector lengths (ints).
        strides -- iterable of element strides (ints).

        Binary add/subtract snippets are emitted only within the Vec family
        and within the Row family; inner/outer products pair every Vec with
        every Row.
        """
        for type_str in self.iter_vec_row_names(sizes, strides):
            # instantiate class
            self.template_instantiation_string += \
                "template class SimTK::%s;\n" % type_str
            # typedef alias
            self.aliases_string += \
                "    typedef %s %s;\n" % (type_str, self._alias_for(type_str))
            # unary methods
            self.write_vec_unary_methods(type_str)
        # Binary methods on Vec<... then on Row<...
        for iter_names in (self.iter_vec_names, self.iter_row_names):
            for type_str1 in iter_names(sizes, strides):
                for type_str2 in iter_names(sizes, strides):
                    self.write_vec_binary_methods(type_str1, type_str2)
        # NOTE: write_mat_vec_multiplies / write_row_mat_multiplies are
        # currently not invoked here (they were disabled in this generator).
        # inner and outer products
        for type_str1 in self.iter_vec_names(sizes, strides):
            for type_str2 in self.iter_row_names(sizes, strides):
                self.write_vec_product_methods(type_str1, type_str2)

    def write_vec_unary_methods(self, type_str):
        """Append a snippet exercising scalar ops and streaming on type_str."""
        self.instantiate_method_string += """//
        { // %s unary methods
            %s vec(1, 2, 3);
            // scalar multiply
            vec * 5.0; 3.0 * vec; vec *= 2.0;
            // and divide
            vec / 5.0; vec /= 2.0;
            // convert to string
            std::cout << vec;
        }
        """ % (type_str, type_str)

    def write_mat_vec_multiplies(self, type_str):
        """
        Append a matrix * vec snippet for size-3 or size-4 types.

        The matrix expressions inside the emitted snippet are all commented
        out, so only the vec declaration is compiled.
        """
        (vec, size, elt, stride) = self.get_vec_components(type_str)
        if size == "3":
            self.instantiate_method_string += """//
        { // matrix * %s
            %s vec(1);
            // Mat33() * vec;
            // (-Mat33()) * vec;
            // (~Mat33()) * vec;
            // (-~Mat33()) * vec;
            // Transform() * vec;
            // ~Transform() * vec;
            // Rotation() * vec;
            // ~Rotation() * vec;
        }
        """ % (type_str, type_str)
        if size == "4":
            self.instantiate_method_string += """//
        { // matrix * %s
            %s vec(1);
            // Transform() * vec;
            // ~Transform() * vec;
        }
        """ % (type_str, type_str)

    def write_row_mat_multiplies(self, type_str):
        """Append a row * Mat33 / Rotation snippet; size-3 types only."""
        # Row3 only
        (vec, size, elt, stride) = self.get_vec_components(type_str)
        if size != "3":
            return
        self.instantiate_method_string += """//
        { // %s * matrix
            %s vec(1);
            vec * Mat33();
            vec * (-Mat33());
            vec * (~Mat33());
            vec * (-~Mat33());
            vec * Rotation();
            vec * ~Rotation();
        }
        """ % (type_str, type_str)

    def write_beginning(self, file):
        """Write the include guard, includes and ``using namespace SimTK``."""
        # beginning of instantiations.h file
        file.write("""
#ifndef PYSIMTK_SIMTKCOMMON_INSTANTIATIONS_H
#define PYSIMTK_SIMTKCOMMON_INSTANTIATIONS_H

#include <iostream>
#include "SimTKcommon/Scalar.h"         // self-contained
#include "SimTKcommon/SmallMatrix.h"    // includes Scalar.h
#include "SimTKcommon/Orientation.h"    // includes SmallMatrix.h

using namespace SimTK;

""")

    def write_end(self, file):
        """Write the closing ``#endif`` of the include guard."""
        # End of instantiation.h file
        file.write("""

#endif // PYSIMTK_SIMTKCOMMON_INSTANTIATIONS_H
    \n""")

    def write_vec_binary_methods(self, type1_str, type2_str):
        """Append a snippet exercising ``+`` and ``-`` between two types."""
        self.instantiate_method_string += """//
        { // %s, %s binary methods
            %s vec1(1);
            %s vec2(2);
            // addition
            vec1 + vec2;
            // subtraction
            vec1 - vec2;
        }\n""" % (type1_str, type2_str, type1_str, type2_str)

    def write_cross_product_methods(self, type1_str, type2_str):
        """Append a snippet exercising the ``%`` (cross product) operator."""
        self.instantiate_method_string += """//
        { // %s, %s cross product
            %s vec1(1);
            %s vec2(2);
            vec1 %% vec2;
        }\n""" % (type1_str, type2_str, type1_str, type2_str)

    def write_vec_product_methods(self, type1_str, type2_str):
        """Append a snippet exercising ``*`` in both operand orders."""
        self.instantiate_method_string += """//
        { // %s, %s inner and outer products
            %s vec1(1);
            %s vec2(2);
            vec1 * vec2;
            vec2 * vec1;
        }\n""" % (type1_str, type2_str, type1_str, type2_str)

    def _iter_small_names(self, kind, sizes, strides):
        """
        Yield "<kind><size, elt, stride>" names for SimTK small vecs.

        For each size and stride, yields the double and negator<double>
        variants. For size 3 with stride 1 or 3 it also yields
        "Unit<kind><double, stride>".
        """
        for size in sizes:
            for stride in strides:
                for elt in ('double', 'negator<double>'):
                    yield "%s<%s, %s, %s>" % (kind, size, elt, stride)
                if size == 3 and stride in (1, 3):
                    yield "Unit%s<double, %s>" % (kind, stride)

    def iter_vec_names(self, sizes, strides):
        """
        Generates class names for SimTK small vecs of particular sizes and strides.

        Returns a series of strings, e.g. "Vec<4, negator<double>, 4>"
        """
        return self._iter_small_names("Vec", sizes, strides)

    def iter_row_names(self, sizes, strides):
        """
        Like iter_vec_names, but "Row<..." instead of "Vec<..."
        """
        return self._iter_small_names("Row", sizes, strides)

    def iter_vec_row_names(self, sizes, strides):
        """Yield all Vec names, then all Row names, for the sizes/strides."""
        for s in self.iter_vec_names(sizes, strides):
            yield s
        for s in self.iter_row_names(sizes, strides):
            yield s

    def iter_vec_of_vec_names(self, sizes):
        """Yield "Vec<2, Vec<3, double, 1>, 1>" when 3 is among sizes."""
        for size in sizes:
            if size == 3:
                yield "Vec<2, Vec<%s, double, 1>, 1>" % (size,) # SpatialVec...

if __name__ == '__main__':
    # Regenerate instantiations.h (path relative to the working directory).
    Instantiator('instantiations.h').generate_instantiations()
