'''
coxa_opt.py

Optimize the thorax-coxa muscle parameters with NSGA-II.

An optimization result of one generation is stored under data/Optimization/ThCo.
'''
import os
import sys
import time
import random
import multiprocessing
from shutil import copy
import geatpy as ea
import numpy as np
import matplotlib.pyplot as plt
import opensim as osim
import pandas as pd
from scipy import stats
from func_timeout import func_set_timeout
import func_timeout
import logging

logging.basicConfig(level=logging.INFO)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import common

def try_task(job,var,t):
    """Evaluate one individual with ``so_fd`` and absorb its timeout.

    Used as the worker entry point of the process pool: a simulation that
    exceeds the time limit set on ``so_fd`` is reported and dropped instead of
    propagating ``FunctionTimedOut`` out of the worker.

    Args:
        job: Index of the individual; forwarded to ``so_fd``.
        var: Muscle parameter vector of the individual; forwarded to ``so_fd``.
        t: Time stamp string of the current evaluation; forwarded to ``so_fd``.
    """
    try:
        so_fd(job, var, t)
    except func_timeout.exceptions.FunctionTimedOut as exc:
        # Report the limit that was actually enforced rather than a
        # hard-coded number that can drift from the decorator on so_fd.
        logging.warning("%s reached %ss timeout", job, exc.timedOutAfter)

def run_so_fd_pipeline(job, var, t, behavior='loco'):
    """Apply one parameter set to the model, then run SO followed by FD.

    The static-optimization (SO) and forward-dynamics (FD) tools are loaded
    from ``SO_setup_<behavior>.xml`` and ``FD_setup_<behavior>.xml``. Every
    muscle of the SO model is re-parameterized from ``var``, SO is run, and its
    activation file is fed to FD as the controls file.

    Args:
        job: Index of the individual. It is embedded in the log file name, the
            tool names and the new path point names so that concurrently
            running jobs do not share names.
        var: Flat parameter sequence with 6 consecutive values per muscle, in
            the order the model enumerates its muscles (see ``mp`` below).
        t: Unused here; kept for interface compatibility.
        behavior: Suffix selecting the setup files and tagging the outputs.

    Returns:
        ``'new<job>_<behavior>_states_degrees.mot'``: the file name the caller
        reads afterwards. This function does not write it itself; presumably
        the forward tool produces it -- TODO confirm against the FD setup file.

    Raises:
        ValueError: If the last path point of a muscle cannot be used as a
            plain ``osim.PathPoint``.
    """
    # FIXME put it into another module
    logging.info("Running job = {},{},{}".format(job, os.getpid(),behavior))
    # Send OpenSim's log to a per-job file so parallel jobs do not write to
    # the same log.
    osim.Logger.setLevel(common.LOG_LEVEL['ERROR'])  # 3 warn  4 error
    osim.Logger.removeFileSink()
    logfile = 'opensim{}.log'.format(job)
    osim.Logger.addFileSink(logfile)

    foutput_name = 'new{}_{}_states_degrees.mot'.format(job, behavior)
    so = osim.AnalyzeTool('SO_setup_{}.xml'.format(behavior))
    soname = 'NMF{}_{}'.format(job, behavior)
    so.setName(soname)

    fd = osim.ForwardTool('FD_setup_{}.xml'.format(behavior),)
    fd.setPrintResultFiles(True)

    model = so.getModel()
    model.initSystem()
    working_state = model.getWorkingState()
    si = model.getWorkingState()

    for i, muscle in enumerate(model.getMuscles()):

        mname = muscle.getName()
        logging.info(f'Muscle name: {mname}')
        # mp: 6 muscle parameters per MTU
        # mp[0]: max isometric force
        # mp[1]: max contraction velocity
        # mp[2]: fraction of the MTU length (optimal fiber length + tendon
        #        slack length) assigned to the optimal fiber length
        # mp[3]: offset added to the insertion point x
        # mp[4]: offset added to the insertion point y
        # mp[5]: offset added to the insertion point z
        mp = var[i * 6:(i + 1) * 6]
        muscle.set_max_isometric_force(mp[0])
        muscle.setMaxContractionVelocity(mp[1])

        # Redistribute the length between fiber and tendon while keeping
        # their sum unchanged.
        Lmtu1 = muscle.getOptimalFiberLength()
        Lmtu2 = muscle.getTendonSlackLength()
        Lmtu = Lmtu1 + Lmtu2
        lopt = Lmtu * mp[2]
        lslack = Lmtu - lopt
        muscle.setOptimalFiberLength(lopt)
        muscle.setTendonSlackLength(lslack)

        # Move the insertion: the last path point is replaced by a new point
        # in the same parent frame, shifted by mp[3:6].
        geometry = muscle.getGeometryPath()
        path_point = None

        for idx, point in enumerate(geometry.getPathPointSet()):
            # After the loop, idx / path_point refer to the last path point.
            path_point = osim.PathPoint.safeDownCast(point)

        if path_point is None:
            # Empty path point set, or the down-cast of the last point failed.
            raise ValueError(
                'Muscle {} has no usable insertion path point'.format(mname)
            )

        initial_guess = common.vec_to_list(path_point.get_location())
        new_insertion = osim.Vec3(
            mp[3] + initial_guess[0],
            mp[4] + initial_guess[1],
            mp[5] + initial_guess[2]
        )
        pathpoint_name = path_point.getName()
        pathpoint_name = pathpoint_name + '{}'.format(job)
        geometry.appendNewPathPoint(
            pathpoint_name,
            geometry.getPathPointSet().get(idx).getParentFrame(),
            new_insertion
        )
        # addPathpoint may be interpolating
        # Append first, then delete the old last point (still at idx), so the
        # appended point becomes the new last point.
        geometry.deletePathPoint(working_state, idx)

    model.initSystem()
    # NOTE(review): `manager` is never used afterwards; kept because it is
    # unclear whether constructing it has side effects on the model.
    manager = osim.Manager(model)
    model.printDetailedInfo(si)
    # frq=var[24]  #low-pass filter fixed at 39Hz
    # so.setLowpassCutoffFrequency(frq)
    so.run()

    logging.info("----------------- SO complete -----------------")

    # Drop the first analysis of the model before it is handed to FD.
    model.removeAnalysis(model.getAnalysisSet().get(0))
    ctrlfile = soname + '_StaticOptimization_activation.sto'
    logging.debug(ctrlfile)
    fname = 'new{}_{}'.format(job, behavior)
    fd.setName(fname)
    fd.setControlsFileName(ctrlfile)
    # `eq` is a module-level flag that is only assigned when this file runs as
    # a script. Worker processes that re-import the module (spawn start
    # method) or other importers do not have it, so fall back to True instead
    # of failing with NameError.
    fd.setSolveForEquilibrium(globals().get('eq', True))
    # fd.setStatesFileName(init_file) # FD followed by SO. init states automatically inferred from SO results
    fd.setModel(model)
    fd.run()

    logging.info("----------------- FD complete -----------------")

    return foutput_name

