#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
import time
import os
import random
from collections import OrderedDict

from video import VideoRecorder
from logger import Logger
from replay_buffer import ReplayBuffer
import utils

from torch.utils.tensorboard import SummaryWriter

import hydra
from omegaconf import OmegaConf
from omegaconf import DictConfig


class Workspace(object):
    """Goal-conditioned fine-tuning workspace.

    Construction builds the environment(s), instantiates the agent from
    ``cfg.agent`` together with a replay buffer that also stores the task id
    of every transition (``keep_ids=True``), and loads pretrained weights:
    ``phi_w`` from a hard-coded ``FineTuningWUnseenGs`` run and the policy
    from a hard-coded ``SFJointPWMG`` run. Both checkpoint directories are
    resolved from the prefix of the cwd up to its first ``runs`` occurrence.

    ``run()`` collects experience (updating the agent when ``fine_tune`` is
    set) and periodically calls ``evaluate()``.
    """

    def __init__(self, cfg):
        """Build envs, agent, replay buffer and load pretrained weights.

        Args:
            cfg: Hydra/OmegaConf config. Reads ``log_save_tb``,
                ``log_frequency``, ``name``, ``seed``, ``device``, ``agent``,
                ``replay_buffer_capacity`` and ``save_video``; the
                ``agent.obs_dim``/``action_dim``/``action_range``/``goal_dim``
                fields are overwritten from a sample environment before the
                agent is instantiated.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency,
                             agent=cfg.name)

        self.writer = SummaryWriter(log_dir='tb')

        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)

        # When True, every episode runs in an env drawn at random from
        # ``self.all_envs`` (task ids 3, 4, 5); otherwise a single task-1 env.
        self.train_on_mix_gs = True
        # When True, the agent is updated from the replay buffer in run().
        self.fine_tune = True

        if self.train_on_mix_gs:
            # Envs are created in task-id order 3, 4, 5 and kept in that order.
            self.all_envs = OrderedDict(
                (task_id, utils.make_env(cfg, task_id=task_id))
                for task_id in range(3, 6))
            # Any of the envs serves as the shape template; take the first.
            env_sample = next(iter(self.all_envs.values()))
        else:
            # Fixed task id; run() stores it with every transition, so it
            # must exist even when goals are not mixed.
            self.env_id = 1
            env_sample = utils.make_env(cfg, task_id=self.env_id)
            self.env = env_sample

        cfg.agent.obs_dim = env_sample.observation_space.shape[0]
        cfg.agent.action_dim = env_sample.action_space.shape[0]
        cfg.agent.action_range = [
            float(env_sample.action_space.low.min()),
            float(env_sample.action_space.high.max())
        ]
        # NOTE(review): goal dimensionality is hard-coded to 2 rather than
        # read from the env — confirm it matches what env.reset/step return.
        env_sample_goal_shape = (2, )
        cfg.agent.goal_dim = env_sample_goal_shape[0]
        self.agent = hydra.utils.instantiate(cfg.agent, _recursive_=False)
        self.replay_buffer = ReplayBuffer(env_sample.observation_space.shape,
                                          env_sample.action_space.shape,
                                          env_sample_goal_shape,
                                          int(cfg.replay_buffer_capacity),
                                          self.device,
                                          keep_ids=True)

        self.video_recorder = VideoRecorder(
            self.work_dir if cfg.save_video else None)
        self.step = 0

        base_dir = self.work_dir.split('runs')[0] + 'runs/'
        self.phi_w_dir = base_dir + \
            '2021.06.28/reacher_easy_FineTuningWUnseenGs_experiment=FineTuningWUnseenGs,train_on_unseen_gs=true/'
        self.phi_w_dir += 'agent_model/'
        latest_step = utils.get_latest_file(self.phi_w_dir)

        self.agent.load_phi_w(self.phi_w_dir, latest_step)

        self.sf_policy_dir = base_dir + \
            '2021.06.28/reacher_easy_SFJointPWMG_experiment=SFJointPWMG/seed=1/'
        self.sf_policy_dir = self.sf_policy_dir + 'agent_model'
        latest_step = utils.get_latest_file(self.sf_policy_dir)
        self.agent.load(self.sf_policy_dir, latest_step)

    def _select_env(self):
        """Point ``self.env`` at the env for the next episode; return its task id.

        In mixed-goal mode a task id is drawn uniformly from ``self.all_envs``
        using the ``random`` module; otherwise the fixed env is kept and its
        stored task id is returned.
        """
        if self.train_on_mix_gs:
            self.env_id = random.choice(list(self.all_envs))
            self.env = self.all_envs[self.env_id]
        return self.env_id

    def evaluate(self):
        """Play ``cfg.num_eval_episodes`` episodes with ``sample=False`` actions.

        Only the first episode is recorded to video. The mean episode reward
        is logged as ``eval/episode_reward`` to both the Logger and the
        TensorBoard writer at the current ``self.step``.
        """
        average_episode_reward = 0
        for episode in range(self.cfg.num_eval_episodes):
            self._select_env()

            obs, goal = self.env.reset(query_goal=True)
            self.agent.reset()
            self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0
            while not done:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, sample=False)
                obs, reward, done, extras = self.env.step(action)
                # The env reports the (possibly updated) goal on every step.
                goal = extras['goal']
                self.video_recorder.record(self.env)
                episode_reward += reward

            average_episode_reward += episode_reward
            self.video_recorder.save(f'{self.step}.mp4')
        average_episode_reward /= self.cfg.num_eval_episodes
        self.logger.log('eval/episode_reward', average_episode_reward,
                        self.step)
        self.logger.dump(self.step)
        self.writer.add_scalar('eval/episode_reward', average_episode_reward,
                               self.step)

    def run(self):
        """Collect experience until ``self.step`` reaches ``cfg.num_train_steps``.

        The first ``cfg.num_seed_steps`` steps use random actions and no
        updates; afterwards actions come from ``agent.act(..., sample=True)``
        and, when ``fine_tune`` is set, the agent is updated once per step.
        Evaluation runs at the start of every ``cfg.eval_frequency``-th
        episode. Every transition is stored with the task id of its env.
        """
        episode, episode_reward, done = 0, 0, True
        start_time = time.time()
        while self.step < self.cfg.num_train_steps:
            if done:
                if self.step > 0:
                    self.logger.log('train/duration',
                                    time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(
                        self.step, save=(self.step > self.cfg.num_seed_steps))

                # evaluate agent periodically
                if episode % self.cfg.eval_frequency == 0:
                    self.logger.log('eval/episode', episode, self.step)
                    self.evaluate()

                self.logger.log('train/episode_reward', episode_reward,
                                self.step)
                self.writer.add_scalar('train/episode_reward', episode_reward,
                                       self.step)
                # Selected after evaluate() so eval cannot change this
                # episode's env; the id is defined in both env modes.
                env_id = self._select_env()

                obs, goal = self.env.reset(query_goal=True)
                print('Selected goal ', goal)
                self.agent.reset()
                done = False
                episode_reward = 0
                episode_step = 0
                episode += 1

                self.logger.log('train/episode', episode, self.step)
                self.writer.add_scalar('train/episode', episode, self.step)

            # sample action for data collection
            if self.step < self.cfg.num_seed_steps:
                action = self.env.action_space.sample()
            else:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, sample=True)

            # run training update
            if self.fine_tune and self.step >= self.cfg.num_seed_steps:
                self.agent.update(self.replay_buffer, self.logger, self.writer,
                                  self.step)

            next_obs, reward, done, extras = self.env.step(action)
            goal = extras['goal']

            # Allow infinite bootstrap: hitting the time limit is stored as
            # non-terminal (done_no_max = 0). NOTE(review): relies on the
            # private ``_max_episode_steps`` attribute of the env wrapper.
            done = float(done)
            done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done
            episode_reward += reward

            self.replay_buffer.add(obs, action, reward, next_obs, done,
                                   done_no_max, goal, env_id)

            obs = next_obs
            episode_step += 1
            self.step += 1

        # SummaryWriter writes asynchronously; make sure the last scalars
        # reach disk before returning.
        self.writer.flush()


@hydra.main(config_path="config/", config_name="train_jarch")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: echo the composed config, then build and run a workspace.

    NOTE(review): the workspace class is imported from ``eval_jarch`` at call
    time instead of using this module's ``Workspace`` — confirm which
    implementation is intended.
    """
    print(OmegaConf.to_yaml(cfg))
    from eval_jarch import Workspace as EvalWorkspace
    EvalWorkspace(cfg).run()


if __name__ == '__main__':
    # Called with no arguments: the @hydra.main decorator on main() composes
    # the config and passes it in as ``cfg``.
    main()
