import os
from types import SimpleNamespace
import torch

from common.envs_utils import make_env
from common.sacred_utils import ex
from symmetry.env_utils import get_env_name_for_method
from fatigue.eva_utils import denoise_binary, RecorderEnv, BaseRecorder
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
from scipy.ndimage.filters import gaussian_filter1d
from scipy.interpolate import interp1d
from scipy import signal, stats
import random
import json


class FatigueRecorder(BaseRecorder):
    """Recorder that stores the negated ``info["fat_rew"]`` of every step.

    The values are accumulated in ``self.fatigue_reward``, which is replaced
    by a fresh empty list in ``after_reset``.
    NOTE(review): ``after_reset`` / ``after_step`` are presumably hooks that
    the ``BaseRecorder`` / ``RecorderEnv`` machinery invokes around
    ``env.reset()`` and ``env.step()`` -- confirm against ``eva_utils``.
    """

    def __init__(self, name, env: RecorderEnv):
        """Create the empty buffer, then delegate to ``BaseRecorder``."""
        # The buffer is created before the base constructor runs so that the
        # attribute already exists should the base class touch the hooks.
        self.fatigue_reward = list()
        super().__init__(name, env)

    def after_reset(self, obs, **kwargs):
        """Discard everything recorded so far by starting a new list."""
        self.fatigue_reward = list()

    def after_step(self, observation, reward, done, info, action):
        """Append the sign-flipped ``fat_rew`` entry of ``info``."""
        negated = -info["fat_rew"]
        self.fatigue_reward.append(negated)


@ex.config
def config():
    """Default Sacred configuration for the fatigue evaluation script.

    NOTE(review): Sacred treats the local variables of a config function as
    configuration entries, so the statements below are kept in their
    original order and form.
    """
    # Explicit checkpoint path; when falsy, main() falls back to
    # <experiment_dir>/models/<env_name>_best.pt.
    net = None
    # Forwarded to make_env(..., render=...) in main().
    render = False
    # Upper bound on the total number of environment steps taken in main().
    max_steps = 30000
    # Must be overridden: main() asserts that it is not empty.
    env_name = ""
    # Directory holding configs.json and the models/ folder; outputs are
    # written here as well.
    experiment_dir = "."
    # Rejects the placeholder default above.
    # NOTE(review): presumably Sacred substitutes a user-supplied value for
    # the default before this line runs -- confirm.
    assert experiment_dir != "."
    ex.add_config(os.path.join(experiment_dir, "configs.json"))  # loads saved configs
    # Number of recorded episodes main() collects before it stops.
    total_episodes = 20
    # Passed to RecorderEnv as its max_episode_steps argument.
    max_episode_steps = 1005


@ex.automain
def main(_config):
    """Roll out a saved policy and save the mean per-step negated ``fat_rew``.

    The policy checkpoint is ``net`` when given, otherwise
    ``<experiment_dir>/models/<env_name>_best.pt`` (with ``:`` in the
    environment name replaced by ``_``).  The policy is run deterministically
    in a ``RecorderEnv``-wrapped environment and the values gathered by
    ``FatigueRecorder`` are collected per episode.

    Only episodes with at least 1000 recorded steps are kept, truncated to
    their first 1000 steps; shorter episodes are discarded and do not count
    towards ``total_episodes``.  The step-wise mean over the kept episodes is
    written to ``test_fatigue_reward.npy`` and plotted to
    ``test_fatigue_reward.pdf`` inside ``experiment_dir``.  If no episode
    qualifies before ``max_steps`` is reached, ``"No data!"`` is printed and
    nothing is written.

    Args:
        _config: Sacred configuration mapping.  Besides the entries defined
            in ``config`` it must provide ``mirror_method`` and
            ``env_params`` (presumably loaded from the saved
            ``configs.json`` -- confirm).
    """
    args = SimpleNamespace(**_config)
    assert args.env_name != ""

    env_name = get_env_name_for_method(args.env_name, args.mirror_method)

    model_path = args.net or os.path.join(
        args.experiment_dir, "models", "{}_best.pt".format(env_name.replace(':', '_'))
    )

    print("Env: {}".format(env_name))
    print("Model: {}".format(os.path.basename(model_path)))

    # Observations, recurrent states and masks below are all CPU tensors, so
    # the checkpoint is mapped onto the CPU as well.  Without this, a model
    # saved from a CUDA device fails to load on CPU-only machines or causes a
    # device mismatch inside ``act``.
    actor_critic = torch.load(model_path, map_location="cpu")

    env = make_env(env_name, args.env_params, render=args.render)
    if hasattr(env.unwrapped, "evaluation_mode"):
        env.unwrapped.evaluation_mode()
    env = RecorderEnv(env, max_episode_steps=args.max_episode_steps)
    env.seed(1093)

    recorder = FatigueRecorder('test', env)

    # Number of leading steps kept from every recorded episode; episodes with
    # fewer recorded steps are skipped so all kept rows have equal length.
    record_steps = 1000
    # Loop-invariant: whether to re-centre the debug camera after each step.
    track_camera = "Bullet" in args.env_name

    states = torch.zeros(1, actor_critic.state_size)
    masks = torch.zeros(1, 1)
    obs = env.reset()
    num_episodes = 0
    steps = 0

    fatigue_reward_list = []

    while num_episodes < args.total_episodes and steps < args.max_steps:
        obs = torch.from_numpy(obs).float().unsqueeze(0)

        with torch.no_grad():
            _, action, _, states = actor_critic.act(
                obs, states, masks, deterministic=True
            )
        cpu_actions = action.squeeze().cpu().numpy()

        obs, _, done, _ = env.step(cpu_actions)

        if track_camera:
            env.unwrapped._p.resetDebugVisualizerCamera(
                3, 0, -5, env.unwrapped.robot.body_xyz
            )
        steps += 1

        if done:
            # Read the recorder before resetting the environment: the
            # recorder's buffer is replaced in its ``after_reset`` hook.
            if len(recorder.fatigue_reward) >= record_steps:
                fatigue_reward_list.append(recorder.fatigue_reward[:record_steps])
                num_episodes += 1

            obs = env.reset()

    if not fatigue_reward_list:
        print("No data!")
        return

    # Rows are episodes, columns are steps; average across episodes.
    fatigue_reward_arr = np.array(fatigue_reward_list)
    mean_fatigue_reward = fatigue_reward_arr.mean(axis=0)
    np.save(os.path.join(args.experiment_dir, "test_fatigue_reward.npy"), mean_fatigue_reward)
    plt.plot(mean_fatigue_reward)
    plt.savefig(os.path.join(args.experiment_dir, 'test_fatigue_reward.pdf'), bbox_inches='tight')
