import numpy as np
from lxml import etree as et
import os
import StiffnessFromLog
import subprocess
import LogPostProcessing
from scipy import optimize
import sys
import random

def write_file(xml_tree, new_filename):
    """Serialize an lxml tree to ``new_filename`` with an XML declaration and indentation.

    lxml's ``pretty_print`` does not re-indent elements that already carry
    whitespace-only text, so after the first write the file is re-parsed with
    ``remove_blank_text=True`` and written a second time to get clean indentation.
    """
    write_options = dict(xml_declaration=True, pretty_print=True)
    xml_tree.write(new_filename, **write_options)

    # round-trip through a whitespace-stripping parser so indentation is regenerated from scratch
    blank_stripping_parser = et.XMLParser(remove_blank_text=True)
    reparsed = et.parse(new_filename, blank_stripping_parser)
    reparsed.write(new_filename, **write_options)


def write_to_summary_file(summary_file, text):
    """Append ``text`` to the plain-text summary log ``summary_file`` (created if absent).

    No newline is added; callers include their own line breaks.
    """
    # the context manager closes the file even if the write raises
    with open(summary_file, "a") as info_file:
        info_file.write(text)


def change_mat_props(feb_file, mat_name, ratio, summary_file, new_name=None, **mat_props):
    """Adjust the properties of one material in an FEBio input file.

    Each keyword in ``mat_props`` names a child element of the material (e.g. ``c1=...``,
    ``c5=...``) whose text is replaced by the given value. Unless ``k`` is passed explicitly,
    ``k`` is then derived from ``ratio``: ``ratio * c1`` when the material type contains
    "coupled", otherwise ``ratio * c5``, using the c1/c5 values now stored in the material.
    Every change is appended to ``summary_file``.

    :param feb_file: path of the FEBio file to read.
    :param mat_name: ``name`` attribute of the material to edit.
    :param ratio: multiplier used to derive ``k``.
    :param summary_file: text file the log of changes is appended to.
    :param new_name: if given, the result is written to this path (include the directory);
        otherwise ``feb_file`` is overwritten.
    :raises ValueError: if the file has no ``Material`` section, no material named
        ``mat_name``, or the material lacks a property that is set or read.
    """
    write_to_summary_file(summary_file, 'changing material properties in the following file: {} \n'.format(feb_file))

    file_tree = et.parse(feb_file)
    materials = file_tree.getroot().find('Material')
    if materials is None:
        raise ValueError("no <Material> section found in {}".format(feb_file))

    # .get() skips XML comments / unnamed entries instead of raising KeyError
    mat = next((m for m in materials if m.get('name') == mat_name), None)
    if mat is None:
        # fail loudly: silently rewriting the file unchanged would let a calibration run on wrong properties
        raise ValueError("material '{}' not found in {}".format(mat_name, feb_file))

    def find_prop(prop_name):
        """Return the ``prop_name`` child element of the material, or raise ValueError."""
        prop = mat.find(prop_name)
        if prop is None:
            raise ValueError("material '{}' in {} has no '{}' property".format(mat_name, feb_file, prop_name))
        return prop

    # change any material properties requested
    for prop_name, prop_val in mat_props.items():
        find_prop(prop_name).text = "{}".format(prop_val)
        write_to_summary_file(summary_file, 'setting {} to {} \n'.format(prop_name, prop_val))

    # unless k is explicitly defined, set it using the ratio
    if "k" not in mat_props:
        # NOTE(review): substring test — a type name containing "uncoupled" would also match; confirm
        # against the FEBio material types in use
        coupled = "coupled" in mat.get('type', '')
        base_name, formulation = ("c1", "coupled") if coupled else ("c5", "uncoupled")
        k_value = ratio * float(find_prop(base_name).text)
        find_prop("k").text = "{}".format(k_value)
        write_to_summary_file(summary_file, 'setting k to {} for {} formulation \n \n'.format(k_value, formulation))

    # write the new file (or overwrite the original)
    write_file(file_tree, new_name if new_name is not None else feb_file)

