#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import os
import sys
import time
import pickle as pkl

from video import VideoRecorder
from logger import Logger
from replay_buffer import ReplayBuffer
import utils
from agent.encoder import Encoder_Decoder

from torch.utils.tensorboard import SummaryWriter

import dmc2gym
import hydra


def make_env(cfg):
    """Create and seed the dm_control environment named by ``cfg.env``.

    ``cfg.env`` is expected to look like ``<domain>_<task>``; everything
    before the first underscore is used as the domain and the remainder as
    the task.  ``ball_in_cup_catch`` is handled explicitly because its
    domain name itself contains underscores.

    Args:
        cfg: configuration object providing ``env`` and ``seed``.

    Returns:
        The environment returned by ``dmc2gym.make``, seeded with
        ``cfg.seed``.

    Raises:
        AssertionError: if the action space extends outside ``[-1, 1]``.
    """
    env_name = cfg.env
    if env_name == 'ball_in_cup_catch':
        domain_name, task_name = 'ball_in_cup', 'catch'
    else:
        # Split on the first underscore only; the task may contain more.
        domain_name, _, task_name = env_name.partition('_')

    env = dmc2gym.make(
        domain_name=domain_name,
        task_name=task_name,
        seed=cfg.seed,
        visualize_reward=True,
    )
    env.seed(cfg.seed)

    # Actions must already be normalised to [-1, 1].
    action_space = env.action_space
    assert action_space.low.min() >= -1
    assert action_space.high.max() <= 1

    return env


