import os
from types import SimpleNamespace
import torch

from common.envs_utils import make_env
from common.sacred_utils import ex
from symmetry.env_utils import get_env_name_for_method
from fatigue.eva_utils import denoise_binary, RecorderEnv, SimpleRecorder
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
from scipy.ndimage.filters import gaussian_filter1d
from scipy.interpolate import interp1d
from scipy import signal, stats
import random
import json


def get_main_joint_name(env_name):
    """
    Returns the name of the joint the recorder treats as the main (hip) joint.

    Arguments:
        env_name: environment id; matched by substring ("Walker2D" or "Walker3D")

    Raises:
        ValueError: if the environment family is not supported.
    """
    if "Walker2D" in env_name:
        return "thigh"
    if "Walker3D" in env_name:
        return "hip_y"
    # the message used to contain an unformatted "%s" and advertised Cassie,
    # which is not handled above
    raise ValueError(
        "Environment %s not supported in evaluation. Please use Walker2D or Walker3D." % env_name
    )


def get_stride_idx(foot_contact):
    """
    Returns the indices at which foot contact starts (a falsy sample followed by a truthy one).

    The (virtual) sample before index 0 counts as "in contact", so a contact that is
    already present at index 0 is not reported as a stride start.

    Arguments:
        foot_contact: iterable of per-step contact flags (any truthy/falsy values)
    """
    onsets = []
    was_touching = True
    for position, touching in enumerate(foot_contact):
        landed = touching and not was_touching
        if landed:
            onsets.append(position)
        was_touching = touching
    return onsets


def calc_best_distance(l, r):
    """
    Calculates best L-1 distance between two matrices with the optimal shift (to get rid of the phase difference)

    For every cyclic shift ``s`` in ``0..N-1`` the rows of ``r`` are rotated by ``s``
    along the time axis (axis 0), and the per-row L-1 distance to ``l`` is averaged.

    Arguments:
        l: A Nxd matrix
        r: A Nxd matrix

    Returns:
        (best_dist, best_shift): the smallest mean per-row L-1 distance and the shift
        that achieves it (``inf`` and 0 when N == 0).
    """
    best_dist, best_shift = float("inf"), 0
    for shift in range(l.shape[0]):
        # Roll along the time axis only. Without ``axis`` np.roll flattens the
        # matrix, which mixes the d columns of neighbouring rows and covers only
        # half of the possible time shifts.
        dist = np.linalg.norm(l - np.roll(r, shift, axis=0), ord=1, axis=1).mean()
        if dist < best_dist:
            best_dist, best_shift = dist, shift

    return best_dist, best_shift


def compute_si(values, vectorized=False):
    """
    Computes the symmetric index (in percent) for the input array.

    Scalar form (SI): the per-step L2 norms of each side are summed over time and
    compared as ``200 * |xl - xr| / (xl + xr)``.

    Vectorized form (VSI): every component is split into its positive and negative
    parts, each part is summed over time, and the resulting vectors are compared as
    ``200 * ||xl - xr|| / (||xl|| + ||xr||)``.

    Arguments:
        values: A Nx2xd array where d is num joints/objects and 2 is for left/right
        vectorized: compute the vector form (VSI) instead of the scalar SI

    Returns:
        A value in [0, 200], where 0 means perfectly symmetric. When both sides are
        entirely zero they are identical, and 0.0 is returned instead of 0/0 = nan.
    """
    l, r = np.array(values).transpose([1, 0, 2])
    if vectorized:
        # Keep positive and negative parts apart so opposite-signed values do not
        # cancel when summed over time: (N, d) -> (N, 2d).
        l = np.concatenate((l * (l >= 0), -l * (l < 0)), axis=1)
        r = np.concatenate((r * (r >= 0), -r * (r < 0)), axis=1)
        # No normalising constant is needed: a common scale cancels in the ratio.
        xl = np.sum(l, axis=0)
        xr = np.sum(r, axis=0)
        numerator = np.linalg.norm(xl - xr)
        denominator = np.linalg.norm(xl) + np.linalg.norm(xr)
    else:
        xl = np.linalg.norm(l, axis=1).sum()
        xr = np.linalg.norm(r, axis=1).sum()
        numerator = abs(xl - xr)
        denominator = xl + xr
    if denominator == 0:
        return 0.0
    return 200 * numerator / denominator


