import os
import datetime
import gym
import numpy as np
import matplotlib.pyplot as plt
import itertools
import torch
from sac import SAC
from torch.utils.tensorboard import SummaryWriter

from replay_memory import ReplayMemory, TrajectoryReplayMemory

from model_psd import Psi

from utils_metra import generate_random_radius, compute_cosine_weight
from utils_sac import LatentVideoRecorder, create_directory, copy_files_and_directories
from utils_psd import center_crop, random_crop

from envs.register import register_custom_envs

from arguments import parser_args


# Parse the run configuration (see arguments.parser_args)
args = parser_args()

# Experiment bookkeeping: every run gets its own directory named after
# args.exp_name, and the sources listed below are handed to
# copy_files_and_directories together with that directory.
exp_num_directory = os.path.join("..", "sdb", "exps", args.exp_name)

# Environment package first, then the individual algorithm modules.
paths_to_copy = ["envs"]
paths_to_copy += [
    "algo/arguments.py",
    "algo/main.py",
    "algo/model_psd.py",
    "algo/model_sac.py",
    "algo/replay_memory.py",
    "algo/sac.py",
    "algo/utils_metra.py",
    "algo/utils_psd.py",
    "algo/utils_sac.py",
    # Add more files or directories as needed
]

# NOTE(review): all paths above are relative, so they presumably resolve
# against the current working directory -- confirm the expected launch dir.
create_directory(exp_num_directory)
copy_files_and_directories(paths_to_copy, exp_num_directory)


# Environment: register the project's environments, then build the requested one
register_custom_envs()
env = gym.make(args.env_name)

# Device
device = torch.device("cpu" if not args.cuda else "cuda")

# Reproducibility: feed the same seed to every RNG entry point used here,
# in the order env -> action space -> torch -> numpy.
for seed_fn in (env.seed, env.action_space.seed, torch.manual_seed, np.random.seed):
    seed_fn(args.seed)

# Output locations for evaluation videos and TensorBoard event files.
# NOTE(review): both use the literal 'path', which looks like a placeholder
# (exp_num_directory is not used here) -- confirm the intended directories.
video_directory = 'path'
writer = SummaryWriter('path')

# Replay buffers: one for single transitions, one for whole trajectories
memory = ReplayMemory(args.replay_size, args.seed)
memory_traj = TrajectoryReplayMemory(args.traj_replay_size, env._max_episode_steps, args.seed)

# Global counters: agent steps taken, episodes run, and the number of
# SAC / Psi updates that have been logged so far.
total_numsteps = episode_idx = updates_sac = updates_psi = 0

# Radius configuration; radius_bound is given as a comma-separated list of ints
radius_input_dim, radius_latent_dim = args.radius_input_dim, args.radius_latent_dim
radius_bound = np.array(list(map(int, args.radius_bound.split(','))))

# Agent
agent = SAC(env.action_space, args)

# Psi
psi = Psi(args).to(device)

# Echo the effective configuration once before training starts
for label, value in (
    ("env_name :", args.env_name),
    ("max_episode_step :", env._max_episode_steps),
    ("state_dim :", env.observation_space.shape[0]),
    ("radius_latent_dim :", radius_latent_dim),
    ("radius_input_dim :", radius_input_dim),
    ("radius_bound :", radius_bound),
    ("radius_sampling_num :", args.num_intervals),
):
    print(label, value)

# Training Loop
# Imported before training starts so that a missing scikit-learn installation
# fails immediately instead of at the first evaluation.
from sklearn.decomposition import PCA


def render_observation(env):
    """Render the 90x90 frame of camera 1 and move its last axis to the front.

    The frame returned by ``env.render`` is transposed with ``(2, 0, 1)``
    (presumably HWC -> CHW; confirm against the renderer).
    """
    frame = env.render(mode='rgb_array', width=90, height=90, camera_id=1)
    return np.transpose(frame, (2, 0, 1))


def repeat_action(env, action, num_repeat):
    """Apply ``action`` for up to ``num_repeat`` environment steps.

    Stepping stops as soon as the environment reports ``done`` so a finished
    episode is never stepped again, and the rewards of all executed sub-steps
    are summed instead of keeping only the last one.

    Returns:
        Tuple ``(total_reward, done, steps_taken)`` where ``steps_taken`` is
        the number of ``env.step`` calls actually made.
    """
    total_reward = 0.0
    done = False
    steps_taken = 0
    for _ in range(num_repeat):
        _, reward, done, _ = env.step(action)
        total_reward += reward
        steps_taken += 1
        if done:
            break
    return total_reward, done, steps_taken


def compute_pseudo_reward(psi_diff, radius_value):
    """Return ``exp(-10 * err**2)`` for the latent step ``psi_diff``.

    ``err`` is the norm of the difference between the target length
    ``radius_value * sin(pi / (2 * radius_value))`` and ``||psi_diff||``, so
    the result is 1 when the lengths match and decays towards 0 otherwise.
    """
    target_length = radius_value * np.sin(np.pi / (2 * radius_value))
    length_error = np.linalg.norm(target_length - np.linalg.norm(psi_diff))
    return np.exp(-10 * length_error ** 2)


