'''
femur_opt.py

Optimize the Coxa-Trochanter muscle parameters with NSGA-II.

Optimization results of one generation are stored under data/Optimization/CoTr.

'''
import os,sys,time,logging
from shutil import copy
import geatpy as ea
import numpy as np
import matplotlib.pyplot as plt
import opensim as osim
import pandas as pd
from func_timeout import func_set_timeout
from pebble import ProcessPool
logging.basicConfig(level=logging.INFO)
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import common

# NSGA-II run configuration.
NIND = 300  # individuals per generation
GEN = 20  # number of generations
MAX_WORKER = 15  # parallel worker processes used to evaluate a generation
TIMEOUT = 480  # per-individual evaluation time limit, in seconds
# Decision variables per muscle, in model.getMuscles() order: the first three
# muscles take 9 values (3 scale factors + via-point offset + insertion offset),
# the remaining three take 6 (3 scale factors + insertion offset).
_PARAMS_PER_MUSCLE = (9, 9, 9, 6, 6, 6)
NDIM = sum(_PARAMS_PER_MUSCLE)
LOGLEVEL = 'ERROR'  # OpenSim logger level name, looked up in common.LOG_LEVEL

def task_done(future):
    """Done-callback for pebble futures: print the result or the failure.

    pebble signals a scheduled-task timeout with
    ``concurrent.futures.TimeoutError('Task timeout', timeout)``. Before
    Python 3.11 that class is distinct from the builtin ``TimeoutError``, so
    both are caught here. The timeout value (``args[1]``) is printed when
    present.
    """
    import concurrent.futures

    try:
        result = future.result()  # the future is already done; does not block
        print(result)
    except (TimeoutError, concurrent.futures.TimeoutError) as error:
        timeout = error.args[1] if len(error.args) > 1 else None
        if timeout is None:
            print("Function timed out")
        else:
            print("Function took longer than %d seconds" % timeout)
    except Exception as error:
        print("Function raised %s" % error)

def try_task(job, var, t):
    """Run so_fd for one individual inside a pool worker without raising.

    The worst-case line that MyProblem.evalVars writes to rmse{job}.mot
    before scheduling is left in place on any failure, so the individual is
    simply scored as bad.

    func_timeout's FunctionTimedOut is caught separately from ordinary
    errors. It subclasses BaseException, so ``except Exception`` alone would
    not catch it, and only real timeouts should be reported as timeouts.

    Args:
        job: individual index; used in every file name so_fd writes.
        var: decision vector for this individual.
        t: timestamp string forwarded to so_fd.
    """
    from func_timeout import FunctionTimedOut

    try:
        print("Job {} starts".format(job))
        so_fd(job, var, t)
    except FunctionTimedOut:
        logging.info("{} reached {}s timeout".format(job, TIMEOUT))
        logging.info("{} default worst performance".format(job))
    except Exception as e:
        logging.info("Error %s" % format(e))
        logging.info("{} default worst performance".format(job))