def align_data(opt_data, origin_data):
    """Put two time-indexed tables on a common index and fill the gaps.

    Both inputs are aligned with ``origin_data.align(opt_data)`` and the
    missing samples introduced by the alignment are filled with time-weighted
    interpolation (``method='time'``).

    Args:
        opt_data: Table produced by the optimization run.
        origin_data: Reference table.

    Returns:
        Tuple ``(origin_aligned, opt_aligned)`` -- note that the reference
        comes first, i.e. in the opposite order of the arguments.
    """
    origin_aligned, opt_aligned = origin_data.align(opt_data)
    # interpolate() returns a new object, so each table is interpolated
    # exactly once and the result is what gets returned.
    return (
        origin_aligned.interpolate(method='time'),
        opt_aligned.interpolate(method='time'),
    )

def load_data(foutput):
    """Read a motion file and return its table after ``common.convert_time``.

    Args:
        foutput: Path of the motion file to read.

    Returns:
        The second item returned by ``common.read_motion_file``, passed
        through ``common.convert_time``.
    """
    # The first item returned by the reader is not needed here.
    motion_table = common.read_motion_file(foutput)[1]
    return common.convert_time(motion_table)


@func_set_timeout(1200)
def so_fd(job, var, t):
    """Evaluate one parameter set on locomotion and grooming and store its score.

    Runs the SO -> FD pipeline for both behaviors, aligns the simulated
    kinematics with the reference recordings, accumulates RMSE and Pearson
    correlation over the three coxa degrees of freedom, saves a comparison
    plot to ``rmse_<job>.png`` and writes a single tab-separated line to
    ``./rmse<job>.mot``.

    The whole call is limited to 1200 s by the ``func_set_timeout`` decorator.

    Args:
        job: Index of the individual; used in all output file names.
        var: Muscle parameter vector, forwarded to the pipeline and appended
            to the result line.
        t: Forwarded to the pipeline; not used otherwise.
    """
    # Run it on locomotion
    logging.info("-------------- Running locomotion --------------")
    fout_loco = run_so_fd_pipeline(job, var, t, behavior='loco')
    loco_kin_opt = load_data(fout_loco)
    loco_kin_ref = load_data('locomotion_left_ref.mot')
    loco_kin_ref_aligned, loco_kin_opt_aligned = align_data(loco_kin_opt, loco_kin_ref)
    # Run it on grooming
    logging.info("-------------- Running grooming --------------")
    fout_groom = run_so_fd_pipeline(job, var, t, behavior='groom')
    groom_kin_opt = load_data(fout_groom)
    groom_kin_ref = load_data('antgrooming_left_ref.mot')
    groom_kin_ref_aligned, groom_kin_opt_aligned = align_data(groom_kin_opt, groom_kin_ref)

    # Columns compared between simulation and reference.
    col_names = [
        f"/jointset/joint_LFCoxa/joint_LFCoxa_{dof}/value"
        for dof in ["yaw", "pitch", "roll"]
    ]

    # Metrics: RMSE, correlation (summed over both behaviors and all columns)
    logging.info("-------------- Calculating RMSE and Pearson correlation --------------")

    _rmse = 0
    _pearson = 0
    for col_name in col_names:
        ref_groom = groom_kin_ref_aligned[col_name]
        opt_groom = groom_kin_opt_aligned[col_name]

        ref_loco = loco_kin_ref_aligned[col_name]
        opt_loco = loco_kin_opt_aligned[col_name]
        # RMSE
        _rmse += common.calc_rmse(ref_groom, opt_groom)
        _rmse += common.calc_rmse(ref_loco, opt_loco)
        # Pearson correlation
        _pearson += common.calc_pearson(ref_groom, opt_groom)
        _pearson += common.calc_pearson(ref_loco, opt_loco)

        # Running totals after each column.
        logging.info(f"RMSE: {_rmse}")
        logging.info(f"Pearson: {_pearson}")

    # Reference vs. optimized traces, one panel per behavior.
    fig, axs = plt.subplots(1, 2, figsize=(15, 10))
    try:
        for col_name in col_names:
            axs[0].plot(groom_kin_ref_aligned[col_name], label='ref', color='black')
            axs[0].plot(groom_kin_opt_aligned[col_name], label='opt')

            axs[1].plot(loco_kin_ref_aligned[col_name], label='ref', color='black')
            axs[1].plot(loco_kin_opt_aligned[col_name], label='opt')
        axs[0].set_title('Grooming')
        axs[1].set_title('Locomotion')

        fig.savefig('rmse_{}.png'.format(job))
    finally:
        # Pool workers evaluate many individuals; without closing, every call
        # would leave one more figure alive in the worker process.
        plt.close(fig)

    # Number of correlation terms summed above (2 behaviors per column), so
    # the second objective is 0 when every correlation equals 1.
    n_corr_terms = 2 * len(col_names)

    # Layout of the result line (tab separated, one line per file):
    # RMSE, (n_corr_terms - Pearson) * 100, job id, muscle parameters
    fields = [_rmse, (n_corr_terms - _pearson) * 100, job, *var]
    msg = ''.join('{}\t'.format(field) for field in fields) + '\n'
    logging.info("-------------------- Printing RMSE and Pearson correlation -------------------")

    with open("./rmse{}.mot".format(job), "w") as f:
        f.write(msg)
    logging.info("File written successfully.")

    logging.info("job:{} complete.".format(job))
    # rmse<job>.mot contains
    # RMSE, (6 - Pearson) * 100, Job ID, muscle parameters