def measure_outputs(feb_file, rb_name):
    """Return the magnitude of the reaction force on rigid body ``rb_name`` at the last time step.

    The data are read with LogPostProcessing from the log file that has the same name as
    ``feb_file`` with a ``.log`` extension.
    """
    # swap only the extension; str.replace would also rewrite ".feb" occurring elsewhere in the path
    log_file = os.path.splitext(feb_file)[0] + '.log'
    rigid_bodies = LogPostProcessing.find_rigid_bodies(log_file)
    all_data, _ = LogPostProcessing.collect_data(log_file)

    # for the total reaction force, use the data from the last time step on the rigid body
    rb_reaction = all_data['Reaction_Forces'][rigid_bodies[rb_name].id][-1]
    return np.linalg.norm(rb_reaction)


def calibrate_c1(c1, febioCommand, mat_name, compression_feb, ratio, summary_file, c5):
    """Reduce c1 until the compression model's total reaction force on 'FMB' is at most 1.

    Each run writes c1 (with c5) into ``compression_feb``, runs FEBio and measures the
    reaction force; the next guess is ``c1 / (2 * reaction)``. The search gives up after
    4 runs, or when the reaction force changed by less than 5% between runs, since both
    suggest k is too high for c1 to control the compression response.

    Returns (c1, total_reaction) of the last model that was run.
    """
    reaction = 10000000  # sentinel larger than any real reaction so the first run always happens
    percent_change = 100
    candidate = c1
    run_count = 0

    while reaction > 1:
        stop_message = None
        if run_count > 3:
            stop_message = "terminated c1 calibration, exceeded max iters. likely k is too high \n \n"
        elif percent_change < 5:
            stop_message = "terminated c1 calibration, reduction in c1 not reducing reation force. likely k is too high \n \n"
        if stop_message is not None:
            write_to_summary_file(summary_file, stop_message)
            break

        c1 = candidate
        change_mat_props(compression_feb, mat_name, ratio, summary_file, c1=c1, c5=c5)

        print("running the compression model")
        subprocess.call([febioCommand, compression_feb])

        previous = reaction
        reaction = measure_outputs(compression_feb, 'FMB')
        print("total reaction force in compression {}".format(reaction))

        # relative change (in %) of the reaction force with respect to the previous run
        percent_change = abs((previous - reaction) / previous) * 100
        write_to_summary_file(summary_file, 'total reaction force in compression: {} \n \n'.format(reaction))

        # shrink c1 in proportion to the remaining reaction force
        candidate = c1 / (2 * reaction)
        run_count += 1

    write_to_summary_file(summary_file, 'calibrated c1 value: {} \n'.format(c1))
    write_to_summary_file(summary_file, "---------------------------------------- \n \n")

    return c1, reaction


def calibrate_c5_auto(c5, expected_stiffness, ratio, febioCommand, tension_feb, mat_name, summary_file, c1):
    """Fit c5 with scipy's Brent minimizer so the tension stiffness matches ``expected_stiffness``.

    The objective writes c5 (and c1) into ``tension_feb``, runs FEBio and returns
    ``|measured stiffness - expected_stiffness|``. A negative c5, an FEBio ERROR TERMINATION,
    or a missing/truncated log file is penalized with 100000.

    After optimization the tension model is rerun with the optimized c5 (and c1).
    Returns (c5, stiffness, total_reaction) from that final run.
    """
    penalty = 100000
    logfile = os.path.splitext(tension_feb)[0] + '.log'

    def run_failed():
        """True if the log is missing, has fewer than 2 lines, or reports ERROR TERMINATION."""
        try:
            with open(logfile, 'r') as f:
                lines = f.readlines()
        except FileNotFoundError:
            return True
        return len(lines) < 2 or 'E R R O R   T E R M I N A T I O N' in lines[-2]

    def opt_fun(c5_trial):
        """Objective: absolute stiffness error for a trial c5 (penalized when the run fails)."""
        if c5_trial < 0:
            return penalty

        change_mat_props(tension_feb, mat_name, ratio, summary_file, c5=c5_trial, c1=c1)

        write_to_summary_file(summary_file, "running tension model \n ")
        subprocess.call([febioCommand, tension_feb])

        # if the model failed to converge (or produced no usable log), penalize the output
        if run_failed():
            write_to_summary_file(summary_file, 'ERROR TERMINATION \n')
            return penalty

        # measure the linear stiffness (check that it reached the linear region?)
        _, _, _, s = StiffnessFromLog.FindStiffness(logfile, 'FMB')

        write_to_summary_file(summary_file, 'measured stiffness in tension: {}, expected stiffness: {} \n'.format(s, expected_stiffness))
        dif = abs(s - expected_stiffness)
        write_to_summary_file(summary_file, 'difference: {} \n \n'.format(dif))
        return dif

    write_to_summary_file(summary_file, "Beginning optimization \n\n")

    res = optimize.minimize_scalar(opt_fun, bracket=(c5, c5 + 100), method='brent', tol=0.1,
                                   options={'xtol': 0.1, 'maxiter': 50})
    c5 = res.x

    write_to_summary_file(summary_file, 'optimized c5 value: {} \n'.format(c5))
    write_to_summary_file(summary_file, "---------------------------------------- \n \n")

    # change to converged result (c1 passed explicitly so it cannot depend on earlier runs) and rerun
    write_to_summary_file(summary_file, 'running tension model with optimized c5 \n \n')
    change_mat_props(tension_feb, mat_name, ratio, summary_file, c5=c5, c1=c1)
    subprocess.call([febioCommand, tension_feb])

    # measure the linear stiffness (check that it reached the linear region?)
    _, _, _, s = StiffnessFromLog.FindStiffness(logfile, 'FMB')
    # measure total reaction
    total_reaction = measure_outputs(tension_feb, 'FMB')

    return c5, s, total_reaction


