from envs.reach_obstacle import FetchReachObstacleEnv
import torch
from rl_modules.gcsl_agent import GCSL
import gymnasium as gym
from envs import register_envs
from mpi_utils.normalizer import normalizer
import argparse
# Presumably registers the project's custom environment ids (such as the
# 'FetchReachObstacle-v0' id passed to gym.make further down) with
# gymnasium -- TODO confirm against the envs package.
register_envs()

# Other imports and helper functions
import time
import itertools
import numpy as np

# Graphics and plotting.
# print('Installing mediapy:')
# !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
# !pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy: 3 decimal places, fixed-point instead of
# scientific notation for small values, and up to 100 characters per line.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

# from IPython.display import clear_output
# clear_output()
# tmp_env = FetchReachObstacleEnv()

# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser()
# Checkpoint to evaluate. Marked as required so that a missing path is
# rejected by argparse with a usage message, instead of surfacing later as a
# file error from torch.load('') (the previous default was an empty string).
parser.add_argument('--model_path', type=str, required=True, help='the model path to evaluate')
# Prefix for the recorded video file names.
parser.add_argument('--video_prefix', type=str, default='', help='video prefix')
# Value written to the environment's _max_episode_steps attribute.
parser.add_argument('--max_episode_steps', type=int, default=50, help='number of timesteps maximum in episode')
# Number of finished episodes after which the evaluation loop stops.
parser.add_argument('--num_ep', type=int, default=60, help='number of episodes')
# Forwarded to gym.make as the bsa_box_size keyword argument.
parser.add_argument('--bsa_box_size', type=float, default=None, help='bsa box size leave as None for normal eval')

args = parser.parse_args()



# Build the evaluation environment; bsa_box_size is passed through gym.make
# as an extra keyword argument for the environment.
tmp_env = gym.make('FetchReachObstacle-v0', bsa_box_size=args.bsa_box_size)
# Force array rendering on the innermost env after construction.
# NOTE(review): presumably needed because RecordVideo captures rendered
# frames and requires an array render mode -- confirm.
tmp_env.unwrapped.render_mode = "rgb_array"
# Wrap the env so that every second episode (ids 0, 2, 4, ...) is recorded
# into the "videos" folder.
env = gym.wrappers.RecordVideo(
    env=tmp_env,
    video_folder="videos",
    name_prefix=args.video_prefix + f"_bsa{args.bsa_box_size}",
    episode_trigger=lambda episode_id: episode_id % 2 == 0,
)
# Store the episode length limit on the env wrapped by RecordVideo; the same
# attribute is read back later by get_env_params.
# NOTE(review): this only limits the episode length if that object is the
# time-limit wrapper -- confirm how 'FetchReachObstacle-v0' is registered.
env.env._max_episode_steps = args.max_episode_steps
def get_env_params(env):
    """Collect the sizes and limits of ``env`` into a plain dict.

    The environment is reset once to obtain a sample observation; it is not
    closed afterwards.

    Args:
        env: environment whose ``reset()`` returns ``(obs, info)``, where
            ``obs`` is a dict holding ``'observation'`` and
            ``'desired_goal'`` arrays. It must also expose ``action_space``
            and ``_max_episode_steps``.

    Returns:
        dict with the keys ``'obs'``, ``'goal'`` and ``'action'`` (first
        dimension of the respective shapes), ``'action_max'`` (first entry
        of the action space's ``high``), ``'action_space'`` and
        ``'max_timesteps'``.
    """
    first_obs, _ = env.reset()
    act_space = env.action_space
    return {
        'obs': first_obs['observation'].shape[0],
        'goal': first_obs['desired_goal'].shape[0],
        'action': act_space.shape[0],
        'action_max': act_space.high[0],
        'action_space': act_space,
        'max_timesteps': env._max_episode_steps,
    }

# Dimensions of the observation / goal / action spaces (resets tmp_env once).
env_params = get_env_params(tmp_env)

