import gym
import numpy as np
import torch
from sac import SAC

import matplotlib.pyplot as plt
from utils_sac import LatentVideoRecorder
from utils_psd import plot_pca_fft
from utils_metra import generate_skill_disc, generate_random_radius
from model_psd import Psi
from sklearn.decomposition import PCA
from envs.register import register_custom_envs
from arguments import parser_args

# ---------------------------------------------------------------------------
# Evaluation setup: parse the arguments, build a seeded environment, and load
# the SAC agent and the Psi encoder from their checkpoints.
# ---------------------------------------------------------------------------
args = parser_args()

# Placeholders -- fill these in before running the script.
exp_name = '#####PATH#####'
env_name = 'Humanoid-v3'
num_epi = '#####TODO#####'
args.seed = '#####TODO#####'

# Directory handed to the video recorder at the end of the script.
video_directory = '#####PATH#####'

# Device the Psi network is moved to.
device = torch.device("cuda") if args.cuda else torch.device("cpu")

# Custom environments must be registered before gym.make can resolve them.
register_custom_envs()
env = gym.make(env_name)

# Seed the environment, its action space, and the global NumPy / torch RNGs.
env.seed(args.seed)
env.action_space.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Radius configuration; the bound arrives as a comma-separated string of ints.
radius_input_dim = args.radius_input_dim
radius_latent_dim = args.radius_latent_dim
radius_bound = np.array(list(map(int, args.radius_bound.split(','))))

# Both networks take the raw observation with the radius input appended.
state_dim = env.observation_space.shape[0]
conditioned_dim = state_dim + radius_input_dim

# Agent
# NOTE(review): the second load_checkpoint argument is presumably an
# "evaluate" flag -- confirm against sac.SAC.load_checkpoint.
agent = SAC(conditioned_dim, env.action_space, args)
agent.load_checkpoint('#####PATH#####', True)

# Psi
psi = Psi(conditioned_dim, args).to(device)
psi.load_checkpoint('#####PATH#####', True)

# Check dim
print("state_dim :", state_dim)

# Num intervals
num_intervals = args.num_intervals

# Evaluation roll-outs.
#
# Runs `episodes` episodes with the deterministic (evaluate=True) policy. Each
# episode is conditioned on a radius drawn by generate_random_radius, whose
# input vector is appended to every observation before it reaches the agent or
# the Psi encoder. For every step the rendered RGB frame and the Psi encoding
# of the resulting state are collected for the video export that follows.
episodes = 8

# Running totals over all episodes (averages are reported after the loop).
avg_reward = 0.
avg_step = 0.

# Per-step frames and Psi encodings, concatenated across all episodes.
all_rgb_arrays = []
all_states = []

# One array per episode holding the logged state slices.
skill_states = []

for i in range(episodes):
    state = env.reset()
    # Cycle through the radius intervals so each episode uses the next one.
    radius_value, radius_input = generate_random_radius(
        radius_bound,
        radius_input_dim,
        args.num_intervals,
        current_index=i % args.num_intervals,
        eval=True,
    )
    state = np.concatenate([state, radius_input])

    episode_reward = 0
    step = 0
    done = False

    print("current radius is : ",radius_value)

    states = [] # list for logging

    while not done:

        action = agent.select_action(state, evaluate=True)
        next_state, reward, done, _ = env.step(action)
        episode_reward += reward
        step += 1
        # Log entries 2..7 of the state the action was chosen from.
        # NOTE(review): the meaning of these indices depends on the env's
        # observation layout -- confirm.
        states.append(state[2:8])
        # Re-attach the radius input; env.step returns the raw observation.
        next_state = np.concatenate([next_state, radius_input])
        state = next_state

        # Frame and encoding both describe the post-step state.
        rgb_array = env.render(mode='rgb_array', camera_id=0)
        encoded_state = psi.forward_np(state)
        all_rgb_arrays.append(rgb_array)
        all_states.append(encoded_state)

    states = np.array(states)
    skill_states.append(states)

    # Report the accumulated return, not the reward of the final step.
    print('episode_reward :' ,episode_reward)
    print('episode_step :' ,step)
    # Accumulate exactly once per episode.
    avg_reward += episode_reward
    avg_step += step

# The accumulators hold totals; report the per-episode averages.
print('avg_reward :', avg_reward / episodes)
print('avg_step :', avg_step / episodes)

############
# Project every Psi encoding collected during the roll-outs onto two principal
# components, then hand the frames and the 2-D points to the video recorder.
# PCA is already imported at the top of the file.
latent_matrix = np.stack(all_states)

pca = PCA(n_components=2)
pca_states = pca.fit_transform(latent_matrix)

LatentVideoRecorder(
    all_rgb_arrays,
    pca_states,
    video_directory,
    "test_",
    fps=args.video_fps,
)
############