def calibrate_c5(c5, expected_stiffness, ratio, febioCommand, tension_feb, mat_name, summary_file, c1):
    """Tune c5 so the tension model's linear stiffness matches ``expected_stiffness``.

    Each run is classified as above or below the target stiffness, and the next guess is
    linearly interpolated between the latest run on each side. A side that has not been
    visited yet is taken as the origin (c5 = 0, stiffness = 0). The loop ends when the
    stiffness is within 0.5 of the target, or when the next guess moves c5 by less than 0.05.

    If a run fails (ERROR TERMINATION, or a missing/truncated log), c5 is scaled by a random
    factor in [0.4, 0.9] and retried. More than 15 failures, or a degenerate interpolation
    (zero slope), abort the calibration.

    Returns (c5, stiffness, total_reaction, success); the first three are None when
    success is False.
    """
    write_to_summary_file(summary_file, "Beginning optimization \n\n")

    logfile = os.path.splitext(tension_feb)[0] + '.log'

    difference = 100
    c5_next = c5

    # latest runs above / below the target; (0, 0) until that side has been visited
    c5_above, s_above = 0, 0
    c5_below, s_below = 0, 0

    error_count = 0
    success = True

    while abs(difference) > 0.5:

        # try the next guess for c5
        c5 = c5_next

        # change c5 in the tension febio file and run the model
        change_mat_props(tension_feb, mat_name, ratio, summary_file, c5=c5, c1=c1)

        write_to_summary_file(summary_file, "running tension model \n ")
        subprocess.call([febioCommand, tension_feb])

        # a missing or truncated log (fewer than 2 lines) is treated like an ERROR TERMINATION
        try:
            with open(logfile, 'r') as f:
                log_lines = f.readlines()
        except FileNotFoundError:
            log_lines = []

        if len(log_lines) < 2 or 'E R R O R   T E R M I N A T I O N' in log_lines[-2]:
            write_to_summary_file(summary_file, 'ERROR TERMINATION \n')
            difference = 100
            # scale back the next guess randomly to avoid ending up in an endless loop
            c5_next = random.uniform(0.4, 0.9) * c5
            # count error terminations; give up after more than 15
            error_count += 1
            if error_count > 15:
                write_to_summary_file(summary_file, "optimization stopped - too many error terminations\n")
                success = False
                break
            continue

        # measure the linear stiffness (check that it reached the linear region?)
        _, _, _, s = StiffnessFromLog.FindStiffness(logfile, 'FMB')

        write_to_summary_file(summary_file, 'measured stiffness in tension: {}, expected stiffness: {} \n'.format(s, expected_stiffness))
        # check the difference
        difference = s - expected_stiffness
        write_to_summary_file(summary_file, 'difference: {} \n \n'.format(difference))

        # set the guess as above or below the solution
        if difference > 0:
            c5_above, s_above = c5, s
        else:
            c5_below, s_below = c5, s

        # linear interpolation between the guess above and the guess below
        dc5 = c5_above - c5_below
        ds = s_above - s_below
        if dc5 == 0 or ds == 0:
            # zero slope (e.g. a measured stiffness of 0): no next guess can be interpolated
            if abs(difference) <= 0.5:
                break  # already within tolerance
            write_to_summary_file(summary_file, "optimization stopped - cannot interpolate next c5 guess\n")
            success = False
            break

        m = ds / dc5
        c5_next = ((expected_stiffness - s_below) / m) + c5_below

        # if the solution gets stuck
        if abs(c5_next - c5) < 0.05:
            write_to_summary_file(summary_file, "optimization stopped\n")
            break

    # if the calibration did not succeed, exit the function
    if not success:
        return None, None, None, success

    write_to_summary_file(summary_file, 'optimized c5 value: {} \n'.format(c5))
    write_to_summary_file(summary_file, "---------------------------------------- \n \n")

    # change to converged result and run tension model again
    write_to_summary_file(summary_file, 'running tension model with optimized c5 \n \n')
    change_mat_props(tension_feb, mat_name, ratio, summary_file, c5=c5, c1=c1)
    subprocess.call([febioCommand, tension_feb])

    # measure the linear stiffness (check that it reached the linear region?)
    _, _, _, s = StiffnessFromLog.FindStiffness(logfile, 'FMB')
    # measure total reaction
    total_reaction = measure_outputs(tension_feb, 'FMB')

    return c5, s, total_reaction, success


