from lxml import etree as et
import subprocess32
import sys
import numpy as np
import time
import math
import matplotlib.pyplot as plt
import os

'''New version of the script, meant to be submitted as its own Python job, since we can't run Python scripts on the login node.
Submit it through a text file, as with any other job. Be sure to edit the paths and the file names for each new model.
File names and paths need to be set up the same way I had them, as shown in the commented portion of the main block. Otherwise
this code will probably need to be edited.
'''

def lin_fit(self, x, y):
    '''Least-squares fit of a line through the origin, y = m * x.

    Adapted from erica's subject_force_displacement_plots.py; the plotting
    part of that function is not reproduced here.

    Parameters:
        self: label for the data set (callers in this file pass a file name).
            It is not used by the computation; kept so existing calls work.
        x: independent samples, given as a 1-D sequence or a single column
            of shape (n, 1).
        y: dependent samples, same number of entries as x, 1-D or (n, 1).

    Returns:
        (m, R_sqr) as plain Python floats: the fitted slope and the
        coefficient of determination measured against the mean of y.
        R_sqr is nan/inf when every y value is identical (zero total
        variance).
    '''
    # Force both inputs to column vectors. Mixing a 1-D y with a column x
    # would otherwise broadcast (y - y_fit) to an (n, n) matrix and silently
    # corrupt the residual sum.
    x_col = np.asarray(x, dtype=float).reshape(-1, 1)
    y_col = np.asarray(y, dtype=float).reshape(-1, 1)

    coeffs, _, _, _ = np.linalg.lstsq(x_col, y_col, rcond=None)
    # coeffs has shape (1, 1); index the scalar explicitly instead of relying
    # on implicit array-to-float conversion.
    slope = float(coeffs[0, 0])

    y_fit = slope * x_col
    y_bar = np.average(y_col)
    ss_tot = np.sum((y_col - y_bar) ** 2)
    ss_res = np.sum((y_col - y_fit) ** 2)
    r_sqr = 1 - ss_res / ss_tot

    return slope, float(r_sqr)

def get_Disp_XMLData(RegistrationXML):
    '''Read probe displacement magnitudes from a registration XML file.

    The limb is taken from the file name: the base name is split on '_' and
    the third field from the end must be one of UL, LL, UA or LA
    (e.g. 'CMULTIS008-1_UL_US_CT.xml' -> 'UL' -> <UpperLeg>). The positions
    are the <USPosition> entries under <limb>/CentralAnterior/Indentation,
    each with x, y and z children carrying a 'value' attribute.

    Parameters:
        RegistrationXML: path of the registration XML file.

    Returns:
        numpy array of shape (n, 1): Euclidean distance of every position
        from the first one, multiplied by 1000 (presumably m -> mm; confirm
        against the registration file units).

    Raises:
        ValueError: the limb code in the file name is not recognised.
        IndexError: the file name has fewer than three '_' separated fields,
            or the file contains no USPosition entries.
    '''
    # Parse eagerly and close the file straight away.
    with open(RegistrationXML, 'r') as f:
        root2 = et.parse(f).getroot()

    limb = os.path.basename(RegistrationXML).split('_')
    limbDict = {'UL': 'UpperLeg', 'LL': 'LowerLeg', 'UA': 'UpperArm', 'LA': 'LowerArm'}

    limbTag = limbDict.get(limb[-3])
    if limbTag is None:
        # Fail with a readable message instead of an obscure error from find(None).
        raise ValueError('Unrecognised limb code {!r} in file name {!r}'.format(limb[-3], RegistrationXML))

    data = root2.find(limbTag).find("CentralAnterior").find("Indentation").findall("USPosition")

    # One row per position, columns x, y, z.
    positions = np.array(
        [[float(entry.find(axis).get('value')) for axis in ('x', 'y', 'z')] for entry in data],
        dtype=float)

    # Displacement relative to the first recorded position, scaled by 1000.
    offsets = (positions - positions[0]) * 1000

    # Magnitude of each displacement vector, returned as a column.
    return np.sqrt(np.sum(offsets ** 2, axis=1)).reshape(-1, 1)

