# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from ast import arg
from matplotlib.pyplot import get
import numpy as np
import random

from utils.config import set_np_formatting, set_seed, get_args, parse_sim_params, load_cfg
from utils.parse_task import parse_task
from utils.process_sarl import *
from utils.process_marl import process_MultiAgentRL, get_AgentIndex
from utils.process_mtrl import *
from utils.process_metarl import *
import os

from rl_games.common import env_configurations, experiment, vecenv
from rl_games.common.algo_observer import AlgoObserver, IsaacAlgoObserver
from rl_games.torch_runner import Runner
from rl_games.algos_torch import torch_ext
import yaml

# from utils.rl_games_custom import 
from rl_games.common.a2c_common import swap_and_flatten01
import time
from copy import deepcopy

# Set at import time, before any of the code below touches the GPU.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid (synchronous kernel
# launches so errors surface at the offending call) and slows training — confirm it is still wanted.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

def rl_games_save_agent_checkpoint(agent, checkpoint_name, mean_rewards, epoch_num=None):
    """Write the periodic, best-so-far and "won" checkpoints for ``agent``.

    Args:
        agent: Object exposing ``save_freq``, ``last_mean_rewards``,
            ``save_best_after``, ``nn_dir``, ``config`` and ``save(path)``.
        checkpoint_name: File stem used for the periodic (``'last_'`` prefixed)
            and the winning checkpoint.
        mean_rewards: Indexable of mean rewards; only element 0 is compared
            against ``agent.last_mean_rewards``.
        epoch_num: Epoch counter used for the ``save_freq`` and
            ``save_best_after`` checks. Falls back to ``agent.epoch_num`` when
            omitted, so the function no longer depends on a module global.

    Returns:
        bool: ``True`` when ``agent.config['score_to_win']`` exists and has been
        exceeded (the caller may stop training), otherwise ``False``.
    """
    if epoch_num is None:
        epoch_num = agent.epoch_num

    if agent.save_freq > 0:
        # Periodic snapshot, only written when this epoch did not set a new best.
        if (epoch_num % agent.save_freq == 0) and (mean_rewards[0] <= agent.last_mean_rewards):
            agent.save(os.path.join(agent.nn_dir, 'last_' + checkpoint_name))

    should_exit = False
    if mean_rewards[0] > agent.last_mean_rewards and epoch_num >= agent.save_best_after:
        print('saving next best rewards: ', mean_rewards)
        agent.last_mean_rewards = mean_rewards[0]
        agent.save(os.path.join(agent.nn_dir, agent.config['name']))

        if 'score_to_win' in agent.config:
            if agent.last_mean_rewards > agent.config['score_to_win']:
                print('Network won!')
                agent.save(os.path.join(agent.nn_dir, checkpoint_name))
                should_exit = True

    return should_exit


