import gym
import numpy as np
import torch

import datetime
import matplotlib.pyplot as plt

from sac import SAC
from utils_metra import generate_skill_disc, generate_random_radius
from utils_sac import LatentVideoRecorder_eval

from model_psd import Psi
from sklearn.decomposition import PCA
from envs.register import register_custom_envs
from arguments import parser_args

# ---------------------------------------------------------------------------
# Evaluation setup: arguments, environment, seeding and trained models.
# ---------------------------------------------------------------------------

# Hyper-parameters / options returned by the project's argument parser.
args = parser_args()

# Run identifiers. The '#####TODO#####' values are placeholders to fill in.
exp_name = '#####TODO#####'
env_name = 'Walker2d-v3'
num_epi = '#####TODO#####'

# Where the rendered video goes and what it is called.
# '#####PATH#####' is a placeholder to fill in.
video_name = 'test'
video_directory = '#####PATH#####'

# Run on CUDA only when args.cuda is truthy, otherwise on CPU.
device = torch.device("cuda") if args.cuda else torch.device("cpu")

# Register the project's custom environments before building the one we need.
register_custom_envs()
env = gym.make(env_name)

# Apply the same seed to the env, its action space, torch and numpy
# (in that order).
for apply_seed in (
    env.seed,
    env.action_space.seed,
    torch.manual_seed,
    np.random.seed,
):
    apply_seed(args.seed)

# Radius-conditioning settings. args.radius_bound is a comma-separated string
# of integers, e.g. "30,60", turned into an integer array here.
radius_input_dim = args.radius_input_dim
radius_latent_dim = args.radius_latent_dim
radius_bound = np.array(list(map(int, args.radius_bound.split(','))))

# Both networks consume the raw observation with the radius input appended,
# so their input width is the observation size plus radius_input_dim.
conditioned_state_dim = env.observation_space.shape[0] + radius_input_dim

# Trained SAC agent. '#####PATH#####' is a placeholder for the checkpoint.
# NOTE(review): the second argument of load_checkpoint is presumably an
# "evaluate" flag -- confirm against the SAC implementation.
agent = SAC(conditioned_state_dim, env.action_space, args)
agent.load_checkpoint('#####PATH#####', True)

# Trained Psi encoder, moved to the selected device before loading weights.
psi = Psi(conditioned_state_dim, args).to(device)
psi.load_checkpoint('#####PATH#####', True)

# Sanity print of the raw (unconditioned) observation size.
print("state_dim :", env.observation_space.shape[0])

# Forwarded to generate_random_radius on every call during the rollout.
num_intervals = 4


# Evaluation
#
# Roll the trained policy out for a fixed number of steps while switching the
# commanded radius on a fixed schedule. Every step records the rendered frame
# and the Psi encoding of the radius-conditioned state; the encodings are then
# projected to 2-D with PCA and handed, together with the frames, to
# LatentVideoRecorder_eval.

episodes = 1

# Rollout length and radius schedule. Every `steps_per_segment` steps the next
# [low, high] range (cycling through the tuple) is passed to
# generate_random_radius. Each range has equal bounds.
total_steps = 2500
steps_per_segment = 500
radius_schedule = ((30, 30), (45, 45), (37, 37), (52, 52), (40, 40))

all_rgb_arrays = []
all_states = []


for _ in range(episodes):

    # Raw environment observation, kept separate from the radius input so the
    # conditioned state can be rebuilt whenever the radius changes.
    obs = env.reset()

    for step in range(total_steps):

        # Integer arithmetic instead of float-division equality: a switch
        # happens exactly on steps 0, 500, 1000, ... (this includes step 0,
        # so radius_input is always defined before its first use).
        if step % steps_per_segment == 0:
            segment = (step // steps_per_segment) % len(radius_schedule)
            radius_value, radius_input = generate_random_radius(
                list(radius_schedule[segment]), radius_input_dim, num_intervals
            )

        # Build the policy input from the *current* radius input. Previously
        # the action on a switch step was selected from a state that still
        # carried the previous segment's radius encoding.
        state = np.concatenate([obs, radius_input])
        action = agent.select_action(state, evaluate=True)

        # NOTE(review): `done` is not acted upon; the rollout keeps stepping
        # for the full `total_steps` even after the env reports termination.
        # Confirm this is intended for the fixed-length video.
        obs, reward, done, _ = env.step(action)

        # Conditioned state after the transition; this is what gets encoded.
        state = np.concatenate([obs, radius_input])

        #######
        rgb_array = env.render(mode='rgb_array', camera_id=0)
        encoded_state = psi.forward_np(state)
        all_rgb_arrays.append(rgb_array)
        all_states.append(encoded_state)
        #######


# Project the collected encodings onto their first two principal components.
pca = PCA(n_components=2)
pca_states = pca.fit_transform(np.stack(all_states))

# save video
LatentVideoRecorder_eval(all_rgb_arrays, pca_states, video_directory, video_name, fps=150)