# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import os
import sys
import time
import pickle as pkl
import itertools
import random
from collections import OrderedDict

from video import VideoRecorder
from logger import Logger
from replay_buffer import ReplayBuffer
import utils
from agent.encoder import Encoder_Decoder
from agent.actor import DiagGaussianW

from torch.utils.tensorboard import SummaryWriter
from torch import linalg as LA

import hydra
from omegaconf import OmegaConf
from omegaconf import DictConfig
from torchmetrics import MeanSquaredError
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score


class Workspace(object):
    def __init__(self, cfg):
        """Build everything needed to re-train the reward weights ``w``.

        Sets up logging, the environment(s), an expert agent (restored from
        a checkpoint) and a second, freshly instantiated agent used as the
        "random" policy, the joint phi/w model (restored from a previous
        phi run), the replay buffer and the video recorder, and finally
        calls ``self.train()``.

        Args:
            cfg: Hydra/OmegaConf config. Fields read here: ``log_save_tb``,
                ``log_frequency``, ``name``, ``agent``, ``seed``,
                ``device``, ``train_on_mix_gs``, ``train_on_unseen_gs``,
                ``goal_mode``, ``goal_id`` (single-goal mode only),
                ``rep_model``, ``replay_buffer_capacity`` and
                ``save_video``. Several ``cfg.agent`` fields are written
                back (``obs_dim``, ``action_dim``, ``action_range``,
                ``env_id_dim``, ``goal_dim``).

        Raises:
            ValueError: If ``cfg.rep_model`` is neither ``'mlp'`` nor
                ``'vae'``.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency,
                             agent=cfg.name)

        self.writer = SummaryWriter(log_dir='tb')
        self.batch_size = cfg.agent.batch_size

        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)

        self.train_on_mix_gs = cfg.train_on_mix_gs
        self.train_on_unseen_gs = cfg.train_on_unseen_gs

        self.goal_mode = cfg.goal_mode

        if self.goal_mode == 'multi_goal':
            # Tasks 0..2 are always created; task 3 is added when training on
            # unseen goals.
            # NOTE(review): the original comment here read "Eval envs have
            # indices 3, 4 and 5" - confirm that including task 3 is intended.
            num_tasks = 4 if self.train_on_unseen_gs else 3
            self.all_envs = OrderedDict(
                (task_id, utils.make_env(cfg, task_id=task_id))
                for task_id in range(num_tasks))
            # Any env works for querying the spaces; use the first one.
            env_sample = next(iter(self.all_envs.values()))
        else:
            self.goal_id = cfg.goal_id
            env_sample = utils.make_env(cfg, task_id=self.goal_id)
            self.env = env_sample

        # Fill in the agent config from the environment's spaces.
        obs_dim = env_sample.observation_space.shape[0]
        action_dim = env_sample.action_space.shape[0]
        cfg.agent.obs_dim = obs_dim
        cfg.agent.action_dim = action_dim
        cfg.agent.action_range = [
            float(env_sample.action_space.low.min()),
            float(env_sample.action_space.high.max())
        ]
        cfg.agent.env_id_dim = 0
        env_sample_goal_shape = (2, )
        cfg.agent.goal_dim = env_sample_goal_shape[0]
        self.agent_expert = hydra.utils.instantiate(cfg.agent,
                                                    _recursive_=False)
        self.agent_random = hydra.utils.instantiate(cfg.agent,
                                                    _recursive_=False)

        self.model_dir = self.work_dir + '/agent_model'
        # Checkpoints of earlier runs live under the shared "runs/" root.
        base_dir = self.work_dir.split('runs')[0] + 'runs/'

        # Input width of the phi/w model for the multi-goal expert run below.
        # (An earlier 4-goal setup used state_features = 10 together with
        # '2021.07.01/reacher_easy_Expert_4G_experiment=Expert_4G,goal_mode=multi_goal/seed=1/'.)
        state_features = 14
        self.expert_dir = base_dir + \
            '2021.09.20/reacher_easy_Expert_MultiGoal_experiment=Expert_MultiGoal,goal_mode=multi_goal/seed=1/'
        self.expert_dir = self.expert_dir + 'agent_model'

        # Restore the expert agent from its most recent checkpoint.
        latest_step = utils.get_latest_file(self.expert_dir)
        self.agent_expert.load(self.expert_dir, latest_step)

        # Random policy action noise
        self.policy_noise = 0.5
        self.noise_clip = 0.5
        self.min_action = cfg.agent.action_range[0]
        self.max_action = cfg.agent.action_range[1]

        self.rep_model = cfg.rep_model

        if self.rep_model == 'mlp':
            self.w_joint = utils.PhiWJointMLP(input_dim=state_features,
                                              hidden_dim_1=128,
                                              hidden_dim_2=256,
                                              output_dim=1).to(self.device)
        elif self.rep_model == 'vae':
            self.w_joint = utils.PhiWJointVAE(state_dim=state_features,
                                              hidden_dim_1=750,
                                              hidden_dim_2=750,
                                              latent_dim=32,
                                              output_dim=1).to(self.device)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # the missing ``self.w_joint`` a few lines further down.
            raise ValueError(
                f"unsupported rep_model {self.rep_model!r}; "
                "expected 'mlp' or 'vae'")

        self.w_optimizer = torch.optim.Adam(self.w_joint.parameters(), lr=3e-4)

        # Previous phi-regression run whose weights are loaded into w_joint.
        # (An earlier alternative was
        # '2021.06.30/reacher_easy_TraingPhiJArchMLP4G_experiment=TraingPhiJArchMLP4G/'.)
        self.phi_dir = base_dir + \
            '2021.09.20/reacher_easy_RegressPhiJArchMLP_experiment=RegressPhiJArchMLP,goal_mode=multi_goal/seed=1/'
        self.phi_dir = self.phi_dir + 'agent_model'
        latest_step = utils.get_latest_file(self.phi_dir)
        self.load_phi(self.phi_dir, latest_step)

        self.replay_buffer = ReplayBuffer(env_sample.observation_space.shape,
                                          env_sample.action_space.shape,
                                          env_sample_goal_shape,
                                          int(cfg.replay_buffer_capacity),
                                          self.device,
                                          keep_ids=True)

        self.video_recorder = VideoRecorder(
            self.work_dir if cfg.save_video else None)
        self.step = 0

        # Freeze phi and re-initialise the w heads (see ``train``).
        self.train()

    def train(self, training=True):
        self.training = training
        self.w_joint.train(training)
        # Freeze phi module weights
        # Re-initialize w's for re-training
        if self.rep_model == 'mlp':
            self.w_joint.phi_l1.weight.requires_grad = False
            self.w_joint.phi_l1.bias.requires_grad = False
            self.w_joint.phi_l2.weight.requires_grad = False
            self.w_joint.phi_l2.bias.requires_grad = False
            self.w_joint.phi_l3.weight.requires_grad = False
            self.w_joint.phi_l3.bias.requires_grad = False

        elif self.rep_model == 'vae':
            self.w_joint.e1.weight.requires_grad = False
            self.w_joint.e1.bias.requires_grad = False

            self.w_joint.e2.weight.requires_grad = False
            self.w_joint.e2.bias.requires_grad = False

            self.w_joint.mean.weight.requires_grad = False
            self.w_joint.mean.bias.requires_grad = False

            self.w_joint.log_std.weight.requires_grad = False
            self.w_joint.log_std.bias.requires_grad = False

            self.w_joint.d1.weight.requires_grad = False
            self.w_joint.d1.bias.requires_grad = False

            self.w_joint.d2.weight.requires_grad = False
            self.w_joint.d2.bias.requires_grad = False

            self.w_joint.d3.weight.requires_grad = False
            self.w_joint.d3.bias.requires_grad = False

        nn.init.orthogonal_(self.w_joint.W1.weight.data)
        nn.init.orthogonal_(self.w_joint.W2.weight.data)
        nn.init.orthogonal_(self.w_joint.W3.weight.data)
        nn.init.orthogonal_(self.w_joint.W4.weight.data)

    def train_w(self, obs, action, next_obs, reward, goal, env_ids, writer,
                step):
        obs_action = torch.cat([obs, action, next_obs], dim=-1)
        mean_squared_error = MeanSquaredError()

        if self.rep_model == 'mlp':
            latent_w_joint = self.w_joint(obs_action, env_ids)

            # print('PhiL1', self.w_joint.phi_l1.weight.data)
            # print('W1', self.w_joint.W1.weight.data)
            # print('W2', self.w_joint.W2.weight.data)
            # print('W3', self.w_joint.W3.weight.data)

            rep_loss = F.mse_loss(latent_w_joint, reward)
            self.w_optimizer.zero_grad()
            rep_loss.backward()
            self.w_optimizer.step()

            # print('PhiL1', self.w_joint.phi_l1.weight.data)
            # print('W1', self.w_joint.W1.weight.data)
            # print('W2', self.w_joint.W2.weight.data)
            # print('W3', self.w_joint.W3.weight.data)

            writer.add_scalar('train_rep/loss', rep_loss, step)

            rew_pred = latent_w_joint.detach().cpu()
            mse_error = mean_squared_error(rew_pred, reward.cpu())
            writer.add_scalar('train_rep/mse', mse_error, step)

        elif self.rep_model == 'vae':
            latent_phi_w, mean, std = self.w_joint(obs_action, env_ids)
            z = mean + std * torch.randn_like(std)
            decoded_state = self.w_joint.decode(obs_action, z)
            recon_loss = F.mse_loss(decoded_state, obs_action)
            KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                              std.pow(2)).mean()
            vae_loss = recon_loss + 0.5 * KL_loss

            reward_loss = F.mse_loss(latent_phi_w, reward)
            total_loss = reward_loss + vae_loss

            self.w_optimizer.zero_grad()
            total_loss.backward()
            self.w_optimizer.step()

            writer.add_scalar('train_rep/loss', reward_loss, step)

            rew_pred = latent_phi_w.detach().cpu()
            mse_error = mean_squared_error(rew_pred, reward.cpu())
            writer.add_scalar('train_rep/mse', mse_error, step)

    def update(self, replay_buffer, logger, writer, step):
        obs, action, reward, next_obs, not_done, not_done_no_max, goal, env_ids = replay_buffer.sample(
            self.batch_size)
        self.train_w(obs, action, next_obs, reward, goal, env_ids, writer,
                     step)

    def evaluate(self):
        """Run evaluation episodes and log how well ``w_joint`` predicts reward.

        Each episode is driven entirely by either the expert agent or the
        noise-perturbed "random" agent, chosen by a coin flip. After every
        environment step the reward predicted by ``self.w_joint`` for the
        transition is recorded alongside the true reward. Once all episodes
        have finished, the prediction MSE, the R^2 score and the mean
        episode return are logged under ``eval_seen/*``; the mean return is
        also written to TensorBoard. Only the first episode is recorded by
        the video recorder.
        """
        average_episode_reward = 0
        # Collected over *all* evaluation episodes so that the MSE / R^2
        # below describe the whole evaluation run, not only the last episode.
        true_rewards = []
        predicted_rewards = []
        for episode in range(self.cfg.num_eval_episodes):
            if self.goal_mode == 'multi_goal':
                env_id = random.sample(list(self.all_envs), 1)[0]
                self.env = self.all_envs[env_id]
                env_id_tens = torch.tensor([env_id]).to(self.device)
            else:
                env_id_tens = torch.ones(1, 1).to(self.device) * self.goal_id

            obs, goal = self.env.reset(query_goal=True)
            self.agent_expert.reset()
            self.agent_random.reset()
            self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0
            use_expert = random.choice([True, False])
            while not done:
                if use_expert:
                    with utils.eval_mode(self.agent_expert):
                        action = self.agent_expert.act(obs, goal, sample=False)
                else:
                    with utils.eval_mode(self.agent_random):
                        # Select action according to random policy and add
                        # clipped noise
                        action = self.agent_random.act(obs, goal, sample=False)
                        action = torch.tensor(action)
                        noise = (torch.randn_like(action) *
                                 self.policy_noise).clamp(
                                     -self.noise_clip, self.noise_clip)
                        action = (action + noise).clamp(
                            self.min_action, self.max_action)
                        action = action.numpy()

                prev_obs = obs
                obs, reward, done, extras = self.env.step(action)

                # Build phi's input in the same [s, a, s'] layout that
                # ``train_w`` uses: the observation before the step, the
                # action, then the observation after the step.
                prev_obs_tens = torch.as_tensor(prev_obs,
                                                device=self.device).float()
                act_tens = torch.as_tensor(action, device=self.device).float()
                next_obs_tens = torch.as_tensor(obs,
                                                device=self.device).float()
                obs_action = torch.cat(
                    [prev_obs_tens, act_tens, next_obs_tens],
                    dim=-1).unsqueeze(0)

                # Pure inference: no autograd graph is needed here.
                with torch.no_grad():
                    if self.rep_model == 'vae':
                        latent_w_joint, _, _ = self.w_joint(
                            obs_action, env_id_tens)
                    else:
                        latent_w_joint = self.w_joint(obs_action, env_id_tens)

                true_rewards.append(reward)
                predicted_rewards.append(float(latent_w_joint.detach().cpu()))

                goal = extras['goal']
                self.video_recorder.record(self.env)
                episode_reward += reward

            average_episode_reward += episode_reward
            self.video_recorder.save(f'{self.step}.mp4')

        mse_error = mean_squared_error(true_rewards, predicted_rewards)
        self.logger.log('eval_seen/reward_pred_mse', mse_error, self.step)
        r2_score_ = r2_score(true_rewards, predicted_rewards)
        self.logger.log('eval_seen/reward_r2_score', r2_score_, self.step)
        average_episode_reward /= self.cfg.num_eval_episodes
        self.logger.log('eval_seen/episode_reward', average_episode_reward,
                        self.step)
        self.logger.dump(self.step)
        self.writer.add_scalar('eval/episode_reward', average_episode_reward,
                               self.step)

    def run(self):
        episode, episode_reward, done = 0, 0, True
        start_time = time.time()
        while self.step < self.cfg.num_train_steps:
            if done:
                if self.step > 0:
                    self.logger.log('train/duration',
                                    time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(
                        self.step, save=(self.step > self.cfg.num_seed_steps))

                # Run evaluation as sanity check
                # also save the phi model
                if episode % self.cfg.eval_frequency == 0:
                    self.logger.log('eval_seen/episode', episode, self.step)
                    self.evaluate()
                    self.save(self.model_dir, episode)

                self.logger.log('train/episode_reward', episode_reward,
                                self.step)
                self.writer.add_scalar('train/episode_reward', episode_reward,
                                       self.step)
                if self.goal_mode == 'multi_goal':
                    env_id = random.sample(list(self.all_envs), 1)[0]
                    self.env = self.all_envs[env_id]
                else:
                    env_id = self.goal_id

                obs, goal = self.env.reset(query_goal=True)
                self.agent_expert.reset()
                self.agent_random.reset()
                expert_agent = random.choice([True, False])
                done = False
                episode_reward = 0
                episode_step = 0
                episode += 1
                self.logger.log('train/episode', episode, self.step)
                self.writer.add_scalar('train/episode', episode, self.step)

            if expert_agent:
                action = self.agent_expert.act(obs, goal, sample=True)
            else:
                # Select action according to policy and add clipped noise
                action = self.agent_random.act(obs, goal, sample=True)
                action = torch.tensor(action)
                noise = (torch.randn_like(action) * self.policy_noise).clamp(
                    -self.noise_clip, self.noise_clip)
                action = (action + noise).clamp(self.min_action,
                                                self.max_action)
                action = action.numpy()

            # run training update
            if self.step >= self.cfg.num_seed_steps:
                self.update(self.replay_buffer, self.logger, self.writer,
                            self.step)

            next_obs, reward, done, extras = self.env.step(action)
            goal = extras['goal']

            # allow infinite bootstrap
            done = float(done)
            done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done
            episode_reward += reward

            self.replay_buffer.add(obs, action, reward, next_obs, done,
                                   done_no_max, goal, env_id)

            obs = next_obs
            episode_step += 1
            self.step += 1

    def save(self, model_dir, step):
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        torch.save(self.w_joint.state_dict(),
                   '%s/phi_w_joint%s.pt' % (model_dir, step))

        torch.save(self.w_optimizer.state_dict(),
                   '%s/phi_w_joint_optim_%s.pt' % (model_dir, step))

    def load_phi(self, model_dir, step):
        # Load the joint architecture, but will only use phi
        self.w_joint.load_state_dict(
            torch.load('%s/phi_joint%s.pt' % (model_dir, step)))

        self.w_optimizer.load_state_dict(
            torch.load('%s/phi_joint_optim_%s.pt' % (model_dir, step)))


# Reuse the same config as phi
@hydra.main(config_path="config/", config_name="train_rep_w")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: print the config, then build and run a Workspace.

    Args:
        cfg: Config composed by Hydra from ``config/train_rep_w``.
    """
    resolved_yaml = OmegaConf.to_yaml(cfg)
    print(resolved_yaml)
    # Workspace is imported by module name inside the function rather than
    # referenced directly.
    from train_w_jointarch import Workspace as W
    W(cfg).run()


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()