import os
import datetime
import gym
import numpy as np
import itertools
import torch
from sac import SAC
from torch.utils.tensorboard import SummaryWriter

from replay_memory import ReplayMemory, TrajectoryReplayMemory

from model_psd import Psi
from model_metra import Phi, Lambda

from utils_metra import skillsampler, generate_skill_cont, generate_skill_disc, generate_random_radius, compute_scheduled_weight
from utils_sac import LatentVideoRecorder, create_directory, copy_files_and_directories

from envs.register import register_custom_envs

from arguments import parser_args


# Parse the run configuration; `args` is read by everything below.
args = parser_args()

# Files and directories handed to copy_files_and_directories() together with
# the experiment directory (presumably a snapshot of the code used for this
# run -- confirm in utils_sac).
paths_to_copy = [
    "envs",
    "algo/arguments.py",
    "algo/main.py",
    "algo/model_metra.py",
    "algo/model_psd.py",
    "algo/model_sac.py",
    "algo/replay_memory.py",
    "algo/sac.py",
    "algo/utils_metra.py",
    "algo/utils_psd.py",
    "algo/utils_sac.py",
    # Extend this list with further files or directories when needed.
]

# Per-experiment output directory: ../sdb/exps/<exp_name>.
exp_num_directory = os.path.join("..", "sdb", "exps", args.exp_name)
create_directory(exp_num_directory)
copy_files_and_directories(paths_to_copy, exp_num_directory)


# Register the project's custom environments, then build the requested one.
register_custom_envs()
env = gym.make(args.env_name)

# Torch device that the Psi / Phi networks are moved to.
if args.cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Reproducibility: feed the same seed to the env, its action space, torch
# and numpy (in that order).
_seed = args.seed
env.seed(_seed)
env.action_space.seed(_seed)
torch.manual_seed(_seed)
np.random.seed(_seed)

# Output locations: evaluation videos and the TensorBoard event files.
video_directory = 'PATH'
writer = SummaryWriter('PATH')

# Skill sampler instance.
# NOTE: this rebinds the imported `skillsampler` name to the instance, so the
# factory itself is no longer reachable under that name afterwards.
skillsampler = skillsampler(env, args)

# Replay buffers: single transitions and whole trajectories (the latter is
# sized with the env's maximum episode length).
memory = ReplayMemory(args.replay_size, args.seed)
memory_traj = TrajectoryReplayMemory(
    args.traj_replay_size,
    env._max_episode_steps,
    args.seed,
)

# Running counters: environment steps, episodes, and per-network update counts
# (the update counts double as TensorBoard step indices in the loop below).
total_numsteps = episode_idx = updates_sac = updates_psi = updates_phi = 0

# Radius skill: input / latent dimensions and the comma-separated integer
# bounds given on the command line.
radius_input_dim, radius_latent_dim = args.radius_input_dim, args.radius_latent_dim
radius_bound = np.array(list(map(int, args.radius_bound.split(','))))

# METRA skill dimension and the number of leading position entries that are
# sliced off the state before it is fed to Psi.
metra_dim = args.metra_skill_dim
pos_dim = args.pos_dim

# Network input sizes. The policy sees [observation, radius input, METRA skill].
obs_dim = env.observation_space.shape[0]
policy_input_dim = obs_dim + radius_input_dim + metra_dim

# SAC agent acting on the skill-augmented state.
agent = SAC(policy_input_dim, env.action_space, args)

# Psi (PSD): same augmented state minus the first `pos_dim` entries.
psi = Psi(policy_input_dim - pos_dim, args).to(device)

# Phi (METRA): sized for observation + radius input only.
# NOTE(review): the training loop passes the full augmented state (including
# the METRA skill) to phi.forward_np -- presumably Phi trims it internally;
# confirm in model_metra.
phi = Phi(obs_dim + radius_input_dim, args).to(device)

# Lambda (METRA): updated alongside Phi in the training loop.
lamb = Lambda(args)

# Print the key dimensions once so a misconfigured run is easy to spot.
for label, value in (
    ("env_name :", args.env_name),
    ("max_episode_step :", env._max_episode_steps),
    ("state_dim :", env.observation_space.shape[0]),
    ("radius_latent_dim :", radius_latent_dim),
    ("radius_input_dim :", radius_input_dim),
    ("radius_bound :", radius_bound),
    ("radius_sampling_num :", args.num_intervals),
):
    print(label, value)

