'''
femur_opt.py

Optimize the Coxa-Trochanter muscle parameters with NSGA-II.

The optimization results of one generation are stored under data/Optimization/CoTr.

'''

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import common
import time
import logging
from shutil import copy
import geatpy as ea
import numpy as np
import matplotlib.pyplot as plt
import opensim as osim
import pandas as pd
from func_timeout import func_set_timeout
from pebble import ProcessPool
logging.basicConfig(level=logging.INFO)

# --- NSGA-II / worker-pool settings ---------------------------------------
# NOTE(review): the numbers in trailing comments appear to be the settings of
# a full run; the active values look reduced, presumably for testing.
NIND = 15           # population size (full run: 150)
GEN = 2             # number of generations (full run: 40)
MAX_WORKER = 15     # parallel worker processes (alternative value: 10)
TIMEOUT = 200       # per-individual time limit in seconds (pool + decorator)
NDIM = 2 * 9        # decision variables: 2 muscles x 9 parameters each
LOGLEVEL = 'ERROR'  # OpenSim logger level name, looked up in common.LOG_LEVEL


def task_done(future):
    """Done-callback for pool futures: print the task result or its failure.

    Args:
        future: a completed future returned by ``ProcessPool.schedule``.
    """
    # pebble reports a task timeout as concurrent.futures.TimeoutError with
    # the timeout value in args[1]. Before Python 3.11 that class is distinct
    # from the builtin TimeoutError, so both are caught here.
    from concurrent.futures import TimeoutError as FutureTimeoutError
    try:
        # Called only once the future is done, so this does not block.
        result = future.result()
        print(result)
    except (FutureTimeoutError, TimeoutError) as error:
        timeout = error.args[1] if len(error.args) > 1 else None
        if isinstance(timeout, (int, float)):
            print("Function took longer than %d seconds" % timeout)
        else:
            # e.g. a builtin OSError-style TimeoutError with no numeric value
            print("Function timed out: %s" % (error,))
    except Exception as error:
        print("Function raised %s" % error)


def try_task(job, var, t):
    """Run ``so_fd`` for one individual and log, rather than propagate, failures.

    A failed or timed-out job leaves the default worst-case ``rmse{job}.mot``
    (written by ``MyProblem.evalVars`` before scheduling) in place, so the
    individual is scored as worst.

    Args:
        job: individual index, used for per-job file names.
        var: the individual's parameter vector.
        t: timestamp string forwarded to ``so_fd``.
    """
    # func_timeout's FunctionTimedOut derives from BaseException, so it is not
    # caught by `except Exception` and must be handled explicitly.
    from func_timeout import FunctionTimedOut
    try:
        print("Job {} starts".format(job))
        so_fd(job, var, t)
    except FunctionTimedOut:
        logging.info("{} reached {}s timeout".format(job, TIMEOUT))
        logging.info("{} default worst performance".format(job))
    except Exception as e:
        # Not a timeout: report the actual error instead of claiming one.
        logging.info("Error {}".format(e))
        logging.info("{} default worst performance".format(job))


def _scale_muscle(muscle, mp):
    """Apply the scalar part (mp[0:3]) of one muscle's 9-value parameter slice.

    mp[0] scales the max isometric force and mp[1] the max contraction
    velocity. mp[2] scales the share of optimal fiber length in
    Lmtu = optimal fiber length + tendon slack length. That share is capped at
    0.95, and Lmtu is kept constant because tendon slack absorbs the change.
    """
    old_force = muscle.get_max_isometric_force()
    old_v = muscle.getMaxContractionVelocity()
    muscle.set_max_isometric_force(old_force * mp[0])
    muscle.setMaxContractionVelocity(old_v * mp[1])
    old_lopt = muscle.getOptimalFiberLength()
    old_lslack = muscle.getTendonSlackLength()
    lmtu = old_lopt + old_lslack
    opt_ratio_new = min(old_lopt / lmtu * mp[2], 0.95)
    lopt = lmtu * opt_ratio_new
    muscle.setOptimalFiberLength(lopt)
    muscle.setTendonSlackLength(lmtu - lopt)


