'''
Created on 2013-5-7 by Scott Sibole

Purpose - reads in a mesh in either Medit .mesh or Abaqus .inp format created by either tetgen or netgen, as well as surfaces used to create that mesh.  Using the surfaces, it determines material regions and writes an Abaqus input file with element, node, and face sets.

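Inputs - 3 arguments:
    - a configuration file with one "<surface>.stl,<set name>" entry per line. Surfaces are evaluated in the listed order and each element is assigned to the first surface that encloses it, so list internal surfaces first; the last entry is treated as the outer surface.
    - mesh file created by tetgen or netgen in .inp or .mesh format
    - desired output format:
        - abq
        - hdf5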

Output - for 'abq', an Abaqus input file named [configuration file name].inp containing element, node, and face sets; for 'hdf5', a binary file named [configuration file name].hdf5

example usage: python make_sets.py example.cfg example.1.mesh abq
'''

import sys, string, os, vtk, re, h5py, gts, subprocess
import numpy as np

def main(cfg, msh, output):
    """Assign material element, node and face sets to a tetrahedral mesh and write them.

    Parameters
    ----------
    cfg : str
        Configuration file.  Each non-blank line is ``<surface>.stl,<set name>``.
        Surfaces are tested in the listed order and an element joins the set of
        the first surface that encloses its centroid, so list internal surfaces
        first.  The surface on the last line is used to build the
        outer-surface node set ``necm``.
    msh : str
        Mesh file in Medit ``.mesh`` or Abaqus ``.inp`` format (see parse_mesh).
    output : str
        If it contains ``'abq'`` (any case) an Abaqus file is written with
        write_abq, otherwise an HDF5 file with write_hdf5.
    """
    surface_files, setnames = _read_config(cfg)

    print('Reading in surfaces')
    surfaces = [_read_stl(name) for name in surface_files]

    print('...Parsing mesh')
    content = parse_mesh(msh)
    nodes = content['Nodes']
    elms = content['Elements']

    print('......Finding element centroids')
    centroid_cloud = _points_polydata(_element_centroids(nodes, elms))

    print('.........Determining element sets')
    matids, n_assigned = _assign_element_sets(centroid_cloud, surfaces, setnames)

    print('\nVerify all elements were assigned a set: numbers should be the same...')
    print('%d %d' % (len(elms), n_assigned))
    print('\n')

    print('Determining node sets')
    nsets = _interface_node_sets(matids, setnames, elms, len(nodes))
    # The outer-surface set replaces any interface set that shares its name.
    nsets['necm'] = _outer_surface_nodes(nodes, surface_files[-1])

    fsets = _face_sets(nsets, matids, elms)

    base = os.path.splitext(cfg)[0]
    if 'abq' in output.lower():
        print('...Writing abaqus input file: ' + base + '.inp')
        write_abq(cfg, nodes, elms, matids, nsets, fsets)
    else:
        print('...Writing hdf5 binary file: ' + base + '.hdf5')
        write_hdf5(cfg, nodes, elms, matids, nsets, fsets)


def _read_config(cfg):
    """Return (surface STL paths, set names) from the ``file,name`` lines of *cfg*.

    Blank lines are skipped.  Raises ValueError for a line without a set name
    or for a file that lists no surfaces.
    """
    surface_files = []
    setnames = []
    with open(cfg, 'r') as fid:
        for line in fid:
            line = line.strip()
            if not line:
                continue
            fields = line.split(',')
            if len(fields) < 2:
                raise ValueError('%s: expected "<surface>.stl,<set name>", got %r'
                                 % (cfg, line))
            surface_files.append(fields[0].strip())
            setnames.append(fields[1].strip())
    if not surface_files:
        raise ValueError('%s: no surfaces listed' % cfg)
    return surface_files, setnames


def _read_stl(filename):
    """Read an STL file with VTK and return the reader's output polydata."""
    reader = vtk.vtkSTLReader()
    reader.SetFileName(filename)
    reader.Update()
    return reader.GetOutput()


def _element_centroids(nodes, elms):
    """Return the centroid (mean of the first four nodes) of every element."""
    centroids = []
    for elm in elms:
        corners = [nodes[n - 1] for n in elm[:4]]  # node ids are 1-based
        centroids.append([sum(c[axis] for c in corners) / 4. for axis in range(3)])
    return centroids


def _points_polydata(coords):
    """Wrap xyz coordinates in a vtkPolyData with one vertex cell per point."""
    points = vtk.vtkPoints()
    vertices = vtk.vtkCellArray()
    for xyz in coords:
        pid = points.InsertNextPoint(xyz)
        vertices.InsertNextCell(1)
        vertices.InsertCellPoint(pid)
    polydata = vtk.vtkPolyData()
    polydata.SetPoints(points)
    polydata.SetVerts(vertices)
    return polydata


def _enclosed_flags(polydata, surface, tolerance):
    """Return, per point of *polydata*, the 'SelectedPoints' flag from
    vtkSelectEnclosedPoints tested against *surface* (non-zero = inside)."""
    enclose_pts = vtk.vtkSelectEnclosedPoints()
    enclose_pts.SetInput(polydata)
    enclose_pts.SetTolerance(tolerance)
    enclose_pts.SetSurface(surface)
    enclose_pts.Update()
    arr = enclose_pts.GetOutput().GetPointData().GetArray('SelectedPoints')
    return [arr.GetComponent(i, 0) for i in range(arr.GetNumberOfTuples())]


def _assign_element_sets(centroid_cloud, surfaces, setnames):
    """Map each set name to the 1-based ids of the elements its surface encloses.

    Surfaces are tried in order; an element already claimed by an earlier
    surface is not added to a later one.  Returns (matids, number assigned).
    """
    matids = {name: [] for name in setnames}
    assigned = set()
    for surface, name in zip(surfaces, setnames):
        # Tolerance kept extremely tight, as in the original script.
        inside = _enclosed_flags(centroid_cloud, surface, 0.0000000000000000001)
        for i, flag in enumerate(inside):
            if flag and i not in assigned:
                matids[name].append(i + 1)
                assigned.add(i)
    return matids, len(assigned)


def _interface_node_sets(matids, setnames, elms, nnodes):
    """Collect nodes shared by elements of two or more materials into node sets.

    Materials touching a node are recorded in config order.  Using the first
    two recorded materials, a shared node goes into ``'n' + <material>`` where
    the material is the partner of one containing 'ecm', else the partner of a
    first material containing 'pcm', else the first material.
    """
    ordered_names = []
    for name in setnames:
        if name not in ordered_names:
            ordered_names.append(name)

    node_assignments = {n: [] for n in range(1, nnodes + 1)}
    for name in ordered_names:
        for e in matids[name]:
            for n in elms[e - 1]:
                if name not in node_assignments[n]:
                    node_assignments[n].append(name)

    nsets = {}
    for node_id in range(1, nnodes + 1):
        mats = node_assignments[node_id]
        if len(mats) < 2:
            continue
        if 'ecm' in mats[0]:     # a PCM set
            mat = mats[1]
        elif 'ecm' in mats[1]:   # a PCM set
            mat = mats[0]
        elif 'pcm' in mats[0]:   # a cell set
            mat = mats[1]
        else:                    # a cell set
            mat = mats[0]
        nsets.setdefault('n' + mat, []).append(node_id)
    return nsets


def _outer_surface_nodes(nodes, outer_stl):
    """Return 1-based ids of nodes not enclosed by *outer_stl* offset by -1.

    offset_mesh writes ``<outer_stl without extension>_offset.stl``; that
    temporary file is read and then deleted.
    """
    offset_mesh(outer_stl, -1.)
    offset_file = os.path.splitext(outer_stl)[0] + '_offset.stl'
    surf = _read_stl(offset_file)
    os.remove(offset_file)

    inside = _enclosed_flags(_points_polydata(nodes), surf, 0.00000000001)
    # Point indices are 0-based; node ids everywhere else are 1-based.
    return [i + 1 for i, flag in enumerate(inside) if not flag]


def _face_sets(nsets, matids, elms):
    """Build Abaqus element-face sets from the node sets.

    For node set ``'n' + <material>`` the face set ``'f' + <material>`` holds
    ``[element id, face label]`` for every element of that material with at
    least three nodes in the node set, using the face spanned by the first
    three such nodes.  A material with no element set gives an empty face set.
    """
    # C3D4 face labels keyed by the 0-based connectivity positions they span;
    # the remaining triple (0, 2, 3) is S4.
    face_labels = {
        frozenset((0, 1, 2)): 'S1',
        frozenset((0, 1, 3)): 'S2',
        frozenset((1, 2, 3)): 'S3',
    }
    fsets = {}
    for setid, node_ids in nsets.items():
        material = setid[1:]  # strip only the leading 'n'
        members = set(node_ids)
        faces = []
        for e in matids.get(material, []):
            positions = [pos for pos, n in enumerate(elms[e - 1]) if n in members]
            if len(positions) < 3:
                continue
            label = face_labels.get(frozenset(positions[:3]), 'S4')
            faces.append([str(e), label])
        fsets['f' + material] = faces
    return fsets
        
def offset_mesh(surf, offset):
    """Write a copy of an STL surface with every vertex moved along its normal.

    Each vertex is translated by ``offset`` times its averaged unit normal (the
    normalized sum of the unit normals of the triangles that share it).  All
    offsets are computed before any vertex moves.  The result is written to
    ``<surf without extension>_offset.stl`` (``name.stl`` -> ``name_offset.stl``).

    Uses the GTS command-line converters ``stl2gts`` and ``gts2stl``, which
    must be on the PATH.  Intermediate ``.gts`` files are removed.

    Parameters
    ----------
    surf : str
        Path of the input STL file.
    offset : float
        Signed distance to move each vertex along its normal.

    Returns
    -------
    str
        Path of the written offset STL file.

    Raises
    ------
    subprocess.CalledProcessError
        If either converter exits with a non-zero status.
    """
    base = os.path.splitext(surf)[0]
    gts_name = base + '.gts'
    offset_gts = base + '_offset.gts'
    offset_stl = base + '_offset.stl'

    # Run the converters without a shell so paths containing spaces or shell
    # metacharacters are passed through untouched.
    try:
        with open(surf, 'rb') as src, open(gts_name, 'wb') as dst:
            subprocess.check_call(['stl2gts'], stdin=src, stdout=dst)
        with open(gts_name, 'r') as fid:
            s = gts.read(fid)
    finally:
        if os.path.exists(gts_name):
            os.remove(gts_name)

    vertices = s.vertices()
    offsets = []
    for v in vertices:
        normal_sum = np.zeros(3)
        for t in v.faces():
            n = np.asarray(t.normal(), dtype=float)
            length = np.linalg.norm(n)
            if length > 0:  # skip degenerate (zero-area) triangles
                normal_sum += n / length
        length = np.linalg.norm(normal_sum)
        if length > 0:
            offsets.append(float(offset) * normal_sum / length)
        else:
            # No faces, or normals cancel: no defined direction, leave in place.
            offsets.append(np.zeros(3))

    # Move vertices only after all normals are known; moving earlier would
    # change the normals seen by neighbouring vertices.
    for v, d in zip(vertices, offsets):
        v.translate(d[0], d[1], d[2])

    try:
        with open(offset_gts, 'w') as fid:
            s.write(fid)
        with open(offset_gts, 'rb') as src, open(offset_stl, 'wb') as dst:
            subprocess.check_call(['gts2stl'], stdin=src, stdout=dst)
    finally:
        if os.path.exists(offset_gts):
            os.remove(offset_gts)
    return offset_stl
    
def parse_mesh(msh):
    """Read node coordinates and element connectivity from a mesh file.

    The format is chosen by file extension:

    * Medit ``.mesh`` (tetgen/netgen): the ``Vertices`` and ``Tetrahedra``
      sections are read.  Each vertex keeps its first three values (x, y, z);
      each tetrahedron keeps its first four values (node ids), dropping the
      trailing reference label.
    * Abaqus ``.inp``: data lines following ``*NODE`` and ``*ELEMENT`` keyword
      lines are read.  The leading id on each line is discarded, so nodes are
      assumed to be numbered 1..N in file order.  Data of any other keyword
      section, ``**`` comment lines and blank lines are ignored.

    Parameters
    ----------
    msh : str
        Path of the mesh file.

    Returns
    -------
    dict
        ``{'Nodes': [[x, y, z], ...], 'Elements': [[n1, n2, ...], ...]}`` where
        element entries are 1-based node ids.

    Raises
    ------
    ValueError
        If the extension is neither ``.mesh`` nor ``.inp``.
    """
    ext = os.path.splitext(msh)[1].lower()
    if ext not in ('.mesh', '.inp'):
        raise ValueError('%s: unsupported mesh format (expected .mesh or .inp)' % msh)

    with open(msh, 'r') as fid:
        lines = [line.strip() for line in fid]

    if ext == '.mesh':
        return _parse_medit(lines)
    return _parse_inp(lines)


def _parse_medit(lines):
    """Parse the Vertices and Tetrahedra sections of stripped Medit lines."""
    node_start = lines.index('Vertices')
    elm_start = lines.index('Tetrahedra')
    nnodes = int(lines[node_start + 1])
    nelms = int(lines[elm_start + 1])
    # str.split() copes with any whitespace and with exponents like 1.5e+02.
    nodes = [[float(v) for v in line.split()[:3]]
             for line in lines[node_start + 2:node_start + 2 + nnodes]]
    elms = [[int(v) for v in line.split()[:4]]
            for line in lines[elm_start + 2:elm_start + 2 + nelms]]
    return {'Nodes': nodes, 'Elements': elms}


def _parse_inp(lines):
    """Parse *NODE and *ELEMENT data lines from stripped Abaqus .inp lines."""
    content = {'Nodes': [], 'Elements': []}
    section = None  # key of content receiving the current data lines, or None
    for line in lines:
        if not line or line.startswith('**'):
            continue  # blank line or comment
        if line.startswith('*'):
            keyword = line.lower()
            if '*node' in keyword:
                section = 'Nodes'
            elif '*element' in keyword:
                section = 'Elements'
            else:
                section = None  # some other keyword: ignore its data lines
            continue
        if section is None:
            continue
        # Drop the leading id and any empty field left by a trailing comma.
        values = [v for v in line.split(',')[1:] if v.strip()]
        if section == 'Nodes':
            content['Nodes'].append([float(v) for v in values])
        else:
            content['Elements'].append([int(v) for v in values])
    return content
        
        
    
def write_abq(f, nodes, elms, matids, nsets, fsets):
    """Write the mesh and its sets to an Abaqus input file.

    The output is named after *f* with its extension replaced by ``.inp``
    (``example.cfg`` -> ``example.inp``), so the configuration file itself is
    never overwritten.  Nodes and C3D4 elements are numbered 1..N in list
    order, followed by the element, node and surface sets.

    Parameters
    ----------
    f : str
        Configuration file path used to derive the output name.
    nodes : list
        ``[x, y, z]`` coordinates per node.
    elms : list
        1-based node-id lists per element.
    matids : dict
        Element-set name -> list of element ids.
    nsets : dict
        Node-set name -> list of node ids.
    fsets : dict
        Surface name -> list of ``[element id, face label]`` pairs.
    """
    with open(os.path.splitext(f)[0] + '.inp', 'w') as fid:
        fid.write('*NODE\n')
        for node_id, coords in enumerate(nodes, 1):
            # repr() keeps full double precision; str() on Python 2 rounds
            # floats to 12 significant digits.
            fid.write(','.join([str(node_id)] + [repr(float(x)) for x in coords]) + '\n')

        fid.write('*ELEMENT,TYPE=C3D4\n')
        for elm_id, conn in enumerate(elms, 1):
            fid.write(','.join([str(elm_id)] + [str(n) for n in conn]) + '\n')

        for name, ids in matids.items():
            fid.write('*ELSET,ELSET=' + str(name) + '\n')
            _write_id_lines(fid, ids)

        for name, ids in nsets.items():
            fid.write('*NSET,NSET=' + str(name) + '\n')
            _write_id_lines(fid, ids)

        for name, faces in fsets.items():
            fid.write('*SURFACE,NAME=' + str(name) + '\n')
            for face in faces:
                fid.write(','.join(str(v) for v in face) + '\n')


def _write_id_lines(fid, ids, per_line=10):
    """Write *ids* as comma-separated data lines of at most *per_line* entries.

    Abaqus allows at most 16 entries per set data line.
    """
    for start in range(0, len(ids), per_line):
        fid.write(','.join(str(i) for i in ids[start:start + per_line]) + '\n')
    
def write_hdf5(f, nodes, elms, matids, nsets, fsets=None):
    """Write the mesh and its sets to an HDF5 file.

    The output is named after *f* with its extension replaced by ``.hdf5``
    (``example.cfg`` -> ``example.hdf5``).  Layout:

    * ``/Nodes``          float32 (N, 3): node coordinates
    * ``/Elements``       uint32 (M, 4): first four node ids of each element
    * ``/Elsets/<name>``  uint32 (k, 1): element ids
    * ``/Nsets/<name>``   uint32 (k, 1): node ids
    * ``/Fsets/<name>``   uint32 (k, 2): ``[element id, face number]`` rows,
      written only when *fsets* is given

    Parameters
    ----------
    f : str
        Configuration file path used to derive the output name.
    nodes, elms, matids, nsets
        As for write_abq.
    fsets : dict, optional
        Surface name -> list of ``[element id, 'S<n>']`` pairs (as passed by
        main); the face label ``'S<n>'`` is stored as the integer ``n``.
    """
    narray = np.array([n[:3] for n in nodes], dtype='f4').reshape(-1, 3)
    earray = np.array([e[:4] for e in elms], dtype='u4').reshape(-1, 4)

    with h5py.File(os.path.splitext(f)[0] + '.hdf5', 'w') as fid:
        fid.create_dataset("Nodes", data=narray)
        fid.create_dataset("Elements", data=earray)

        esetgrp = fid.create_group("Elsets")
        nsetgrp = fid.create_group("Nsets")

        # Sets are stored as (k, 1) column vectors.
        for name, ids in nsets.items():
            nsetgrp.create_dataset(name, data=np.array(ids, dtype='u4').reshape(-1, 1))
        for name, ids in matids.items():
            esetgrp.create_dataset(name, data=np.array(ids, dtype='u4').reshape(-1, 1))

        if fsets is not None:
            fsetgrp = fid.create_group("Fsets")
            for name, faces in fsets.items():
                rows = [[int(elem), int(str(label)[1:])] for elem, label in faces]
                fsetgrp.create_dataset(name, data=np.array(rows, dtype='u4').reshape(-1, 2))
    
def cli_arguments(argv):
    """Return ``(cfg, msh, output)`` from a command line, or None if too short.

    Accepts ``make_sets.py <cfg> <mesh> [abq|hdf5]``; the format defaults to
    ``'abq'`` when omitted.  When more than three arguments are given, the
    last three are used.
    """
    args = list(argv[1:])
    if len(args) == 2:
        args.append('abq')
    if len(args) < 3:
        return None
    return tuple(args[-3:])


if __name__ == '__main__':
    parsed = cli_arguments(sys.argv)
    if parsed is None:
        sys.stderr.write('usage: python %s config.cfg mesh_file [abq|hdf5]\n' % sys.argv[0])
        sys.exit(2)
    main(*parsed)