def pseudo_reward_weight(episode_idx):
    """Return the factor applied to the pseudo reward when logging episode totals.

    Uses the cosine schedule when ``args.use_reward_scheduling`` is set,
    otherwise the constant ``args.rew_weight``.
    """
    # The explicit comparison is kept on purpose: the type of the flag is
    # decided by the argument parser, which is not visible from this file.
    if args.use_reward_scheduling == True:  # noqa: E712
        return compute_cosine_weight(episode_idx, args.saturation_episode, args.max_weight)
    return args.rew_weight


def evaluate_skills(episode_idx):
    """Roll out the deterministic policy, log averages and record a latent video.

    Runs ``2 * args.num_intervals`` episodes; the radius index passed to
    ``generate_random_radius`` wraps with a modulo, so every interval index is
    evaluated twice. Relies on the module-level ``env``, ``agent``, ``psi``,
    ``writer``, ``args``, ``radius_bound``, ``radius_input_dim`` and
    ``video_directory``.
    """
    episodes = args.num_intervals * 2
    reward_weight = pseudo_reward_weight(episode_idx)

    avg_pseudo_reward = 0.
    avg_reward = 0.
    avg_step = 0.

    # Per-step data collected over all evaluation episodes for the video
    all_rgb_arrays = []
    all_rgb_arrays_train = []
    all_states_eval = []

    for eval_idx in range(episodes):
        episode_steps = 0
        episode_pseudo_reward = 0
        episode_reward = 0

        done = False
        env.reset()

        # Initialize the frame stack with copies of the first frame
        state_buffer = [render_observation(env)] * args.num_stacked_frames

        radius_value, radius_input = generate_random_radius(
            radius_bound, radius_input_dim, args.num_intervals,
            current_index=eval_idx % args.num_intervals, eval=True)

        while not done:
            # Stack frames; evaluation uses a center crop instead of a random one
            state_img_stacked = center_crop(np.concatenate(state_buffer, axis=0), args.output_size)

            # Select action
            action = agent.select_action(state_img_stacked, radius_input, evaluate=True)

            # Do simulation
            step_reward, done, _ = repeat_action(env, action, args.num_action_repeat)

            # Slide the frame stack forward by one frame
            state_buffer.append(render_observation(env))
            state_buffer.pop(0)
            next_state_img_stacked = center_crop(np.concatenate(state_buffer, axis=0), args.output_size)

            # Compute pseudo reward (length); the encoding of the current
            # state is reused below for the latent-space visualisation.
            next_latent = psi.forward_np(next_state_img_stacked, radius_input)
            encoded_state = psi.forward_np(state_img_stacked, radius_input)
            pseudo_reward = compute_pseudo_reward(next_latent - encoded_state, radius_value)

            episode_steps += 1
            episode_reward += step_reward
            episode_pseudo_reward += reward_weight * pseudo_reward

            all_rgb_arrays.append(env.render(mode='rgb_array', camera_id=0))
            all_rgb_arrays_train.append(env.render(mode='rgb_array', camera_id=1))
            all_states_eval.append(encoded_state)

        avg_reward += episode_reward
        avg_pseudo_reward += episode_pseudo_reward
        avg_step += episode_steps

    # Project the visited latent states to 2-D for the video overlay
    pca = PCA(n_components=2)
    pca_states = pca.fit_transform(np.stack(all_states_eval))

    LatentVideoRecorder(all_rgb_arrays, all_rgb_arrays_train, pca_states, video_directory, episode_idx, fps=args.video_fps)

    avg_reward /= episodes
    avg_pseudo_reward /= episodes
    avg_step /= episodes

    # For tensorboard
    writer.add_scalar('test/avg_gt_reward', avg_reward, episode_idx)
    writer.add_scalar('test/avg_pseudo_reward', avg_pseudo_reward, episode_idx)
    writer.add_scalar('test/avg_total_reward', avg_reward + avg_pseudo_reward, episode_idx)

    print("----------------------------------------")
    print("Test Episodes: {}, Avg. Reward: {}, Avg. Pseudo Reward: {}, Avg. step: {}"
          .format(episodes, round(avg_reward, 2), round(avg_pseudo_reward, 2), round(avg_step, 2)))
    print("----------------------------------------")


if args.num_action_repeat < 1:
    # With zero repeats the environment would never advance and the episode
    # loop below could not terminate.
    raise ValueError("num_action_repeat must be >= 1, got {}".format(args.num_action_repeat))

# Checkpointing and evaluation share one schedule (in episodes)
eval_interval = args.episodes_per_epoch * args.eval_epoch_ratio