def check_c1_reduction(summary_file, total_reaction_calib, s_calib, mat_name, ratio, c1, tension_feb, febioCommand):
    """Check how sensitive the calibrated tension results are to a 20% reduction of c1.

    Reruns ``tension_feb`` with c1 reduced by 20% and compares the total reaction force and
    stiffness with the calibrated values. The sensitivity ``change_ratio`` is the larger of
    the two relative changes divided by the relative change in c1 (0.2). ``tension_feb`` is
    left containing the reduced c1.

    Returns (change_ratio, c1), with c1 halved if change_ratio > 0.25 and unchanged otherwise.
    """
    write_to_summary_file(summary_file, 'Checking that tension results are not sensitive to change in c1 \n \n' +
                          'Current total reaction force in tension: {} \n'.format(total_reaction_calib) +
                          'Current stiffness in tension: {} \n \n'.format(s_calib))

    c1_change = 0.2  # will reduce c1 by 20%

    change_mat_props(tension_feb, mat_name, ratio, summary_file, c1=(1 - c1_change) * c1)
    subprocess.call([febioCommand, tension_feb])

    # Flag a failed run (ERROR TERMINATION, or a log too short to check) in the summary. No retry
    # with a smaller reduction is attempted; the results below are then measured from the failed run.
    logfile = os.path.splitext(tension_feb)[0] + '.log'
    with open(logfile, 'r') as f:
        log_lines = f.readlines()
    if len(log_lines) < 2 or 'E R R O R   T E R M I N A T I O N' in log_lines[-2]:
        write_to_summary_file(summary_file, 'ERROR TERMINATION \n')

    total_reaction_red = measure_outputs(tension_feb, 'FMB')
    _, _, _, s_red = StiffnessFromLog.FindStiffness(logfile, 'FMB')

    # relative change of each result, normalized by the relative change applied to c1
    total_reaction_change = abs((total_reaction_red - total_reaction_calib) / total_reaction_calib)
    s_change = abs((s_red - s_calib) / s_calib)
    change_ratio = max(total_reaction_change, s_change) / c1_change

    write_to_summary_file(summary_file, 'total reaction force in tension after reduction of c1 by 20%: {} \n'.format(total_reaction_red) +
                          'stiffness in tension after reduction of c1 by 20%: {} \n \n'.format(s_red) +
                          "ratio of maximum result change/c1 change: {} \n".format(change_ratio))

    if change_ratio > 0.25:
        c1 = c1 / 2

    return change_ratio, c1

def find_geometry_tree(febfile):
    """Return the XML tree holding the Geometry section of an FEBio file, and its file name.

    The Geometry section may reference a separate file through ``<Geometry from="...">``.
    When that attribute is present, the referenced file is parsed (a bare file name is
    resolved relative to the folder of ``febfile``) and ``(geometry_tree, geometry_file)``
    is returned. Otherwise the geometry is inline and ``(tree_of_febfile, None)`` is returned.

    :raises ValueError: if ``febfile`` has no Geometry section.
    """
    feb_tree = et.parse(febfile)
    geo_section = feb_tree.getroot().find('Geometry')
    if geo_section is None:
        raise ValueError("no <Geometry> section found in {}".format(febfile))

    # .get() returns None when the attribute is absent; attrib['from'] would raise KeyError instead
    geofile = geo_section.get('from')
    if not geofile:
        return feb_tree, None

    if not os.path.dirname(geofile):
        # no pathway given: assume the geometry file sits next to the FEBio file
        geofile = os.path.join(os.path.dirname(febfile), geofile)

    return et.parse(geofile), geofile

