'''
Inspect result.py

Inspect the resulting trajectories against the benchmark (reference) for the
top 10 individuals of the final generation.
'''

import os
import time
import logging
import pickle
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import opensim as osim

import common


# Date stamp (YYYYMMDD) used as a suffix for every exported file name.
today = time.strftime("%Y%m%d", time.localtime())

# Install a default root handler, then only let warnings and above through.
logging.basicConfig()
logging.root.setLevel(logging.WARNING)

def butter_lowpass_filter(data, cutoff, fs=1/5e-4, order=4):
    """
    Zero-phase low-pass filter ``data`` with a digital Butterworth filter.

    Parameters
    ----------
    data : np.array
        Signal to filter (filtered along its last axis).
    cutoff : float
        Cutoff frequency, in the same units as ``fs``.
    fs : float
        Sampling frequency of ``data`` (default ``1 / 5e-4``).
    order : int
        Order of the Butterworth filter.

    Returns
    -------
    np.ndarray
        The filtered signal, same shape as ``data``.
    """
    from scipy.signal import butter, filtfilt

    # ``butter`` expects the cutoff normalised by the Nyquist frequency.
    wn = cutoff / (fs / 2.0)
    numerator, denominator = butter(N=order, Wn=wn, btype='low', analog=False)
    # Forward-backward filtering, so the output has no phase lag.
    return filtfilt(numerator, denominator, data)

def plot_ref_vs_optimized(ref_path, opt_path, dof_names, export_path=None, title='', ylabel='Angle (deg)', cutoff=20):
    """
    Plot optimized trajectories against the low-pass filtered reference.

    One subplot is drawn per DOF. Each subplot overlays every optimized
    trajectory (labelled with its job id) and the filtered reference in black.

    Parameters
    ----------
    ref_path : str
        Motion file holding the reference trajectory.
    opt_path : str | Path | iterable of (str | Path)
        One or more motion files produced by the optimization. File names are
        expected to start with ``<prefix><job id>_`` (e.g.
        ``new12_states_degrees.mot``); the job id is used as the legend label.
    dof_names : list of str
        Columns to plot, one subplot each.
    export_path : str, optional
        Where to save the figure; nothing is saved when None.
    title : str
        Figure title.
    ylabel : str
        Y-axis label of the first subplot.
    cutoff : float
        Cutoff frequency of the low-pass filter applied to the reference.
    """
    _, ref_data = common.read_motion_file(ref_path)
    ref_time = ref_data['time']

    # Filter ref data
    for dof_name in dof_names:
        if dof_name == 'time':
            continue
        ref_data[dof_name] = butter_lowpass_filter(ref_data[dof_name], cutoff=cutoff)

    # A single path is wrapped; any other iterable (list, tuple, generator) is materialised.
    if isinstance(opt_path, (str, os.PathLike)):
        opt_path = [opt_path]
    else:
        opt_path = list(opt_path)

    # Read every optimized file once instead of once per DOF.
    opt_runs = []
    for path in opt_path:
        path = path.as_posix() if isinstance(path, Path) else path
        _, opt_data = common.read_motion_file(path)
        # Job id = digits of the file-name prefix before the first '_'
        # (e.g. "new7_states_degrees.mot" -> "7"); directories are ignored.
        job_id = ''.join(ch for ch in Path(path).name.split('_')[0] if ch.isdigit())
        opt_runs.append((job_id, opt_data))

    # squeeze=False keeps a 1-D array of axes even for a single DOF.
    fig, axs = plt.subplots(1, len(dof_names), figsize=(8 * len(dof_names), 5), dpi=300, squeeze=False)
    axs = axs[0]

    for ax, dof_name in zip(axs, dof_names):
        for job_id, opt_data in opt_runs:
            ax.plot(opt_data['time'], opt_data[dof_name], label='job={}'.format(job_id), alpha=0.7)
        ax.plot(ref_time, ref_data[dof_name], label="reference", lw=2, color='black')
        ax.set_title(dof_name)
        ax.set_xlabel('Time')

    # All subplots share the same entries, so one legend on the last one suffices.
    axs[-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncols=4)
    axs[0].set_ylabel(ylabel)
    fig.suptitle(title)
    if export_path is not None:
        fig.savefig(export_path, bbox_inches='tight')
    plt.close(fig)


