import febio, FebPlt, parse_cfg, parse_feb, hex_cent
import sys, string, pickle
import numpy as np
from scipy import spatial
from scipy.optimize import fmin_l_bfgs_b

# Nodal field variables that can be interpolated, mapped to the load-curve
# keys their components are appended to (component l -> dofs[l]).
_DOF_NAMES = {
    'effective fluid pressure': ['p'],
    'displacement': ['x', 'y', 'z'],
}

# Nodal variables whose first-axis length is taken as the number of stored
# time steps when the configuration asks for 'all' steps.
_STEP_COUNT_VARS = ('displacement', 'nodal fluid flux', 'effective fluid pressure')


def _count_steps(results):
    """Return the number of stored time steps for each plot file in *results*.

    For each plot file, the count is the first-axis length of the first
    variable (in iteration order) that is listed in _STEP_COUNT_VARS and is
    stored for the first node id of ``NodeData``.

    Raises ValueError if a plot file has none of those variables; skipping it
    would misalign the returned list with *results*.
    """
    counts = []
    for idx, r in enumerate(results):
        first_nid = next(iter(r.NodeData))
        node_vars = r.NodeData[first_nid]
        for var in node_vars:
            if var in _STEP_COUNT_VARS:
                counts.append(node_vars[var].shape[0])
                break
        else:
            raise ValueError('Plot file %d has no nodal variable to count time '
                             'steps from (looked for: %s)'
                             % (idx, ', '.join(_STEP_COUNT_VARS)))
    return counts


def _mesh_arrays(mesh):
    """Return ``(elements, nodes)`` arrays built from a parsed mesh.

    elements: (N, 8) int array built from entries ``[2:]`` of each row of
    ``mesh.elements``. Presumably these are 1-based node ids; the rest of this
    module indexes nodes with ``id - 1``.
    nodes: (M, 3) float array built from entries ``[1:]`` of each row of
    ``mesh.nodes``.
    """
    elms = np.zeros((len(mesh.elements), 8), int)
    for j, elem in enumerate(mesh.elements):
        elms[j, :] = elem[2:]
    nodes = np.zeros((len(mesh.nodes), 3), float)
    for j, node in enumerate(mesh.nodes):
        nodes[j, :] = node[1:]
    return elms, nodes


def _nodal_values(result, var, step):
    """Return ``{node id: result.NodeData[id][var][step, :]}`` for every node
    that stores *var* at *step*; other nodes are left out."""
    dat = {}
    for nid in result.NodeData:
        try:
            dat[nid] = result.NodeData[nid][var][step, :]
        except (KeyError, IndexError):
            # The node has no such variable, or no such step.
            continue
    return dat