def get_ForceTime_XMLData(manThickXML):
    '''Read force magnitudes and time stamps from a manThick XML file.

    Every <Frame> under Subject/Source contributes one sample: the force
    magnitude sqrt(Fx^2 + Fy^2 + Fz^2) from the text of Forces/Fx, Fy, Fz and
    the time from the 'value' attribute of <Time>.

    Parameters:
        manThickXML: path of the manThick XML file.

    Returns:
        (Force, Time): two numpy arrays of shape (n, 1), one row per frame
        in file order. Both are (0, 1) when the file holds no frames.
    '''
    # Parse eagerly and close the file straight away.
    with open(manThickXML, 'r') as f:
        root1 = et.parse(f).getroot()
    frames = root1.find('Subject').find('Source').findall('Frame')

    # Collect into lists and convert once; np.append in a loop copies the
    # whole array on every frame.
    forces = []
    times = []
    for frame in frames:
        load = frame.find('Forces')
        fx = float(load.find('Fx').text)
        fy = float(load.find('Fy').text)
        fz = float(load.find('Fz').text)
        forces.append(math.sqrt(fx ** 2 + fy ** 2 + fz ** 2))
        times.append(float(frame.find('Time').get('value')))

    Force = np.array(forces, dtype=float).reshape(-1, 1)
    Time = np.array(times, dtype=float).reshape(-1, 1)

    return Force, Time

def get_ForceDisp_LogData(logFile, MaxDispMagnitude):
    '''Extract force, displacement and time columns from a febio log file.

    The log is scanned for lines reading 'Data = Fx;Fy;Fz'. For each one the
    time is the last space-separated token of the preceding line and the
    force is the magnitude of fields 1..3 of the following line (field 0 is
    skipped). Only samples with time >= 1 are kept.

    Displacement is derived from time as |(time - 1) * MaxDispMagnitude|
    (presumably the indentation step runs from time 1 to 2; confirm against
    the .feb load curve).

    All but one of the zero-force samples are then dropped from the start of
    the arrays, so duplicate zero-force points do not influence the line
    fit. This assumes the zero-force samples are the leading ones. Finally
    time and displacement are shifted so the first remaining sample is 0.

    Parameters:
        logFile: path of the febio .log file.
        MaxDispMagnitude: scale factor applied to (time - 1).

    Returns:
        (ForceArray, DispArray, TimeArray), each a numpy array of shape (n, 1).

    Raises:
        IndexError: no sample with time >= 1 was found in the log.
    '''
    with open(logFile, 'r') as log:
        lines = log.readlines()

    times = []
    forces = []
    for i, line in enumerate(lines):
        # Strip the line ending explicitly so CRLF logs are matched as well.
        if line.rstrip('\r\n') != 'Data = Fx;Fy;Fz':
            continue
        stamp = float(lines[i - 1].split(' ')[-1])
        fields = [float(value) for value in lines[i + 1].split()]
        force = (fields[1] ** 2 + fields[2] ** 2 + fields[3] ** 2) ** 0.5
        if stamp >= 1:
            times.append(stamp)
            forces.append(force)

    TimeArray = np.array(times, dtype=float).reshape(-1, 1)
    ForceArray = np.array(forces, dtype=float).reshape(-1, 1)
    DispArray = np.abs((TimeArray - 1) * MaxDispMagnitude)

    # Keep a single zero-force sample: drop the first (NumZeros - 1) rows.
    NumZeros = len(ForceArray) - np.count_nonzero(ForceArray)
    firstKept = max(NumZeros - 1, 0)
    ForceArray = ForceArray[firstKept:]
    TimeArray = TimeArray[firstKept:]
    DispArray = DispArray[firstKept:]

    # Re-zero time and displacement on the first remaining sample.
    TimeArray = TimeArray - TimeArray[0]
    DispArray = DispArray - DispArray[0]

    return ForceArray, DispArray, TimeArray