try:
    for i_epoch in itertools.count(1):
        # ---- Data collection -------------------------------------------
        for _ in range(args.episodes_per_epoch):
            episode_reward = 0
            episode_pseudo_reward = 0
            episode_steps = 0   # agent decisions taken in this episode
            env_steps = 0       # env.step calls made in this episode
            episode_idx += 1
            episode_trajectory = []

            done = False
            env.reset()

            # Initialize the frame stack with copies of the first frame
            state_buffer = [render_observation(env)] * args.num_stacked_frames

            radius_value, radius_input = generate_random_radius(radius_bound, radius_input_dim, args.num_intervals)

            # The weight depends only on the episode index, so compute it once
            reward_weight = pseudo_reward_weight(episode_idx)

            while not done:
                # Stack & augment img
                state_img_stacked = random_crop(np.concatenate(state_buffer, axis=0), args.output_size)

                # Select action
                if args.start_steps > total_numsteps:
                    action = env.action_space.sample()  # Sample random action
                else:
                    action = agent.select_action(state_img_stacked, radius_input)  # Sample action from policy

                # Do simulation
                step_reward, done, steps_taken = repeat_action(env, action, args.num_action_repeat)
                env_steps += steps_taken

                # Slide the frame stack forward by one frame
                state_buffer.append(render_observation(env))
                state_buffer.pop(0)

                # Stack & augment next img
                next_state_img_stacked = random_crop(np.concatenate(state_buffer, axis=0), args.output_size)

                # Compute pseudo reward (length)
                psi_diff = psi.forward_np(next_state_img_stacked, radius_input) - psi.forward_np(state_img_stacked, radius_input)
                pseudo_reward = compute_pseudo_reward(psi_diff, radius_value)

                # update
                episode_steps += 1
                total_numsteps += 1
                episode_reward += step_reward
                episode_pseudo_reward += reward_weight * pseudo_reward

                # Keep bootstrapping (mask = 1) when the episode ended only
                # because the step limit was reached. The limit is expressed
                # in environment steps, so it has to be compared with the
                # number of env.step calls, not with the number of agent
                # decisions (they differ whenever num_action_repeat > 1).
                mask = 1 if env_steps >= env._max_episode_steps else float(not done)
                memory.push(state_img_stacked, action, step_reward, radius_input, radius_value, next_state_img_stacked, mask)  # Append transition to memory
                episode_trajectory.append((state_img_stacked, radius_input, radius_value))

            # Append the finished episode to the trajectory memory
            memory_traj.push(episode_trajectory)

            writer.add_scalar('train/gt_reward', episode_reward, episode_idx)
            writer.add_scalar('train/pseudo_reward', episode_pseudo_reward, episode_idx)
            writer.add_scalar('train/total_reward', episode_reward + episode_pseudo_reward, episode_idx)

            print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {} , pseudo_reward: {}"
                  .format(episode_idx, total_numsteps, episode_steps, round(episode_reward, 2), round(episode_pseudo_reward, 2)))

        ## (Option2) : Alternating Updates Method
        if (len(memory) > args.batch_size) and (len(memory_traj) > args.traj_batch_size):
            for _ in range(args.gradient_steps_per_epoch):

                # Update parameters of SAC networks
                critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = agent.update_parameters(memory, psi, updates_sac, episode_idx, args)

                # SAC Loss
                writer.add_scalar('loss/critic_1', critic_1_loss, updates_sac)
                writer.add_scalar('loss/critic_2', critic_2_loss, updates_sac)
                writer.add_scalar('loss/policy', policy_loss, updates_sac)
                writer.add_scalar('loss/entropy_loss', ent_loss, updates_sac)
                writer.add_scalar('entropy_temprature/alpha', alpha, updates_sac)

                updates_sac += 1

                # Update parameters of Psi networks
                loss_total, loss_max, loss_min, loss_const_1, loss_const_2, flag = psi.update_parameters(memory_traj, args)

                # If a trajectory that meets the conditions is not sampled,
                # skip logging for this step (comparison kept explicit because
                # the type of `flag` is defined outside this file).
                if flag == False:  # noqa: E712
                    continue

                # Encoder Loss
                writer.add_scalar('psi_loss/total', loss_total, updates_psi)
                writer.add_scalar('psi_loss/max', loss_max, updates_psi)
                writer.add_scalar('psi_loss/min', loss_min, updates_psi)
                writer.add_scalar('psi_loss/const_L', loss_const_1, updates_psi)
                writer.add_scalar('psi_loss/const_1', loss_const_2, updates_psi)

                updates_psi += 1

        if episode_idx > args.num_episode:
            break

        # Save checkpoints and evaluate the skill policy on the same schedule
        if episode_idx % eval_interval == 0:
            agent.save_checkpoint(args, "{}".format(episode_idx))
            psi.save_checkpoint(args, "{}".format(episode_idx))
            evaluate_skills(episode_idx)
finally:
    # Release resources even when training is interrupted or fails, and make
    # sure pending TensorBoard events are written out.
    writer.close()
    env.close()