def _apply_muscle_params(model, var, job, working_state, force_scale=1):
    """Scale muscle properties and move attachment points of `model` in place.

    Muscles are visited in model.getMuscles() order. The first three take 9
    entries of `var` each, and the following muscles take 6 entries each,
    starting at var[27]:
        mp[0]     max isometric force scale (further multiplied by force_scale)
        mp[1]     max contraction velocity scale
        mp[2]     scale of the optimal-fiber share of (fiber + tendon slack)
        mp[3:6]   via-point offset (first 3 muscles) / insertion offset (others)
        mp[6:9]   insertion offset (first 3 muscles only)
    Offsets are added to the original path-point locations. New points are
    named after the original point with `job` appended.
    """
    for i, muscle in enumerate(model.getMuscles()):
        if i < 3:
            mp = var[i * 9:(i + 1) * 9]
        else:
            mp = var[27 + (i - 3) * 6:27 + (i - 2) * 6]

        old_force = muscle.get_max_isometric_force()
        old_v = muscle.getMaxContractionVelocity()
        muscle.set_max_isometric_force(old_force * mp[0] * force_scale)
        muscle.setMaxContractionVelocity(old_v * mp[1])

        # Rescale the fiber share of the MTU length while keeping
        # optimal fiber length + tendon slack length constant; cap it at 0.95.
        lopt_old = muscle.getOptimalFiberLength()
        lmtu = lopt_old + muscle.getTendonSlackLength()
        opt_ratio_new = min(lopt_old / lmtu * mp[2], 0.95)
        lopt = lmtu * opt_ratio_new
        muscle.setOptimalFiberLength(lopt)
        muscle.setTendonSlackLength(lmtu - lopt)

        geometry = muscle.getGeometryPath()
        if i < 3:
            # Replace the points at index 1 (via point) and 2 (insertion).
            via = osim.PathPoint.safeDownCast(geometry.getPathPointSet().get(1))
            ins = osim.PathPoint.safeDownCast(geometry.getPathPointSet().get(2))
            via_xyz = common.vec_to_list(via.get_location())
            ins_xyz = common.vec_to_list(ins.get_location())
            new_viapoint = osim.Vec3(
                mp[3] + via_xyz[0], mp[4] + via_xyz[1], mp[5] + via_xyz[2])
            new_insertion = osim.Vec3(
                mp[6] + ins_xyz[0], mp[7] + ins_xyz[1], mp[8] + ins_xyz[2])
            geometry.appendNewPathPoint(
                via.getName() + '{}'.format(job),
                geometry.getPathPointSet().get(1).getParentFrame(),
                new_viapoint,
            )
            geometry.appendNewPathPoint(
                ins.getName() + '{}'.format(job),
                geometry.getPathPointSet().get(2).getParentFrame(),
                new_insertion,
            )
            # New points are appended at the end; delete the originals
            # higher index first so index 1 still names the old via point.
            geometry.deletePathPoint(working_state, 2)
            geometry.deletePathPoint(working_state, 1)
        else:
            # Replace the last path point (insertion).
            last = geometry.getPathPointSet().getSize() - 1
            ins = osim.PathPoint.safeDownCast(geometry.getPathPointSet().get(last))
            ins_xyz = common.vec_to_list(ins.get_location())
            new_insertion = osim.Vec3(
                mp[3] + ins_xyz[0], mp[4] + ins_xyz[1], mp[5] + ins_xyz[2])
            geometry.appendNewPathPoint(
                ins.getName() + '{}'.format(job),
                geometry.getPathPointSet().get(last).getParentFrame(),
                new_insertion,
            )
            geometry.deletePathPoint(working_state, last)


def _run_so_fd(task, job, var, force_scale):
    """Run static optimization, then forward dynamics, for one task.

    Reads SO_setup_{task}.xml / FD_setup_{task}.xml and applies `var` to the
    SO tool's model. The SO activation file drives the FD run.
    Returns (FD states file name, modified model).
    """
    label = task.capitalize()
    foutput_name = 'new{}_{}_states_degrees.mot'.format(job, task)
    so = osim.AnalyzeTool('SO_setup_{}.xml'.format(task))
    soname = 'NMF{}_{}'.format(job, task)
    so.setName(soname)

    fd = osim.ForwardTool('FD_setup_{}.xml'.format(task))
    fd.setPrintResultFiles(True)

    model = so.getModel()
    model.initSystem()
    _apply_muscle_params(model, var, job, model.getWorkingState(), force_scale)
    model.initSystem()

    logging.info("{} ----------------- {} SO starts -----------------".format(job, label))
    so.run()
    logging.info("{} ----------------- {} SO completes -----------------".format(job, label))

    try:
        model.removeAnalysis(model.getAnalysisSet().get(0))
    except Exception:
        pass  # best effort: nothing to remove is not an error

    fd.setName('new{}_{}'.format(job, task))
    fd.setControlsFileName(soname + '_StaticOptimization_activation.sto')
    fd.setSolveForEquilibrium(True)
    fd.setModel(model)  # initial states are inferred from the SO results

    logging.info("{} ----------------- {} FD starts -----------------".format(job, label))
    fd.run()
    logging.info("{} ----------------- {} FD completes -----------------".format(job, label))
    return foutput_name, model


