#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from builtins import ValueError
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import random
import sys
import time
import pickle
from collections import OrderedDict
from torch.utils.tensorboard import SummaryWriter

from video import VideoRecorder
from logger import Logger
from replay_buffer import ReplayBuffer
import utils

import hydra
from omegaconf import OmegaConf
from omegaconf import DictConfig

import metaworld


class Workspace(object):
    def __init__(self, cfg):
        """Set up logging, environments, the agent and the replay buffer.

        Args:
            cfg: Experiment configuration (Hydra/OmegaConf). Fields read
                directly here: ``log_save_tb``, ``log_frequency``, ``name``,
                ``seed``, ``device``, ``goal_mode``, ``single_task``, ``env``,
                ``agent``, ``replay_buffer_capacity`` and ``save_video``.
                ``cfg.agent`` is mutated: ``obs_dim``, ``action_dim``,
                ``action_range``, ``goal_dim`` and ``env_id_dim`` are filled
                in before the agent is instantiated.

        Raises:
            ValueError: If ``cfg.env`` is not a supported environment name.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency,
                             agent=cfg.name)

        # Additional TensorBoard writer, independent of `self.logger`.
        self.writer = SummaryWriter(log_dir='tb')

        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)

        self.goal_mode = cfg.goal_mode
        self.eval_on_unseen = True
        # The single-task code path is switched off here; `cfg.single_task`
        # is only stored as the task name that path would pass to ML1.
        self.single_task = False
        self.mt_task = cfg.single_task

        # Map the environment name to its goal shape and benchmark family.
        if cfg.env == 'ant_walk':
            env_sample_goal_shape = (3, )
            self.benchmark = 'dm_suite'
        elif cfg.env == 'reacher_easy':
            env_sample_goal_shape = (2, )
            self.benchmark = 'dm_suite'
        elif cfg.env in ('walker_stand', 'walker_walk', 'walker_run'):
            env_sample_goal_shape = (1, )
            self.benchmark = 'dm_suite'
        elif cfg.env == 'metaworld':
            # The goal is the task id itself.
            env_sample_goal_shape = (1, )
            self.benchmark = 'metaworld'
        else:
            raise ValueError('Invalid benchmark env.')

        self.all_envs = OrderedDict()
        # Created up front so the attribute exists in every configuration.
        self.eval_envs_unseen = OrderedDict()
        if self.benchmark == 'dm_suite':
            if self.goal_mode == 'multi_goal':
                for task_id in range(8):
                    self.all_envs[task_id] = utils.make_env(cfg,
                                                            task_id=task_id)
            else:
                # Single-goal training runs on one fixed task.
                self.env = utils.make_env(cfg, task_id=6)

            if self.eval_on_unseen:
                for task_id in range(5, 8):
                    self.eval_envs_unseen[task_id] = utils.make_env(
                        cfg, task_id=task_id)
        elif not self.single_task:
            # Meta-World: one training env per ML2 train class.
            ml2 = metaworld.ML2()
            for i, (name, env_cls) in enumerate(ml2.train_classes.items()):
                env = env_cls()
                env.set_task(self._sample_task(ml2.train_tasks, name))
                self.all_envs[i] = env
            if self.eval_on_unseen:
                for i, (name, env_cls) in enumerate(ml2.test_classes.items()):
                    env = env_cls()
                    env.set_task(self._sample_task(ml2.test_tasks, name))
                    # Offset the index so evaluation ids differ from the
                    # training task ids.
                    self.eval_envs_unseen[i + 2] = env
        else:
            # Meta-World: a single ML1 training task.
            ml1 = metaworld.ML1(self.mt_task)
            env = ml1.train_classes[self.mt_task]()
            env.set_task(random.choice(ml1.train_tasks))
            self.all_envs[0] = env

            ml2 = metaworld.ML2()
            unseen_name = 'reach-v2-g2'
            env = ml2.test_classes[unseen_name]()
            # Only tasks that belong to this env class are candidates, as in
            # the multi-task branch above.
            env.set_task(self._sample_task(ml2.test_tasks, unseen_name))
            self.eval_envs_unseen[4] = env

        # Any one environment serves as the reference for the space shapes.
        if self.goal_mode == 'multi_goal' or self.benchmark == 'metaworld':
            env_sample = list(self.all_envs.values())[0]
        else:
            env_sample = self.env

        cfg.agent.obs_dim = env_sample.observation_space.shape[0]
        cfg.agent.action_dim = env_sample.action_space.shape[0]
        cfg.agent.action_range = [
            float(env_sample.action_space.low.min()),
            float(env_sample.action_space.high.max())
        ]

        cfg.agent.goal_dim = env_sample_goal_shape[0]
        cfg.agent.env_id_dim = 0
        self.agent = hydra.utils.instantiate(cfg.agent, _recursive_=False)

        self.replay_buffer = ReplayBuffer(env_sample.observation_space.shape,
                                          env_sample.action_space.shape,
                                          env_sample_goal_shape,
                                          int(cfg.replay_buffer_capacity),
                                          self.device)

        self.video_recorder = VideoRecorder(
            self.work_dir if cfg.save_video else None, benchmark=self.benchmark)
        # Directory handed to `self.agent.save` when checkpointing.
        # NOTE: to resume from a checkpoint, load it at this point, e.g.
        # `self.agent.load(self.model_dir, utils.get_latest_file(self.model_dir))`.
        self.model_dir = os.path.join(self.work_dir, 'agent_model')

        self.step = 0

    @staticmethod
    def _sample_task(tasks, env_name):
        """Return a random element of ``tasks`` whose ``env_name`` matches."""
        return random.choice(
            [task for task in tasks if task.env_name == env_name])

    def evaluate(self):
        """Roll out the deterministic policy on the training environments.

        Runs ``cfg.num_eval_episodes`` episodes with ``sample=False`` and logs
        their mean return as ``eval_seen/episode_reward`` at ``self.step``.
        Video recording is enabled for the first episode only.

        Side effect: in multi-goal or Meta-World mode ``self.env`` is rebound
        to the environment sampled for each episode.
        """
        on_dm_suite = self.benchmark == 'dm_suite'
        on_metaworld = self.benchmark == 'metaworld'
        pick_env = self.goal_mode == 'multi_goal' or on_metaworld
        num_episodes = self.cfg.num_eval_episodes

        reward_sum = 0
        for episode_idx in range(num_episodes):
            if pick_env:
                (env_id, ) = random.sample(list(self.all_envs), 1)
                self.env = self.all_envs[env_id]

            if on_dm_suite:
                obs, goal = self.env.reset(query_goal=True)
            elif on_metaworld:
                # For Meta-World the task id itself is used as the goal
                # (tasks are assumed to be single-goal).
                obs = self.env.reset()
                goal = [env_id]

            self.agent.reset()
            self.video_recorder.init(enabled=(episode_idx == 0))
            episode_return = 0
            steps_taken = 0
            while True:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, sample=False)
                obs, reward, done, extras = self.env.step(action)
                # dm_suite reports the goal with every step; for Meta-World
                # the id-based goal is simply reused.
                if on_dm_suite:
                    goal = extras['goal']
                self.video_recorder.record(self.env)
                episode_return += reward
                steps_taken += 1
                # Meta-World episodes are cut off by the step budget.
                if on_metaworld and \
                        steps_taken >= self.env.max_path_length - 1:
                    break
                if done:
                    break

            reward_sum += episode_return
            self.video_recorder.save(f'{self.step}.mp4')

        mean_return = reward_sum / num_episodes
        self.logger.log('eval_seen/episode_reward', mean_return, self.step)
        self.logger.dump(self.step)

    def evaluate_unseen(self):
        """Evaluate the deterministic policy on the held-out environments.

        Runs ``cfg.num_eval_episodes`` episodes with ``sample=False``, each on
        an environment drawn at random from ``self.eval_envs_unseen``, and
        logs the mean return as ``eval_unseen/episode_reward`` at
        ``self.step``. Video recording is enabled for the first episode only.

        The evaluation environment is kept in a local variable, so
        ``self.env`` (the environment used for training) is left untouched.
        """
        average_episode_reward = 0
        for episode in range(self.cfg.num_eval_episodes):
            # Always draw from the unseen pool, whatever the goal mode.
            # Falling back to `self.env` here would report the training
            # environment's return under the `eval_unseen` key.
            env_id = random.sample(list(self.eval_envs_unseen), 1)[0]
            env = self.eval_envs_unseen[env_id]

            if self.benchmark == 'dm_suite':
                obs, goal = env.reset(query_goal=True)
            elif self.benchmark == 'metaworld':
                obs = env.reset()
                # For Meta-World the env id itself is used as the goal.
                goal = [env_id]

            self.agent.reset()
            self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0
            episode_step = 0
            while not done:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, sample=False)
                obs, reward, done, extras = env.step(action)
                if self.benchmark == 'dm_suite':
                    goal = extras['goal']
                self.video_recorder.record(env)
                episode_reward += reward
                episode_step += 1
                # Meta-World episodes are cut off by the step budget.
                if self.benchmark == 'metaworld':
                    if episode_step >= env.max_path_length - 1:
                        done = True

            average_episode_reward += episode_reward
            self.video_recorder.save(f'{self.step}_eval_unseen.mp4')
        average_episode_reward /= self.cfg.num_eval_episodes
        self.logger.log('eval_unseen/episode_reward', average_episode_reward,
                        self.step)
        self.logger.dump(self.step)

    def run(self):
        """Run the training loop until ``cfg.num_train_steps`` steps are done.

        At every episode boundary the loop dumps the logger, evaluates the
        agent every ``cfg.eval_frequency`` episodes (checkpointing when the
        episode index is also a multiple of ``cfg.ckpt_frequency``), logs the
        return of the episode that just ended, selects the environment for
        the next episode and resets it.

        Within an episode, actions come from ``action_space.sample()`` for
        the first ``cfg.num_seed_steps`` steps and from the agent afterwards.
        Every transition is added to the replay buffer, and the agent is
        updated once per step after the seeding phase.
        """
        # TensorBoard tag of the per-task training return, keyed by the
        # Meta-World training env id.
        task_reward_tags = {
            0: 'train_task/episode_reacher_reward',
            1: 'train_task/episode_close_reward',
        }
        is_metaworld = self.benchmark == 'metaworld'

        episode, episode_reward, done = 0, 0, True
        # Id of the env the current episode runs on. None until the first
        # episode starts, and it stays None for single-goal dm_suite runs.
        env_id = None
        start_time = time.time()
        while self.step < self.cfg.num_train_steps:
            if done:
                if self.step > 0:
                    self.logger.log('train/duration',
                                    time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(
                        self.step, save=(self.step > self.cfg.num_seed_steps))

                # evaluate agent periodically
                if episode % self.cfg.eval_frequency == 0:
                    self.logger.log('eval_seen/episode', episode, self.step)
                    self.evaluate()
                    if self.eval_on_unseen:
                        self.logger.log('eval_unseen/episode', episode,
                                        self.step)
                        self.evaluate_unseen()
                    if episode % self.cfg.ckpt_frequency == 0:
                        self.agent.save(self.model_dir, episode)

                self.logger.log('train/episode_reward', episode_reward,
                                self.step)

                # Attribute the finished episode's return to the task it was
                # collected on. This must happen before `env_id` is
                # re-sampled below, otherwise the value ends up under the
                # next episode's task tag.
                if is_metaworld and env_id in task_reward_tags:
                    self.writer.add_scalar(task_reward_tags[env_id],
                                           episode_reward, self.step)

                if self.goal_mode == 'multi_goal' or is_metaworld:
                    env_id = random.sample(list(self.all_envs), 1)[0]
                    self.env = self.all_envs[env_id]

                if self.benchmark == 'dm_suite':
                    obs, goal = self.env.reset(query_goal=True)
                elif is_metaworld:
                    obs = self.env.reset()
                    # The task id itself is used as the goal
                    # (tasks are assumed to be single-goal).
                    goal = [env_id]

                self.agent.reset()
                done = False
                episode_reward = 0
                episode_step = 0
                episode += 1

                self.logger.log('train/episode', episode, self.step)

            # sample action for data collection
            if self.step < self.cfg.num_seed_steps:
                action = self.env.action_space.sample()
            else:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, sample=True)

            # run training update
            if self.step >= self.cfg.num_seed_steps:
                self.agent.update(self.replay_buffer, self.logger, self.step)

            next_obs, reward, done, extras = self.env.step(action)
            if self.benchmark == 'dm_suite':
                goal = extras['goal']

            # Meta-World episodes are cut off by the step budget.
            if is_metaworld:
                if episode_step >= self.env.max_path_length - 1:
                    done = True
            # allow infinite bootstrap: a termination caused only by the
            # step limit is stored as "not done" in `done_no_max`
            done = float(done)
            if is_metaworld:
                done_no_max = 0 if episode_step + 1 == self.env.max_path_length else done
            elif self.benchmark == 'dm_suite':
                done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done
            episode_reward += reward

            self.replay_buffer.add(obs, action, reward, next_obs, done,
                                   done_no_max, goal)

            obs = next_obs
            episode_step += 1
            self.step += 1

        # NOTE: the replay buffer could be pickled to disk at this point
        # (e.g. pickle.dump(self.replay_buffer, f, pickle.HIGHEST_PROTOCOL));
        # it would have to be reloaded via the ReplayBuffer class.


@hydra.main(config_path="config/", config_name="train_expert")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: echo the composed config, then train.

    Args:
        cfg: Configuration composed by Hydra from the ``train_expert``
            config found under ``config/``.
    """
    config_yaml = OmegaConf.to_yaml(cfg)
    print(config_yaml)

    # The workspace class is looked up on the `train_expert` module rather
    # than taken from this module's globals.
    # NOTE(review): presumably this file is train_expert.py -- confirm.
    from train_expert import Workspace as W
    W(cfg).run()


if __name__ == '__main__':
    main()