import json
import os
import sys
import time
from glob import glob
from types import SimpleNamespace

# Put the directory above this script's directory at the front of sys.path,
# so the `common` and `symmetry` imports below can be resolved from it.
# Uses `sys` directly: `os.sys` is only an implementation detail of `os`.
current_dir = os.path.dirname(os.path.realpath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

import torch

from common.envs_utils import make_env
from common.render_utils import StatsVisualizer
from common.sacred_utils import ex

import symmetry.sym_envs
from symmetry.env_utils import get_env_name_for_method
from symmetry.torque_eva_env import TorqueEvaEnv
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal, stats
from scipy.ndimage import gaussian_filter1d


@ex.config
def config():
    # Sacred config scope: the locals assigned below become config entries.
    net = None  # model file to evaluate; None -> <experiment_dir>/models/<env>_best.pt
    render = False  # passed to make_env(render=...)
    max_steps = 100000  # give up after this many environment steps
    needed_episodes = 20  # number of episodes with "torques" in info to collect
    accepted_fail = 10  # give up once more than this many episodes lack "torques"
    env_name = ""  # must be set on the command line; main() rejects ""
    experiment_dir = "."  # must be overridden; "." is rejected by the assert below
    assert experiment_dir != "."
    # Merge the configs.json saved with the experiment; presumably this is what
    # provides env_params and mirror_method, which main() reads but this scope
    # does not define.
    ex.add_config(os.path.join(experiment_dir, "configs.json"))  # loads saved configs


def get_main_joint_name(env_name):
    """Return the joint name handed to TorqueEvaEnv for ``env_name``.

    The environment family is detected by substring match, checked in the
    order Cassie, Walker2D, Walker3D.

    Raises:
        ValueError: if ``env_name`` matches none of the supported families.
    """
    if "Cassie" in env_name:  # TODO: reported not to work for Cassie yet
        return "hip_flexion"
    if "Walker2D" in env_name:
        return "thigh"
    if "Walker3D" in env_name:
        return "hip_y"
    # The original message never interpolated env_name, printing a literal "%s".
    raise ValueError(
        "Environment %s not supported in evaluation. Please use Walker2D, Walker3D, or Cassie."
        % env_name
    )


# Helpers are defined before the @ex.automain-decorated main(): when this file
# is run as a script, automain starts the run at decoration time, so anything
# defined after it would not exist yet.


def _collect_torques(env, actor_critic, args):
    """Roll out ``actor_critic`` deterministically and gather torque traces.

    An episode counts as a success when the ``info`` returned with ``done``
    contains a ``"torques"`` entry, and as a failure otherwise. Rolling out
    stops when ``args.needed_episodes`` successes are collected, when
    ``args.max_steps`` environment steps have been taken, or when more than
    ``args.accepted_fail`` failures have occurred.

    Returns:
        (torque_list, reached): the collected ``info["torques"]`` values and
        whether the success target was reached.
    """
    states = torch.zeros(1, actor_critic.state_size)
    # NOTE(review): masks stay zero for the whole rollout and ``states`` is
    # not reset on episode end. If ``act`` follows the usual "mask 0 = reset
    # hidden state" convention, recurrent policies lose their memory every
    # step. This is harmless for feed-forward policies; confirm before using RNNs.
    masks = torch.zeros(1, 1)
    obs = env.reset()
    torque_list = []
    steps = 0
    fail_time = 0

    while True:
        obs_t = torch.from_numpy(obs).float().unsqueeze(0)
        with torch.no_grad():
            _, action, _, states = actor_critic.act(
                obs_t, states, masks, deterministic=True
            )
        cpu_actions = action.squeeze().cpu().numpy()

        obs, _, done, info = env.step(cpu_actions)

        if "Bullet" in args.env_name:
            # Re-center the PyBullet debug camera on the robot's body position.
            env.unwrapped._p.resetDebugVisualizerCamera(
                3, 0, -5, env.unwrapped.robot.body_xyz
            )

        if done:
            if "torques" in info:
                torque_list.append(info["torques"])
            else:
                fail_time += 1
            obs = env.reset()
        steps += 1

        if len(torque_list) >= args.needed_episodes:
            return torque_list, True
        if steps >= args.max_steps:
            print("exceeding max step: ", steps)
            return torque_list, False
        if fail_time > args.accepted_fail:
            print("exceeding fail times: ", fail_time)
            return torque_list, False


def _action_psd_entropy(torque_list, dt):
    """Compute the normalized torque power spectrum and its entropy.

    Processing steps:
      1. Truncate every trace to the shortest one.
      2. Smooth each trace along time (axis 1 of the stacked array) with a
         Gaussian filter (sigma=1).
      3. Compute a periodogram sampled at ``1 / dt``.
      4. Average the spectra over episodes and normalize them along
         frequency so that each column sums to 1.
      5. Take the entropy (``scipy.stats.entropy``, natural log) of each
         column.

    NOTE(review): assumes each trace is 2-D as (time, joints); the stacked
    array is then (episodes, time, joints). Confirm against TorqueEvaEnv.

    Returns:
        (min_len, freqs, pxx, ent): common trace length, periodogram
        frequencies, normalized mean spectrum, and per-column entropy.
    """
    min_len = min(len(torques) for torques in torque_list)
    stacked = np.array([torques[:min_len] for torques in torque_list])
    stacked = gaussian_filter1d(stacked, sigma=1, axis=1)
    freqs, pxx = signal.periodogram(stacked, axis=1, fs=1.0 / dt)
    pxx = pxx.mean(axis=0)
    pxx = pxx / pxx.sum(axis=0)
    ent = stats.entropy(pxx, axis=0)
    return min_len, freqs, pxx, ent


def _plot_psd(freqs, pxx, path):
    """Save side-by-side plots of PSD columns 0 and 1 to ``path``."""
    fig = plt.figure()
    for col in (0, 1):
        plt.subplot(1, 2, col + 1)
        # Columns already sum to 1 (see _action_psd_entropy); no renormalization.
        plt.plot(freqs, pxx[:, col])
    fig.savefig(path)
    plt.close(fig)


@ex.automain
def main(_config):
    """Evaluate the spectral entropy of the torques reported by TorqueEvaEnv.

    Loads the policy from ``net`` (or the experiment's best model) and
    collects ``needed_episodes`` torque traces. It then writes two files into
    ``experiment_dir``:
      - ``action_psd.json``: trace length, per-column entropy, mean entropy.
      - ``action_psd.png``: plots of the normalized spectrum.

    Nothing is written if the rollout hit the step or failure limit.

    Raises:
        ValueError: if ``env_name`` is empty.
    """
    args = SimpleNamespace(**_config)
    if args.env_name == "":
        raise ValueError("env_name must be set")

    env_name = get_env_name_for_method(args.env_name, args.mirror_method)

    model_path = args.net or os.path.join(
        args.experiment_dir, "models", "{}_best.pt".format(env_name.replace(':', '_'))
    )

    print("Env: {}".format(env_name))
    print("Model: {}".format(os.path.basename(model_path)))

    # SECURITY: torch.load unpickles arbitrary objects; only load trusted
    # checkpoints. NOTE(review): recent PyTorch versions default to
    # weights_only=True, which may reject a pickled full model; confirm.
    actor_critic = torch.load(model_path)

    env = make_env(env_name, env_params=args.env_params, render=args.render)
    env = TorqueEvaEnv(env, get_main_joint_name(env_name), 1000)
    env.seed(1093)  # fixed seed, presumably for reproducible evaluations

    torque_list, reached = _collect_torques(env, actor_critic, args)
    print("succeeded episodes: ", len(torque_list))
    if not reached:
        return
    if not torque_list:
        # Only possible when needed_episodes <= 0; min() below would crash.
        print("no torque traces collected; skipping PSD analysis")
        return

    min_len, freqs, pxx, ent = _action_psd_entropy(
        torque_list, env.unwrapped.scene.dt
    )

    print(min_len)
    print(ent)
    print(np.mean(ent))
    metric = {"length": min_len,
              "entropy": ent.tolist(),
              "mean_entropy": float(np.mean(ent))}
    with open(os.path.join(args.experiment_dir, "action_psd.json"), "w") as eval_file:
        json.dump(metric, eval_file)

    _plot_psd(freqs, pxx, os.path.join(args.experiment_dir, "action_psd.png"))
