import string

class Mesh(object):
    '''Parses an Abaqus input (.inp) file and stores the mesh in an instance of Mesh.

    Keyword lines, options, set names and face labels are lower-cased while
    parsing, so every set name used as a dictionary key below is lower case.

    members:
        nodes - list of [node id (int), coordinates (float, multiplied by scale)...]
        elements - list of [element type (str), element id (int), node ids (int)...];
            the type is one of 'hex8', 'tet4', 'penta6', 'tri3', 'quad4'
        elsets - dictionary with keys as set name (str), items are lists of element ids (int)
        nsets - same structure as elsets, items are lists of node ids (int)
        fsets - same structure as elsets but items are surface facet definitions
            [face type (str), facet id (int), node ids (int)...]

    Facet ids are drawn from the class-level counter Mesh.facetID, which is
    shared by all instances, so ids keep increasing across every Mesh built
    in the same process.
    '''
    facetID = 1

    # Abaqus element-type substrings and the internal type they map to. They
    # are tested in order with a substring match, so e.g. 'c3d8r' -> 'hex8'.
    _ELEMENT_TYPES = (
        ('c3d8', 'hex8'),
        ('c3d4', 'tet4'),
        ('c3d6', 'penta6'),
        ('cpe3', 'tri3'),
        ('cpe4', 'quad4'),
    )

    # Face label -> [facet type, zero-based positions in the element's node list],
    # following the Abaqus face numbering of each supported solid element.
    _FACE_DEF = {
        'hex8': {'s1': ['quad4', 0, 1, 2, 3], 's2': ['quad4', 4, 7, 6, 5],
                 's3': ['quad4', 0, 4, 5, 1], 's4': ['quad4', 1, 5, 6, 2],
                 's5': ['quad4', 2, 6, 7, 3], 's6': ['quad4', 3, 7, 4, 0]},
        'tet4': {'s1': ['tri3', 0, 1, 2], 's2': ['tri3', 0, 3, 1],
                 's3': ['tri3', 1, 3, 2], 's4': ['tri3', 2, 3, 0]},
        'penta6': {'s1': ['tri3', 0, 1, 2], 's2': ['tri3', 3, 5, 4],
                   's3': ['quad4', 0, 3, 4, 1], 's4': ['quad4', 1, 4, 5, 2],
                   's5': ['quad4', 2, 5, 3, 0]},
    }

    def __init__(self,mesh,scale=1.0):
        '''mesh - path of the Abaqus input file; scale - factor applied to node coordinates.'''
        self.scale = scale
        self.fname = mesh
        self.nodes = []
        self.elements = []
        self.elsets = {}
        self.nsets = {}
        self.fsets = {}
        self.__parse_mesh()

    def __parse_mesh(self):
        '''Read self.fname and fill nodes, elements, elsets, nsets and fsets.

        Sections are dispatched on the exact keyword name (the text before the
        first comma), so set names or options that merely contain words such
        as "node" or "element" are not misinterpreted. Unsupported keywords
        are reported and skipped.
        '''
        surfaces = []
        for name, opts, rows, line in self._read_sections():
            if name == 'node':
                self._parse_nodes(rows)
            elif name == 'element':
                self._parse_elements(opts, rows)
            elif name == 'elset':
                self._parse_set(self.elsets, opts.get('elset') or '', opts, rows)
            elif name == 'nset':
                self._parse_set(self.nsets, opts.get('nset') or '', opts, rows)
            elif name == 'surface' and (opts.get('type') or 'element') == 'element':
                # Surfaces reference elements and element sets, which may be
                # defined anywhere in the file: resolve them after reading it all.
                surfaces.append((opts, rows))
            else:
                print('Mesh parser ignored keyword line: ' + line)
        by_id = dict((elm[1], elm) for elm in self.elements)
        for opts, rows in surfaces:
            self._parse_surface(opts, rows, by_id)

    def _read_sections(self):
        '''Split the input file into keyword sections.

        Returns a list of (keyword name, options dict, data rows, keyword line)
        tuples in file order. Option flags without "=" (e.g. GENERATE) map to
        None. Each data row is a list of non-empty, stripped, lower-cased
        comma-separated fields. Comment lines (starting with "**"), blank
        lines and data lines before the first keyword are skipped.
        '''
        sections = []
        rows = None
        with open(self.fname, 'r') as fid:
            for raw in fid:
                line = raw.strip().lower()
                if not line or line.startswith('**'):
                    continue
                if line.startswith('*'):
                    parts = line.split(',')
                    opts = {}
                    for part in parts[1:]:
                        key, sep, value = part.partition('=')
                        key = key.strip()
                        if key:
                            opts[key] = value.strip() if sep else None
                    rows = []
                    sections.append((parts[0][1:].strip(), opts, rows, line))
                    continue
                if rows is None:
                    continue
                # Dropping empty fields handles trailing commas and ", " padding.
                fields = [field.strip() for field in line.split(',')]
                rows.append([field for field in fields if field])
        return sections

    def _parse_nodes(self, rows):
        '''Append [id, scaled coordinates...] for every *NODE data row.'''
        for row in rows:
            self.nodes.append([int(row[0])] + [self.__scaleNode(float(v)) for v in row[1:]])

    def _parse_elements(self, opts, rows):
        '''Append [type, id, node ids...] for an *ELEMENT section of a supported type.'''
        abq_type = opts.get('type') or ''
        for key, etype in self._ELEMENT_TYPES:
            if key in abq_type:
                break
        else:
            print("WARNING: Element type "+abq_type+" is not supported. This section will be ignored...")
            return
        for row in rows:
            self.elements.append([etype] + [int(v) for v in row])

    def _parse_set(self, sets, name, opts, rows):
        '''Add the ids of an *NSET/*ELSET section to sets[name].

        A repeated set name extends the existing set. With the GENERATE
        option each row is (first, last[, increment]), expanded inclusively.
        '''
        members = sets.setdefault(name, [])
        if 'generate' in opts:
            for row in rows:
                step = int(row[2]) if len(row) > 2 else 1
                members.extend(range(int(row[0]), int(row[1]) + 1, step))
        else:
            members.extend(int(v) for row in rows for v in row)

    def _parse_surface(self, opts, rows, by_id):
        '''Build facets for an element-based *SURFACE section.

        Each row is (element id or element-set name, face label). by_id maps
        element id -> element list. Rows whose element, set, or face label
        cannot be resolved (e.g. faces of ignored element types) are reported
        and skipped instead of producing wrong facets.
        '''
        name = opts.get('name') or ''
        facets = self.fsets.setdefault(name, [])
        for row in rows:
            if len(row) < 2:
                print('WARNING: Surface ' + name + ' entry ' + ','.join(row) + ' has no face label and will be ignored...')
                continue
            target, label = row[0], row[1]
            try:
                eids = [int(target)]
            except ValueError:
                eids = self.elsets.get(target)
                if eids is None:
                    print('WARNING: Surface ' + name + ' references unknown element set ' + target + ' and it will be ignored...')
                    continue
            for eid in eids:
                elm = by_id.get(eid)
                faces = self._FACE_DEF.get(elm[0]) if elm is not None else None
                if faces is None or label not in faces:
                    print('WARNING: Surface ' + name + ' face ' + label + ' of element ' + str(eid) + ' cannot be resolved and will be ignored...')
                    continue
                face = faces[label]
                facets.append([face[0], Mesh.facetID] + [elm[n + 2] for n in face[1:]])
                Mesh.facetID += 1

    def __scaleNode(self,n):
        return n*self.scale