def rl_games_agent_logs(agent, total_time, epoch_num, step_time, play_time, update_time, training_info, frame, scaled_time, scaled_play_time, curr_frames):
    """Log one epoch of statistics for ``agent`` and trigger checkpointing.

    Args:
        agent: Agent whose ``write_stats``, ``writer``, ``game_rewards`` and
            ``game_lengths`` are used for logging.
        total_time, epoch_num, step_time, play_time, update_time, frame,
        scaled_time, scaled_play_time, curr_frames: Forwarded to
            ``agent.write_stats`` and used as x-axis values for the scalars.
        training_info: Sequence
            ``(a_losses, c_losses, b_losses, entropies, kls, last_lr, lr_mul)``.
            An optional eighth element is taken as the augmentation losses.

    Returns:
        bool: The stop flag reported by ``rl_games_save_agent_checkpoint``;
        ``False`` when no episode statistics are available yet.
    """
    a_losses, c_losses, b_losses, entropies, kls, last_lr, lr_mul, *extra = training_info
    # Augmentation losses are optional; the 7-element form carries none.
    aug_losses = extra[0] if extra else None

    agent.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
    if len(b_losses) > 0:
        agent.writer.add_scalar('losses/bounds_loss', torch_ext.mean_list(b_losses).item(), frame)

    # Only log the augmentation loss when the caller actually supplied it.
    if agent.has_soft_aug and aug_losses is not None:
        agent.writer.add_scalar('losses/aug_loss', np.mean(aug_losses), frame)

    if agent.game_rewards.current_size <= 0:
        # No finished episode recorded yet: nothing to log or checkpoint.
        return False

    mean_rewards = agent.game_rewards.get_mean()
    mean_lengths = agent.game_lengths.get_mean()
    agent.mean_rewards = mean_rewards[0]

    for i in range(agent.value_size):
        rewards_name = 'rewards' if i == 0 else 'rewards{0}'.format(i)
        tag_prefix = agent.name + '_' + rewards_name
        agent.writer.add_scalar(tag_prefix + '/step', mean_rewards[i], frame)
        agent.writer.add_scalar(tag_prefix + '/iter', mean_rewards[i], epoch_num)
        agent.writer.add_scalar(tag_prefix + '/time', mean_rewards[i], total_time)

    agent.writer.add_scalar('episode_lengths/step', mean_lengths, frame)
    agent.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num)
    agent.writer.add_scalar('episode_lengths/time', mean_lengths, total_time)

    checkpoint_name = agent.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0]) + '_agent_' + agent.name
    # Pass the epoch explicitly so the checkpoint gating uses the same counter as the file name.
    return rl_games_save_agent_checkpoint(agent, checkpoint_name, mean_rewards, epoch_num)

def rl_games_train_epoch(agent_retri, agent_grasp):
    """Gather one rollout with both agents, then update each agent in turn.

    Returns a 6-tuple: the rollout's accumulated ``step_time`` (read from the
    retri batch), play time, update time, total time, the training info of the
    *last* agent updated (grasp), and the label ``"grasp"``.
    """
    # Experience collection runs with both agents in eval mode.
    for acting_agent in (agent_retri, agent_grasp):
        acting_agent.vec_env.set_train_info(acting_agent.frame, acting_agent)
        acting_agent.set_eval()

    rollout_begin = time.time()
    with torch.no_grad():
        batch_dict_retri, batch_dict_grasp = rl_games_play_steps(agent_retri, agent_grasp)
    rollout_end = time.time()

    updates_begin = time.time()

    def _update_agent(agent, batch_dict):
        """Run all mini-epochs over ``agent``'s dataset and collect loss statistics."""
        agent.set_train()
        agent.curr_frames = batch_dict.pop('played_frames')
        agent.prepare_dataset(batch_dict)
        agent.algo_observer.after_steps()
        if agent.has_central_value:
            agent.train_central_value()

        actor_losses, critic_losses, bound_losses = [], [], []
        entropy_terms, epoch_kls = [], []

        for mini_ep in range(agent.mini_epochs_num):
            minibatch_kls = []
            for batch_idx in range(len(agent.dataset)):
                step_result = agent.train_actor_critic(agent.dataset[batch_idx])
                a_loss, c_loss, entropy, kl, last_lr, lr_mul, cmu, csigma, b_loss = step_result

                actor_losses.append(a_loss)
                critic_losses.append(c_loss)
                minibatch_kls.append(kl)
                entropy_terms.append(entropy)
                if agent.bounds_loss_coef is not None:
                    bound_losses.append(b_loss)

                agent.dataset.update_mu_sigma(cmu, csigma)

                # 'legacy' schedule: adapt the learning rate after every minibatch.
                if agent.schedule_type == 'legacy':
                    av_kls = kl
                    if agent.multi_gpu:
                        dist.all_reduce(kl, op=dist.ReduceOp.SUM)
                        av_kls /= agent.rank_size
                    agent.last_lr, agent.entropy_coef = agent.scheduler.update(
                        agent.last_lr, agent.entropy_coef, agent.epoch_num, 0, av_kls.item())
                    agent.update_lr(agent.last_lr)

            av_kls = torch_ext.mean_list(minibatch_kls)

            # 'standard' schedule: adapt once per mini-epoch using the mean KL.
            if agent.schedule_type == 'standard':
                agent.last_lr, agent.entropy_coef = agent.scheduler.update(
                    agent.last_lr, agent.entropy_coef, agent.epoch_num, 0, av_kls.item())
                agent.update_lr(agent.last_lr)

            epoch_kls.append(av_kls)
            agent.diagnostics.mini_epoch(agent, mini_ep)
            if agent.normalize_input:
                # Input statistics are only updated during the first mini-epoch.
                agent.model.running_mean_std.eval()

        return actor_losses, critic_losses, bound_losses, entropy_terms, epoch_kls, last_lr, lr_mul

    # NOTE(review): the retri agent's training info is discarded here; only the
    # grasp agent's info is returned to the caller.
    _update_agent(agent_retri, batch_dict_retri)
    training_info = _update_agent(agent_grasp, batch_dict_grasp)
    training_info_name = "grasp"

    updates_end = time.time()

    return (
        batch_dict_retri['step_time'],
        rollout_end - rollout_begin,
        updates_end - updates_begin,
        updates_end - rollout_begin,
        training_info,
        training_info_name,
    )