class MyProblem(ea.Problem):
    """Two-objective problem over 7 muscles x 6 parameters.

    Each individual is evaluated by ``try_task`` in a worker process; the two
    objective values are read back from the ``rmse<i>.mot`` file the worker
    writes for individual ``i``.
    """

    def __init__(self):
        name = 'NSGA2'
        M = 2  # Objective count
        maxormins = [1] * M  # 1: minimize target; -1: maximize target
        n_muscles = 7  # 7 muscles in the thorax
        Dim = n_muscles * 6  # Parameter dimension: muscle number * 6
        varTypes = [0] * Dim  # 0: real; 1: integer
        # Per muscle: Fiso, Vmax, lopt, apx, apy, apz
        lb = [100, 5.0, 0.5, -0.02, -0.02, -0.02] * n_muscles  # lower bounds
        ub = [1000, 40, 0.9, 0.02, 0.02, 0.02] * n_muscles  # upper bounds
        lbin = [1] * Dim  # lower bound included (0 exclude, 1 include)
        ubin = [1] * Dim  # upper bound included (0 exclude, 1 include)

        ea.Problem.__init__(self, name, M, maxormins, Dim, varTypes, lb, ub, lbin, ubin)

    @staticmethod
    def _read_objectives(fn):
        """Return the two objective values stored in ``fn``, or the penalty ``[999, 999]``."""
        penalty = [999, 999]
        if not os.path.exists(fn):  # missing files
            logging.warning("missing file")
            return penalty
        with open(fn, 'r') as f:
            lines = f.readlines()
        if not lines:
            logging.warning("not enough lines")
            return penalty
        # Exactly one row per individual: only the first line is used, so an
        # unexpected extra line cannot change the shape of the objective matrix.
        temp_split = lines[0].split()
        if len(temp_split) > 2:
            return [float(temp_split[0]), float(temp_split[1])]
        return penalty

    def evalVars(self, Vars): # target
        """Evaluate all individuals in parallel.

        Args:
            Vars: One row of decision variables per individual.

        Returns:
            Tuple ``(ObjV, CV)``: an array with one ``[objective 1, objective 2]``
            row per individual and ``None`` for the constraint violations.
            Individuals without a usable result file get ``[999, 999]``.
        """
        # backup to log/
        t = time.strftime("%H%M%S")
        os.makedirs("log", exist_ok=True)

        if os.path.exists("opensim.log"):
            os.replace("opensim.log", "log/opensim{}.log".format(t))

        for i in range(len(Vars)):
            fn = "rmse{}.mot".format(i)
            if os.path.exists(fn):
                # Move instead of copy: a result left in place would be read
                # back below as if it belonged to the current individual
                # whenever this generation's job fails or times out.
                os.replace(fn, "log/{}".format(t) + fn)
            # State files are only backed up (copied). The behavior-specific
            # names are the ones the SO/FD pipeline in this file uses; the
            # name without behavior is kept for older outputs.
            state_files = ['new{}_states_degrees.mot'.format(i)] + [
                'new{}_{}_states_degrees.mot'.format(i, behavior)
                for behavior in ('loco', 'groom')
            ]
            for fn1 in state_files:
                if os.path.exists(fn1):
                    copy(fn1, "log/{}".format(t) + fn1)

        # Create the process pool; leave 2 cores free but always keep at
        # least one worker (Pool rejects a process count below 1).
        n_workers = max(1, multiprocessing.cpu_count() - 2)
        pool = multiprocessing.Pool(processes=n_workers)
        # SO => FD
        for i, var in enumerate(Vars):
            pool.apply_async(
                try_task,
                (i, var, t,),
                # Without a callback, exceptions raised in a worker vanish.
                error_callback=lambda exc, job=i: logging.error(
                    "job %s failed: %r", job, exc
                ),
            )
        pool.close()  # no more tasks
        pool.join()  # block until all workers have finished

        CV = None
        ObjV = [
            self._read_objectives("rmse{}.mot".format(i))
            for i in range(len(Vars))
        ]
        O = np.array(ObjV)
        logging.info("Objective values: %s", O)
        return O, CV