def ChangeFebio(filename, name, finding, replacing, path):
    """Overwrite material parameters in a febio file (.feb is an .xml file).

    Every <material> under <Material> whose 'name' attribute equals `name`
    gets the text of each element listed in `finding` replaced with the
    matching value from `replacing`. For materials of type
    "prestrain elastic" the elements are looked up under the nested
    <elastic> child instead of the material itself. The file is rewritten in
    place, and progress messages are appended to path + 'Py_Script.out'.

    Parameters:
        filename: path of the .feb file to modify in place.
        name: value of the 'name' attribute of the material to edit.
        finding: element tag, or list of tags, to update.
        replacing: new value, or list of values, paired with `finding` in order.
        path: directory prefix (including trailing separator) of the
            Py_Script.out log file.
    """
    def _log(message):
        # Same text as the former `print >> open(...), message`, but the
        # handle is closed right away instead of being left to the
        # garbage collector.
        with open(path + 'Py_Script.out', 'a') as out:
            out.write(str(message) + '\n')

    _log("ChangingFEBIO")
    if isinstance(finding, str):
        finding = [finding]
    # Wrap any scalar (ints included) so zip() below always receives a sequence.
    if isinstance(replacing, (str, int, float)):
        replacing = [replacing]
    replaceDict = dict(zip(finding, replacing))
    _log(replaceDict)
    _log(type(filename))

    # Parse and close before the same file is rewritten further down.
    with open(filename, 'r') as f:
        tree = et.parse(f)
    _log("Parsed")

    root = tree.getroot()
    for mat in root.find('Material').findall('material'):
        if name != mat.get('name'):
            continue
        # Prestrain materials keep their parameters on the nested <elastic> element.
        if mat.get('type') == "prestrain elastic":
            holder = mat.find('elastic')
        else:
            holder = mat
        for find, replace in replaceDict.items():
            holder.find(find).text = str(replace)

    tree.write(filename, xml_declaration=True, pretty_print=True)
    _log("ChangedFebio")

def ReadFebioLog(logFile, expectedValue, epsilon):

	'''Deprecated in the HPC version of this code; kept only as a reference.

	Searches a febio logfile for the reaction force of the probe (may one day be
	expanded to be more general). To improve robustness of the script, an estimate
	is used when the simulation did not run to completion.

	Parameters:
		logFile: path of the febio .log file.
		expectedValue: target value that the force is divided by to form the ratio.
		epsilon: relative tolerance; converged means 1-epsilon < ratio < 1+epsilon.

	Returns:
		(complete, converged, force) when the last stored time is 2,
		(complete, converged, estimatedForce) with complete False when it is
		between 1 and 2, and None (implicitly) otherwise.

	NOTE(review): the messages are written to path + 'Py_Script.out', but `path`
	is not a parameter here. It resolves to a module-level name, which this file
	only assigns inside the `if __name__ == '__main__'` block.
	'''
	# NOTE(review): this handle is never closed.
	log = open(logFile, 'r')
	lines = log.readlines()
	# Rows of [time, force]; starts empty with two columns.
	data = np.ones((0,2))
	for i,line in enumerate( lines ):
		if line[:-1] == 'Data = Fx;Fy;Fz':
			# Time is the last space-separated token of the line before the header.
			time =   float(lines[i-1].split(' ')[-1])
			forces = lines[i+1]		# forces will end up being the last entry
			forces = [float(f) for f in forces.split()]
			# Magnitude of fields 1..3; field 0 of the data line is skipped.
			force = (forces[1]**2 + forces[2]**2 + forces[3]**2)**0.5
			if time >= 1:
				if (time*10).is_integer():	# this finds the MUST_POINTS which provide a better estimate than other times
					data = np.append(data, [[time,force]], axis=0)
					print >> open(path + 'Py_Script.out','a') ,  time, '\t\t', force
	# NOTE(review): data[-1] raises IndexError when no row was stored above.
	if data[-1][0] == 2:
		complete = True
		# `force` is whatever the loop parsed last, which need not be the last row of `data`.
		ratio = force / expectedValue
		if 1.0-epsilon < ratio < 1.0+epsilon:	converged = True
		else:									converged = False
		print >> open(path + 'Py_Script.out','a') ,  'Converged.'
		print >> open(path + 'Py_Script.out','a') ,  'Force is {}'.format(force)
		return complete, converged, force
	elif data[-1][0] > 1:
		complete = False
		# NOTE(review): needs at least two stored rows; the value is a force-per-time slope.
		estimatedForce = (data[-1][1]-data[-2][1]) /  (data[-1][0]-data[-2][0])		# difference in forces / difference in times
		ratio = estimatedForce / expectedValue
		if 1.0-epsilon < ratio < 1.0+epsilon:	converged = True
		else:									converged = False
		print >> open(path + 'Py_Script.out','a') ,  'FAILED to converge.'
		print >> open(path + 'Py_Script.out','a') ,  'estimated force is {}'.format(estimatedForce)
		return complete, converged, estimatedForce
	else:
		# No return statement here, so the caller receives None on this path.
		print >> open(path + 'Py_Script.out','a') ,  'Prestraining did not converge! '