def compute_msi(values):
    """
    Computes the modified symmetric index for the input array.
    The array is scaled by the maximum absolute value to make the results scale-invariant.

    The left and right curves are compared with ``calc_best_distance``, which
    searches over cyclic time shifts to remove the phase difference; the returned
    index is twice that best mean L-1 distance.

    Arguments:
        values: A Nx2xd array where d is num joints/objects and 2 is for left/right
    """
    values = np.asarray(values, dtype=float)
    # Scale by max |v| in each dimension (over time and both sides) to make the
    # results scale-invariant. A dimension that is identically zero is left
    # unscaled, which avoids 0/0 = nan.
    scale = np.abs(values).max(axis=(0, 1))
    scale = np.where(scale > 0, scale, 1.0)
    l, r = values.transpose([1, 0, 2]) / scale

    distance, _ = calc_best_distance(l, r)

    return 2 * distance


def average_values(arrays):
    """
    Takes in a list of lists. It does two things:
      1- makes all the inner lists have the same sizes (using interpolation)
      2- averages the values

    The output has the same length as the longest array in the input list

    Arguments:
        arrays: list of list of number
    """
    max_len = max([len(stride) for stride in arrays])

    # fix the lengths (all will be stretched to `max_len`)
    arrays = [
        interp1d(range(len(arr)), arr)(np.linspace(0, len(arr) - 1, max_len))
        for arr in arrays
    ]

    # smooth out the signal and average over strides
    return (
        gaussian_filter1d(np.concatenate(arrays), sigma=5).reshape((-1, max_len)).mean(axis=0)
    )


@ex.config
def config():
    # Sacred configuration scope holding the defaults for this evaluation script;
    # main() reads these values from its `_config` argument.
    # NOTE(review): Sacred presumably applies command-line overrides (e.g.
    # `with experiment_dir=...`) on top of these assignments, which is why
    # `experiment_dir` can be asserted to differ from its own default — confirm.
    net = None  # explicit model path; if falsy, main loads <experiment_dir>/models/<env>_best.pt
    render = False  # forwarded to make_env
    max_steps = 30000  # overall budget of environment steps across all episodes
    env_name = ""  # must be non-empty (asserted in main)
    experiment_dir = "."
    assert experiment_dir != "."  # experiment_dir must be supplied explicitly
    ex.add_config(os.path.join(experiment_dir, "configs.json"))  # loads saved configs (presumably provides mirror_method/env_params used by main)
    strides = 10  # number of consecutive strides analysed per episode
    skip_strides = 2  # leading strides discarded in each episode
    total_episodes = 21  # number of usable episodes to collect
    max_episode_steps = 1000  # episode step limit passed to RecorderEnv