def viz_generation(gen_path, title='', export_path=None):
    """
    Scatter-plot the two objectives of every individual in a generation file.

    Parameters
    ----------
    gen_path : str
        Whitespace-separated text file with one individual per line: the first
        two fields are plotted as (RMSE, Rev. Pearson) and the third is the
        job id used as the legend label. Blank lines are ignored.
    title : str
        Axes title.
    export_path : str, optional
        Where to save the figure; nothing is saved when None.
    """
    # The context manager closes the file even if parsing fails.
    with open(gen_path, 'r') as f:
        rows = [line.split() for line in f]

    fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
    for fields in rows:
        if not fields:  # e.g. a trailing empty line
            continue
        obj1, obj2, job = fields[:3]
        ax.scatter(float(obj1), float(obj2), label='job={}'.format(job), alpha=0.8)

    ax.set_xlabel('RMSE')
    ax.set_ylabel('Rev. Pearson')
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncols=4)
    if export_path is not None:
        fig.savefig(export_path, bbox_inches='tight')
    plt.close(fig)


def plot_muscle_act(
        act_paths, muscles_to_plot=("LFTibia_flex_93434", "LFTibia_extensor_93932"),
        title='', export_path=None):
    """
    Plot muscle activations of several runs and pickle the raw traces.

    Parameters
    ----------
    act_paths : iterable of (str | Path)
        Control files, one per run. File names are expected to start with
        ``<prefix><job id>_`` (e.g. ``new12_controls.sto``); the job id is
        used as the legend label.
    muscles_to_plot : sequence of str
        Columns to plot, one subplot each.
    title : str
        Figure title.
    export_path : str, optional
        Where to save the figure. When given, the activations are also pickled
        next to it (same name, ``.pkl`` suffix) as
        ``{muscle_name: [trace of each run, in act_paths order]}``.
        Nothing is written when None.
    """
    muscle_act_dict = {muscle_name: [] for muscle_name in muscles_to_plot}
    # squeeze=False keeps axs indexable even when a single muscle is plotted.
    fig, axs = plt.subplots(1, len(muscles_to_plot), figsize=(len(muscles_to_plot) * 7, 5), dpi=300,
                            squeeze=False)
    axs = axs[0]

    for path in act_paths:
        path = path.as_posix() if isinstance(path, Path) else path
        _, data = common.read_motion_file(path)
        # Job id = digits of the file-name prefix before the first '_'
        # (e.g. "new7_controls.sto" -> "7"); directories are ignored.
        job_id = ''.join(ch for ch in Path(path).name.split('_')[0] if ch.isdigit())
        for ax, muscle_name in zip(axs, muscles_to_plot):
            muscle_act_dict[muscle_name].append(data[muscle_name])
            # NOTE(review): plotted against the sample index although the x axis is labelled 'Time'.
            ax.plot(data[muscle_name], label=job_id, alpha=0.8)

    for ax, muscle_name in zip(axs, muscles_to_plot):
        ax.set_title(muscle_name)
        ax.set_xlabel('Time')
    axs[0].set_ylabel('Activation')
    axs[-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncols=4)
    fig.suptitle(title)

    if export_path is not None:
        fig.savefig(export_path, bbox_inches='tight')
        # Swap only the final suffix; never overwrite the figure itself.
        with open(Path(export_path).with_suffix('.pkl'), 'wb') as f:
            pickle.dump(muscle_act_dict, f)
    plt.close(fig)