def rl_games_play_steps(agent_retri, agent_grasp):
    """Roll out ``agent_retri.horizon_length`` steps with both agents in one env.

    Each step, both agents pick actions, the actions are concatenated along
    dim 0 (retri first) and applied through ``agent_retri.env_step``. The retri
    agent's observations, states, dones and rewards are read from the
    ``infos['retri_*']`` entries (first ``retri_num_envs`` rows); the grasp
    agent's come from the env outputs (rows from ``retri_num_envs`` onward).

    Returns:
        tuple: ``(batch_dict_retri, batch_dict_grasp)``, each holding the
        flattened experience plus ``'returns'``, ``'played_frames'`` and
        ``'step_time'`` (seconds spent in env stepping during this rollout).
    """
    update_list_retri = agent_retri.update_list
    update_list_grasp = agent_grasp.update_list

    step_time = 0.0

    def pre_step(agent, update_list, res_dict, n):
        """Store the pre-step observation, done flags and policy outputs at slot ``n``."""
        agent.experience_buffer.update_data('obses', n, agent.obs['obs'])
        agent.experience_buffer.update_data('dones', n, agent.dones)

        for k in update_list:
            agent.experience_buffer.update_data(k, n, res_dict[k])
        if agent.has_central_value:
            agent.experience_buffer.update_data('states', n, agent.obs['states'])

    def post_step(agent, res_dict, rewards, infos, n):
        """Store shaped rewards at slot ``n`` and update the agent's episode statistics."""
        shaped_rewards = agent.rewards_shaper(rewards)
        if agent.value_bootstrap and 'time_outs' in infos:
            # NOTE(review): infos['time_outs'] is used unsliced while rewards are
            # per-agent slices — confirm the shapes line up for both agents.
            shaped_rewards += agent.gamma * res_dict['values'] * agent.cast_obs(infos['time_outs']).unsqueeze(1).float()

        agent.experience_buffer.update_data('rewards', n, shaped_rewards)

        agent.current_rewards += rewards
        agent.current_lengths += 1
        # An env counts as finished only when all of its agents are done.
        env_done_indices = agent.dones.view(agent.num_actors, agent.num_agents).all(dim=1).nonzero(as_tuple=False)

        agent.game_rewards.update(agent.current_rewards[env_done_indices])
        agent.game_lengths.update(agent.current_lengths[env_done_indices])
        agent.algo_observer.process_infos(infos, env_done_indices)

        # Reset the running totals of finished episodes.
        not_dones = 1.0 - agent.dones.float()

        agent.current_rewards = agent.current_rewards * not_dones.unsqueeze(1)
        agent.current_lengths = agent.current_lengths * not_dones

    for n in range(agent_retri.horizon_length):
        res_dict_retri = agent_retri.get_action_values(agent_retri.obs)
        res_dict_grasp = agent_grasp.get_action_values(agent_grasp.obs)

        pre_step(agent_grasp, update_list_grasp, res_dict_grasp, n)
        pre_step(agent_retri, update_list_retri, res_dict_retri, n)

        step_time_start = time.time()
        concat_action = torch.cat([res_dict_retri['actions'], res_dict_grasp['actions']], dim=0)
        obs, rewards, dones, infos = agent_retri.env_step(concat_action)

        # Retri agent: dedicated buffers in `infos`, first `retri_num_envs` rows.
        retri_count = agent_retri.vec_env.task.retri_num_envs
        agent_retri.obs['obs'] = infos["retri_obs"][:retri_count]
        agent_retri.obs['states'] = infos["retri_states"][:retri_count]
        agent_retri.dones = infos["retri_reset_buf"][:retri_count]
        retri_rewards = infos["retri_rew_buf"].unsqueeze(1)[:retri_count]

        # Grasp agent: regular env outputs, rows after the retri envs.
        grasp_offset = agent_grasp.vec_env.task.retri_num_envs
        agent_grasp.obs['obs'] = obs['obs'][grasp_offset:]
        agent_grasp.obs['states'] = obs['states'][grasp_offset:]
        agent_grasp.dones = dones[grasp_offset:]
        grasp_rewards = rewards[grasp_offset:]

        step_time_end = time.time()

        step_time += (step_time_end - step_time_start)

        post_step(agent_retri, res_dict_retri, retri_rewards, infos, n)
        post_step(agent_grasp, res_dict_grasp, grasp_rewards, infos, n)

    def generate_batch_dict(agent, step_time):
        """Compute returns for ``agent`` and flatten its experience buffer into a batch dict."""
        last_values = agent.get_values(agent.obs)

        fdones = agent.dones.float()
        mb_fdones = agent.experience_buffer.tensor_dict['dones'].float()
        mb_values = agent.experience_buffer.tensor_dict['values']
        mb_rewards = agent.experience_buffer.tensor_dict['rewards']
        mb_advs = agent.discount_values(fdones, last_values, mb_fdones, mb_values, mb_rewards)
        mb_returns = mb_advs + mb_values

        batch_dict = agent.experience_buffer.get_transformed_list(swap_and_flatten01, agent.tensor_list)
        batch_dict['returns'] = swap_and_flatten01(mb_returns)
        batch_dict['played_frames'] = agent.batch_size
        batch_dict['step_time'] = step_time

        return batch_dict

    batch_dict_retri = generate_batch_dict(agent_retri, step_time)
    batch_dict_grasp = generate_batch_dict(agent_grasp, step_time)

    return batch_dict_retri, batch_dict_grasp