def _score_task(task, job, foutput_name, ref_file, pitch_weight=1):
    """Compare the FD LFTrochanter yaw/pitch/roll angles with a reference motion.

    Returns (rmse_list, pearson_list) in yaw, pitch, roll order, with the
    pitch RMSE multiplied by `pitch_weight`. Returns None if the FD output
    ends before 0.35 s. Raises Exception if any unweighted RMSE exceeds 100.
    """
    _, data = common.read_motion_file(foutput_name)
    if data['time'].max() < 0.35:  # FD incomplete, quit
        print("job:{} not complete. FD time not enough.".format(job))
        return None

    data = common.convert_time(data)
    _, origin_data = common.read_motion_file(ref_file)
    origin_data = common.convert_time(origin_data)
    origin_data_aligned, data_aligned = origin_data.align(data)
    origin_aligned_degree = origin_data_aligned.interpolate(method='time')
    data_aligned_degree = data_aligned.interpolate(method='time')
    logging.info("{} -------------- {} Evaluating --------------".format(job, task.capitalize()))

    rmses, pearsons = [], []
    for dof in ("yaw", "pitch", "roll"):
        col_name = f"/jointset/joint_LFTrochanter/joint_LFTrochanter_{dof}/value"
        reference_angles = origin_aligned_degree[col_name]
        optimized_angles = data_aligned_degree[col_name]
        rmse_temp = common.calc_rmse(reference_angles, optimized_angles)
        pearson_temp = common.calc_pearson(reference_angles, optimized_angles)
        logging.info(f"rmse {task} {dof}: {rmse_temp}")
        if rmse_temp > 100:
            raise Exception('Too big error in dof, breaking current optimization')
        if dof == "pitch":
            rmse_temp = rmse_temp * pitch_weight
        rmses.append(rmse_temp)
        pearsons.append(pearson_temp)
    return rmses, pearsons


@func_set_timeout(TIMEOUT)
def so_fd(job, var, t):
    """Evaluate one parameter vector on the locomotion and grooming tasks.

    Each task is evaluated in three steps:
    1. Apply `var` to the muscles of the task's SO model. Grooming uses an
       extra x20 max-isometric-force factor.
    2. Run SO, then FD.
    3. Score the LFTrochanter yaw/pitch/roll angles against the task's
       reference motion. Locomotion pitch RMSE is weighted x3.

    On success, writes rmse{job}.mot as one tab-separated line:
    RMSE sum, (6 - Pearson sum) * 100, job id, every entry of `var`.
    It also writes {job}.osim from the grooming model.
    If either FD run ends before 0.35 s, it returns early without writing.
    Always returns None. Raises if a DOF RMSE exceeds 100 or an OpenSim
    call fails. The whole call is limited to TIMEOUT seconds.

    Args:
        job: individual index; appended to every file name to avoid
            clashes between parallel workers.
        var: decision vector of NDIM values.
        t: timestamp string (unused here).
    """
    # FIXME: consider moving the evaluation into its own module.
    logging.info("Running job = {},{}".format(job, os.getpid()))
    osim.Logger.setLevel(common.LOG_LEVEL[LOGLEVEL])  # 3 WARN  4 ERROR
    osim.Logger.removeFileSink()
    osim.Logger.addFileSink('opensim{}.log'.format(job))

    _rmse = 0
    _pearson = 0
    model = None
    # (task, reference motion, extra force scale, pitch RMSE weight)
    for task, ref_file, force_scale, pitch_weight in (
        ('locomotion', 'locomotion_left_ref.mot', 1, 3),
        ('grooming', 'antgrooming_left_ref.mot', 20, 1),
    ):
        foutput_name, model = _run_so_fd(task, job, var, force_scale)
        scores = _score_task(task, job, foutput_name, ref_file, pitch_weight)
        if scores is None:
            return
        for rmse_temp, pearson_temp in zip(*scores):
            _rmse += rmse_temp
            _pearson += pearson_temp

    logging.info(f"rmse total;: {_rmse}")
    logging.info(f"Pearson total: {_pearson}")
    # 6 DOF scores in total (3 per task), so a perfect correlation gives 0.
    fields = [_rmse, (6 - _pearson) * 100, job, *var]
    msg = ''.join('{}\t'.format(x) for x in fields) + '\n'
    with open("./rmse{}.mot".format(job), "w") as f:
        f.write(msg)
    logging.info("File written successfully.")
    model.printToXML("{}.osim".format(job))
    logging.info("job:{} complete.".format(job))

