#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import os
import random
import sys
import time
import pickle as pkl
from collections import OrderedDict

from video import VideoRecorder
from logger import Logger
from replay_buffer import ReplayBuffer
import utils
from agent.sac_expert import SACAgent

import hydra
from omegaconf import OmegaConf
from omegaconf import DictConfig

import pickle


class Workspace(object):
    """Run a pre-trained goal-conditioned SAC expert, optionally fine-tuning it.

    Construction builds the environment(s), instantiates the agent from
    ``cfg.agent``, allocates a replay buffer and loads the latest checkpoint
    found under a hard-coded expert run directory. ``run`` then collects
    experience into the replay buffer, evaluates and saves the agent every
    ``cfg.eval_frequency`` episodes, and performs gradient updates only when
    ``cfg.fine_tune`` is set.
    """

    def __init__(self, cfg):
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency,
                             agent=cfg.name)

        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)

        # Hard-coded switch: use a mixture of environments (task ids 3-5)
        # instead of the single task-1 environment.
        self.train_on_mix_gs = True
        self.fine_tune = cfg.fine_tune

        if self.train_on_mix_gs:
            # Environments keyed 'G_3', 'G_4', 'G_5', created in that order.
            # Insertion order matters: env sampling draws from list(keys).
            self.all_envs = OrderedDict(
                ('G_' + str(task_id), utils.make_env(cfg, task_id=task_id))
                for task_id in range(3, 6))
            # The first environment is only used to read space shapes.
            env_sample = next(iter(self.all_envs.values()))
        else:
            env_sample = utils.make_env(cfg, task_id=1)
            self.env = env_sample

        cfg.agent.obs_dim = env_sample.observation_space.shape[0]
        cfg.agent.action_dim = env_sample.action_space.shape[0]
        cfg.agent.action_range = [
            float(env_sample.action_space.low.min()),
            float(env_sample.action_space.high.max())
        ]

        # Goal dimensionality is hard-coded to 2, not read from the env.
        env_sample_goal_shape = (2, )
        cfg.agent.goal_dim = env_sample_goal_shape[0]
        self.agent = hydra.utils.instantiate(cfg.agent, _recursive_=False)

        self.replay_buffer = ReplayBuffer(env_sample.observation_space.shape,
                                          env_sample.action_space.shape,
                                          env_sample_goal_shape,
                                          int(cfg.replay_buffer_capacity),
                                          self.device)

        self.video_recorder = VideoRecorder(
            self.work_dir if cfg.save_video else None)
        self.model_dir = os.path.join(self.work_dir, 'agent_model')

        # The expert checkpoint lives under the same 'runs/' root as this run.
        # NOTE(review): split('runs') cuts at the FIRST occurrence of the
        # substring 'runs' anywhere in the path; confirm that is intended.
        base_dir = self.work_dir.split('runs')[0] + 'runs/'
        self.expert_dir = os.path.join(
            base_dir,
            '2021.06.25/reacher_easy_Expert_3G_experiment=Expert_3G/seed=1',
            'agent_model')
        latest_step = utils.get_latest_file(self.expert_dir)
        self.agent.load(self.expert_dir, latest_step)

        self.step = 0

    def _sample_env(self):
        """Pick one of ``self.all_envs`` at random, store it in ``self.env``, and return it."""
        # random.sample(..., 1)[0] is kept (rather than random.choice) so the
        # seeded random stream, and therefore the env sequence, is unchanged.
        self.env = self.all_envs[random.sample(list(self.all_envs), 1)[0]]
        return self.env

    def evaluate(self):
        """Roll out ``cfg.num_eval_episodes`` episodes with ``sample=False`` and log the mean return.

        Video recording is enabled only for the first episode. When the
        environment mixture is used, each episode runs in a freshly sampled
        environment, and ``self.env`` is updated as a side effect.
        """
        average_episode_reward = 0
        for episode in range(self.cfg.num_eval_episodes):
            if self.train_on_mix_gs:
                self._sample_env()
            obs, goal = self.env.reset(query_goal=True)

            self.agent.reset()
            self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0
            while not done:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, sample=False)
                obs, reward, done, extras = self.env.step(action)
                # Act on the goal the env reports in this step's info dict.
                goal = extras['goal']
                self.video_recorder.record(self.env)
                episode_reward += reward

            average_episode_reward += episode_reward
            self.video_recorder.save(f'{self.step}.mp4')
        average_episode_reward /= self.cfg.num_eval_episodes
        self.logger.log('eval/episode_reward', average_episode_reward,
                        self.step)
        self.logger.dump(self.step)

    def run(self):
        """Collect experience for ``cfg.num_train_steps`` environment steps.

        Behaviour:
        - Before ``cfg.num_seed_steps``, actions are sampled uniformly from
          the action space; afterwards they come from the agent with
          ``sample=True``.
        - Agent updates are run only when ``self.fine_tune`` is set.
        - Evaluation and checkpointing happen at episode boundaries whenever
          ``episode % cfg.eval_frequency == 0``. That includes episode 0,
          before any data is collected.
        """
        episode, episode_reward, done = 0, 0, True
        episode_step = 0
        start_time = time.time()
        while self.step < self.cfg.num_train_steps:
            if done:
                if self.step > 0:
                    self.logger.log('train/duration',
                                    time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(
                        self.step, save=(self.step > self.cfg.num_seed_steps))

                # evaluate agent periodically
                if episode % self.cfg.eval_frequency == 0:
                    self.logger.log('eval/episode', episode, self.step)
                    self.evaluate()
                    self.agent.save(self.model_dir, episode)

                self.logger.log('train/episode_reward', episode_reward,
                                self.step)
                if self.train_on_mix_gs:
                    self._sample_env()
                obs, goal = self.env.reset(query_goal=True)
                print('Selected goal ', goal)
                self.agent.reset()
                done = False
                episode_reward = 0
                episode_step = 0
                episode += 1

                self.logger.log('train/episode', episode, self.step)

            # sample action for data collection
            if self.step < self.cfg.num_seed_steps:
                action = self.env.action_space.sample()
            else:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, goal, sample=True)

            # run training update (only when fine-tuning the loaded expert)
            if self.fine_tune and self.step >= self.cfg.num_seed_steps:
                self.agent.update(self.replay_buffer, self.logger, self.step)

            next_obs, reward, done, extras = self.env.step(action)
            goal = extras['goal']

            # allow infinite bootstrap: on the time-limit step store
            # done_no_max = 0, so the transition is not marked terminal.
            # Presumably the agent uses it as the bootstrap mask; verify
            # against ReplayBuffer/agent.
            done = float(done)
            done_no_max = (0 if episode_step + 1 == self.env._max_episode_steps
                           else done)
            episode_reward += reward

            self.replay_buffer.add(obs, action, reward, next_obs, done,
                                   done_no_max, goal)

            obs = next_obs
            episode_step += 1
            self.step += 1


@hydra.main(config_path="config/", config_name="train_expert")
def main(cfg: DictConfig) -> None:
    """Print the composed config, then build and run a Workspace on it.

    The ``@hydra.main`` wrapper composes ``cfg`` from ``config/train_expert``.
    """
    print(OmegaConf.to_yaml(cfg))
    # The Workspace class is taken from the ``eval_sac`` module, not from this
    # file's namespace. NOTE(review): presumably so the class is referenced by
    # an importable module name instead of ``__main__``. If this file is not
    # eval_sac.py, confirm that eval_sac's Workspace is the intended one.
    from eval_sac import Workspace as EvalWorkspace
    runner = EvalWorkspace(cfg)
    runner.run()


if __name__ == '__main__':
    # Called with no arguments; the hydra.main wrapper supplies the config.
    main()