def _move_path_points(muscle, mp, job, working_state):
    """Offset path point 1 by mp[3:6] and path point 2 by mp[6:9].

    New points, named like the originals with ``job`` appended, are added on
    the same parent frames at the offset locations. Original points 2 and 1
    are then deleted.

    Raises:
        ValueError: if the muscle has fewer than 3 path points, or if points
            1/2 are not ``osim.PathPoint`` instances. Without this check,
            values left over from the previous muscle would be reused
            silently.
    """
    geometry = muscle.getGeometryPath()
    n_points = sum(1 for _ in geometry.getPathPointSet())
    if n_points < 3:
        raise ValueError(
            "muscle {} has {} path points; at least 3 are required".format(
                muscle.getName(), n_points))
    via = osim.PathPoint.safeDownCast(geometry.getPathPointSet().get(1))
    ins = osim.PathPoint.safeDownCast(geometry.getPathPointSet().get(2))
    if via is None or ins is None:
        raise ValueError(
            "muscle {}: path points 1 and 2 must be PathPoints".format(
                muscle.getName()))

    # Read everything needed before appending, which changes the point set.
    vp = common.vec_to_list(via.get_location())
    ip = common.vec_to_list(ins.get_location())
    vp_name = via.getName() + '{}'.format(job)
    ip_name = ins.getName() + '{}'.format(job)
    new_viapoint = osim.Vec3(mp[3] + vp[0], mp[4] + vp[1], mp[5] + vp[2])
    new_insertion = osim.Vec3(mp[6] + ip[0], mp[7] + ip[1], mp[8] + ip[2])

    # Appended points go to the end, so indices 1 and 2 still refer to the
    # original points until they are deleted (highest index first).
    geometry.appendNewPathPoint(
        vp_name, geometry.getPathPointSet().get(1).getParentFrame(),
        new_viapoint)
    geometry.appendNewPathPoint(
        ip_name, geometry.getPathPointSet().get(2).getParentFrame(),
        new_insertion)
    geometry.deletePathPoint(working_state, 2)
    geometry.deletePathPoint(working_state, 1)


def _run_so_fd(so, fd, model, soname, fname, job, label):
    """Run static optimization, then forward dynamics driven by its activations.

    The caller owns ``so``, ``fd`` and ``model`` (``model`` is
    ``so.getModel()``) and must keep them alive across this call.
    """
    logging.info(
        "{} ----------------- {} SO starts -----------------".format(job, label))
    so.run()
    logging.info(
        "{} ----------------- {} SO completes -----------------".format(job, label))

    # Best effort: drop the first analysis attached to the model (presumably
    # the one used by SO) before FD. Only regular exceptions are ignored, so a
    # func_timeout BaseException is never swallowed here.
    try:
        model.removeAnalysis(model.getAnalysisSet().get(0))
    except Exception as err:
        logging.debug("No analysis removed: {}".format(err))

    fd.setName(fname)
    fd.setControlsFileName(soname + '_StaticOptimization_activation.sto')
    fd.setSolveForEquilibrium(True)
    # No states file is set: initial states are expected to follow from the
    # SO results (per the original author's note).
    fd.setModel(model)

    logging.info(
        "{} ----------------- {} FD starts -----------------".format(job, label))
    fd.run()
    logging.info(
        "{} ----------------- {} FD completes -----------------".format(job, label))