def select_good_inds(gen_path, ind_th=30):
    """
    Return the ``ind_th`` best individuals of a generation file.

    Each non-blank line of ``gen_path`` is whitespace-separated: two objective
    values, the job id, then six fields per muscle, of which only the first
    three are read. Individuals are ranked by the product of the two
    objectives, smallest first.

    Parameters
    ----------
    gen_path : str
        Path to the generation file.
    ind_th : int
        Maximum number of individuals to keep.

    Returns
    -------
    sorted_inds : np.ndarray, shape (k, 3)
        ``(obj1, obj2, job)`` rows of the selected individuals, best first,
        with ``k = min(ind_th, number of individuals)``. The job id is stored
        as a float.
    sorted_muscle_params : dict[str, np.ndarray]
        Maps ``"muscle_<i>"`` to a ``(k, 3)`` array with the first three
        parameters of muscle ``i`` for the same individuals, in the same order.

    Raises
    ------
    ValueError
        If the file contains no individuals.
    """
    with open(gen_path, 'r') as f:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing on them.
        data = [fields for fields in (line.split() for line in f) if fields]

    if not data:
        raise ValueError(f"No individuals found in {gen_path}")

    # Three leading fields (two objectives + job id), then six fields per muscle.
    muscle_no = (len(data[0]) - 3) // 6
    # Extract objectives and jobs
    inds = np.array([(float(d[0]), float(d[1]), int(d[2])) for d in data])
    # Extract the first three parameters of every muscle
    muscle_params = {
        f"muscle_{i}": np.array([
            (float(d[3 + i * 6]), float(d[4 + i * 6]), float(d[5 + i * 6]))
            for d in data
        ])
        for i in range(muscle_no)
    }

    # Rank by the product of the two objectives, then keep the best ind_th.
    order = np.argsort(inds[:, 0] * inds[:, 1])[:ind_th]

    sorted_inds = inds[order]
    sorted_muscle_params = {k: v[order] for k, v in muscle_params.items()}
    return sorted_inds, sorted_muscle_params


def analyze_feti(cutoff_freq=39):
    """
    Analyze the results of the optimization of the FeTi model.

    Writes figures under ``./res_imgs``:
    - reference vs. optimized trajectories for all individuals in ``./FeTi``;
    - their muscle activations;
    - the objective scatter of ``./FeTi/seed.mot``;
    - the objectives and muscle parameters of the top 10 individuals;
    - the trajectories and activations of the top 10 individuals.

    The top-10 rows are also saved to
    ``./res_imgs/FeTi_optimization_top10_inds.npy``.

    Parameters
    ----------
    cutoff_freq : float
        Cutoff frequency of the low-pass filter applied to the reference.
    """
    # savefig / np.save fail when the output directory does not exist yet.
    os.makedirs("./res_imgs", exist_ok=True)

    ref_path = "./FeTi/locomotion_left_ref.mot"
    opt_path = list(Path('./FeTi').rglob('new*_states_degrees.mot'))
    dof_names = ["/jointset/joint_LFTibia/joint_LFTibia_pitch/value"]
    export_path = f"./res_imgs/FeTi_optimization_res_{today}.png"
    title = 'FeTi reference vs. optimized'

    plot_ref_vs_optimized(ref_path, opt_path, dof_names, export_path, title, cutoff=cutoff_freq)

    opt_act_path = list(Path('./FeTi').rglob('new*_controls.sto'))
    plot_muscle_act(opt_act_path, title='FeTi all inds muscle activations',
                    export_path=f"./res_imgs/FeTi_optimization_all_act_{today}.png")

    opt_results = "./FeTi/seed.mot"
    viz_generation(opt_results, title='FeTi optimization results',
                   export_path=f"./res_imgs/FeTi_optimization_gen_{today}.png")

    top_10_inds, top_10_params = select_good_inds(opt_results, ind_th=10)

    # Objective values of the selected individuals.
    fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
    for obj1, obj2, job in top_10_inds:
        ax.scatter(float(obj1), float(obj2), label='job={}'.format(int(job)), alpha=0.8)
    ax.set_xlabel('RMSE')
    ax.set_ylabel('Rev. Pearson')
    ax.set_title('FeTi top 10 individuals')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncols=4)
    # Save while the figure is still open, then release it.
    fig.savefig(f"./res_imgs/FeTi_optimization_top10_{today}.png", bbox_inches='tight')
    plt.close(fig)

    # plot parameters of top 10 individuals
    muscle_labels = list(top_10_params)
    fig, axs = plt.subplots(1, 3, figsize=(9, 3), dpi=300)
    for i, params in enumerate(top_10_params.values()):
        # Size x from the data: the generation may hold fewer than 10 individuals.
        x = [i] * len(params)
        axs[0].scatter(x, params[:, 0], label='fiso')
        axs[1].scatter(x, params[:, 1], label='vmax')
        axs[2].scatter(x, params[:, 2], label='lopt')

    for ax in axs:
        ax.set_xticks(range(len(muscle_labels)), muscle_labels)
        ax.set_xlabel('Muscle')

    axs[0].set_title('fiso')
    axs[1].set_title('vmax')
    axs[2].set_title('lopt')

    axs[0].set_ylabel('Value')
    fig.suptitle('FeTi top 10 muscle parameters')
    fig.tight_layout()
    fig.savefig(f"./res_imgs/FeTi_optimization_top10_params_{today}.png", bbox_inches='tight')
    plt.close(fig)

    # Plot the best 10 muscle results
    best_10_inds = [f"./FeTi/new{int(ind[2])}_states_degrees.mot" for ind in top_10_inds]
    plot_ref_vs_optimized(
        ref_path,
        best_10_inds,
        dof_names,
        cutoff=cutoff_freq,
        export_path=f"./res_imgs/FeTi_optimization_top10_res_{today}.png",
        title='FeTi top 10 individuals')

    best_10_muscle_act = [f"./FeTi/new{int(ind[2])}_controls.sto" for ind in top_10_inds]
    plot_muscle_act(best_10_muscle_act, title='FeTi top 10 muscle activations',
                    export_path=f"./res_imgs/FeTi_optimization_top10_act_{today}.png")

    np.save("./res_imgs/FeTi_optimization_top10_inds.npy", top_10_inds)