def oriented_bounding_box(Nodes):
    """Fit a principal-axis oriented bounding box around a 3-D point cloud.

    The box axes are the eigenvectors of the covariance matrix of ``Nodes``. The points are
    rotated into that frame, the min/max box is taken there, and its corners and center are
    rotated back into the original frame.

    :param Nodes: array-like of shape (n, 3), one point per row.
    :return: (edge_lengths, edge_vectors, center): ``edge_lengths`` (3,) are the box edge
        lengths, ``edge_vectors`` (3, 3) holds the matching unit edge directions as rows,
        and ``center`` (3,) is the box center in the original frame.
    """
    points = np.asarray(Nodes, dtype=float)
    cov = np.cov(points, y=None, rowvar=False, bias=True)

    # eigh is the routine for symmetric matrices: real eigenvalues and orthonormal eigenvectors
    # even for repeated eigenvalues, which np.linalg.eig does not guarantee
    _, axes = np.linalg.eigh(cov)  # columns are the principal axes

    # rotate the points into the principal frame (axes is orthonormal, so its inverse is axes.T)
    local = points @ axes
    lo = local.min(axis=0)
    hi = local.max(axis=0)
    half = (hi - lo) * 0.5
    center_local = lo + half

    # the 8 corners: center +/- half extent along each principal axis
    signs = np.array([[-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1],
                      [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]])
    corners_local = center_local + signs * half

    # rotate the corners and the center back into the original frame
    corners = corners_local @ axes.T
    center = center_local @ axes.T

    # get the vectors of the 3 edges
    edges = np.array([corners[1] - corners[0], corners[4] - corners[0], corners[3] - corners[0]])
    edge_lengths = np.linalg.norm(edges, axis=1)
    edge_vectors = edges / edge_lengths[:, None]

    return edge_lengths, edge_vectors, center

def GetNodes(febfile, mat_name):
    """Return the node coordinates of the ``Nodes`` section named ``mat_name`` as a numpy array.

    The geometry is located with find_geometry_tree, so a Geometry section stored in a separate
    file is followed. Each row holds the comma-separated values of one node (presumably x, y, z),
    in file order. If several Nodes sections share the name, the last one wins.

    :raises ValueError: if there is no Geometry section or no Nodes section named ``mat_name``.
    """
    geo_tree, _ = find_geometry_tree(febfile)
    geometry_section = geo_tree.getroot().find('Geometry')
    if geometry_section is None:
        raise ValueError("no <Geometry> section found for {}".format(febfile))

    def get_data(geom):
        """Map each child's integer ``id`` attribute to its list of comma-separated float values."""
        data_dict = {}
        for point in geom:
            try:
                point_id = int(point.attrib['id'])
                point_data = [float(x) for x in point.text.split(',')]
            except (KeyError, ValueError, AttributeError, TypeError):
                # XML comments and malformed entries have no usable id/text: skip them
                continue
            data_dict[point_id] = point_data
        return data_dict

    # find the node section for the material, store the nodes in an array
    nodes = None
    for node_data in geometry_section.findall('Nodes'):
        if node_data.get('name') == mat_name:
            nodes = np.asarray(list(get_data(node_data).values()))

    if nodes is None:
        raise ValueError("no Nodes section named '{}' found in the geometry of {}".format(mat_name, febfile))
    return nodes

def change_load_curve(feb_file, dispalcement):
    """Set the magnitude of the load-curve point at time 1 to ``dispalcement``.

    Only the first ``loadcurve`` under ``LoadData`` is edited. The sign of the point's
    current value is kept, and the new value is written with two decimals as ``1,<value>``.
    The file is overwritten in place.
    """
    file_tree = et.parse(feb_file)
    root = file_tree.getroot()
    load_curve = root.find('LoadData').find('loadcurve')
    for point in load_curve:
        # skip XML comments / processing instructions (their tag is not a string) and empty points
        if not isinstance(point.tag, str) or not point.text:
            continue
        fields = point.text.split(',')
        try:
            time = float(fields[0])
            current_value = float(fields[-1])
        except ValueError:
            continue
        # compare the parsed time: a first-character test would also match t = 10, 1.5, ...
        if time != 1.0:
            continue
        # keep the sign the same as before, but change magnitude
        point.text = '1,{0:.2f}'.format(np.sign(current_value) * dispalcement)

    # re-write the file
    write_file(file_tree, feb_file)