class MyProblem(ea.Problem):
    """geatpy problem definition for the NSGA-II muscle-parameter search.

    Two objectives, both minimized. They are columns 0 and 1 of
    rmse{i}.mot as written by so_fd: the summed joint-angle RMSE and
    (6 - summed Pearson r) * 100.

    The decision vector has NDIM entries:
    - 3 muscles x 9 values (Fiso, Vmax, lopt scales; via-point offset;
      insertion offset);
    - then 3 muscles x 6 values (Fiso, Vmax, lopt scales; insertion offset).
    """

    # Objective line written before a job runs; kept if the job fails.
    _WORST_RMSE = 600
    _WORST_PEARSON = -1
    # Objective used when a result file is missing or unreadable.
    _MISSING_OBJ = (999, 999)

    def __init__(self):
        name = 'NSGA2'
        M = 2  # number of objectives
        maxormins = [1] * M  # per objective: 1 = minimize, -1 = maximize
        Dim = NDIM

        varTypes = []  # 0: real, 1: integer
        lb = []  # lower bounds
        ub = []  # upper bounds
        lbin = []  # 1: lower bound inclusive, 0: exclusive
        ubin = []  # 1: upper bound inclusive, 0: exclusive

        for _ in range(3):  # two-attachment muscles, 9 parameters each
            # Fiso, Vmax, lopt, vpx, vpy, vpz, apx, apy, apz
            varTypes.extend([0] * 9)
            lb.extend([0.3, 0.2, 0.5, -0.005, -0.005, -0.005, -0.003, -0.003, -0.003])
            ub.extend([3.0, 4.0, 1.5, 0.005, 0.005, 0.005, 0.003, 0.003, 0.003])
            lbin.extend([1] * 9)
            ubin.extend([1] * 9)
        for _ in range(3):  # one-attachment muscles, 6 parameters each
            # Fiso, Vmax, lopt, apx, apy, apz
            varTypes.extend([0] * 6)
            lb.extend([0.3, 0.2, 0.5, -0.005, -0.005, -0.005])
            ub.extend([3.0, 4.0, 1.5, 0.005, 0.005, 0.005])
            lbin.extend([1] * 6)
            ubin.extend([1] * 6)
        ea.Problem.__init__(self, name, M, maxormins, Dim, varTypes, lb, ub, lbin, ubin)

    @staticmethod
    def _backup_previous_outputs(count, t):
        """Copy the previous generation's per-job outputs into log/ with prefix `t`."""
        os.makedirs("log", exist_ok=True)
        if os.path.exists("opensim.log"):
            os.rename("opensim.log", "log/opensim{}.log".format(t))
        for i in range(count):
            names = ["rmse{}.mot".format(i)] + [
                'new{}_{}_states_degrees.mot'.format(i, task)
                for task in ('locomotion', 'grooming')
            ]
            for fn in names:
                if os.path.exists(fn):
                    copy(fn, "log/{}".format(t) + fn)

    @classmethod
    def _write_worst_result(cls, i, var):
        """Pre-write a worst-case objective line for job `i`.

        so_fd overwrites this line only when it succeeds.
        """
        fields = [cls._WORST_RMSE, (6 - cls._WORST_PEARSON) * 100, i, *var]
        msg = ''.join('{}\t'.format(x) for x in fields) + '\n'
        with open("./rmse{}.mot".format(i), "w") as f:
            f.write(msg)
        logging.info("File written successfully.")

    @classmethod
    def _read_objectives(cls, fn):
        """Return exactly one [obj0, obj1] row parsed from result file `fn`.

        The first non-blank line is used. A missing, empty, short, or
        unparsable file yields the _MISSING_OBJ row.
        """
        if not os.path.exists(fn):
            logging.info("missing file")
            return list(cls._MISSING_OBJ)
        with open(fn, 'r') as f:
            lines = [line for line in f if line.strip()]
        if not lines:
            logging.info("not enough lines")
            return list(cls._MISSING_OBJ)
        fields = lines[0].split()
        if len(fields) <= 2:
            return list(cls._MISSING_OBJ)
        try:
            return [float(fields[0]), float(fields[1])]
        except ValueError:
            return list(cls._MISSING_OBJ)

    def evalVars(self, Vars):
        """Evaluate a population in parallel and return (ObjV, CV).

        `Vars` holds one row of NDIM decision variables per individual.
        Each individual `i` runs try_task in a pebble worker. A worst-case
        line is first written to rmse{i}.mot so that a crashed or timed-out
        job still produces a (bad) objective row. ObjV has exactly one row
        per individual. CV is None (no constraints).
        """
        t = time.strftime("%H%M%S")
        self._backup_previous_outputs(len(Vars), t)

        with ProcessPool(max_workers=MAX_WORKER, max_tasks=2 * MAX_WORKER) as pool:
            for i, var in enumerate(Vars):
                self._write_worst_result(i, var)
                future = pool.schedule(try_task, [i, var, t], timeout=TIMEOUT)
                future.add_done_callback(task_done)
        # The pool is closed and joined on leaving the `with` block, so all
        # result files are final here.

        ObjV = [self._read_objectives("rmse{}.mot".format(i)) for i in range(len(Vars))]
        print('ObjV', ObjV)
        return np.array(ObjV), None