if __name__ == '__main__':
    # Script entry point: restores a "retri" and a "grasp" rl_games agent from
    # checkpoints and trains them together on one shared vectorised env.
    set_np_formatting()
    args = get_args(use_rlg_config=True)
    if args.checkpoint == "Base":
        args.checkpoint = ""

    # Select the rl_games YAML for the requested algorithm / task.
    if args.algo == "ppo":
        config_name = "cfg/{}/ppo_continuous.yaml".format(args.algo)
    elif args.algo == "lego":
        config_name = "cfg/{}/ppo_continuous.yaml".format(args.algo)
        if args.task in ["AllegroHandLegoGrasp", "AllegroHandLegoTest"]:
            config_name = "cfg/{}/ppo_continuous_grasp.yaml".format(args.algo)
        if args.task in ["AllegroHandLegoInsert"]:
            config_name = "cfg/{}/ppo_continuous_insert.yaml".format(args.algo)
        if args.task in ["AllegroHandLegoRetrieveGrasp", "AllegroHandLegoRetrieveGraspVValue"]:
            config_name = "cfg/{}/ppo_continuous_retrieve_grasp_v_value.yaml".format(args.algo)

    elif args.algo == "ppo_lstm":
        config_name = "cfg/{}/ppo_continuous_lstm.yaml".format(args.algo)
    else:
        # Fail with a clear error instead of crashing below on an undefined `config_name`.
        raise ValueError("We don't support this config in RL-games now")

    args.task_type = "Lego"
    print('Loading config: ', config_name)

    args.cfg_train = config_name
    cfg, cfg_train, logdir = load_cfg(args, use_rlg_config=True)
    sim_params = parse_sim_params(args, cfg, cfg_train)
    cfg_train["seed"] = args.seed
    cfg["env"]["numEnvs"] = args.num_envs

    set_seed(cfg_train.get("seed", -1), cfg_train.get("torch_deterministic", False))

    agent_index = get_AgentIndex(cfg)
    task, env = parse_task(args, cfg, cfg_train, sim_params, agent_index)

    # override
    def override_rlgames_cfg(config_name, args, env, cfg_train, is_grasp):
        """Load an rl_games YAML and patch in task name, seed, env and per-agent settings."""
        with open(config_name, 'r') as stream:
            rlgames_cfg = yaml.safe_load(stream)
            rlgames_cfg['params']['config']['name'] = args.task
            # Each agent is configured with half of the environments.
            rlgames_cfg['params']['config']['num_actors'] = int(env.num_environments / 2)
            rlgames_cfg['params']['seed'] = cfg_train["seed"]
            rlgames_cfg['params']['config']['env_config']['seed'] = cfg_train["seed"]
            rlgames_cfg['params']['config']['vec_env'] = env
            if is_grasp:
                rlgames_cfg['params']['config']['env_info'] = env.get_grasp_env_info()
                rlgames_cfg['params']['config']['train_dir'] = "runs/grasp"
            else:
                rlgames_cfg['params']['config']['env_info'] = env.get_retri_env_info()
                rlgames_cfg['params']['config']['train_dir'] = "runs/retri"

        return rlgames_cfg

    rl_games_cfg_retri = override_rlgames_cfg("cfg/{}/ppo_continuous_retrieve_grasp_v_value_retri.yaml".format(args.algo), args, env, cfg_train, is_grasp=False)
    rl_games_cfg_grasp = override_rlgames_cfg(config_name, args, env, cfg_train, is_grasp=True)

    vargs = vars(args)

    runner_retri = Runner(algo_observer=IsaacAlgoObserver())
    runner_retri.load(rl_games_cfg_retri)
    runner_retri.reset()

    runner_grasp = Runner(algo_observer=IsaacAlgoObserver())
    runner_grasp.load(rl_games_cfg_grasp)
    runner_grasp.reset()

    from rl_games.torch_runner import _restore, _override_sigma
    import time
    from gym import spaces

    print('Started to train')

    agent_retri = runner_retri.algo_factory.create(runner_retri.algo_name, base_name='retri', params=runner_retri.params)
    agent_grasp = runner_grasp.algo_factory.create(runner_grasp.algo_name, base_name='grasp', params=runner_grasp.params)

    # NOTE(review): checkpoint paths are hard-coded to one developer's machine and
    # override whatever was passed via args.checkpoint — consider making them configurable.
    vargs["checkpoint"] = "/home/jmji/DexterousHandEnvs/dexteroushandenvs/runs/AllegroHandLego_15-07-26-59/nn/last_AllegroHandLego_ep_19000_rew_32.662014.pth"
    _restore(agent_retri, vargs)
    _override_sigma(agent_retri, vargs)
    vargs["checkpoint"] = "/home/jmji/DexterousHandEnvs/dexteroushandenvs/runs/AllegroHandLegoRetrieveGraspVValue_19-01-39-15/nn/last_AllegroHandLegoRetrieveGraspVValue_ep_150000_rew_827.7031.pth"
    _restore(agent_grasp, vargs)
    _override_sigma(agent_grasp, vargs)

    agent_retri.init_tensors()
    agent_grasp.init_tensors()

    agent_retri.last_mean_rewards = -100500
    agent_grasp.last_mean_rewards = -100500

    start_time = time.time()
    total_time = 0
    rep_count = 0
    # Initial observations: retri uses the dedicated 'retri_*' buffers (first
    # retri_num_envs rows), grasp uses the regular buffers after those rows.
    agent_retri.obs = agent_retri.env_reset()
    agent_retri.obs["obs"] = agent_retri.obs["retri_obs"].clone()[:agent_retri.vec_env.task.retri_num_envs]
    agent_retri.obs["states"] = agent_retri.obs["retri_states"].clone()[:agent_retri.vec_env.task.retri_num_envs]
    agent_grasp.obs = agent_grasp.env_reset()
    agent_grasp.obs["obs"] = agent_grasp.obs["obs"][agent_grasp.vec_env.task.retri_num_envs:]
    agent_grasp.obs["states"] = agent_grasp.obs["states"][agent_grasp.vec_env.task.retri_num_envs:]

    agent_retri.curr_frames = agent_retri.batch_size_envs
    agent_grasp.curr_frames = agent_retri.batch_size_envs

    while True:
        # `epoch_num` ends up holding the grasp agent's counter (assigned last).
        epoch_num = agent_retri.update_epoch()
        epoch_num = agent_grasp.update_epoch()

        step_time, play_time, update_time, sum_time, training_info, training_info_name = rl_games_train_epoch(agent_retri=agent_retri, agent_grasp=agent_grasp)

        total_time += sum_time
        frame = agent_retri.frame // agent_retri.num_agents

        # cleaning memory to optimize space
        agent_retri.dataset.update_values_dict(None)
        agent_grasp.dataset.update_values_dict(None)

        should_exit = False

        if agent_retri.rank == 0 or agent_grasp.rank == 0:
            agent_retri.diagnostics.epoch(agent_retri, current_epoch=epoch_num)
            agent_grasp.diagnostics.epoch(agent_grasp, current_epoch=epoch_num)
            # do we need scaled_time?
            scaled_time = agent_retri.num_agents * sum_time
            scaled_play_time = agent_retri.num_agents * play_time
            curr_frames = agent_retri.curr_frames * agent_retri.rank_size if agent_retri.multi_gpu else agent_retri.curr_frames
            agent_retri.frame += curr_frames

            curr_frames = agent_grasp.curr_frames * agent_grasp.rank_size if agent_grasp.multi_gpu else agent_grasp.curr_frames
            agent_grasp.frame += curr_frames

            if agent_retri.print_stats or agent_grasp.print_stats:
                step_time = max(step_time, 1e-6)
                fps_step = curr_frames / step_time
                fps_step_inference = curr_frames / scaled_play_time
                fps_total = curr_frames / scaled_time
                print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{agent_retri.max_epochs}')

                # NOTE(review): the same `training_info` is logged for both agents.
                rl_games_agent_logs(agent_retri, total_time, epoch_num, step_time, play_time, update_time, training_info, frame, scaled_time, scaled_play_time, curr_frames)
                rl_games_agent_logs(agent_grasp, total_time, epoch_num, step_time, play_time, update_time, training_info, frame, scaled_time, scaled_play_time, curr_frames)
                # if agent.has_agent_play_config:
                #     agent.agent_play_manager.update(agent)

            if epoch_num >= agent_retri.max_epochs:
                if agent_retri.game_rewards.current_size == 0:
                    print('WARNING: Max epochs reached before any env terminated at least once')
                    mean_rewards = -np.inf
                else:
                    # Use the current mean reward for the file name; this name was
                    # previously undefined in this scope on this path.
                    mean_rewards = agent_retri.game_rewards.get_mean()[0]
                agent_retri.save(os.path.join(agent_retri.nn_dir, 'last_' + agent_retri.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards)))
                print('MAX EPOCHS NUM!')
                should_exit = True

            update_time = 0

        # Honour the stop request; without this the loop never terminates.
        # NOTE(review): the flag is not broadcast to other ranks in multi-GPU runs.
        if should_exit:
            break