def CallLog(logFile, complete, converged, expected, simulated, matParam, path):
    """Record the outcome of one fitting iteration.

    The same information goes to two places: it is appended to
    path + 'Py_Script.out' and written to the already-open `logFile`, which
    is flushed before returning.

    Parameters:
        logFile: open, writable file-like object for the convergence log.
        complete: True when the simulation ran to completion.
        converged: True when the simulated slope is within tolerance.
        expected: target (experimental) slope.
        simulated: slope obtained from the simulation.
        matParam: latest material parameters, written using their str() form.
        path: directory prefix (including trailing separator) of Py_Script.out.

    Raises:
        ZeroDivisionError: `expected` is zero.
    """
    def _log(message):
        # Same text as the former `print >> open(...), message`, but the
        # handle is closed right away instead of being leaked.
        with open(path + 'Py_Script.out', 'a') as out:
            out.write(str(message) + '\n')

    ratio = simulated / expected
    summary = 'Expected Slope:\t{}\nSimulated Slope:\t{}\nratio :  {}\n'.format(expected, simulated, ratio)
    _log(summary)
    logFile.write(summary)

    if complete and converged:
        _log('\n\nSucessfully converged!!!\n')
        logFile.write('\n\nSucessfully converged!!!\n\n')
    elif converged:
        _log('\n\nSimulation expected to converge, but FAILED to run to completion.\n')
        logFile.write('\nSimulation expected to converge, but FAILED to run to completion.\n')

    _log("Latest Material parameters are:\t{}".format(matParam))
    logFile.write("Latest Material parameters are:\t{}\n\n".format(matParam))
    logFile.flush()

def ClipFebioForceDisp(FebioDisp, FebioForce, ExpDisp, ExpForce):
    '''Trim the simulated force/displacement curve to the experimental range.

    Steps:
      1. Keep the simulated forces that are >= the smallest experimental force.
      2. Keep the same number of displacement samples, taken from the END of
         FebioDisp, and shift them so the first one is zero.
      3. Walk those displacements in order and stop after the first one that
         exceeds the largest experimental displacement (that sample is kept).
      4. Cut the force list to the same length.

    Steps 1 and 2 only pair each force with its own displacement when the
    simulated force does not drop back below the threshold once it has
    reached it.

    Parameters:
        FebioDisp: simulated displacements, 1-D or shape (n, 1).
        FebioForce: simulated forces, 1-D or shape (n, 1).
        ExpDisp: experimental displacements (only the maximum is used).
        ExpForce: experimental forces (only the minimum is used).

    Returns:
        (FebioDispClip, FebioForceClip): numpy arrays of shape (k, 1).

    Raises:
        IndexError: no displacement sample is left after step 2 (for example
            when no simulated force reaches the experimental minimum).
    '''
    # The thresholds do not change while scanning, so compute them once.
    forceFloor = np.min(ExpForce)
    dispCeiling = np.max(ExpDisp)

    force = np.asarray(FebioForce, dtype=float).ravel()
    disp = np.asarray(FebioDisp, dtype=float).ravel()

    keptForce = force[force >= forceFloor]

    # Same number of samples from the tail of the displacement curve.
    firstKept = max(len(disp) - len(keptForce), 0)
    tailDisp = disp[firstKept:]
    if tailDisp.size == 0:
        raise IndexError('no FEBio sample reaches the smallest experimental force')
    tailDisp = tailDisp - tailDisp[0]

    # Index of the first displacement beyond the experimental range, if any;
    # that sample is included so the clipped curve spans the whole range.
    beyond = np.nonzero(~(tailDisp <= dispCeiling))[0]
    if beyond.size:
        count = int(beyond[0]) + 1
    else:
        count = len(tailDisp)

    FebioDispClip = tailDisp[:count].reshape(-1, 1)
    FebioForceClip = keptForce[:count].reshape(-1, 1)

    return FebioDispClip, FebioForceClip