def main(cfg):
    """Interpolate FEBio nodal results onto a node set of a coupled model.

    Steps:
    1. Read the configuration file *cfg* with ``parse_cfg`` and parse the
       listed FEBio meshes (``MeshFiles``) and plot files (``PlotFiles``).
    2. Load the coupled model. ``CoupledModel`` is split on ',' into the model
       file and a node-set name. Translate the model by
       ``CellEndPosition - CellStartPosition``.
    3. For every node of the node set, locate the host 8-node element in each
       mesh and solve for its natural coordinates (in [-1, 1]^3).
    4. Trilinearly interpolate each ``FieldVariables`` entry at those
       coordinates for the requested ``Steps`` ('all' or comma-separated
       step indices).

    The resulting ``load_curves`` dict (node-set index -> {'nid', 't', 'x',
    'y', 'z', 'p'}) is pickled to ``<model file without extension>_lc.pkl``.
    Plot file ``i`` is paired with mesh ``i``.

    NOTE(review): with several plot files, 't' holds the last file's TIME
    while the x/y/z/p lists concatenate values from every file. Confirm that
    this is intended.

    Raises:
        ValueError: a field variable is unsupported, a node-set point lies in
            no candidate element, a plot file has no data for a requested
            variable/step, or (for 'all' steps) a plot file has no countable
            nodal variable.
    """
    import os

    print('Parsing configuration file')
    cfgcontent = parse_cfg.parse_cfg(cfg)

    # Fail fast. An unsupported variable would otherwise crash late (no dof
    # names) or silently reuse the previous variable's dof names.
    variables = cfgcontent['FieldVariables']
    for var in variables:
        if var not in _DOF_NAMES:
            raise ValueError('Unsupported field variable %r; expected one of: %s'
                             % (var, ', '.join(sorted(_DOF_NAMES))))

    meshes = []
    for fname in cfgcontent['MeshFiles']:
        print('...Parsing mesh file: ' + fname)
        meshes.append(parse_feb.Mesh(fname))

    results = []
    for fname in cfgcontent['PlotFiles']:
        print('......Parsing plot file:' + fname)
        results.append(FebPlt.FebPlt(fname))

    all_steps = cfgcontent['Steps'][0].lower() == 'all'
    if all_steps:
        steps = _count_steps(results)
    else:
        # A list, not a lazy iterator: it is re-iterated for every variable.
        steps = [int(s) for s in cfgcontent['Steps'][0].split(',')]

    # Load the coupled model and shift it from the cell start to end position.
    cmodel = cfgcontent['CoupledModel'][0].split(',')
    cmesh = febio.MeshDef(cmodel[0], 'abq')
    end = np.array([float(c) for c in cfgcontent['CellEndPosition'][0].split(',')])
    start = np.array([float(c) for c in cfgcontent['CellStartPosition'][0].split(',')])
    trans_v = end - start
    for node in cmesh.nodes:
        node[1:] = list(np.array(node[1:], float) + trans_v)

    # Coordinates of the node-set points and an empty load curve for each.
    nset = cmesh.nsets[cmodel[1]]
    npts = len(nset)
    pi = np.zeros((npts, 3), float)
    load_curves = {}
    for j, nid in enumerate(nset):
        node = cmesh.nodes[nid - 1]
        pi[j, :] = node[1:]
        load_curves[j] = {'nid': node[0], 't': None,
                          'x': [], 'y': [], 'z': [], 'p': []}

    m_elms = []
    m_nodes = []
    for mesh in meshes:
        elms, nodes = _mesh_arrays(mesh)
        m_elms.append(elms)
        m_nodes.append(nodes)

    # KD trees over element centroids for fast candidate lookup.
    trees = []
    for elms, nodes in zip(m_elms, m_nodes):
        cent = np.asfortranarray(np.zeros((elms.shape[0], 3), float))
        hex_cent.hexcalc(np.asfortranarray(elms), np.asfortranarray(nodes), cent)
        trees.append(spatial.KDTree(cent))

    # Host element of every point, per mesh.
    phomes = []
    for mi in range(len(meshes)):
        phome = np.zeros(npts, int)
        for j in range(npts):
            eid = econtain(pi[j, :], trees[mi], m_elms[mi], m_nodes[mi])
            if eid is None:
                raise ValueError('Node %s of set %s at %s lies in no element of '
                                 'mesh %s' % (load_curves[j]['nid'], cmodel[1],
                                              pi[j, :], cfgcontent['MeshFiles'][mi]))
            phome[j] = eid
        phomes.append(phome)

    # Natural coordinates of every point inside its host element, found by
    # minimising |x(eta) - p| over the bounded reference cube.
    bounds = [(-1.0, 1.0)] * 3
    etas = []
    for mi in range(len(meshes)):
        eta = np.zeros((npts, 3), float)
        for j in range(npts):
            # Corner coordinates come from this mesh's own node table.
            x = m_nodes[mi][m_elms[mi][phomes[mi][j], :] - 1, :]
            eta[j, :] = fmin_l_bfgs_b(obj, eta[j, :], args=(pi[j, :], x),
                                      approx_grad=True, bounds=bounds)[0]
        etas.append(eta)

    # Interpolate the requested variables; plot file cnt pairs with mesh cnt.
    for cnt, r in enumerate(results):
        for lc in load_curves.values():
            lc['t'] = r.TIME
        elms = m_elms[cnt]
        for var in variables:
            print('......Interpolating ' + var + ' for plot file ' + str(cnt))
            dofs = _DOF_NAMES[var]
            time_steps = range(steps[cnt]) if all_steps else steps
            for k in time_steps:
                dat = _nodal_values(r, var, k)
                if not dat:
                    raise ValueError('No nodal data for %r at step %d in plot '
                                     'file %d' % (var, k, cnt))
                ncomp = len(next(iter(dat.values())))
                for l in range(ncomp):
                    dof = dofs[l]
                    for j in range(npts):
                        corner_nids = elms[phomes[cnt][j], :]
                        d = np.array([dat[nid][l] for nid in corner_nids], float)
                        load_curves[j][dof].append(interp(etas[cnt][j], d))

    if load_curves:
        print(load_curves[0])

    # Write once, after all plot files are processed. Derive the name from the
    # extension so the input model can never be overwritten.
    out_path = os.path.splitext(cmodel[0])[0] + '_lc.pkl'
    with open(out_path, 'wb') as fid:
        pickle.dump(load_curves, fid)