# The checkpoint is unpacked below as a 5-tuple: observation normalizer
# mean/std, goal normalizer mean/std, and the actor network itself.
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
# NOTE(review): newer PyTorch releases that default torch.load to
# weights_only=True may refuse a checkpoint holding a pickled network
# object -- confirm against the installed version.
x = torch.load(args.model_path, map_location=torch.device("cpu"))
(o_norm_mean, o_norm_std, g_norm_mean, g_norm_std, actor_network) = x

# Clipping constants; only clip_range is used in this section.
clip_obs = 200
clip_range = 5

# Rebuild the observation normalizer and install the saved statistics.
o_norm = normalizer(size=env_params['obs'], default_clip_range=clip_range)
o_norm.mean = o_norm_mean
o_norm.std = o_norm_std

# Same for the goal normalizer.
g_norm = normalizer(size=env_params['goal'], default_clip_range=clip_range)
g_norm.mean = g_norm_mean
g_norm.std = g_norm_std
def preproc_inputs(obs, g, o_norm, g_norm):
    """Normalize an observation/goal pair and pack it as one network input.

    Args:
        obs: observation array, passed to ``o_norm.normalize``.
        g: goal array, passed to ``g_norm.normalize``.
        o_norm: object whose ``normalize`` method is applied to ``obs``.
        g_norm: object whose ``normalize`` method is applied to ``g``.

    Returns:
        A float32 ``torch.Tensor`` holding the normalized observation
        followed by the normalized goal, with a leading dimension of size 1
        added in front.
    """
    # Observation first, then goal: the order fixes the input layout.
    normed_parts = [o_norm.normalize(obs), g_norm.normalize(g)]
    stacked = np.concatenate(normed_parts)
    return torch.tensor(stacked, dtype=torch.float32).unsqueeze(0)
# print(env.unwrapped.goal)
# print(env.unwrapped.pos_obstacle)
# print(env.unwrapped.initial_gripper_xpos)

# Roll out the loaded policy until args.num_ep episodes have finished, then
# report the mean per-step cost and the success rate.
obs, info = env.reset()
# Accumulators for the episode currently being played.
ep_ret, ep_cost, ep_len = 0, 0, 0
# Per-episode results, one entry per finished episode.
costs, normalized_costs, returns, successes = [], [], [], []
i = 0  # number of finished episodes

# Upper bound on the total number of environment steps. It never drops below
# the previous fixed budget of 5000 steps, but grows with the requested
# workload so that large --num_ep / --max_episode_steps values are not cut
# short silently.
max_total_steps = max(5000, args.num_ep * args.max_episode_steps)
for _ in range(max_total_steps):
    if i >= args.num_ep:
        break
    assert env.observation_space.contains(obs)
    input_tensor = preproc_inputs(obs["observation"], obs["desired_goal"], o_norm, g_norm)
    # preproc_inputs already returns a float32 tensor, so it is fed to the
    # network directly; wrapping it in torch.tensor() again would only copy
    # it. No gradients are needed for evaluation.
    with torch.no_grad():
        act = actor_network(input_tensor).detach().numpy().flatten()
    assert env.action_space.contains(act)
    # The Safe RL variant of the env returns a cost next to the reward.
    # NOTE(review): this 5-tuple passes through RecordVideo, which presumably
    # expects (obs, reward, terminated, truncated, info) -- confirm that the
    # wrapper tolerates cost/done in those two positions.
    obs, reward, cost, done, info = env.step(act)
    ep_ret += reward
    ep_cost += cost
    ep_len += 1
    env.render()
    if done or info['is_success']:  # stop early if goal reached
        i += 1
        successes.append(info['is_success'])
        costs.append(ep_cost)
        # ep_len is at least 1 here, so the division is safe.
        normalized_costs.append(ep_cost / ep_len)
        returns.append(ep_ret)

        # reset everything
        ep_ret, ep_cost, ep_len = 0, 0, 0
        obs, info = env.reset()
env.close()

if i < args.num_ep:
    print(f"Warning: only {i} of {args.num_ep} episodes finished within {max_total_steps} steps")
if normalized_costs:
    print("Mean Normalized Cost: ", np.mean(normalized_costs))
    print("Success Rate: ", np.mean(successes))
else:
    # np.mean of an empty list would return nan and emit a warning.
    print("No episode finished; nothing to report.")