#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
import time
import os
import random
from collections import OrderedDict

from video import VideoRecorder
from logger import Logger
from replay_buffer import ReplayBuffer
import utils

from torch.utils.tensorboard import SummaryWriter

import hydra
from omegaconf import OmegaConf
from omegaconf import DictConfig

from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

import metaworld


class Workspace(object):
    """Training/evaluation harness for a goal-conditioned agent.

    Builds the environments for the configured benchmark (``dm_suite`` tasks
    via ``utils.make_env`` or Meta-World ML2/ML1 tasks), instantiates the
    agent and replay buffer from the Hydra config, optionally reloads
    pretrained ``phi``/``w`` weights, and runs the collect/update/evaluate
    loop in :meth:`run`.
    """

    def __init__(self, cfg):
        """Set up logging, environments, agent, replay buffer and recorder.

        Args:
            cfg: Hydra/OmegaConf config. Mutated in place: ``cfg.agent``
                receives ``obs_dim``, ``action_dim``, ``action_range``,
                ``env_id_dim``, ``goal_dim``, ``goal_mode`` and ``task``
                before the agent is instantiated from it.

        Raises:
            ValueError: if ``cfg.env`` is not a supported environment, or if
                ``cfg.reload_phi_w`` is set but neither ``w`` nor ``phi`` is
                requested.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency,
                             agent=cfg.name)

        self.writer = SummaryWriter(log_dir='tb')

        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)

        self.goal_mode = cfg.goal_mode
        self.eval_on_unseen = True
        # Hard-coded switch between the ML2 multi-task setup (False) and a
        # single ML1 task named by ``cfg.single_task`` (True).
        self.single_task = False

        self.mt_task = cfg.single_task
        self.task = cfg.env

        if cfg.env == 'ant_walk':
            env_sample_goal_shape = (3, )
            self.benchmark = 'dm_suite'
        elif cfg.env == 'reacher_easy':
            env_sample_goal_shape = (2, )
            self.benchmark = 'dm_suite'
        elif cfg.env in ('walker_stand', 'walker_walk', 'walker_run'):
            env_sample_goal_shape = (1, )
            self.benchmark = 'dm_suite'
        elif cfg.env == 'metaworld':
            # The goal is the task id itself.
            env_sample_goal_shape = (1, )
            self.benchmark = 'metaworld'
        else:
            raise ValueError('Invalid benchmark env.')

        self.all_envs = OrderedDict()
        if self.benchmark == 'dm_suite':
            if self.goal_mode == 'multi_goal':
                # Training tasks use task ids 0-3.
                for i in range(4):
                    self.all_envs[i] = utils.make_env(cfg, task_id=i)
            else:
                self.env = utils.make_env(cfg, task_id=6)

            if self.eval_on_unseen:
                # Task ids 4 and 5 are only used for unseen evaluation.
                self.eval_envs_unseen = OrderedDict()
                for i in range(4, 6):
                    self.eval_envs_unseen[i] = utils.make_env(cfg, task_id=i)
        elif not self.single_task:
            ml2 = metaworld.ML2()
            for i, (name, env_cls) in enumerate(ml2.train_classes.items()):
                env = env_cls()
                env.set_task(random.choice(
                    [task for task in ml2.train_tasks if task.env_name == name]))
                self.all_envs[i] = env
            if self.eval_on_unseen:
                self.eval_envs_unseen = OrderedDict()
                for i, (name, env_cls) in enumerate(ml2.test_classes.items()):
                    env = env_cls()
                    env.set_task(random.choice(
                        [task for task in ml2.test_tasks if task.env_name == name]))
                    # Offset the ids so they differ from the training task
                    # ids (0 and 1 are logged as reacher/press in run()).
                    self.eval_envs_unseen[i + 2] = env
        else:
            ml1 = metaworld.ML1(self.mt_task)
            env = ml1.train_classes[self.mt_task]()
            env.set_task(random.choice(ml1.train_tasks))
            self.all_envs[0] = env
            # No held-out envs are built for the single-task setup, so
            # evaluate_unseen() would fail on the missing eval_envs_unseen.
            self.eval_on_unseen = False

        if self.goal_mode == 'multi_goal' or self.benchmark == 'metaworld':
            env_sample = next(iter(self.all_envs.values()))
        else:
            env_sample = self.env

        cfg.agent.obs_dim = env_sample.observation_space.shape[0]
        cfg.agent.action_dim = env_sample.action_space.shape[0]
        cfg.agent.action_range = [
            float(env_sample.action_space.low.min()),
            float(env_sample.action_space.high.max())
        ]
        # Extra task-id input to the agent's state; only used for
        # multi-goal Meta-World.
        cfg.agent.env_id_dim = int(self.goal_mode != 'single_goal'
                                   and self.benchmark != 'dm_suite')
        cfg.agent.goal_dim = env_sample_goal_shape[0]
        cfg.agent.goal_mode = cfg.goal_mode
        cfg.agent.task = self.task
        self.agent = hydra.utils.instantiate(cfg.agent, _recursive_=False)
        self.replay_buffer = ReplayBuffer(env_sample.observation_space.shape,
                                          env_sample.action_space.shape,
                                          env_sample_goal_shape,
                                          int(cfg.replay_buffer_capacity),
                                          self.device,
                                          keep_ids=True)

        self.video_recorder = VideoRecorder(
            self.work_dir if cfg.save_video else None, benchmark=self.benchmark)
        self.step = 0

        if cfg.reload_phi_w:
            load_w = cfg.representation.load_w
            load_phi = cfg.representation.load_phi
            base_dir = self.work_dir.split('runs')[0] + 'runs/'
            # Checkpoint dirs are '<phi_w_model_date>/<env>_<overrides>/seed=1/'.
            # The '<env>' prefix equals self.task for every supported env,
            # including 'metaworld'.
            phi_w_dir = f'{base_dir}{cfg.phi_w_model_date}/{self.task}_'
            # These names are hard-coded so they match the pretraining runs'
            # directory names; Hydra doesn't make it easy to pass a directory
            # as an argument. Can be fixed later.
            if load_w and load_phi:
                phi_w_dir += 'LearnPhiReachDoorRWall14_env=metaworld,experiment=LearnPhiReachDoorRWall14,representation.latent_size=14/seed=1/'
            elif load_w:
                phi_w_dir += 'FixPhilearnW_experiment=FixPhilearnW,goal_mode=single_goal,representation.learn_phi=false,representation.learn_w=true/seed=1/'
            elif load_phi:
                phi_w_dir += 'LearnPhiFixW_experiment=LearnPhiFixW,goal_mode=single_goal,representation.learn_phi=true,representation.learn_w=false/seed=1/'
            else:
                raise ValueError('Must load w or phi.')

            self.phi_w_dir = phi_w_dir + 'agent_model/'
            latest_step = utils.get_latest_file(self.phi_w_dir)
            self.agent.load_phi_w(self.phi_w_dir, latest_step)
        self.model_dir = self.work_dir + '/agent_model'

    def _sample_env(self, envs):
        """Choose the env for the next episode and make it ``self.env``.

        In multi-goal mode and on Meta-World an id is drawn uniformly from
        ``envs``. Otherwise the single ``self.env`` is kept and id 0 is used.

        Returns:
            The env id of the selected environment.
        """
        if self.goal_mode == 'multi_goal' or self.benchmark == 'metaworld':
            env_id = random.sample(list(envs), 1)[0]
            self.env = envs[env_id]
            return env_id
        return 0

    def _reset_episode(self, env_id):
        """Reset ``self.env`` and return ``(obs, goal)`` for a new episode.

        dm_suite envs report their goal on reset. For Meta-World the goal is
        the one-element list ``[env_id]`` (the task id).
        """
        if self.benchmark == 'dm_suite':
            obs, goal = self.env.reset(query_goal=True)
            return obs, goal
        obs = self.env.reset()
        return obs, [env_id]

    def evaluate(self):
        """Evaluate the deterministic policy on the training envs.

        Runs ``cfg.num_eval_episodes`` episodes (video recorded for the
        first one) and logs the mean episode return as
        ``eval_seen/episode_reward`` (logger) and ``eval/episode_reward``
        (TensorBoard).
        """
        average_episode_reward = 0
        for episode in range(self.cfg.num_eval_episodes):
            env_id = self._sample_env(self.all_envs)
            obs, goal = self._reset_episode(env_id)

            self.agent.reset()
            self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0
            episode_step = 0
            while not done:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, env_id, sample=False)
                obs, reward, done, extras = self.env.step(action)
                if self.benchmark == 'dm_suite':
                    goal = extras['goal']
                self.video_recorder.record(self.env)
                episode_reward += reward
                episode_step += 1
                # Meta-World episodes are truncated by step count here. Stop
                # after max_path_length steps, the same length run() trains on.
                if (self.benchmark == 'metaworld'
                        and episode_step >= self.env.max_path_length):
                    done = True

            average_episode_reward += episode_reward
            self.video_recorder.save(f'{self.step}.mp4')
        average_episode_reward /= self.cfg.num_eval_episodes

        self.logger.log('eval_seen/episode_reward', average_episode_reward,
                        self.step)
        self.writer.add_scalar('eval/episode_reward', average_episode_reward,
                               self.step)

        self.logger.dump(self.step)

    def evaluate_unseen(self):
        """Evaluate the policy and the phi/w reward estimate on held-out envs.

        Runs ``cfg.num_eval_episodes`` deterministic episodes. In multi-goal
        mode and on Meta-World the envs are drawn from
        ``self.eval_envs_unseen``. The method logs the mean return, plus the
        MSE and R^2 between the true per-step rewards and the
        ``agent.evaluate_phi_w_approx`` estimates.
        """
        average_episode_reward = 0
        all_rewards = []
        reward_est = []
        for episode in range(self.cfg.num_eval_episodes):
            # NOTE(review): in dm_suite single_goal mode this keeps the
            # training env (self.env) with env_id 0 and never uses
            # self.eval_envs_unseen. Confirm that this is intended.
            env_id = self._sample_env(self.eval_envs_unseen)
            obs, goal = self._reset_episode(env_id)

            self.agent.reset()
            self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0
            episode_step = 0
            while not done:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, env_id, sample=False)
                current_obs = obs
                obs, reward, done, extras = self.env.step(action)
                if self.benchmark == 'dm_suite':
                    goal = extras['goal']
                self.video_recorder.record(self.env)
                episode_reward += reward
                reward_estimate = self.agent.evaluate_phi_w_approx(
                    current_obs, action, obs, goal)

                all_rewards.append(reward)
                reward_est.append(reward_estimate)
                episode_step += 1
                # Same step-count truncation as evaluate() / run().
                if (self.benchmark == 'metaworld'
                        and episode_step >= self.env.max_path_length):
                    done = True

            average_episode_reward += episode_reward
            self.video_recorder.save(f'{self.step}_eval_unseen.mp4')
        average_episode_reward /= self.cfg.num_eval_episodes
        self.logger.log('eval_unseen/episode_reward', average_episode_reward,
                        self.step)
        self.writer.add_scalar('eval_unseen/episode_reward',
                               average_episode_reward, self.step)

        mse_error = mean_squared_error(all_rewards, reward_est)
        self.logger.log('eval_unseen/reward_pred_mse', mse_error, self.step)
        r2_score_ = r2_score(all_rewards, reward_est)
        self.logger.log('eval_unseen/reward_r2_score', r2_score_, self.step)
        self.logger.dump(self.step)

    def run(self):
        """Collect experience and train the agent for ``cfg.num_train_steps``.

        Behaviour by phase:
        - During the first ``cfg.num_seed_steps`` steps, actions are sampled
          from the action space. After that they come from the agent, and
          agent updates begin.
        - At episode boundaries, the agent is evaluated every
          ``cfg.eval_frequency`` episodes.
        - Every transition is stored in the replay buffer together with its
          goal and env id.
        """
        episode, episode_reward, done = 0, 0, True
        # Env id of the episode in progress (None before the first episode).
        env_id = None
        # Per-episode returns of the Meta-World training tasks 0 and 1.
        reacher_rew = 0
        press_rew = 0
        episode_step = 0
        start_time = time.time()
        while self.step < self.cfg.num_train_steps:
            if done:
                if self.step > 0:
                    self.logger.log('train/duration',
                                    time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(
                        self.step, save=(self.step > self.cfg.num_seed_steps))

                # evaluate agent periodically
                if episode % self.cfg.eval_frequency == 0:
                    self.logger.log('eval_seen/episode', episode, self.step)
                    self.evaluate()
                    if self.eval_on_unseen:
                        self.logger.log('eval_unseen/episode', episode,
                                        self.step)
                        self.evaluate_unseen()
                    self.agent.save(self.model_dir, self.step)

                self.logger.log('train/episode_reward', episode_reward,
                                self.step)
                self.writer.add_scalar('train/episode_reward', episode_reward,
                                       self.step)

                # Log the finished episode's per-task return under the env it
                # actually ran on. This must happen before a new env_id is
                # drawn below.
                if self.benchmark == 'metaworld' and env_id is not None:
                    if env_id == 0:
                        self.writer.add_scalar(
                            'train_task/episode_reacher_reward', reacher_rew,
                            self.step)
                    elif env_id == 1:
                        self.writer.add_scalar(
                            'train_task/episode_press_reward', press_rew,
                            self.step)

                env_id = self._sample_env(self.all_envs)
                obs, goal = self._reset_episode(env_id)
                self.agent.reset()
                done = False
                episode_reward = 0
                reacher_rew = 0
                press_rew = 0
                episode_step = 0
                episode += 1

                self.logger.log('train/episode', episode, self.step)
                self.writer.add_scalar('train/episode', episode, self.step)

            # sample action for data collection
            if self.step < self.cfg.num_seed_steps:
                action = self.env.action_space.sample()
            else:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, env_id, sample=True)

            # run training update
            if self.step >= self.cfg.num_seed_steps:
                self.agent.update(self.replay_buffer, self.logger, self.writer,
                                  self.step)

            next_obs, reward, done, extras = self.env.step(action)
            if self.benchmark == 'dm_suite':
                goal = extras['goal']
                max_steps = self.env._max_episode_steps
            else:
                max_steps = self.env.max_path_length
                # Meta-World episodes are truncated by step count: this is
                # the episode's max_steps-th (last) step.
                if episode_step >= max_steps - 1:
                    done = True
            # allow infinite bootstrap: reaching the time limit is not stored
            # as a terminal transition.
            done = float(done)
            done_no_max = 0 if episode_step + 1 == max_steps else done
            episode_reward += reward
            if self.benchmark == 'metaworld':
                if env_id == 0:
                    reacher_rew += reward
                elif env_id == 1:
                    press_rew += reward

            self.replay_buffer.add(obs, action, reward, next_obs, done,
                                   done_no_max, goal, env_id)

            obs = next_obs
            episode_step += 1
            self.step += 1


@hydra.main(config_path="config/", config_name="train")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: print the composed config, then build and run a Workspace.

    Args:
        cfg: Config composed by Hydra from ``config/train``.
    """
    print(OmegaConf.to_yaml(cfg))
    # The class is imported through the module name instead of using the
    # definition above. Presumably this is so it resolves as
    # ``train.Workspace`` rather than ``__main__.Workspace`` (e.g. for
    # pickling) — TODO confirm.
    from train import Workspace as TrainWorkspace
    TrainWorkspace(cfg).run()


if __name__ == '__main__':
    # main() is called without arguments; the @hydra.main decorator supplies cfg.
    main()