def interp(eta, v):
    """Evaluate the trilinear 8-node interpolant at natural coordinates *eta*.

    Args:
        eta: three natural coordinates (r, s, t); indexed as eta[0..2].
        v: eight nodal values, in the corner order of ``corner_signs`` below.

    Returns:
        sum_i (1 + r_i*r)(1 + s_i*s)(1 + t_i*t) * v[i] / 8, as a float.
    """
    # (r_i, s_i, t_i) sign pattern of each of the eight corners.
    corner_signs = (
        (-1, -1, -1), (1, -1, -1), (1, 1, -1), (-1, 1, -1),
        (-1, -1, 1), (1, -1, 1), (1, 1, 1), (-1, 1, 1),
    )
    total = 0.0
    for idx, (sr, ss, st) in enumerate(corner_signs):
        total += (1 + sr * eta[0]) * (1 + ss * eta[1]) * (1 + st * eta[2]) * v[idx]
    return total / 8.0

def obj(eta, p, x):
    """Distance between point *p* and the trilinear image of *eta*.

    This is the objective that fmin_l_bfgs_b minimises to invert the
    8-node isoparametric map.

    Args:
        eta: natural coordinates (r, s, t).
        p: target point, shape (3,).
        x: (8, 3) array of corner coordinates in ``corner_signs`` order.

    Returns:
        Euclidean norm of x(eta) - p.
    """
    corner_signs = (
        (-1, -1, -1), (1, -1, -1), (1, 1, -1), (-1, 1, -1),
        (-1, -1, 1), (1, -1, 1), (1, 1, 1), (-1, 1, 1),
    )
    mapped = np.zeros((1, 3), float)
    for idx, (sr, ss, st) in enumerate(corner_signs):
        weight = (1 + sr * eta[0]) * (1 + ss * eta[1]) * (1 + st * eta[2])
        mapped[0, :] += weight * x[idx, :]
    mapped /= 8.0
    return np.linalg.norm(mapped - p)

def econtain(p, tree, elms, nodes, k=8):
    """Return the row index of the 8-node element that contains point *p*.

    The candidates are the *k* elements whose centroids (the points stored in
    *tree*) are nearest to *p*, tested in order of distance. A candidate
    contains *p* when *p* falls inside the Delaunay tetrahedralisation of its
    eight corner nodes, i.e. their convex hull. This test is exact for convex
    elements.

    Args:
        p: query point, shape (3,).
        tree: scipy.spatial.KDTree over element centroids; its point i
            corresponds to row i of *elms*.
        elms: (N, 8) int array of 1-based node ids.
        nodes: (M, 3) float array; row j holds the coordinates of node id j+1.
        k: number of nearest candidates to test (default 8).

    Returns:
        The row index into *elms* (int), or None when no candidate contains *p*.
    """
    _, eids = tree.query(p, k=k)
    # With k == 1 the query returns a scalar rather than an array.
    for eid in np.atleast_1d(eids):
        # KDTree pads missing neighbours (fewer than k centroids) with
        # index == tree.n, which is past the end of elms.
        if eid >= len(elms):
            continue
        x = nodes[elms[eid, :] - 1, :]
        if spatial.Delaunay(x).find_simplex(p) >= 0:
            return int(eid)
    return None

if __name__ == '__main__':
    # The configuration file is the last command-line argument. Without one,
    # sys.argv[-1] would be this script itself, so stop with a usage message.
    if len(sys.argv) < 2:
        sys.exit('Usage: python %s <configuration file>' % sys.argv[0])
    main(sys.argv[-1])