def RunInverseFEA(febioFile=None, RegistrationFile=None, ManThickFile=None, path=None, epsilon=0.025, febioCommand=None):
    """Iteratively scale the 'Flesh' material until the simulated slope matches the experiment.

    Each iteration:
      1. writes the current guess for 'c1' and 'k' into <stem>run<N>.feb,
      2. submits path + 'FEBio_Test.txt' with `febioCommand` and polls every
         300 s until <stem>run<N>.log contains 'T E R M I N A T I O N',
         resubmitting when path + 'Py_Script.err' contains 'sbatch: error:',
      3. fits lines through the origin to the experimental and the simulated
         force/displacement data,
      4. divides the guess by ratio = simulated slope / experimental slope,
      5. replaces 'run<N>' with 'run<N+1>' inside FEBio_Test.txt and renames
         the .feb file the same way, so the next job does not overwrite the
         previous log.
    The loop ends when 1 - epsilon < ratio < 1 + epsilon. Progress is appended
    to path + 'Py_Script.out'; per-iteration results go to the '.dat' file
    named after the initial febioFile.

    Parameters:
        febioFile: path of the first .feb file; its name must contain
            'run<N>'. Defaults to path + '008UL_Quad_run1.feb'.
        RegistrationFile: registration XML with the experimental displacements.
        ManThickFile: manThick XML with the experimental forces.
        path: working directory prefix including the trailing separator.
            Required, even though the signature defaults it to None.
        epsilon: relative tolerance on the slope ratio.
        febioCommand: executable used to submit FEBio_Test.txt; 'sbatch'
            when empty or None.

    Returns:
        None.
    """
    def _log(*items):
        # Same text as the former `print >> open(...), a, b`: the items
        # separated by one space and followed by a newline. The handle is
        # closed immediately.
        with open(path + 'Py_Script.out', 'a') as out:
            out.write(' '.join(str(item) for item in items) + '\n')

    def _read_creating(filename):
        # 'a+' creates the file when it does not exist yet. Rewind explicitly
        # because the initial read position of append mode is platform dependent.
        with open(filename, 'a+') as handle:
            handle.seek(0)
            return handle.read()

    def _stem(filename):
        # Everything before the LAST 'run', so a directory name that happens
        # to contain 'run' does not truncate the path.
        return str(filename).rsplit('run', 1)[0]

    if not febioCommand:
        febioCommand = 'sbatch'
    if not febioFile:
        febioFile = path + '008UL_Quad_run1.feb'

    _log(febioFile)
    output = febioFile.replace('.feb', '.log')
    _log(output)
    convergeLog = open(febioFile.replace('.feb', '.dat'), 'a+')
    try:
        _log(convergeLog)

        currentInt = 1
        optimizeMaterialName = 'Flesh'
        matGuess = [0.01, 10]  # guess of material properties to change
        # `complete` is never cleared: ReadFebioLog is no longer used, so
        # only `converged` controls the loop.
        complete = True
        converged = False
        while not (complete and converged):
            # File names for this iteration.
            febioFile = _stem(febioFile) + 'run{}.feb'.format(currentInt)
            _log(febioFile, os.path.getsize(febioFile))
            output = _stem(output) + 'run{}.log'.format(currentInt)
            _log(output)

            # Write the current guess into the model (Mooney-Rivlin parameters).
            ChangeFebio(febioFile, optimizeMaterialName, ['c1', 'k'], matGuess, path)

            _log("Calling subprocess32")
            # Run the febio simulation by submitting the batch job.
            subprocess32.call([febioCommand, path + 'FEBio_Test.txt'])

            # Poll until febio reports termination in its log.
            time.sleep(300)
            while True:
                if 'T E R M I N A T I O N' in _read_creating(output):
                    break
                if 'sbatch: error:' in _read_creating(path + 'Py_Script.err'):
                    # Submission failed on the scheduler side: submit again
                    # and clear the error file so the same message is not
                    # acted on twice.
                    subprocess32.call([febioCommand, path + 'FEBio_Test.txt'])
                    open(path + 'Py_Script.err', 'w').close()
                    _log('Resubmitting job because of HPC bugs')
                    _log('Size of err File: ', os.path.getsize(path + 'Py_Script.err'))
                else:
                    _log("Still running")
                time.sleep(300)

            # Experimental data.
            ExpForce, ExpTime = get_ForceTime_XMLData(ManThickFile)
            ExpDisp = get_Disp_XMLData(RegistrationFile)
            _log("Got Exp Data")

            # Prescribed rigid-body displacement of the model.
            with open(febioFile, 'r') as f:
                febio_root = et.parse(f).getroot()
            _log("Opened .feb file: ", febioFile)

            BCs = febio_root.find('Boundary').find('rigid_body').findall('prescribed')
            xFebDisp = float(BCs[0].text)
            yFebDisp = float(BCs[1].text)
            zFebDisp = float(BCs[2].text)
            MaxFebioDisplacement = np.sqrt(xFebDisp ** 2 + yFebDisp ** 2 + zFebDisp ** 2)
            _log("Got Max Febio disp: ", MaxFebioDisplacement)

            _log("Got Experimental Data")
            FebioForce, FebioDisp, FebioTime = get_ForceDisp_LogData(output, MaxFebioDisplacement)
            FebioDispClip, FebioForceClip = ClipFebioForceDisp(FebioDisp, FebioForce, ExpDisp, ExpForce)
            _log("Got Febio data")
            FebioSlope, FebioRSq = lin_fit(febioFile, FebioDispClip, FebioForceClip)
            ExpectedSlope, ExpectedRsq = lin_fit(RegistrationFile, ExpDisp, ExpForce)
            _log("Got Fits")

            # Scale the guess by how far the simulated slope is from the target.
            ratio = FebioSlope / ExpectedSlope
            matGuess = [m / float(ratio) for m in matGuess]

            # The slurm text file has to submit the job under the next run name.
            with open(path + 'FEBio_Test.txt', 'r') as handle:
                filedata = handle.read()
            filedata = filedata.replace('run' + str(currentInt), 'run' + str(currentInt + 1))
            with open(path + 'FEBio_Test.txt', 'w') as handle:
                handle.write(filedata)

            # Rename the model to keep the HPC from overwriting log data.
            os.rename(_stem(febioFile) + 'run{}.feb'.format(currentInt),
                      _stem(febioFile) + 'run{}.feb'.format(currentInt + 1))

            currentInt += 1

            if 1.0 - epsilon < ratio < 1.0 + epsilon:
                converged = True
                _log('Converged.')
                _log('Final Slope is {}'.format(ExpectedSlope))
                # Undo the scaling above: report the parameters that were
                # actually simulated.
                matGuess = [m * float(ratio) for m in matGuess]
            else:
                converged = False
                _log('FAILED to converge for this iteration.')
                _log('estimated slope is {}'.format(ExpectedSlope))

            CallLog(convergeLog, complete, converged, ExpectedSlope, FebioSlope, matGuess, path)
    finally:
        # Close the convergence log even when an iteration raises.
        convergeLog.close()


if __name__ == '__main__':
    # Job entry point. For every new model, edit the directory and the three
    # file names below (and make sure to find and replace the print commands
    # for the linux pathways). The values currently set are the ones used for
    # subject 008, upper leg.
    path = '/home/doherts/lustre/MULTIS/008_UL/'
    FebioFile = path + '008UL_Quad_run1.feb'
    RegistrationXML = path + 'CMULTIS008-1_UL_US_CT.xml'
    ManThickXML = path + '003_CMULTIS008-1_UL_AC_I-1_manThick201708241020.xml'

    RunInverseFEA(
        febioFile=FebioFile,
        RegistrationFile=RegistrationXML,
        ManThickFile=ManThickXML,
        path=path,
        epsilon=0.025,
        febioCommand=None,
    )