def _score_pitch(foutput_name, reference_file, job, label):
    """Compare the FD LFTibia pitch trajectory with a reference motion file.

    Returns:
        ``(rmse, pearson)``, or None when the FD output ends before time
        0.35 (treated as an incomplete run).

    Raises:
        Exception: if rmse > 100 or pearson < 0.5, or either is NaN.
    """
    col_name = "/jointset/joint_LFTibia/joint_LFTibia_pitch/value"
    _, data = common.read_motion_file(foutput_name)
    if data['time'].max() < 0.35:  # FD incomplete, quit
        print("job:{} not complete. FD time not enough.".format(job))
        return None

    data = common.convert_time(data)
    _, origin_data = common.read_motion_file(reference_file)
    origin_data = common.convert_time(origin_data)
    # Assumes convert_time yields time-indexed DataFrames: align() takes the
    # union of both time indices and interpolate fills the resulting gaps.
    origin_aligned, data_aligned = origin_data.align(data)
    origin_aligned_degree = origin_aligned.interpolate(method='time')
    data_aligned_degree = data_aligned.interpolate(method='time')
    logging.info(
        "{} -------------- {} Evaluating --------------".format(job, label))

    reference_angles = origin_aligned_degree[col_name]
    optimized_angles = data_aligned_degree[col_name]
    rmse = common.calc_rmse(reference_angles, optimized_angles)
    pearson = common.calc_pearson(reference_angles, optimized_angles)
    logging.info(f"rmse {label.lower()} pitch: {rmse}")
    # Written as "not (good)" so that NaN metrics are rejected too.
    if not (rmse <= 100 and pearson >= 0.5):
        raise Exception('Too big error in dof, breaking current optimization')
    return rmse, pearson


@func_set_timeout(TIMEOUT)
def so_fd(job, var, t):
    """Score one candidate muscle-parameter vector on locomotion and grooming.

    For each task the SO setup's model is loaded and ``var`` is applied to
    its muscles. Static optimization then runs, followed by forward dynamics
    driven by the SO activations. The resulting LFTibia pitch is compared
    with a reference motion.

    Args:
        job: individual index; used to name every per-job file.
        var: 9 values per muscle, in ``model.getMuscles()`` order:
            [0] max isometric force scale, [1] max contraction velocity
            scale, [2] scale of the optimal-fiber share of
            (fiber + tendon slack) length, [3:6] offset added to path point
            1, [6:9] offset added to path point 2.
        t: timestamp passed in by the caller; unused here.

    On success, writes ``rmse{job}.mot`` as one tab-separated line: summed
    RMSE, (2 - summed Pearson) * 100, job, then ``var``. It also writes
    ``{job}.osim``, which holds the grooming model. Returns early without
    writing if either FD run is incomplete.

    Raises:
        Exception: if a task's pitch RMSE/Pearson fails the quality gate.
    """
    # FIXME put it into another module
    logging.info("Running job = {},{}".format(job, os.getpid()))
    # One OpenSim log file per job, so parallel workers do not share a sink.
    osim.Logger.setLevel(common.LOG_LEVEL[LOGLEVEL])  # 3 WARN  4 ERROR
    osim.Logger.removeFileSink()
    osim.Logger.addFileSink('opensim{}.log'.format(job))

    # ------------------------------ LOCOMOTION ------------------------------
    foutput_name = 'new{}_locomotion_states_degrees.mot'.format(job)
    so = osim.AnalyzeTool('SO_setup_locomotion.xml')
    soname = 'NMF{}_locomotion'.format(job)
    so.setName(soname)
    fd = osim.ForwardTool('FD_setup_locomotion.xml')
    fd.setPrintResultFiles(True)

    model = so.getModel()
    model.initSystem()
    working_state = model.getWorkingState()
    for i, muscle in enumerate(model.getMuscles()):
        mp = var[i * 9:(i + 1) * 9]
        _scale_muscle(muscle, mp)
        _move_path_points(muscle, mp, job, working_state)

    logging.info("----------------- LOADING MODEL -----------------")
    model.initSystem()
    logging.info("----------------- MODEL loaded -----------------")
    _run_so_fd(so, fd, model, soname, 'new{}_locomotion'.format(job),
               job, 'Locomotion')
    scores = _score_pitch(foutput_name, 'locomotion_left_ref.mot',
                          job, 'Locomotion')
    if scores is None:
        return
    _rmse, _pearson = scores

    # ------------------------------- GROOMING -------------------------------
    foutput_name = 'new{}_grooming_states_degrees.mot'.format(job)
    so = osim.AnalyzeTool('SO_setup_grooming.xml')
    soname = 'NMF{}_grooming'.format(job)
    so.setName(soname)
    fd = osim.ForwardTool('FD_setup_grooming.xml')
    fd.setPrintResultFiles(True)

    model = so.getModel()
    model.initSystem()
    # NOTE(review): only mp[0:3] are applied for grooming; the path-point
    # offsets mp[3:9] are not, so the saved {job}.osim lacks them. Confirm
    # that this is intended.
    for i, muscle in enumerate(model.getMuscles()):
        _scale_muscle(muscle, var[i * 9:(i + 1) * 9])
    model.initSystem()
    _run_so_fd(so, fd, model, soname, 'new{}_grooming'.format(job),
               job, 'Grooming')
    scores = _score_pitch(foutput_name, 'antgrooming_left_ref.mot',
                          job, 'Grooming')
    if scores is None:
        return
    _rmse += scores[0]
    _pearson += scores[1]

    logging.info(f"rmse total;: {_rmse}")
    logging.info(f"Pearson total: {_pearson}")
    # Record layout: summed RMSE, (2 - summed Pearson) * 100, job ID, params.
    fields = [_rmse, (2 - _pearson) * 100, job, *var]
    with open("./rmse{}.mot".format(job), "w") as f:
        f.write(''.join('{}\t'.format(v) for v in fields) + '\n')
    logging.info("File written successfully.")
    model.printToXML("{}.osim".format(job))
    logging.info("job:{} complete.".format(job))