# Main loop. Each epoch:
#   1. collects `args.episodes_per_epoch` episodes into both replay buffers,
#   2. runs `args.gradient_steps_per_epoch` rounds of SAC / Phi+Lambda / Psi updates,
#   3. every `args.eval_epoch_ratio` epochs (measured in episodes) evaluates,
#      updates the radius sampling bounds and saves checkpoints.
# The loop ends once `episode_idx` exceeds `args.num_episode`.
for i_epoch in itertools.count(1):
    for i_episode in range(args.episodes_per_epoch):
        # Per-episode accumulators (pseudo rewards are only used for logging here).
        episode_pseudo_reward_metra = 0
        episode_pseudo_reward_psd = 0
        episode_steps = 0
        episode_idx += 1
        episode_trajectory = []

        done = False
        state = env.reset()

        # Select skills: one radius skill and one METRA skill, fixed for the episode
        radius_value, radius_input = skillsampler.sample()
        metra_skill = generate_skill_cont(metra_dim)

        # Augmented state layout: [observation, radius_input, metra_skill]
        state = np.concatenate([state, radius_input, metra_skill])
        
        while not done:
            # Uniform random actions until `args.start_steps` env steps were taken
            if args.start_steps > total_numsteps:
                action = env.action_space.sample()  # Sample random action
            else:
                action = agent.select_action(state)  # Sample action from policy

            next_state, reward, done, _ = env.step(action) # Step
            # env.render()

            next_state = np.concatenate([next_state, radius_input, metra_skill])

            # Compute pseudo reward (PSD):
            #   exp(-10 * (r*sin(pi/(2r)) - ||psi(s') - psi(s)||)^2), with r = radius_value.
            # Psi only sees the state without its first `pos_dim` entries.
            psi_diff = psi.forward_np(next_state[pos_dim:]) - psi.forward_np(state[pos_dim:])
            pseudo_reward_psd = np.exp(-10*np.linalg.norm(radius_value*np.sin(np.pi/(2*radius_value)) - np.linalg.norm(psi_diff))**2)

            # Compute pseudo reward (METRA): (phi(s') - phi(s)) . z
            pseudo_reward_metra = np.dot(phi.forward_np(next_state) - phi.forward_np(state), metra_skill)

            # Update counters and the logged episode returns
            episode_steps += 1
            total_numsteps += 1
            episode_pseudo_reward_metra += pseudo_reward_metra

            # The PSD term is weighted either by an episode-dependent schedule
            # or by the constant `args.rew_weight`.
            if args.use_reward_scheduling == True:
                episode_pseudo_reward_psd += compute_scheduled_weight(episode_idx, args.saturation_episode, args.min_weight, args.max_weight)*pseudo_reward_psd
            else:
                episode_pseudo_reward_psd += args.rew_weight*pseudo_reward_psd

            # mask = 1 when the episode hit the step limit (time-out, not a true
            # terminal state); otherwise 0.0 if `done` else 1.0.
            mask = 1 if episode_steps == env._max_episode_steps else float(not done)

            # Append the transition to the step-level memory. Note that the raw env
            # reward is what gets stored; the pseudo rewards above are not pushed.
            memory.push(state, action, reward, radius_value, next_state, mask)
            episode_trajectory.append((state, radius_value))

            state = next_state

        # Append the whole episode to the trajectory memory
        memory_traj.push(episode_trajectory)

        writer.add_scalar('train/pseudo_reward_metra', episode_pseudo_reward_metra, episode_idx)
        writer.add_scalar('train/pseudo_reward_psd', episode_pseudo_reward_psd, episode_idx)
        writer.add_scalar('train/total_reward', episode_pseudo_reward_metra + episode_pseudo_reward_psd, episode_idx)

        print("Episode: {}, total numsteps: {}, episode steps: {}, pseudo_reward_metra: {} , pseudo_reward_psd: {}" \
              .format(episode_idx, total_numsteps, episode_steps, round(episode_pseudo_reward_metra, 2), round(episode_pseudo_reward_psd, 2)))


    ## (Option2) : Alternating Updates Method
    # Updates only start once both buffers hold more than one batch.
    if (len(memory) > args.batch_size) and (len(memory_traj) > args.traj_batch_size):
        # Each iteration performs one SAC update, one Phi + Lambda update and one Psi update
        for i in range(args.gradient_steps_per_epoch):

            ### Update parameters of SAC networks
            critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = agent.update_parameters(memory, psi, phi, updates_sac, episode_idx, args)

            # SAC Loss
            writer.add_scalar('sac_loss/critic_1', critic_1_loss, updates_sac)
            writer.add_scalar('sac_loss/critic_2', critic_2_loss, updates_sac)
            writer.add_scalar('sac_loss/policy', policy_loss, updates_sac)
            writer.add_scalar('sac_loss/entropy_loss', ent_loss, updates_sac)
            writer.add_scalar('sac_entropy_temprature/alpha', alpha, updates_sac)

            updates_sac += 1


            ### Update parameters of Phi networks (Phi first, then Lambda using the updated Phi)
            phi_loss = phi.update_parameters(memory, args.batch_size, lamb.lambda_value)
            lamb_loss = lamb.update_parameters(memory, args.batch_size, phi)

            # Phi Loss
            writer.add_scalar('phi_loss/phi', phi_loss, updates_phi)
            writer.add_scalar('phi_loss/lambda', lamb_loss, updates_phi)

            updates_phi += 1


            ### Update parameters of Psi networks
            loss_total, loss_max, loss_min, loss_const_1, loss_const_2, flag = psi.update_parameters(memory_traj, args)

            # If no trajectory meeting the conditions was sampled, Psi reports flag == False;
            # skip the Psi logging and do not advance `updates_psi`.
            if flag == False:
                continue

            # Psi Loss
            # NOTE(review): tag 'const_L' receives loss_const_1 and tag 'const_1' receives
            # loss_const_2 -- confirm the tag names match the intended constraints.
            writer.add_scalar('psi_loss/total', loss_total, updates_psi)
            writer.add_scalar('psi_loss/max', loss_max, updates_psi)
            writer.add_scalar('psi_loss/min', loss_min, updates_psi)
            writer.add_scalar('psi_loss/const_L', loss_const_1, updates_psi)
            writer.add_scalar('psi_loss/const_1', loss_const_2, updates_psi)

            updates_psi += 1


    # Evaluate skill policy every `args.eval_epoch_ratio` epochs
    if episode_idx % (args.episodes_per_epoch*args.eval_epoch_ratio) == 0:
        avg_pseudo_reward_psd = 0.
        avg_pseudo_reward_metra = 0.
        avg_step = 0.
        episodes = args.num_intervals*2

        # Buffers for the evaluation video: rendered frames plus the Psi / Phi
        # encodings of every visited state, across all evaluation episodes.
        all_rgb_arrays = []
        all_states_psd = []
        all_states_metra = []

        # `episodes` is 2 * num_intervals and the sampler index is i % num_intervals,
        # so every radius interval index is evaluated twice.
        for i in range(episodes):

            state = env.reset()
            radius_value, radius_input = skillsampler.sample(current_index=i%args.num_intervals, eval=True)
            metra_skill = generate_skill_cont(metra_dim)

            state = np.concatenate([state, radius_input, metra_skill])

            episode_steps = 0
            episode_pseudo_reward_psd = 0
            episode_pseudo_reward_metra = 0

            done = False
            
            while not done:
                action = agent.select_action(state, evaluate=True)
                next_state, reward, done, _ = env.step(action)

                next_state = np.concatenate([next_state, radius_input, metra_skill])

                # Compute pseudo reward (PSD), same formula as in training
                psi_diff = psi.forward_np(next_state[pos_dim:]) - psi.forward_np(state[pos_dim:])
                pseudo_reward_psd = np.exp(-10*np.linalg.norm(radius_value*np.sin(np.pi/(2*radius_value)) - np.linalg.norm(psi_diff))**2)

                # Compute pseudo reward (METRA), same formula as in training
                pseudo_reward_metra = np.dot(phi.forward_np(next_state) - phi.forward_np(state), metra_skill)
                
                # Update episode statistics (nothing is pushed to the replay buffers here)
                episode_steps += 1
                episode_pseudo_reward_metra += pseudo_reward_metra

                if args.use_reward_scheduling == True:
                    episode_pseudo_reward_psd += compute_scheduled_weight(episode_idx, args.saturation_episode, args.min_weight, args.max_weight)*pseudo_reward_psd
                else:
                    episode_pseudo_reward_psd += args.rew_weight*pseudo_reward_psd
                    
                state = next_state

                # `state` was already advanced above, so the frame and both
                # encodings recorded here belong to the post-step state.
                rgb_array = env.render(mode='rgb_array', camera_id=0)
                encoded_state_psd = psi.forward_np(state[pos_dim:])
                encoded_state_metra = phi.forward_np(state)

                all_rgb_arrays.append(rgb_array)
                all_states_psd.append(encoded_state_psd)
                all_states_metra.append(encoded_state_metra)

            avg_pseudo_reward_metra += episode_pseudo_reward_metra
            avg_pseudo_reward_psd += episode_pseudo_reward_psd
            avg_step += episode_steps

        # Project both latent spaces to 2-D with separate PCA fits for the video.
        # The import lives here, so scikit-learn is only needed once an evaluation runs.
        from sklearn.decomposition import PCA

        pca = PCA(n_components=2)
        pca_states_psd = pca.fit_transform(np.stack(all_states_psd))
        pca1 = PCA(n_components=2)
        pca_states_metra = pca1.fit_transform(np.stack(all_states_metra))

        # Called for its side effect only (presumably writes a video under
        # `video_directory` -- confirm in utils_sac); the result is discarded.
        LatentVideoRecorder(all_rgb_arrays, pca_states_psd, pca_states_metra, video_directory, episode_idx, fps=args.video_fps)

        # Turn the accumulated sums into per-episode averages
        avg_pseudo_reward_metra /= episodes
        avg_pseudo_reward_psd /= episodes
        avg_step /= episodes

        # For tensorboard
        writer.add_scalar('test/avg_pseudo_reward_metra', avg_pseudo_reward_metra, episode_idx)
        writer.add_scalar('test/avg_pseudo_reward_psd', avg_pseudo_reward_psd, episode_idx)
        writer.add_scalar('test/avg_total_reward', avg_pseudo_reward_metra + avg_pseudo_reward_psd, episode_idx)

        print("----------------------------------------")
        print("Test Episodes: {}, Avg. METRA Reward: {}, Avg. PSD Reward: {}, Avg. step: {}"\
              .format(episodes, round(avg_pseudo_reward_metra, 2), round(avg_pseudo_reward_psd, 2), round(avg_step, 2)))
        print("----------------------------------------")

    # Adaptive sampling: same schedule as the evaluation above; the sampler
    # returns new radius bounds plus the average rewards at both ends.
    if episode_idx % (args.episodes_per_epoch*args.eval_epoch_ratio) == 0:
        radius_bound, average_reward_low, average_reward_high = skillsampler.update_bound(agent, psi)
        writer.add_scalar('sampling/radius_lowerbound', radius_bound[0], episode_idx)
        writer.add_scalar('sampling/radius_upperbound', radius_bound[1], episode_idx)
        writer.add_scalar('sampling/reward_lowerbound', average_reward_low, episode_idx)
        writer.add_scalar('sampling/reward_upperbound', average_reward_high, episode_idx)
        
    # Checkpoints: same schedule again, tagged with the current episode index
    if episode_idx % (args.episodes_per_epoch*args.eval_epoch_ratio) == 0:
        agent.save_checkpoint(args,"{}".format(episode_idx))
        psi.save_checkpoint(args,"{}".format(episode_idx))
        phi.save_checkpoint(args,"{}".format(episode_idx))

    # Checked once per epoch, so training ends after the first epoch that
    # pushes `episode_idx` past `args.num_episode`.
    if episode_idx > args.num_episode:
        break
        
# Shut down the environment, and always close the TensorBoard writer so that
# scalars still buffered by SummaryWriter are flushed to disk before the
# process exits (the writer was previously never closed).
try:
    env.close()
finally:
    writer.close()