def adjust_disaplcement(compression_model, tension_model, mat_name, strain=0.075):
    """Set the prescribed displacement of both models from the ligament length.

    The ligament length is approximated by the longest edge of the oriented bounding box
    of the ``mat_name`` nodes in the tension model. The displacement ``strain * length``
    is written into the load curves of both models via change_load_curve.

    :param strain: nominal strain as a fraction of the ligament length (default 0.075, i.e. 7.5%).
    """
    # get the longest edge length of the oriented bounding box of the ligament
    mat_nodes = GetNodes(tension_model, mat_name)
    edge_lengths, _, _ = oriented_bounding_box(mat_nodes)
    displacement = strain * max(edge_lengths)

    # change the load curves so that the strain is set to the requested value
    change_load_curve(tension_model, displacement)
    change_load_curve(compression_model, displacement)


def calibrate_materials(compression_model, tension_model, mat_name, expected_stiffness, ratio,febioCommand = None):
    """Calibrate c1 and c5 of material ``mat_name`` for a given ``ratio`` (k = ratio * c1 or c5).

    Steps:
      1. scale the prescribed displacement of both models from the ligament length;
      2. reduce c1 on the compression model (calibrate_c1);
      3. fit c5 on the tension model to ``expected_stiffness`` (calibrate_c5). If the tension
         results are too sensitive to a c1 reduction (change ratio > 0.25), c1 is halved and
         c5 is refit;
      4. write and run final copies of both models named ``*_r<ratio>.feb``.

    Progress is appended to ``Summary.txt`` in the compression model's folder. FEBio is run
    from that folder, so the tension model is expected to live there too. The caller's
    working directory is restored before returning. Returns None (also when c5 calibration fails).
    """
    # initial guesses for coefficients
    c1 = 1.95
    c5 = 535

    # absolute paths stay valid after the chdir below (and across repeated calls with relative paths)
    compression_model = os.path.abspath(compression_model)
    tension_model = os.path.abspath(tension_model)

    # check the ligament length, adjust the displacement of the tension and compression models
    # (adjust_disaplcement uses 7.5% nominal strain by default)
    adjust_disaplcement(compression_model, tension_model, mat_name)

    if not febioCommand:
        febioCommand = '/home/schwara2/Programs/FEBio/FEBio2.8.2/bin/febio2.lnx64'
        # Windows alternative: r"C:\Users\schwara2\Programs\Febio\bin\FEBio2.exe"
        print('Using {} to call febio'.format(febioCommand))

    # go to the folder containing the model files; restore the caller's cwd afterwards
    dirname = os.path.dirname(compression_model)
    original_cwd = os.getcwd()
    os.chdir(dirname)
    try:
        summary_file = os.path.join(dirname, 'Summary.txt')

        # get the base file names for compression and tension models
        compression_feb = os.path.basename(compression_model)
        tension_feb = os.path.basename(tension_model)

        # check if it's a coupled or uncoupled material (only affects how k is reported)
        coupled = False
        for mat in et.parse(compression_feb).getroot().find('Material'):
            if mat.get('name') == mat_name:
                coupled = "coupled" in mat.get('type', '')
                break

        write_to_summary_file(summary_file, '================================================================================== \n' +
                              'intial guess c1: {} \n'.format(c1) +
                              'intial guess c5: {} \n'.format(c5) +
                              'ratio: {} \n'.format(ratio) +
                              'coupled: {}\n \n'.format(coupled))

        write_to_summary_file(summary_file, 'c1 calibration using compression model {} \n \n'.format(compression_feb))
        c1, _ = calibrate_c1(c1, febioCommand, mat_name, compression_feb, ratio, summary_file, c5)

        change_ratio = 1
        success = True

        while change_ratio > 0.25:

            write_to_summary_file(summary_file, 'c5 calibration using tension model {} \n \n'.format(tension_feb))
            c5, s_calib, tension_reaction_calib, success = calibrate_c5(c5, expected_stiffness, ratio, febioCommand, tension_feb, mat_name, summary_file, c1)

            if not success:
                break

            # check the influence of perturbing c1 on the tension model; this returns the same c1
            # if the change ratio is ok, and c1/2 if the change ratio is too large
            change_ratio, c1 = check_c1_reduction(summary_file, tension_reaction_calib, s_calib, mat_name, ratio, c1, tension_feb, febioCommand)

        # if the calibration process did not succeed, exit the function
        if not success:
            return

        write_to_summary_file(summary_file, 'running compression and tension model with final calibrated results \n \n')

        # run compression and tension models with the final calibrated material properties,
        # saved under new names
        final_compression = compression_feb.replace('.feb', '_r{}.feb'.format(ratio))
        change_mat_props(compression_feb, mat_name, ratio, summary_file,
                         new_name=final_compression, c1=c1, c5=c5)
        subprocess.call([febioCommand, final_compression])
        compression_reaction = measure_outputs(final_compression, 'FMB')

        final_tension = tension_feb.replace('.feb', '_r{}.feb'.format(ratio))
        change_mat_props(tension_feb, mat_name, ratio, summary_file,
                         new_name=final_tension, c1=c1, c5=c5)
        subprocess.call([febioCommand, final_tension])
        _, _, _, stiffness = StiffnessFromLog.FindStiffness(final_tension.replace('.feb', '.log'), 'FMB')

        write_to_summary_file(summary_file, 'final results: \n' + "c1 = {} \n".format(c1) + "c5 = {} \n".format(c5))
        if coupled:
            write_to_summary_file(summary_file, "k = {} \n \n".format(ratio * c1))
        else:
            write_to_summary_file(summary_file, "k = {} \n \n".format(ratio * c5))

        write_to_summary_file(summary_file, 'stiffness in tension: {} \n'.format(stiffness) + 'compression reaction force: {} \n'.format(compression_reaction))
    finally:
        os.chdir(original_cwd)