if __name__ == '__main__':
    # Module-level flag read by run_so_fd_pipeline (solve for equilibrium in FD).
    eq = True
    behavior = "walking"

    NIND = 120 # Population size
    GEN = 200 # Generation number
    # NOTE(review): TIMEOUT is not referenced anywhere in this file; the limit
    # actually applied to a simulation is the one in the decorator of so_fd.
    TIMEOUT = 900  # Timeout for each simulation

    if behavior == "walking":
        # _, _, origin_data = common.readMotionFile('1.008-2.356 - ref.mot',[12])  #12 speed
        _, origin_data = common.read_motion_file('locomotion_left_ref.mot')
    else:
        # TODO grooming
        pass

    # origin_data = common.convert_time(origin_data)
    # params = [
    #     94.4000779172051,28.934524897357846,0.8307929001873691,-0.0067876681559818144,-0.01896250840054271,0.01962525758759589,85.62802626811126,39.446087446102815,0.8338687277475718,-0.016514848405765118,-0.00906943673438408,0.001035503763888274,42.41062587888701,32.59825726789659,0.801122747877884,-0.012753326185923692,-0.014282745235926642,0.014375917795786789,552.8116960458075,16.47597965086951,0.7087095853894507,-0.003868513882628114,-0.012090282610869434,0.010416940519316598,719.452373106308,32.25957832424334,0.5132530944450693,-0.013157194597638295,-0.012556755047784193,-0.005287969807966066,376.30372312293576,24.47190320114283,0.7951727489592787,-0.01698209539013724,-0.007603566363654149,-0.012744074115670488,73.3449110614356,34.55048878035477,0.6025996560638969,0.006283232469188586,0.00703262016297467,0.019747279949913807,

    # ]
    # so_fd(0, params, 0) # sanity check

    # from IPython import embed
    # embed()


    problem = MyProblem()
    Encoding = 'RI' # real-integer encoding
    Field = ea.crtfld(Encoding, problem.varTypes, problem.ranges, problem.borders)  # create field descriptor
    population = ea.Population(Encoding, Field, NIND)
    myAlgorithm = ea.moea_NSGA2_templet(problem, population)
    myAlgorithm.MAXGEN = GEN
    myAlgorithm.mutOper.F = 0.7  # mutation
    myAlgorithm.recOper.XOVR = 0.5  # cross-over
    myAlgorithm.logTras = 1  # logging every generation
    myAlgorithm.verbose = True
    myAlgorithm.drawing = 2  # plot mode (0: no plot; 1: result plot; 2: evolution of the Pareto front; 3: state space animation)
    noParams = 7*6 # muscles * parameters

    prophetPop = None
    if os.path.exists("seed.mot"): # warm start if seed.mot exists
        seed_rows = []
        with open("seed.mot", 'r') as f:
            for line in f:
                temp_split = line.split()
                if not temp_split:
                    continue  # tolerate blank lines (e.g. a trailing newline)
                # skip RMSE, Pearson, Job ID
                values = temp_split[3:3 + noParams]
                if len(values) < noParams:
                    logging.warning("Skipping seed.mot line with too few columns")
                    continue
                seed_rows.append([float(v) for v in values])
                if len(seed_rows) >= NIND:
                    break

        if seed_rows:
            # Only the rows actually read become prophet individuals. Padding
            # up to NIND with zeros would inject chromosomes that violate the
            # parameter bounds.
            prophetChrom = np.array(seed_rows, dtype=float)
            logging.info("Warm start, get init pops from seed.mot{}".format(prophetChrom.shape))
            prophetPop = ea.Population(Encoding, Field, len(seed_rows), prophetChrom)
        else:
            logging.warning("seed.mot contains no usable rows")

    if prophetPop is not None:
        myAlgorithm.call_aimFunc(prophetPop)
        [BestIndi, population] = myAlgorithm.run(prophetPop)  # get best individual of last generation; warm start
    else:
        logging.info("No warm start")
        [BestIndi, population] = myAlgorithm.run()  # get best individual of last generation; cold start
    BestIndi.save()

    logging.info('Evals Num：%s' % myAlgorithm.evalsNum)
    logging.info('Time eclapsed %s s' % myAlgorithm.passTime)
    if BestIndi.sizes != 0:
        logging.info('Best target so far：%s' % BestIndi.ObjV[0][0]) # might not indeed be the best; an alternative is to filter manually in last generation
        logging.info('Best parameters so far: ')
        for i in range(BestIndi.Phen.shape[1]):
            logging.info(BestIndi.Phen[0, i])
    else:
        logging.warning('No feasible solution')