class Workspace(object):
    """Collects transitions with a loaded expert agent and trains an
    encoder-decoder model on them.

    The agent's weights are loaded from ``expert_dir`` and are never updated
    here; only ``encoder_decoder`` is optimised (see ``update``).
    """

    def __init__(self, cfg):
        """Build the logger, environment, agent, model, buffer and optimizer.

        Args:
            cfg: hydra configuration.  ``cfg.agent.params`` is mutated in
                place with the observation/action/goal dimensions and the
                action range read from the environment.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency,
                             agent=cfg.agent.name)

        self.writer = SummaryWriter(log_dir='tb')

        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)
        self.env = utils.make_env(cfg)

        # Fill in the agent parameters that depend on the environment.
        obs_dim = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.shape[0]
        cfg.agent.params.obs_dim = obs_dim
        cfg.agent.params.action_dim = action_dim
        cfg.agent.params.action_range = [
            float(self.env.action_space.low.min()),
            float(self.env.action_space.high.max())
        ]
        # The goal is hard-coded as a 2-element vector.
        self.env_goal_shape = (2, )
        cfg.agent.params.goal_dim = self.env_goal_shape[0]
        self.agent = hydra.utils.instantiate(cfg.agent)

        self.model_dir = self.work_dir + '/agent_model'
        # NOTE(review): machine-specific absolute path; this script only runs
        # where this directory exists.  Consider moving it into the config.
        self.expert_dir = '/home/melissa/Workspace/nvidia/pytorch_sac_sf/runs/2021.05.11/reacher_easy_single_goal_expert_reach_2G_env=reacher_easy,experiment=expert_reach_2G/agent_model/'

        # Load the expert weights (presumably the most recent checkpoint in
        # expert_dir -- confirm against utils.get_latest_file).
        latest_step = utils.get_latest_file(self.expert_dir)
        self.agent.load(self.expert_dir, latest_step)

        self.encoder_decoder = Encoder_Decoder(obs_dim,
                                               action_dim).to(self.device)

        self.replay_buffer = ReplayBuffer(self.env.observation_space.shape,
                                          self.env.action_space.shape,
                                          self.env_goal_shape,
                                          int(cfg.replay_buffer_capacity),
                                          self.device)

        self.ed_optimizer = torch.optim.Adam(self.encoder_decoder.parameters(),
                                             lr=3e-4)
        self.step = 0
        self.batch_size = cfg.agent.params.batch_size

    def train_encoder_decoder(self, obs, action, reward, next_obs, writer,
                              step):
        """Run one optimisation step of the encoder-decoder on a batch.

        The loss is ``mse(next_obs) + 0.1 * mse(reward) + mse(action)``
        between the model's reconstructions and the batch values.

        Args:
            obs: batch of observations fed to the model.
            action: batch of actions fed to the model and used as a
                reconstruction target.
            reward: batch of rewards used as a reconstruction target.
            next_obs: batch of next observations used as a reconstruction
                target.
            writer: object exposing ``add_scalar(tag, value, step)``.
            step: global step used as the x-axis for the logged scalars.
        """
        # The fourth output of the model is not used for training.
        recons_next, recons_reward, recons_action, _ = self.encoder_decoder(
            obs, action)
        ed_loss = F.mse_loss(recons_next, next_obs) + 0.1 * F.mse_loss(
            recons_reward, reward) + F.mse_loss(recons_action, action)

        self.ed_optimizer.zero_grad()
        ed_loss.backward()
        writer.add_scalar('train_encoder/value', ed_loss, step)
        writer.add_scalar('train_reconstructed_rew/value',
                          recons_reward.mean(), step)
        # Log the true batch reward so it can be compared against the
        # reconstructed reward logged above.
        writer.add_scalar('train_actual_rew/value', reward.mean(), step)
        self.ed_optimizer.step()

    def update(self, replay_buffer, logger, writer, step):
        """Sample one batch from ``replay_buffer`` and train on it.

        Args:
            replay_buffer: buffer whose ``sample(batch_size)`` returns a
                7-tuple starting with obs, action, reward, next_obs.
            logger: unused; kept for interface compatibility.
            writer: scalar writer forwarded to ``train_encoder_decoder``.
            step: global step forwarded to ``train_encoder_decoder``.
        """
        obs, action, reward, next_obs, _, _, _ = replay_buffer.sample(
            self.batch_size)

        self.train_encoder_decoder(obs, action, reward, next_obs, writer, step)

    def run(self):
        """Roll out the agent for ``cfg.num_train_steps`` environment steps.

        Every transition is stored in the replay buffer; once
        ``cfg.num_seed_steps`` steps have elapsed, one encoder-decoder update
        is performed per environment step.
        """
        # ``done`` starts True so the first iteration resets the environment.
        episode, episode_reward, done = 0, 0, True
        start_time = time.time()
        while self.step < self.cfg.num_train_steps:
            if done:
                if self.step > 0:
                    self.logger.log('train/duration',
                                    time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(
                        self.step, save=(self.step > self.cfg.num_seed_steps))

                self.logger.log('train/episode_reward', episode_reward,
                                self.step)
                self.writer.add_scalar('train/episode_reward', episode_reward,
                                       self.step)

                # Checkpoint the model every ``eval_frequency`` episodes.
                if episode % self.cfg.eval_frequency == 0:
                    self.save(self.model_dir, episode)

                obs, goal = self.env.reset(query_goal=True)
                print('Selected goal ', goal)
                self.agent.reset()
                done = False
                episode_reward = 0
                episode_step = 0
                episode += 1

                self.logger.log('train/episode', episode, self.step)
                self.writer.add_scalar('train/episode', episode, self.step)

            action = self.agent.act(obs, goal, sample=True)

            # Train only after the seed phase has put data in the buffer.
            if self.step >= self.cfg.num_seed_steps:
                self.update(self.replay_buffer, self.logger, self.writer,
                            self.step)

            next_obs, reward, done, extras = self.env.step(action)
            goal = extras['goal']

            # Allow infinite bootstrap: hitting the episode step limit is
            # stored as not-done in ``done_no_max``.
            done = float(done)
            done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done
            episode_reward += reward

            self.replay_buffer.add(obs, action, reward, next_obs, done,
                                   done_no_max, goal)

            obs = next_obs
            episode_step += 1
            self.step += 1

    def save(self, model_dir, step):
        """Write the encoder-decoder and optimizer state dicts to disk.

        Args:
            model_dir: destination directory; created if it does not exist.
            step: identifier embedded in the checkpoint file names.
        """
        # exist_ok avoids the check-then-create race of a manual exists test.
        os.makedirs(model_dir, exist_ok=True)
        torch.save(self.encoder_decoder.state_dict(),
                   '%s/encoder_decoder_%s.pt' % (model_dir, step))
        torch.save(self.ed_optimizer.state_dict(),
                   '%s/encoder_optim_%s.pt' % (model_dir, step))


@hydra.main(config_path='config/train_expert.yaml', strict=True)
def main(cfg):
    """Build a ``Workspace`` from ``cfg`` and run its training loop."""
    Workspace(cfg).run()


# Script entry point: main() is called without arguments; the hydra.main
# decorator on it is what supplies ``cfg``.
if __name__ == '__main__':
    main()
