#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
import time
import os
import random
from collections import OrderedDict

from video import VideoRecorder
from logger import Logger
from replay_buffer import ReplayBuffer
import utils

from torch.utils.tensorboard import SummaryWriter

import hydra
from omegaconf import OmegaConf
from omegaconf import DictConfig

from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

import metaworld


class Workspace(object):
    """Fine-tune a pretrained SAC agent on a Meta-World ML2 test environment.

    Construction builds the (currently hard-coded) ``reach-v2-g2`` ML2 test
    environment, sizes ``cfg.agent`` from its observation/action spaces,
    instantiates the agent and an empty replay buffer, and loads the latest
    checkpoint from ``PRETRAINED_RUN``. ``run`` then collects experience and
    updates the agent for ``cfg.num_train_steps`` environment steps,
    calling ``evaluate`` every ``cfg.eval_frequency`` episodes.
    """

    # Run directory, relative to the ``runs/`` root found in the working
    # directory path, whose ``agent_model/`` checkpoints are loaded.
    PRETRAINED_RUN = ('2021.10.02/metaworld_SACReacherDoorClose_env=metaworld,'
                      'experiment=SACReacherDoorClose/seed=2/')

    # Maps ``cfg.env`` to the benchmark family; the family selects the
    # reset/step protocol used by ``run`` and ``evaluate``.
    _ENV_BENCHMARKS = {
        'ant_walk': 'dm_suite',
        'reacher_easy': 'dm_suite',
        'metaworld': 'metaworld',
    }

    def __init__(self, cfg):
        """Set up logging, environment, agent, replay buffer and checkpoint.

        Args:
            cfg: Hydra config. ``cfg.agent`` is mutated in place with the
                obs/action/goal dimensions before the agent is instantiated.

        Raises:
            ValueError: if ``cfg.env`` is not a known benchmark env.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency,
                             agent=cfg.name)

        self.writer = SummaryWriter(log_dir='tb')

        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)

        try:
            self.benchmark = self._ENV_BENCHMARKS[cfg.env]
        except KeyError:
            raise ValueError('Invalid benchmark env.') from None

        # NOTE(review): regardless of cfg.env, a single Meta-World ML2 test
        # env (reach-v2-g2 with a randomly chosen test task) is used, keyed
        # by id 2. The dm_suite code paths below are unreachable with it.
        self.all_envs = OrderedDict()
        ml2 = metaworld.ML2()
        env = ml2.test_classes['reach-v2-g2']()
        env.set_task(random.choice(ml2.test_tasks))
        self.all_envs[2] = env

        env_sample = next(iter(self.all_envs.values()))
        obs_space = env_sample.observation_space
        action_space = env_sample.action_space

        cfg.agent.obs_dim = obs_space.shape[0]
        cfg.agent.action_dim = action_space.shape[0]
        cfg.agent.action_range = [
            float(action_space.low.min()),
            float(action_space.high.max())
        ]
        # The goal is a 1-element vector ([env_id] for Meta-World), so its
        # shape is fixed at (1,). NOTE(review): dm_suite goals were sized
        # (3,) for ant_walk and (2,) for reacher_easy; revisit if re-enabled.
        env_sample_goal_shape = (1, )
        cfg.agent.goal_dim = env_sample_goal_shape[0]
        cfg.agent.env_id_dim = 0
        self.agent = hydra.utils.instantiate(cfg.agent, _recursive_=False)
        self.replay_buffer = ReplayBuffer(obs_space.shape,
                                          action_space.shape,
                                          env_sample_goal_shape,
                                          int(cfg.replay_buffer_capacity),
                                          self.device,
                                          keep_ids=False)

        self.video_recorder = VideoRecorder(
            self.work_dir if cfg.save_video else None, benchmark='metaworld')
        self.step = 0

        # Everything up to the first 'runs' in the cwd is the runs root.
        base_dir = self.work_dir.split('runs')[0] + 'runs/'
        self.model_dir = base_dir + self.PRETRAINED_RUN + 'agent_model/'
        latest_step = utils.get_latest_file(self.model_dir)
        self.agent.load(self.model_dir, latest_step)

    def _reset_random_env(self, verbose=False):
        """Pick an env from ``self.all_envs`` uniformly, reset it, return ``(obs, goal)``.

        Also sets ``self.env`` to the chosen environment. For dm_suite the
        goal comes from ``reset(query_goal=True)``; otherwise it is
        ``[env_id]``. When ``verbose`` is true, the dm_suite goal is printed.
        """
        env_id = random.sample(list(self.all_envs), 1)[0]
        self.env = self.all_envs[env_id]
        if self.benchmark == 'dm_suite':
            obs, goal = self.env.reset(query_goal=True)
            if verbose:
                print('Selected goal ', goal)
        else:
            obs = self.env.reset()
            goal = [env_id]
        return obs, goal

    def _hit_time_limit(self, steps_taken):
        """Return True once a Meta-World episode has taken ``max_path_length`` steps.

        ``steps_taken`` counts the steps of the current episode including
        the one just executed. Always False for non-Meta-World benchmarks.
        """
        return (self.benchmark == 'metaworld'
                and steps_taken >= self.env.max_path_length)

    def evaluate(self):
        """Run ``cfg.num_eval_episodes`` episodes with ``sample=False`` actions.

        The mean episode return is logged to the logger under
        ``eval_seen/episode_reward`` and to TensorBoard under
        ``eval/episode_reward``. Video recording is enabled only for the
        first episode.
        """
        average_episode_reward = 0
        for episode in range(self.cfg.num_eval_episodes):
            obs, goal = self._reset_random_env()
            self.agent.reset()
            self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0
            episode_step = 0
            while not done:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, sample=False)
                obs, reward, done, extras = self.env.step(action)
                if self.benchmark == 'dm_suite':
                    goal = extras['goal']
                self.video_recorder.record(self.env)
                episode_reward += reward
                episode_step += 1
                # episode_step already includes this step, so eval episodes
                # last max_path_length steps, the same as training in run().
                if self._hit_time_limit(episode_step):
                    done = True

            average_episode_reward += episode_reward
            self.video_recorder.save(f'{self.step}.mp4')
        average_episode_reward /= self.cfg.num_eval_episodes
        self.logger.log('eval_seen/episode_reward', average_episode_reward,
                        self.step)
        self.writer.add_scalar('eval/episode_reward', average_episode_reward,
                               self.step)

    def run(self):
        """Collect experience and train for ``cfg.num_train_steps`` steps.

        Actions are sampled uniformly from the action space for the first
        ``cfg.num_seed_steps`` steps and from the agent afterwards; from
        ``num_seed_steps`` on, the agent is updated once per environment
        step. At each episode boundary, training stats are logged and, every
        ``cfg.eval_frequency`` episodes, ``evaluate`` is called.
        """
        episode, episode_reward, done = 0, 0, True
        start_time = time.time()
        while self.step < self.cfg.num_train_steps:
            if done:
                if self.step > 0:
                    self.logger.log('train/duration',
                                    time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(
                        self.step, save=(self.step > self.cfg.num_seed_steps))

                # evaluate agent periodically
                if episode % self.cfg.eval_frequency == 0:
                    self.logger.log('eval_seen/episode', episode, self.step)
                    self.evaluate()

                self.logger.log('train/episode_reward', episode_reward,
                                self.step)
                self.writer.add_scalar('train/episode_reward', episode_reward,
                                       self.step)
                obs, goal = self._reset_random_env(verbose=True)
                self.agent.reset()
                done = False
                episode_reward = 0
                episode_step = 0
                episode += 1

                self.logger.log('train/episode', episode, self.step)
                self.writer.add_scalar('train/episode', episode, self.step)

            # sample action for data collection
            if self.step < self.cfg.num_seed_steps:
                action = self.env.action_space.sample()
            else:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, sample=True)

            # run training update
            if self.step >= self.cfg.num_seed_steps:
                self.agent.update(self.replay_buffer, self.logger, self.step)

            next_obs, reward, done, extras = self.env.step(action)
            if self.benchmark == 'dm_suite':
                goal = extras['goal']
            # On Meta-World, force the episode to end after max_path_length
            # steps (counting the step just taken).
            if self._hit_time_limit(episode_step + 1):
                done = True
            # allow infinite bootstrap: done_no_max is 0 on the time-limit
            # step (presumably used by the agent as the bootstrap mask).
            done = float(done)
            if self.benchmark == 'metaworld':
                done_no_max = 0 if episode_step + 1 == self.env.max_path_length else done
            elif self.benchmark == 'dm_suite':
                done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done

            episode_reward += reward

            self.replay_buffer.add(obs, action, reward, next_obs, done,
                                   done_no_max, goal)

            obs = next_obs
            episode_step += 1
            self.step += 1


@hydra.main(config_path="config/", config_name="train_expert")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: print the composed config, then build and run a Workspace.

    ``cfg`` is supplied by the ``hydra.main`` decorator from the
    ``config/train_expert`` config.
    """
    print(OmegaConf.to_yaml(cfg))
    # NOTE(review): Workspace is resolved through the ``fine_tune_sac``
    # module rather than this file's namespace; presumably this file *is*
    # fine_tune_sac.py — confirm, otherwise a different class is run.
    from fine_tune_sac import Workspace as FineTuneWorkspace
    FineTuneWorkspace(cfg).run()


if __name__ == "__main__":
    # No arguments: the @hydra.main decorator on main() builds cfg.
    main()