def calibrate_all_knees(models_dir):
    """Calibrate the ACL of every knee model at ratios 1, 10, 100, 1000 and 10000.

    ``models_dir`` should contain one folder per knee, named after the specimen (e.g. ``oks001``),
    each holding ``FeBio_custom_compression.feb`` and ``FeBio_custom_tension.feb``. Entries that
    are not folders, or whose name has no expected stiffness below, are skipped with a message.
    """
    knee_stiffness = {"oks001": 180, "oks002": 180, "oks003": 242, "oks004": 220, "oks006": 180,
                      "oks007": 180, "oks008": 220, "oks009": 242}

    # absolute, so the paths stay valid if the working directory changes during a calibration
    models_dir = os.path.abspath(models_dir)

    # for each knee, calibrate at ratios of 1-10000
    for knee in sorted(os.listdir(models_dir)):
        knee_dir = os.path.join(models_dir, knee)
        if not os.path.isdir(knee_dir) or knee not in knee_stiffness:
            print('skipping {}: not a knee folder with a known expected stiffness'.format(knee))
            continue
        compression_model = os.path.join(knee_dir, "FeBio_custom_compression.feb")
        tension_model = os.path.join(knee_dir, "FeBio_custom_tension.feb")
        expected_stiffness = knee_stiffness[knee]
        for r in np.logspace(0, 4, num=5):
            calibrate_materials(compression_model, tension_model, "ACL", expected_stiffness, r)

if __name__ == '__main__':

    # For the ACL continuum modeling study, the input is the directory holding one folder per knee,
    # each containing a tension model and a compression model:
    # calibrate_all_knees(*sys.argv[1:])

    # To calibrate the material properties of a single ligament, provide the compression and tension
    # models (reduced models containing only the bones and the ligament in question), the expected
    # stiffness and the ratio being used, then run calibrate_materials().
    # Raw strings keep the Windows backslashes literal instead of parsing them as escape sequences.
    compression_model = r"C:\Users\schwara2\Documents\Open_Knees\app\ACL_modeling\Coupled_Models\oks009\FeBio_custom_compression.feb"
    tension_model = r"C:\Users\schwara2\Documents\Open_Knees\app\ACL_modeling\Coupled_Models\oks009\FeBio_custom_tension.feb"
    expected_stiffness = 242
    ratio = 1
    calibrate_materials(compression_model, tension_model, "ACL", expected_stiffness, ratio)