class MyProblem(ea.Problem):
    """Two-objective NSGA-II problem over NDIM muscle parameters.

    Both objectives are minimized and are read back from the ``rmse{i}.mot``
    files written by ``so_fd``: column 0 is the summed pitch RMSE and column
    1 is (2 - summed Pearson) * 100. A worst-case line (200, 300) is written
    before evaluation so that failed jobs keep it. A missing, empty or
    malformed file scores [999, 999].
    """

    # Per-muscle (lower, upper) bounds, 9 values each: force scale, velocity
    # scale, optimal-fiber-share scale, path point 1 offset (x, y, z),
    # path point 2 offset (x, y, z).
    _MUSCLE_BOUNDS = (
        ([0.3, 0.4, 0.8, -0.005, -0.005, -0.005, -0.005, -0.005, -0.005],
         [3.0, 2.4, 1.2, 0.005, 0.005, 0.005, 0.005, 0.005, 0.005]),
        ([0.3, 0.4, 0.8, -0.005, -0.005, -0.005, -0.01, -0.01, -0.01],
         [3.0, 2.4, 1.2, 0.005, 0.005, 0.005, 0.01, 0.01, 0.01]),
    )

    # Per-job files copied to log/ before a generation overwrites them. The
    # locomotion/grooming names are the ones so_fd produces; the last entry
    # is the legacy name that used to be backed up.
    _BACKUP_PATTERNS = (
        "rmse{}.mot",
        "new{}_locomotion_states_degrees.mot",
        "new{}_grooming_states_degrees.mot",
        "new{}_states_degrees.mot",
    )

    def __init__(self):
        name = 'NSGA2'
        M = 2  # objective count
        maxormins = [1] * M  # 1: minimize, -1: maximize
        Dim = NDIM

        lb = []  # lower bounds
        ub = []  # upper bounds
        for low, high in self._MUSCLE_BOUNDS:
            lb.extend(low)
            ub.extend(high)
        if len(lb) != Dim:
            raise ValueError(
                "bounds define {} variables but NDIM is {}".format(len(lb), Dim))
        varTypes = [0] * Dim  # 0: real, 1: integer
        lbin = [1] * Dim  # 1: lower bound inclusive
        ubin = [1] * Dim  # 1: upper bound inclusive

        ea.Problem.__init__(self, name, M, maxormins, Dim,
                            varTypes, lb, ub, lbin, ubin)

    @staticmethod
    def _write_worst_score(i, var):
        """Write the default worst-case record (RMSE 200, Pearson -1) for job i."""
        _rmse = 200
        _pearson = -1
        fields = [_rmse, (2 - _pearson) * 100, i, *var]
        with open("./rmse{}.mot".format(i), "w") as f:
            f.write(''.join('{}\t'.format(v) for v in fields) + '\n')
        logging.info("File written successfully.")

    @staticmethod
    def _read_objectives(fn):
        """Return [obj1, obj2] from the first line of ``fn``, or [999, 999]."""
        if not os.path.exists(fn):
            logging.info("missing file")
            return [999, 999]
        with open(fn, 'r') as f:
            lines = f.readlines()
        if not lines:
            logging.info("not enough lines")
            return [999, 999]
        # Exactly one row per individual, even if the file has extra lines.
        fields = lines[0].split()
        if len(fields) <= 2:
            return [999, 999]
        try:
            return [float(fields[0]), float(fields[1])]
        except ValueError:
            logging.info("unparsable objectives in {}".format(fn))
            return [999, 999]

    def evalVars(self, Vars):
        """Evaluate every row of ``Vars`` in parallel; return (ObjV, CV=None)."""
        # Back up the previous generation's outputs to log/, tagged with time.
        t = time.strftime("%H%M%S")
        os.makedirs("log", exist_ok=True)
        if os.path.exists("opensim.log"):
            os.rename("opensim.log", "log/opensim{}.log".format(t))

        for i in range(len(Vars)):
            for pattern in self._BACKUP_PATTERNS:
                fn = pattern.format(i)
                if os.path.exists(fn):
                    copy(fn, "log/{}".format(t) + fn)

        with ProcessPool(max_workers=MAX_WORKER, max_tasks=2 * MAX_WORKER) as pool:
            for i, var in enumerate(Vars):
                # Pre-write the worst score; so_fd overwrites it only on success.
                self._write_worst_score(i, var)
                future = pool.schedule(try_task, [i, var, t], timeout=TIMEOUT)
                future.add_done_callback(task_done)
        # Results are read only after the pool block exits (pebble's context
        # manager closes and joins the pool).

        ObjV = [self._read_objectives("rmse{}.mot".format(i))
                for i in range(len(Vars))]
        print('ObjV', ObjV)
        CV = None
        return np.array(ObjV), CV


