import os
import datetime
import gym
import numpy as np
import itertools
import torch
from sac import SAC
from torch.utils.tensorboard import SummaryWriter

from replay_memory import ReplayMemory, TrajectoryReplayMemory

from model_psd import Psi

from utils_metra import skillsampler, generate_skill_disc, generate_random_radius, compute_cosine_weight
from utils_psd import plot_graph, get_evalbatch, onehot2radius
from utils_sac import VideoRecorder, LatentVideoRecorder, create_directory, copy_files_and_directories

from envs.register import register_custom_envs

from arguments import parser_args


# Parse the run configuration
args = parser_args()

# Experiment logging: copy the sources listed below into this run's
# experiment directory (../exps/<exp_name>).
env_sources = ["envs"]
algo_sources = [
    "algo/arguments.py",
    "algo/main.py",
    "algo/model_psd.py",
    "algo/model_sac.py",
    "algo/replay_memory.py",
    "algo/sac.py",
    "algo/utils_metra.py",
    "algo/utils_psd.py",
    "algo/utils_sac.py",
    # Add more files or directories as needed
]
paths_to_copy = env_sources + algo_sources

exp_num_directory = os.path.join("..", "exps", args.exp_name)
create_directory(exp_num_directory)
copy_files_and_directories(paths_to_copy, exp_num_directory)


# Device (selected purely from the --cuda style flag in args)
device = torch.device("cuda") if args.cuda else torch.device("cpu")

# Environment: register the project's custom environments, then build the requested one
register_custom_envs()
env = gym.make(args.env_name)

# Reproducibility: feed the same seed to the environment, its action space,
# torch and numpy (in that order)
seed = args.seed
env.seed(seed)
env.action_space.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

# Output directory handed to LatentVideoRecorder during evaluation.
# NOTE(review): the value is still a placeholder string - set a real path before running.
video_directory = '#####TODO#####'

# TensorBoard writer that receives every training / evaluation scalar logged below.
# NOTE(review): the log path is a placeholder as well - confirm the intended run directory.
writer = SummaryWriter('#####PATH#####')

# Skill (radius) sampler.
# NOTE: this rebinds the imported name `skillsampler` to the object it returns, so the
# callable imported from utils_metra is no longer reachable under that name afterwards.
skillsampler = skillsampler(env, args)

# Replay buffers: `memory` receives single transitions and `memory_traj` receives whole
# episode trajectories; they are passed to agent.update_parameters and
# psi.update_parameters respectively in the training loop.
memory = ReplayMemory(args.replay_size, args.seed)
memory_traj = TrajectoryReplayMemory(args.traj_replay_size, env._max_episode_steps, args.seed)

# Global progress counters (environment steps, episodes, SAC updates, Psi updates)
total_numsteps = episode_idx = updates_sac = updates_psi = 0

# Radius configuration
radius_input_dim = args.radius_input_dim
radius_latent_dim = args.radius_latent_dim
# args.radius_bound is a comma-separated string of integers
radius_bound = np.array(list(map(int, args.radius_bound.split(','))))

# Both networks consume the observation concatenated with the radius input
conditioned_input_dim = env.observation_space.shape[0] + radius_input_dim

# Agent
agent = SAC(conditioned_input_dim, env.action_space, args)

# Psi
psi = Psi(conditioned_input_dim, args)
psi = psi.to(device)

# Print a summary of the run configuration, one "label : value" line per entry
for _label, _value in (
    ("env_name :", args.env_name),
    ("max_episode_step :", env._max_episode_steps),
    ("state_dim :", env.observation_space.shape[0]),
    ("radius_latent_dim :", radius_latent_dim),
    ("radius_input_dim :", radius_input_dim),
    ("radius_bound :", radius_bound),
    ("radius_sampling_num :", args.num_intervals),
    ("use_adaptive_sampling :", args.use_adaptive_sampling),
):
    print(_label, _value)

# Hoisted out of the evaluation branch, where it was re-executed on every evaluation;
# importing here also surfaces a missing scikit-learn before any training is done.
from sklearn.decomposition import PCA


def _pseudo_reward_weight(episode_index):
    """Return the weight applied to the pseudo reward for the given episode.

    This is the value returned by ``compute_cosine_weight`` when
    ``args.use_reward_scheduling`` is set, and the constant ``args.rew_weight``
    otherwise.
    """
    # Explicit comparison kept on purpose: the flag's type is not visible in this file.
    if args.use_reward_scheduling == True:
        return compute_cosine_weight(episode_index, args.saturation_episode, args.max_weight)
    return args.rew_weight