def analyze_thco(cutoff_freq=20, beh=''):
    """
    Analyze the results of the optimization of the ThCo muscles.

    Writes figures under ``./res_imgs`` comparing the optimized coxa
    trajectories and muscle activations against the reference, for all
    individuals and for the top 10 of ``./ThCo/seed.mot``. The top-10 rows are
    saved to ``./res_imgs/ThCo_optimization_top10_inds.npy``.

    Parameters
    ----------
    cutoff_freq : float
        Cutoff frequency of the low-pass filter applied to the reference.
    beh : str
        'loco' uses ``./ThCo/locomotion_left_ref.mot``; 'groom' uses
        ``./ThCo/antgrooming_left_ref.mot``. Also selects the
        ``new<job>_<beh>_*`` result files.

    Raises
    ------
    ValueError
        If ``beh`` is neither 'loco' nor 'groom'.
    """
    if beh == 'loco':
        ref_path = "./ThCo/locomotion_left_ref.mot"
    elif beh == 'groom':
        ref_path = "./ThCo/antgrooming_left_ref.mot"
    else:
        raise ValueError("Behavior not recognized. Use 'loco' or 'groom'.")

    # savefig / np.save fail when the output directory does not exist yet.
    os.makedirs("./res_imgs", exist_ok=True)

    opt_path = list(Path('./ThCo').rglob(f'new*_{beh}_states_degrees.mot'))
    dof_names = [
        f"/jointset/joint_LFCoxa/joint_LFCoxa_{dof}/value"
        for dof in ["yaw", "pitch", "roll"]
    ]

    export_path = f"./res_imgs/ThCo_optimization_res_{beh}_{today}.png"
    title = 'ThCo reference vs. optimized'

    plot_ref_vs_optimized(ref_path, opt_path, dof_names, export_path, title, cutoff=cutoff_freq)

    opt_act_path = list(Path('./ThCo').rglob(f'new*_{beh}_controls.sto'))
    muscle_names = [
        "LFC_tergopleural_promotor_a",
        "LFC_tergopleural_promotor_b",
        "LFC_pleural_remotor_and_abductor",
        "LFC_pleural_promotor",
        "LFC_sternal_anterior_rotator",
        "LFC_sternal_posterior_rotator",
        "LFC_sternal_adductor"
    ]

    plot_muscle_act(opt_act_path,
                    muscles_to_plot=muscle_names,
                    title='ThCo all inds muscle activations',
                    export_path=f"./res_imgs/ThCo_optimization_all_act_{beh}_{today}.png"
                    )

    opt_results = "./ThCo/seed.mot"
    viz_generation(opt_results, title='ThCo optimization results',
                   export_path=f"./res_imgs/ThCo_optimization_gen_{today}.png")

    top_10_inds, top_10_params = select_good_inds(opt_results, ind_th=10)

    # Objective values of the selected individuals.
    fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
    for obj1, obj2, job in top_10_inds:
        ax.scatter(float(obj1), float(obj2), label='job={}'.format(int(job)), alpha=0.8)
    ax.set_xlabel('RMSE')
    ax.set_ylabel('Rev. Pearson')
    ax.set_title('ThCo top 10 individuals')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncols=4)
    # Save while the figure is still open, then release it.
    fig.savefig(f"./res_imgs/ThCo_optimization_top10_{today}.png", bbox_inches='tight')
    plt.close(fig)

    # plot parameters of top 10 individuals
    muscle_labels = list(top_10_params)
    fig, axs = plt.subplots(1, 3, figsize=(9, 3), dpi=300)
    for i, params in enumerate(top_10_params.values()):
        # Size x from the data: the generation may hold fewer than 10 individuals.
        x = [i] * len(params)
        axs[0].scatter(x, params[:, 0], label='fiso')
        axs[1].scatter(x, params[:, 1], label='vmax')
        axs[2].scatter(x, params[:, 2], label='lopt')

    for ax in axs:
        ax.set_xticks(range(len(muscle_labels)), muscle_labels)
        ax.set_xlabel('Muscle')

    axs[0].set_title('fiso')
    axs[1].set_title('vmax')
    axs[2].set_title('lopt')

    axs[0].set_ylabel('Value')
    fig.suptitle('ThCo top 10 muscle parameters')
    fig.tight_layout()
    fig.savefig(f"./res_imgs/ThCo_optimization_top10_params_{today}.png", bbox_inches='tight')
    plt.close(fig)

    np.save("./res_imgs/ThCo_optimization_top10_inds.npy", top_10_inds)

    # Plot the best 10 muscle results
    best_10_inds = [f"./ThCo/new{int(ind[2])}_{beh}_states_degrees.mot" for ind in top_10_inds]
    plot_ref_vs_optimized(
        ref_path,
        best_10_inds,
        dof_names,
        cutoff=cutoff_freq,
        export_path=f"./res_imgs/ThCo_optimization_top10_res_{beh}_{today}.png",
        title='ThCo top 10 individuals')

    # Controls files carry the behavior tag, like the states files and the glob above.
    best_10_muscle_act = [f"./ThCo/new{int(ind[2])}_{beh}_controls.sto" for ind in top_10_inds]
    plot_muscle_act(
        best_10_muscle_act,
        title='ThCo top 10 muscle activations',
        muscles_to_plot=muscle_names,
        export_path=f"./res_imgs/ThCo_optimization_top10_act_{beh}_{today}.png"
    )