@ex.automain
def main(_config):
    """
    Evaluates a trained policy and writes gait-symmetry and fatigue metrics and plots.

    Loads the saved actor-critic for the configured environment and rolls it out
    deterministically. Collection stops when ``total_episodes`` usable episodes
    have been gathered or when ``max_steps`` environment steps have elapsed.

    An episode is usable when it contains more than ``strides + skip_strides``
    detected stride starts. Its first ``skip_strides`` strides are discarded and
    the next ``strides`` strides are analysed.

    Writes into ``experiment_dir``:
      - phase_list.npy, phase_plot.pdf, median_phase_plot.pdf: stride-averaged
        (angle, angular velocity) phase curves of the main left/right joint
      - actions_list.npy, mean_fatigue_plot.pdf: actions truncated to the
        shortest episode, and the mean of an exponential moving average of |action|
      - peridiogram.npy, periodogram.pdf: averaged, normalised power spectra
      - eva_metric_111.json: PPI, spectral entropy, roll variance, SI and VSI

    Arguments:
        _config: the Sacred configuration dictionary
    """
    args = SimpleNamespace(**_config)
    assert args.env_name != ""

    env_name = get_env_name_for_method(args.env_name, args.mirror_method)

    model_path = args.net or os.path.join(
        args.experiment_dir, "models", "{}_best.pt".format(env_name.replace(':', '_'))
    )

    print("Env: {}".format(env_name))
    print("Model: {}".format(os.path.basename(model_path)))

    actor_critic = torch.load(model_path)

    env = make_env(env_name, args.env_params, render=args.render)
    if hasattr(env.unwrapped, "evaluation_mode"):
        env.unwrapped.evaluation_mode()
    env = RecorderEnv(env, max_episode_steps=args.max_episode_steps)
    env.seed(1093)

    recorder = SimpleRecorder('test', env, get_main_joint_name(env_name))

    states = torch.zeros(1, actor_critic.state_size)
    # NOTE(review): masks stay 0 for the whole rollout; for a recurrent policy this
    # presumably means the hidden state is treated as reset every step — confirm.
    masks = torch.zeros(1, 1)
    obs = env.reset()
    num_episodes = 0
    steps = 0

    skip_strides = args.skip_strides
    strides = args.strides
    stride_idx_list = []
    ql_list = []
    qr_list = []
    qdotl_list = []
    qdotr_list = []
    actl_list = []
    actr_list = []
    body_roll_list = []
    rl_torque_list = []
    actions_list = []

    # ---- rollout: collect recorder data of episodes with enough strides ----
    while num_episodes < args.total_episodes and steps < args.max_steps:
        obs = torch.from_numpy(obs).float().unsqueeze(0)

        with torch.no_grad():
            _, action, _, states = actor_critic.act(
                obs, states, masks, deterministic=True
            )
        cpu_actions = action.squeeze().cpu().numpy()

        obs, _, done, _ = env.step(cpu_actions)

        if "Bullet" in args.env_name:
            env.unwrapped._p.resetDebugVisualizerCamera(
                3, 0, -5, env.unwrapped.robot.body_xyz
            )
        steps += 1

        if done:
            # NOTE(review): assumes the recorder's buffers hold the episode that
            # just finished at this point — confirm against SimpleRecorder.
            foot_contact = denoise_binary(recorder.foot_contact)
            stride_idx = get_stride_idx(foot_contact)
            if len(stride_idx) > strides + skip_strides:
                # keep `strides + 1` stride starts, i.e. `strides` whole strides
                stride_idx_list.append(stride_idx[skip_strides:strides + skip_strides + 1])
                ql_list.append(recorder.ql)
                qr_list.append(recorder.qr)
                qdotl_list.append(recorder.qdotl)
                qdotr_list.append(recorder.qdotr)
                actl_list.append(recorder.actl)
                actr_list.append(recorder.actr)
                body_roll_list.append(recorder.body_roll)
                rl_torque_list.append(recorder.rltorques)
                actions_list.append(recorder.actions)

                num_episodes += 1
            obs = env.reset()

    if len(stride_idx_list) == 0:
        print("No data!")
        return

    # ---- per-episode analysis ----
    phase_list = []

    pxx_list = []
    f = None

    roll_var = 0
    total_length = 0.0

    vsi_list = []
    si_list = []

    # sampling frequency of the recorded signals (one sample per env step)
    fs = 1.0 / env.unwrapped.scene.dt
    # shortest analysed span (first to last kept stride start) over all episodes,
    # so every episode contributes an equally long signal to the periodogram
    min_ep_len = min(stride_idx[-1] - stride_idx[0] for stride_idx in stride_idx_list)
    for i, stride_idx in enumerate(stride_idx_list):
        ql = ql_list[i]
        qr = qr_list[i]
        qdotl = qdotl_list[i]
        qdotr = qdotr_list[i]
        actl = actl_list[i]
        actr = actr_list[i]

        # process q and qdot to draw phase diagram and compute phase distance
        bounds = list(zip(stride_idx[:-1], stride_idx[1:]))
        ql_average = average_values([ql[start:stop] for start, stop in bounds])
        qdotl_average = average_values([qdotl[start:stop] for start, stop in bounds])
        qr_average = average_values([qr[start:stop] for start, stop in bounds])
        qdotr_average = average_values([qdotr[start:stop] for start, stop in bounds])
        # stacked to (T, 2, 2): time x (left, right) x (angle, angular velocity)
        phase_list.append(np.stack([np.stack([ql_average, qdotl_average]),
                                    np.stack([qr_average, qdotr_average])]).transpose((2, 0, 1)))

        # compute power spectral density for angle, angular velocity and torque
        start, stop = stride_idx[0], stride_idx[0] + min_ep_len
        psd_data = np.array([[ql[start:stop], qr[start:stop]],
                             [qdotl[start:stop], qdotr[start:stop]],
                             [actl[start:stop], actr[start:stop]]])
        psd_data = gaussian_filter1d(psd_data, sigma=1, axis=2)
        f, pxx = signal.periodogram(psd_data, axis=2, fs=fs)
        # stored as (frequency, quantity, side)
        pxx_list.append(pxx.transpose([2, 0, 1]))

        body_roll = body_roll_list[i]
        body_roll_episodes = np.array(body_roll[stride_idx[0]:stride_idx[-1]])
        roll_var += np.linalg.norm(body_roll_episodes * 8 / np.pi)
        total_length += stride_idx[-1] - stride_idx[0]

        rl_torque_episode = rl_torque_list[i][stride_idx[0]:stride_idx[-1]]
        si_list.append(compute_si(rl_torque_episode, False))
        vsi_list.append(compute_si(rl_torque_episode, True))

    np.save(os.path.join(args.experiment_dir, "phase_list.npy"), np.array(phase_list, dtype=object))

    # ---- fatigue: exponential moving average of |action| with rate 0.2/60 ----
    min_action_len = min(len(actions) for actions in actions_list)
    actions_list_ = [actions[:min_action_len] for actions in actions_list]
    np.save(os.path.join(args.experiment_dir, "actions_list.npy"), actions_list_)

    actions_arr = np.array(actions_list_)
    fatigue_rate = 0.2 / 60.0
    m_f_list = np.zeros_like(actions_arr)
    for step in range(min_action_len - 1):
        m_f_list[:, step + 1, :] = ((1.0 - fatigue_rate) * m_f_list[:, step, :]
                                    + fatigue_rate * np.abs(actions_arr[:, step, :]))
    # average over action dimensions, then over episodes
    m_f_average = np.mean(np.mean(m_f_list, axis=2), axis=0)
    # Use a dedicated figure and close it afterwards. Plotting on the implicit
    # default figure (number 1) would leak this curve into plt.figure(1) below,
    # which is the median phase plot.
    fatigue_fig = plt.figure()
    plt.plot(m_f_average)
    plt.savefig(os.path.join(args.experiment_dir, 'mean_fatigue_plot.pdf'), bbox_inches='tight')
    plt.close(fatigue_fig)

    # ---- metrics ----
    print('The number of valid episodes: ', len(stride_idx_list))
    metric = {'num_episodes': len(stride_idx_list)}

    distance_list = [compute_msi(phase_data) for phase_data in phase_list]
    print('PPI: ', np.median(distance_list))
    metric['PPI'] = {'list': distance_list,
                     'median': np.median(distance_list),
                     'mean': np.mean(distance_list),
                     'std': np.std(distance_list)}

    pxx_list = np.array(pxx_list)
    pxx = pxx_list.mean(axis=0)
    # normalise each spectrum over frequency so it can be treated as a distribution
    pxx = pxx / pxx.sum(axis=0)
    ent = stats.entropy(pxx, axis=0)
    print('Spectral Entropy: ', ent)
    print('Average Entropy:', ent.mean(axis=1))
    metric['Spectral Entropy'] = {'entropy': ent.tolist(),
                                  'average entropy': ent.mean(axis=1).tolist()}

    peridiogram_list = np.array([pxx.transpose(1, 2, 0), f], dtype=object)
    np.save(os.path.join(args.experiment_dir, "peridiogram.npy"), peridiogram_list)

    print('Roll Variance: ', roll_var / total_length)
    metric['Roll Variance'] = roll_var / total_length

    print('SI: ', np.median(si_list))
    metric['SI'] = {
        'list': si_list,
        'median': np.median(si_list)
    }

    print('VSI: ', np.median(vsi_list))
    metric['VSI'] = {
        'list': vsi_list,
        'median': np.median(vsi_list)
    }

    with open(os.path.join(args.experiment_dir, "eva_metric_111.json"), "w") as eval_file:
        json.dump(metric, eval_file, indent=2)

    # ---- phase plot for all episodes ----
    plt.figure(0)
    for phase_data in phase_list:
        plt.plot(phase_data[:, 0, 0], phase_data[:, 0, 1], "ro-", markersize=1, linewidth=0.1)
    for phase_data in phase_list:
        plt.plot(phase_data[:, 1, 0], phase_data[:, 1, 1], "go-", markersize=1, linewidth=0.1, alpha=0.7)

    # draw start and end points for at most 20 distinct episodes
    if len(phase_list) <= 20:
        phase_idx = range(len(phase_list))
    else:
        # sample without replacement; random.choices could repeat an episode
        phase_idx = random.sample(range(len(phase_list)), k=20)
    cmap = plt.get_cmap('tab20', len(phase_idx))
    for i, idx in enumerate(phase_idx):
        phase_data = phase_list[idx]
        plt.plot(phase_data[0, 0, 0], phase_data[0, 0, 1], "s", c=cmap(i), markersize=1.5)
        plt.plot(phase_data[0, 1, 0], phase_data[0, 1, 1], "s", c=cmap(i), markersize=1.5)

        plt.plot(phase_data[-1, 0, 0], phase_data[-1, 0, 1], "v", c=cmap(i), markersize=1.5)
        plt.plot(phase_data[-1, 1, 0], phase_data[-1, 1, 1], "v", c=cmap(i), markersize=1.5)

    plt.xlabel("Hip Flexion Angle")
    plt.ylabel("Hip Flexion Velocity")
    legend_elements = [Line2D([0], [0], linestyle='-', marker='o', color='r', lw=0.5, ms=1),
                       Line2D([0], [0], linestyle='-', marker='o', color='g', lw=0.5, ms=1),
                       Line2D([0], [0], ls='', marker='s', markersize=5),
                       Line2D([0], [0], ls='', marker='v', markersize=5)]
    legend_labels = ['left', 'right', 'stride start', 'stride end']
    plt.legend(legend_elements, legend_labels, fontsize='x-small', loc=4)
    plt.savefig(os.path.join(args.experiment_dir, 'phase_plot.pdf'), bbox_inches='tight')

    # ---- phase plot for episode with median distance ----
    median_idx = sorted(range(len(distance_list)),
                        key=lambda list_idx: distance_list[list_idx])[int(len(distance_list) / 2)]
    phase_data = phase_list[median_idx]
    plt.figure(1)
    plt.plot(phase_data[:, 0, 0], phase_data[:, 0, 1], "ro-", markersize=1, linewidth=0.1, label="left")
    plt.plot(phase_data[:, 1, 0], phase_data[:, 1, 1], "go-", markersize=1, linewidth=0.1, label="right")
    plt.plot(phase_data[0, 0, 0], phase_data[0, 0, 1], "rs", markersize=2)
    plt.plot(phase_data[0, 1, 0], phase_data[0, 1, 1], "gs", markersize=2)
    plt.plot(phase_data[-1, 0, 0], phase_data[-1, 0, 1], "rv", markersize=2)
    plt.plot(phase_data[-1, 1, 0], phase_data[-1, 1, 1], "gv", markersize=2)
    plt.xlabel("Hip Flexion Angle")
    plt.ylabel("Hip Flexion Velocity")
    plt.legend(legend_elements, legend_labels, fontsize='x-small', loc=4)
    plt.savefig(os.path.join(args.experiment_dir, 'median_phase_plot.pdf'), bbox_inches='tight')

    # ---- periodogram: rows are sides, columns are quantities ----
    cols = ['Angle', 'Angular Velocity', 'Torque']
    rows = ['Left Hip', 'Right Hip']
    fig, axs = plt.subplots(nrows=2, ncols=3, sharex='all', sharey='row', figsize=(8, 4.5))
    plt.setp(axs[1, :], xlabel='Frequency[Hz]')
    plt.setp(axs[:, 0], ylabel='Spectral Density')
    pad = 5  # in points
    for ax, col in zip(axs[0], cols):
        ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                    xycoords='axes fraction', textcoords='offset points',
                    size=12, ha='center', va='baseline')

    for ax, row in zip(axs[:, 0], rows):
        ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                    xycoords=ax.yaxis.label, textcoords='offset points',
                    size=12, ha='right', va='center', rotation=90)
    for i in range(len(rows)):
        for j in range(len(cols)):
            axs[i, j].plot(f[:50], pxx[:50, j, i], linewidth=0.5, color='r')
    fig.tight_layout()
    plt.savefig(os.path.join(args.experiment_dir, 'periodogram.pdf'), bbox_inches='tight')