def _pseudo_reward(state, next_state, radius_value):
    """Return the radius-matching pseudo reward for one transition.

    ``state`` and ``next_state`` are the skill-conditioned vectors (observation
    concatenated with the radius input).  With ``d`` the norm of the change of
    the Psi embedding across the transition and
    ``target = r * sin(pi / (2 * r))`` for the radius value ``r``, the result is
    ``exp(-10 * ||target - d||**2)``.
    """
    psi_diff = psi.forward_np(next_state) - psi.forward_np(state)
    target_step = radius_value*np.sin(np.pi/(2*radius_value))
    return np.exp(-10*np.linalg.norm(target_step - np.linalg.norm(psi_diff))**2)


# Training Loop
# Every epoch: (1) collect args.episodes_per_epoch episodes, (2) run the SAC / Psi
# gradient steps, (3) periodically evaluate, update the sampling bounds and checkpoint.
# The loop stops once episode_idx exceeds args.num_episode.
try:
    for i_epoch in itertools.count(1):
        for i_episode in range(args.episodes_per_epoch):
            episode_reward = 0
            episode_pseudo_reward = 0
            episode_steps = 0
            episode_idx += 1
            episode_trajectory = []

            # The weight depends only on the episode index, so it is computed once
            # per episode instead of once per environment step.
            # NOTE(review): assumes compute_cosine_weight is a pure function - confirm.
            pseudo_reward_weight = _pseudo_reward_weight(episode_idx)

            done = False
            state = env.reset()
            radius_value, radius_input = skillsampler.sample()

            # Condition the observation on the sampled skill
            state = np.concatenate([state, radius_input])

            while not done:
                if args.start_steps > total_numsteps:
                    action = env.action_space.sample()  # Sample random action
                else:
                    action = agent.select_action(state)  # Sample action from policy

                next_state, reward, done, _ = env.step(action)  # Step
                next_state = np.concatenate([next_state, radius_input])

                # Only accumulated for logging here; the buffer below stores the
                # environment reward.
                pseudo_reward = _pseudo_reward(state, next_state, radius_value)

                # update
                episode_steps += 1
                total_numsteps += 1
                episode_reward += reward
                episode_pseudo_reward += pseudo_reward_weight*pseudo_reward

                # Mask is 1 when the episode ended by reaching the step limit,
                # otherwise float(not done).
                mask = 1 if episode_steps == env._max_episode_steps else float(not done)
                memory.push(state, action, reward, radius_value, next_state, mask)  # Append transition to memory
                episode_trajectory.append((state, radius_value))
                state = next_state

            # Append the finished episode to the trajectory memory
            memory_traj.push(episode_trajectory)

            writer.add_scalar('train/gt_reward', episode_reward, episode_idx)
            writer.add_scalar('train/pseudo_reward', episode_pseudo_reward, episode_idx)
            writer.add_scalar('train/total_reward', episode_reward + episode_pseudo_reward, episode_idx)

            print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {} , pseudo_reward: {}" \
                  .format(episode_idx, total_numsteps, episode_steps, round(episode_reward, 2), round(episode_pseudo_reward, 2)))

        # Gradient updates start once both buffers hold more than one batch
        if (len(memory) > args.batch_size) and (len(memory_traj) > args.traj_batch_size):
            for i in range(args.gradient_steps_per_epoch):

                # Update parameters of SAC networks
                critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = agent.update_parameters(memory, psi, updates_sac, episode_idx, args)

                # SAC Loss
                writer.add_scalar('loss/critic_1', critic_1_loss, updates_sac)
                writer.add_scalar('loss/critic_2', critic_2_loss, updates_sac)
                writer.add_scalar('loss/policy', policy_loss, updates_sac)
                writer.add_scalar('loss/entropy_loss', ent_loss, updates_sac)
                writer.add_scalar('entropy_temprature/alpha', alpha, updates_sac)

                updates_sac += 1

                # Update parameters of Psi networks
                loss_total, loss_max, loss_min, loss_const_1, loss_const_2, flag = psi.update_parameters(memory_traj, args)

                # If a trajectory that meets the conditions is not sampled, nothing was
                # updated, so skip the Psi logging and do not advance its counter
                if flag == False:
                    continue

                # Encoder Loss
                writer.add_scalar('psi_loss/total', loss_total, updates_psi)
                writer.add_scalar('psi_loss/max', loss_max, updates_psi)
                writer.add_scalar('psi_loss/min', loss_min, updates_psi)
                writer.add_scalar('psi_loss/const_L', loss_const_1, updates_psi)
                writer.add_scalar('psi_loss/const_1', loss_const_2, updates_psi)

                updates_psi += 1

        # Evaluation / adaptive sampling run every eval_epoch_ratio epochs; latent videos
        # and checkpoints are written on every 10th of those occasions.
        eval_due = episode_idx % (args.episodes_per_epoch*args.eval_epoch_ratio) == 0
        milestone_due = episode_idx % (10*args.episodes_per_epoch*args.eval_epoch_ratio) == 0

        if eval_due:
            # Evaluate skill policy
            avg_pseudo_reward = 0.
            avg_reward = 0.
            avg_step = 0.
            episodes = args.num_intervals
            eval_reward_weight = _pseudo_reward_weight(episode_idx)

            # Frames and embeddings are consumed only by the latent video, so they are
            # collected only when a video is actually written (they used to be rendered
            # and PCA-projected on every evaluation and then discarded).
            # NOTE(review): assumes env.render and psi.forward_np do not affect the
            # rollout itself - confirm.
            record_video = milestone_due
            all_rgb_arrays = []
            all_states = []

            for i in range(episodes):  # Make sure that 'episodes' == 'num_intervals', so all cases can be evaluated.

                state = env.reset()
                radius_value, radius_input = skillsampler.sample(current_index=i, eval=True)

                state = np.concatenate([state, radius_input])

                episode_steps = 0
                episode_pseudo_reward = 0
                episode_reward = 0

                done = False

                while not done:
                    action = agent.select_action(state, evaluate=True)
                    next_state, reward, done, _ = env.step(action)

                    next_state = np.concatenate([next_state, radius_input])

                    # Compute pseudo reward (length)
                    pseudo_reward = _pseudo_reward(state, next_state, radius_value)

                    episode_steps += 1
                    episode_reward += reward
                    episode_pseudo_reward += eval_reward_weight*pseudo_reward
                    state = next_state

                    if record_video:
                        # Frame and Psi embedding of the state that was just reached
                        all_rgb_arrays.append(env.render(mode='rgb_array', camera_id=0))
                        all_states.append(psi.forward_np(state))

                avg_reward += episode_reward
                avg_pseudo_reward += episode_pseudo_reward
                avg_step += episode_steps

            if record_video:
                # Project the embeddings of all evaluation rollouts to 2-D for the video
                pca = PCA(n_components=2)
                pca_states = pca.fit_transform(np.stack(all_states))
                LatentVideoRecorder(all_rgb_arrays, pca_states, video_directory, episode_idx, fps=args.video_fps)

            avg_reward /= episodes
            avg_pseudo_reward /= episodes
            avg_step /= episodes

            # For tensorboard
            writer.add_scalar('test/avg_gt_reward', avg_reward, episode_idx)
            writer.add_scalar('test/avg_pseudo_reward', avg_pseudo_reward, episode_idx)
            writer.add_scalar('test/avg_total_reward', avg_reward + avg_pseudo_reward, episode_idx)

            print("----------------------------------------")
            print("Test Episodes: {}, Avg. Reward: {}, Avg. Pseudo Reward: {}, Avg. step: {}"\
                  .format(episodes, round(avg_reward, 2), round(avg_pseudo_reward, 2), round(avg_step, 2)))
            print("----------------------------------------")

            # Adaptive sampling
            # NOTE(review): this runs regardless of args.use_adaptive_sampling; presumably
            # the flag is honoured inside update_bound - confirm.
            radius_bound, average_reward_low, average_reward_high = skillsampler.update_bound(agent, psi)
            writer.add_scalar('sampling/radius_lowerbound', radius_bound[0], episode_idx)
            writer.add_scalar('sampling/radius_upperbound', radius_bound[1], episode_idx)
            writer.add_scalar('sampling/reward_lowerbound', average_reward_low, episode_idx)
            writer.add_scalar('sampling/reward_upperbound', average_reward_high, episode_idx)

        if milestone_due:
            agent.save_checkpoint(args,"{}".format(episode_idx))
            psi.save_checkpoint(args,"{}".format(episode_idx))

        if episode_idx > args.num_episode:
            break
finally:
    # Flush buffered TensorBoard events and release the environment even when the run
    # is interrupted (Ctrl-C) or fails; previously the writer was never closed.
    writer.close()
    env.close()