if __name__ == '__main__':
    problem = MyProblem()
    Encoding = 'RI'
    Field = ea.crtfld(Encoding, problem.varTypes, problem.ranges, problem.borders) # create field descriptor
    population = ea.Population(Encoding, Field, NIND)
    myAlgorithm = ea.moea_NSGA2_templet(problem, population)
    myAlgorithm.MAXGEN = GEN
    myAlgorithm.mutOper.F = 0.75  # mutation
    myAlgorithm.recOper.XOVR = 0.6  # cross-over
    myAlgorithm.logTras = 1  # logging every generation
    myAlgorithm.verbose = True
    myAlgorithm.drawing = 2  # set plot methods（0：no plot；1：result plot；2：evolve of Parator Frontal；3：state space anime）
    noParams = NDIM # muscles * parameters
    
    if os.path.exists("seed_new.txt"): # warm start if seed.mot exists
        f = open("seed_new.txt",'r')
        try:
            lines = f.readlines()
            prophetChrom = np.zeros((len(lines), noParams))
            i = 0
            for line in lines:
                temp_split = line.split()
                for j in range(noParams):
                    prophetChrom[i][j] = temp_split[j+3]
                    # skip RMSE, Pearson, Job ID
                i += 1
        finally:
            f.close()
        logging.info("Warm start, get init pops from seed.txt{}".format(prophetChrom.shape))
        prophetPop = ea.Population(Encoding, Field, len(lines), prophetChrom)
        myAlgorithm.call_aimFunc(prophetPop)
        [BestIndi, population] = myAlgorithm.run(prophetPop)  # get best individual of last generation; warm start
    else:
        [BestIndi, population] = myAlgorithm.run()   # get best individual of last generation; cold start
    BestIndi.save()

    logging.info('Evals Num：%s' % myAlgorithm.evalsNum)
    logging.info('Time eclapsed %s s' % myAlgorithm.passTime)
    if BestIndi.sizes != 0:
        logging.info('Best target so far：%s' % BestIndi.ObjV[0][0]) # might not indeed be the best; an alternative is to filter manually in last generation
        logging.info('Best parameters so far: ')
        for i in range(BestIndi.Phen.shape[1]):
            logging.info(BestIndi.Phen[0, i])
    else:
        logging.info('No feasible solution')