def analyze_cotr(cutoff_freq=20):
    """
    Analyze the results of the optimization of the CoTr muscles.

    Writes figures under ``./res_imgs`` comparing the optimized trochanter
    trajectories and muscle activations against
    ``./CoTr/locomotion_left_ref.mot``, for all individuals and for the top 10
    of ``./CoTr/seed.mot``. The top-10 rows are saved to
    ``./res_imgs/CoTr_optimization_top10_inds.npy``.

    Parameters
    ----------
    cutoff_freq : float
        Cutoff frequency of the low-pass filter applied to the reference.
    """
    # savefig / np.save fail when the output directory does not exist yet.
    os.makedirs("./res_imgs", exist_ok=True)

    ref_path = "./CoTr/locomotion_left_ref.mot"
    opt_path = list(Path('./CoTr').rglob('new*_states_degrees.mot'))
    dof_names = [
        f"/jointset/joint_LFTrochanter/joint_LFTrochanter_{dof}/value"
        for dof in ["yaw", "pitch", "roll"]
    ]

    export_path = f"./res_imgs/CoTr_optimization_res_{today}.png"
    title = 'CoTr reference vs. optimized'

    plot_ref_vs_optimized(ref_path, opt_path, dof_names, export_path, title, cutoff=cutoff_freq)

    opt_act_path = list(Path('./CoTr').rglob('new*_controls.sto'))
    muscle_names = [
        "LFF_sterno-tergo-trochanter_extensor_a",
        "LFF_sterno-tergo-trochanter_extensor_b",
        "LFF_trochanter_extensor",
        "LFF_accesory_trochanter_flexor",
        "LFF_trochanter_flexor",
    ]

    plot_muscle_act(opt_act_path,
                    muscles_to_plot=muscle_names,
                    title='CoTr all inds muscle activations',
                    export_path=f"./res_imgs/CoTr_optimization_all_act_{today}.png"
                    )

    opt_results = "./CoTr/seed.mot"
    viz_generation(opt_results, title='CoTr optimization results',
                   export_path=f"./res_imgs/CoTr_optimization_gen_{today}.png")

    top_10_inds, top_10_params = select_good_inds(opt_results, ind_th=10)

    # Objective values of the selected individuals.
    fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
    for obj1, obj2, job in top_10_inds:
        ax.scatter(float(obj1), float(obj2), label='job={}'.format(int(job)), alpha=0.8)
    ax.set_xlabel('RMSE')
    ax.set_ylabel('Rev. Pearson')
    ax.set_title('CoTr top 10 individuals')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ncols=4)
    # Save while the figure is still open, then release it.
    fig.savefig(f"./res_imgs/CoTr_optimization_top10_{today}.png", bbox_inches='tight')
    plt.close(fig)

    # plot parameters of top 10 individuals
    muscle_labels = list(top_10_params)
    fig, axs = plt.subplots(1, 3, figsize=(9, 3), dpi=300)
    for i, params in enumerate(top_10_params.values()):
        # Size x from the data: the generation may hold fewer than 10 individuals.
        x = [i] * len(params)
        axs[0].scatter(x, params[:, 0], label='fiso')
        axs[1].scatter(x, params[:, 1], label='vmax')
        axs[2].scatter(x, params[:, 2], label='lopt')

    for ax in axs:
        ax.set_xticks(range(len(muscle_labels)), muscle_labels)
        ax.set_xlabel('Muscle')

    axs[0].set_title('fiso')
    axs[1].set_title('vmax')
    axs[2].set_title('lopt')

    axs[0].set_ylabel('Value')
    fig.suptitle('CoTr top 10 muscle parameters')
    fig.tight_layout()

    fig.savefig(f"./res_imgs/CoTr_optimization_top10_params_{today}.png", bbox_inches='tight')
    plt.close(fig)

    np.save("./res_imgs/CoTr_optimization_top10_inds.npy", top_10_inds)

    # Plot the best 10 muscle results
    best_10_inds = [f"./CoTr/new{int(ind[2])}_states_degrees.mot" for ind in top_10_inds]
    plot_ref_vs_optimized(
        ref_path,
        best_10_inds,
        dof_names,
        export_path=f"./res_imgs/CoTr_optimization_top10_res_{today}.png",
        title='CoTr top 10 individuals',
        cutoff=cutoff_freq
    )

    best_10_muscle_act = [f"./CoTr/new{int(ind[2])}_controls.sto" for ind in top_10_inds]
    plot_muscle_act(
        best_10_muscle_act,
        title='CoTr top 10 muscle activations',
        muscles_to_plot=muscle_names,
        export_path=f"./res_imgs/CoTr_optimization_top10_act_{today}.png"
    )


if __name__ == "__main__":
    # Only the CoTr analysis is enabled; swap in one of these to analyze
    # another joint instead:
    #   analyze_feti()
    #   analyze_thco(beh='loco')
    analyze_cotr(cutoff_freq=20)