if __name__ == '__main__':
    problem = MyProblem()
    Encoding = 'RI'
    Field = ea.crtfld(Encoding, problem.varTypes, problem.ranges,
                      problem.borders)  # create field descriptor
    population = ea.Population(Encoding, Field, NIND)
    myAlgorithm = ea.moea_NSGA2_templet(problem, population)
    myAlgorithm.MAXGEN = GEN
    myAlgorithm.mutOper.F = 0.7  # mutation
    myAlgorithm.recOper.XOVR = 0.5  # cross-over
    myAlgorithm.logTras = 1  # logging every generation
    myAlgorithm.verbose = True
    # set plot methods（0：no plot；1：result plot；2：evolve of Parator Frontal；3：state space anime）
    myAlgorithm.drawing = 2
    noParams = NDIM  # muscles * parameters

    if os.path.exists("seed.txt"):  # warm start if seed.mot exists
        f = open("seed.txt", 'r')
        try:
            lines = f.readlines()
            prophetChrom = np.zeros((len(lines), noParams))
            i = 0
            for line in lines:
                temp_split = line.split()
                for j in range(noParams):
                    prophetChrom[i][j] = temp_split[j+3]
                    # skip RMSE, Pearson, Job ID
                i += 1
        finally:
            f.close()
        logging.info("Warm start, get init pops from seed.txt{}".format(
            prophetChrom.shape))
        prophetPop = ea.Population(Encoding, Field, len(lines), prophetChrom)
        myAlgorithm.call_aimFunc(prophetPop)
        # get best individual of last generation; warm start
        [BestIndi, population] = myAlgorithm.run(prophetPop)
    else:
        # get best individual of last generation; cold start
        [BestIndi, population] = myAlgorithm.run()
    BestIndi.save()

    logging.info('Evals Num：%s' % myAlgorithm.evalsNum)
    logging.info('Time eclapsed %s s' % myAlgorithm.passTime)
    if BestIndi.sizes != 0:
        # might not indeed be the best; an alternative is to filter manually in last generation
        logging.info('Best target so far：%s' % BestIndi.ObjV[0][0])
        logging.info('Best parameters so far: ')
        for i in range(BestIndi.Phen.shape[1]):
            logging.info(BestIndi.Phen[0, i])
    else:
        logging.info('No feasible solution')
