# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os

from agent import Agent
from agent.actor import DiagGaussianW
from agent.encoder import VAE

import utils

import hydra
from torchmetrics import MeanSquaredError
from torch import linalg as LA


class SACAgent(Agent):
    r"""SAC with successor features (SAC-SF).

    This is the core algorithm for training a SAC policy with successor
    features. The critic predicts a vector of per-feature values (successor
    features); the scalar Q-value used by the actor is their inner product
    with the goal-conditioned weights ``w``. The learned features ``phi`` and
    weights ``w`` are loaded with ``load_phi_w`` and kept fixed (frozen)
    during policy training.

    Note: training single-goal policies as a sanity check is currently
    disabled (see the commented-out code sections). Some parts of the code
    assume fixed, hand-crafted features for either \phi or w and regress the
    other, or keep both fixed. These features are pre-calculated and only
    work with the modified reward for the reacher task.
    """
    def __init__(self, obs_dim, action_dim, goal_dim, env_id_dim, action_range,
                 device, critic_cfg, actor_cfg, representation_cfg, discount,
                 init_temperature, alpha_lr, alpha_betas, actor_lr,
                 actor_betas, actor_update_frequency, critic_lr, critic_betas,
                 critic_tau, critic_target_update_frequency, batch_size,
                 learnable_temperature, goal_mode, task) -> None:
        """Build the actor, twin critics, temperature and the phi / w models.

        ``obs_dim`` and ``goal_dim`` are accepted for config compatibility but
        are not used: the feature sizes are derived from ``goal_mode`` and
        ``task``. ``representation_cfg`` selects which of phi / w are loaded
        (``load_phi`` / ``load_w``) and their architectures.

        Raises:
            ValueError: for an unsupported ``goal_mode``/``task`` combination,
                an unsupported phi/w configuration, an unsupported
                ``phi_model`` (when both phi and w are loaded) or an
                unsupported ``w_model``.
        """
        super().__init__()

        self.action_range = action_range
        self.device = torch.device(device)
        self.discount = discount
        self.critic_tau = critic_tau
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_update_frequency = critic_target_update_frequency
        self.batch_size = batch_size
        self.learnable_temperature = learnable_temperature
        self.task = task

        self.critic = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target = hydra.utils.instantiate(critic_cfg).to(
            self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)

        self.log_alpha = torch.tensor(np.log(init_temperature)).to(self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -action_dim

        # optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr,
                                                betas=actor_betas)

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr,
                                                 betas=critic_betas)

        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr,
                                                    betas=alpha_betas)

        self.load_w = representation_cfg.load_w
        self.load_phi = representation_cfg.load_phi
        self.w_model = representation_cfg.w_model
        self.phi_model = representation_cfg.phi_model
        self.goal_mode = goal_mode

        self.append_env_id = env_id_dim != 0

        # Per-task feature sizes. `state_features` is the phi input width
        # (in multi-goal ant this includes obs, action and next obs).
        if self.goal_mode == 'multi_goal':
            if self.task == 'reacher_easy':
                self.state_features = 14
                self.goal_features = 14
                self.goal_dim = 2
                # self.state_features = 10
                # self.goal_features = 10
                self.latent_state_features = 14
            elif self.task == 'ant_walk':
                # including state obs, actions, next_states
                self.state_features = 326
                self.goal_features = 326
                self.latent_state_features = 28
                self.goal_dim = 3
            elif self.task == 'metaworld':
                self.state_features = 82
                self.goal_dim = 1
                self.latent_state_features = representation_cfg.latent_size
            else:
                raise ValueError('Invalid Env.')
        else:
            if self.task == 'walker_stand' or self.task == 'walker_walk' \
                    or self.task == 'walker_run':
                self.state_features = 54
                self.goal_dim = 1
                self.latent_state_features = representation_cfg.latent_size
            elif self.task == 'metaworld':
                self.state_features = 82
                self.latent_state_features = representation_cfg.latent_size
                self.goal_dim = 1
            else:
                raise ValueError('Invalid task.')

        #############################################
        # Both phi and w are learned models (loaded)
        #############################################
        if self.load_w and self.load_phi:
            if self.phi_model == 'mlp':
                # `Phi` is a MLP
                self.phi = utils.mlp(
                    input_dim=self.state_features,
                    hidden_dim=representation_cfg.phi_hidden_dim,
                    output_dim=self.latent_state_features,
                    hidden_depth=3).to(self.device)

            elif self.phi_model == 'vae':
                self.phi = VAE(state_dim=self.state_features,
                               action_dim=action_dim,
                               latent_dim=self.latent_state_features,
                               device=self.device).to(self.device)
            else:
                # Fail early instead of leaving `self.phi` undefined.
                raise ValueError('Unsupported phi model.')

            self.W = utils.mlp(input_dim=self.goal_dim,
                               hidden_dim=representation_cfg.w_hidden_dim,
                               output_dim=self.latent_state_features,
                               hidden_depth=0,
                               weight_norm=representation_cfg.w_norm).to(
                                   self.device)

            # self.W_gt = torch.ones(self.batch_size,
            #                        self.state_features,
            #                        requires_grad=False,
            #                        device=self.device)

        ######################################
        # Fixed (hand-crafted) phi, learned w
        ######################################
        elif self.load_w and not self.load_phi:
            # `Phi` is a fixed matrix
            self.phi = torch.ones(self.batch_size,
                                  self.state_features,
                                  requires_grad=False,
                                  device=self.device)

            # `W` is a single layer linear MLP.
            # NOTE(review): `self.goal_features` is only set for reacher_easy
            # and ant_walk in multi-goal mode; other configs fail here.
            self.W = utils.mlp(input_dim=2,
                               hidden_dim=32,
                               output_dim=self.goal_features,
                               hidden_depth=0).to(self.device)

        ######################################
        # Fixed (hand-crafted) w, learned phi
        ######################################
        elif not self.load_w and self.load_phi:
            # `Phi` is a MLP
            self.phi = utils.mlp(input_dim=self.state_features,
                                 hidden_dim=256,
                                 output_dim=self.latent_state_features,
                                 hidden_depth=2).to(self.device)
            # `W` is a fixed vector
            self.W = torch.ones(self.batch_size,
                                self.latent_state_features,
                                requires_grad=False,
                                device=self.device)

        else:
            raise ValueError('Need to learn either w or phi or both.')

        if self.w_model not in ('distr', 'mlp', 'vector'):
            raise ValueError('Unsupported W model.')

        # phi / w are pre-trained and restored via `load_phi_w`; freeze the
        # loaded modules so policy training never updates them. The fixed
        # tensors created above already have requires_grad=False.
        if self.load_w:
            self.W.requires_grad_(False)
        if self.load_phi:
            self.phi.requires_grad_(False)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Set the training mode of actor and critic.

        Loaded (frozen) phi / w modules are always kept in eval mode.
        """
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        if self.load_w:
            self.W.train(False)
        if self.load_phi:
            self.phi.train(False)

    @property
    def alpha(self):
        """Current SAC temperature, exp(log_alpha)."""
        return self.log_alpha.exp()

    def act(self, obs, goal, env_id=None, sample=False):
        """Select an action for a single (unbatched) observation and goal.

        Args:
            obs: observation array for one step.
            goal: goal array for one step.
            env_id: environment id, only used when ``env_id_dim != 0``.
            sample: sample from the policy if True, else use its mean.

        Returns:
            numpy array with the action, clamped to ``action_range``.
        """
        obs = torch.as_tensor(obs, dtype=torch.float32,
                              device=self.device).unsqueeze(0)
        goal = torch.as_tensor(goal, dtype=torch.float32,
                               device=self.device).unsqueeze(0)

        # Acting never needs gradients; avoid building an autograd graph.
        with torch.no_grad():
            if self.append_env_id:
                env_id = torch.tensor([env_id], dtype=torch.float32,
                                      device=self.device).unsqueeze(0)
                dist = self.actor(obs, goal, env_id)
            else:
                dist = self.actor(obs, goal)
            action = dist.sample() if sample else dist.mean
            action = action.clamp(*self.action_range)
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])

    def _phi_forward(self, phi_input):
        """Encode ``phi_input`` with phi, dispatching on ``self.phi_model``.

        The 'vae' model returns a 3-tuple; its first element is the feature.

        Raises:
            ValueError: if ``phi_model`` is neither 'mlp' nor 'vae'.
        """
        if self.phi_model == 'mlp':
            return self.phi(phi_input)
        if self.phi_model == 'vae':
            phi_latent, _, _ = self.phi(phi_input)
            return phi_latent
        raise ValueError('Unsupported phi model.')

    def update_critic(self, obs, action, reward, next_obs, not_done, goal,
                      env_id, logger, writer, step):
        """One TD update of the successor-feature critics.

        The TD target is ``phi(s, a, s') + discount * not_done * V(s')``, where
        V is the soft value of the selected target head. The reward-prediction
        error of phi^T w is logged to ``writer`` as 'eval_phi_w/mse'.
        """
        if self.append_env_id:
            dist = self.actor(next_obs, goal, env_id)
        else:
            dist = self.actor(next_obs, goal)
        next_action = dist.rsample()
        log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
        target_Q1, target_Q2 = self.critic_target(next_obs, next_action, goal,
                                                  env_id)

        # Clipped double-Q on feature vectors: per sample, keep the entire
        # feature vector of the head whose feature sum is smaller. Exactly one
        # head is selected; ties go to Q1 (an equality mask on both heads
        # would add them together on a tie).
        target_q1_sum = torch.sum(target_Q1, 1).unsqueeze(-1)
        target_q2_sum = torch.sum(target_Q2, 1).unsqueeze(-1)
        target_Q_sel = torch.where(target_q1_sum <= target_q2_sum, target_Q1,
                                   target_Q2)
        target_V = target_Q_sel - self.alpha.detach() * log_prob

        if self.load_w and self.load_phi:
            # if self.goal_mode == 'multi_goal':
            obs_action = torch.cat([obs, action, next_obs], dim=-1)
            phi_latent = self._phi_forward(obs_action)

            # Predicted reward phi(s, a, s')^T w(g), shape (batch, 1).
            latent_phi_w = torch.bmm(
                phi_latent.view(self.batch_size, 1,
                                self.latent_state_features),
                self.W(goal).view(self.batch_size, self.latent_state_features,
                                  1))
            # else:
            #     obs_wt_goal = torch.ones([self.batch_size, 4]).to(self.device)
            #     obs_wt_goal[:, 0:2] = obs[:, 0:2]
            #     obs_wt_goal[:, 2:] = obs[:, 4:]
            #     latent_phi_w = torch.bmm(
            #         self.phi(obs_wt_goal).view(self.batch_size, 1,
            #                                    self.latent_state_features),
            #         self.W(goal).view(self.batch_size,
            #                           self.latent_state_features, 1))

            latent_phi_w = latent_phi_w.squeeze(-1)

            # phi = torch.ones_like(self.W_gt)
            # phi[:, 1:3] = obs[:, 0:2]
            # phi[:, 3:] = ((LA.norm(obs[:, 0:2], dim=1, keepdim=True))**2)

            # W = torch.ones_like(self.W_gt)
            # W[:, 0:1] = -((LA.norm(goal, dim=1, keepdim=True))**2)
            # W[:, 1:3] = 2 * goal
            # W[:, 3:] = -torch.ones_like(W[:, 3:])

        elif self.load_w and not self.load_phi:
            # Legacy reacher path: overwrite the fixed phi tensor in place
            # with hand-crafted features of the first two obs dimensions.
            self.phi[:, 1:3] = obs[:, 0:2]
            self.phi[:, 3:] = ((LA.norm(obs[:, 0:2], dim=1, keepdim=True))**2)

            W = torch.ones_like(self.phi)
            W[:, 0:1] = -((LA.norm(goal, dim=1, keepdim=True))**2)
            W[:, 1:3] = 2 * goal
            W[:, 3:] = -torch.ones_like(W[:, 3:])

            if self.w_model == 'mlp':
                latent_phi_w = torch.bmm(
                    self.phi.view(self.batch_size, 1, 4),
                    self.W(goal).view(self.batch_size, 4, 1))
                latent_phi_w = latent_phi_w.squeeze(-1)

            # NOTE(review): `self.phi` is a plain tensor in this mode, so the
            # `self.phi(obs)` calls below raise TypeError if reached.
            if self.w_model == 'distr':
                w_dist = self.W(goal)
                w_out = w_dist.mean
                latent_phi_w = (self.phi(obs) * w_out).mean(1, keepdim=True)
            elif self.w_model == 'vector':
                latent_phi_w = (self.phi(obs) * self.W).mean(1, keepdim=True)

        elif not self.load_w and self.load_phi:
            # Legacy reacher path with hand-crafted w written into self.W.
            phi = torch.ones_like(self.W)
            phi[:, 1:3] = obs[:, 0:2]
            phi[:, 3:] = ((LA.norm(obs[:, 0:2], dim=1, keepdim=True))**2)

            self.W[:, 0:1] = -((LA.norm(goal, dim=1, keepdim=True))**2)
            self.W[:, 1:3] = 2 * goal
            self.W[:, 3:] = -torch.ones_like(self.W[:, 3:])

            obs_wt_goal = torch.ones([self.batch_size, 4]).to(self.device)
            obs_wt_goal[:, 0:2] = obs[:, 0:2]
            obs_wt_goal[:, 2:] = obs[:, 4:]

            latent_phi_w = torch.bmm(
                self.phi(obs_wt_goal).view(self.batch_size, 1, 4),
                self.W.view(self.batch_size, 4, 1))
            latent_phi_w = latent_phi_w.squeeze(-1)

        # Diagnostic only: how well phi^T w reproduces the environment reward.
        mean_squared_error = MeanSquaredError()
        rew_pred = latent_phi_w.detach().cpu()
        mse_error = mean_squared_error(rew_pred, reward.cpu())
        writer.add_scalar('eval_phi_w/mse', mse_error, step)

        # if self.load_w and self.load_phi:
        #     w_pred = self.W(goal).detach().cpu()
        #     mse_w_gt = mean_squared_error(w_pred, W.cpu())
        #     writer.add_scalar('train_w_GT/mse', mse_w_gt, step)

        #     if self.goal_mode == 'multi_goal':
        #         obs_action = torch.cat([obs, action, next_obs], dim=-1)
        #         if self.phi_model == 'mlp':
        #             phi_pred = self.phi(obs_action).detach().cpu()
        #         elif self.phi_model == 'vae':
        #             phi_pred, _, _ = self.phi(obs_action)
        #             phi_pred = phi_pred.detach().cpu()
        #     else:
        #         obs_wt_goal = torch.ones([self.batch_size, 4]).to(self.device)
        #         obs_wt_goal[:, 0:2] = obs[:, 0:2]
        #         obs_wt_goal[:, 2:] = obs[:, 4:]
        #         phi_pred = self.phi(obs_wt_goal).detach().cpu()

        #     mse_w_gt = mean_squared_error(phi_pred, phi.cpu())
        #     writer.add_scalar('train_phi_GT/mse', mse_w_gt, step)

        # elif self.load_w and not (self.load_phi):
        #     w_pred = self.W(goal).detach().cpu()
        #     mse_w_gt = mean_squared_error(w_pred, W.cpu())
        #     writer.add_scalar('train_w_GT/mse', mse_w_gt, step)

        # elif not (self.load_w) and self.load_phi:
        #     obs_wt_goal = torch.ones([self.batch_size, 4]).to(self.device)
        #     obs_wt_goal[:, 0:2] = obs[:, 0:2]
        #     obs_wt_goal[:, 2:] = obs[:, 4:]

        #     phi_pred = self.phi(obs_wt_goal).detach().cpu()
        #     mse_w_gt = mean_squared_error(phi_pred, phi.cpu())
        #     writer.add_scalar('train_phi_GT/mse', mse_w_gt, step)

        if self.load_phi:
            # if self.goal_mode == 'multi_goal':
            # NOTE: phi is evaluated again here (the logging path above used
            # its own forward pass).
            obs_action = torch.cat([obs, action, next_obs], dim=-1)
            phi_latent = self._phi_forward(obs_action)
            target_Q = phi_latent + (not_done * self.discount * target_V)
            # else:
            #     obs_wt_goal = torch.ones([self.batch_size, 4]).to(self.device)
            #     obs_wt_goal[:, 0:2] = obs[:, 0:2]
            #     obs_wt_goal[:, 2:] = obs[:, 4:]
            #     target_Q = self.phi(obs_wt_goal) + (not_done * self.discount *
            #                                         target_V)
        else:
            target_Q = self.phi + (not_done * self.discount * target_V)

        target_Q = target_Q.detach()

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action, goal, env_id)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)
        logger.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(logger, step)

    def evaluate_phi_w_approx(self, obs, action, next_obs, goal):
        """Predict the reward phi(s, a, s')^T w(g) for one transition.

        Args:
            obs, action, next_obs, goal: unbatched 1-D arrays for one step.

        Returns:
            The predicted reward as a numpy scalar.

        Raises:
            ValueError: unless both phi and w are loaded (the only
                configuration for which this prediction is defined).
        """
        if not (self.load_w and self.load_phi):
            raise ValueError(
                'evaluate_phi_w_approx requires both phi and w to be loaded.')

        obs_tens = torch.as_tensor(obs, device=self.device).float()
        goal_tens = torch.as_tensor(goal, device=self.device).float()
        next_obs_tens = torch.as_tensor(next_obs, device=self.device).float()
        act_tens = torch.as_tensor(action, device=self.device).float()

        with torch.no_grad():
            # if self.goal_mode == 'multi_goal':
            obs_action = torch.cat([obs_tens, act_tens, next_obs_tens],
                                   dim=0).unsqueeze(0)
            goal_tens = goal_tens.unsqueeze(0)
            phi_latent = self._phi_forward(obs_action)

            n = obs_action.shape[0]
            latent_phi_w = torch.bmm(
                phi_latent.view(n, 1, self.latent_state_features),
                self.W(goal_tens).view(n, self.latent_state_features, 1))

        return latent_phi_w.squeeze(0).squeeze(0).cpu().numpy()[0]

    def update_actor_and_alpha(self, obs, goal, env_id, logger, step):
        """SAC actor update, plus the temperature update if learnable.

        The critics' feature-valued outputs are projected onto w to obtain
        scalar Q-values, and the standard SAC actor loss is minimized.
        """
        if self.append_env_id:
            dist = self.actor(obs, goal, env_id)
        else:
            dist = self.actor(obs, goal)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        actor_Q1, actor_Q2 = self.critic(obs, action, goal, env_id)

        # Use the actual batch size so any configured `batch_size` works.
        batch = obs.shape[0]
        if self.load_w:
            # w is frozen; compute it once and share it between both heads.
            w = self.W(goal).detach().view(batch, self.latent_state_features,
                                           1)
            actor_Q1 = torch.bmm(
                actor_Q1.view(batch, 1, self.latent_state_features),
                w).squeeze(-1)
            actor_Q2 = torch.bmm(
                actor_Q2.view(batch, 1, self.latent_state_features),
                w).squeeze(-1)

        else:
            # Legacy fixed-w path; this view assumes 4 features per sample.
            w = self.W.view(batch, 4, 1)
            actor_Q1 = torch.bmm(actor_Q1.view(batch, 1, 4), w).squeeze(-1)
            actor_Q2 = torch.bmm(actor_Q2.view(batch, 1, 4), w).squeeze(-1)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()

        logger.log('train_actor/loss', actor_loss, step)
        logger.log('train_actor/target_entropy', self.target_entropy, step)
        logger.log('train_actor/entropy', -log_prob.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(logger, step)

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            alpha_loss = (self.alpha *
                          (-log_prob - self.target_entropy).detach()).mean()
            logger.log('train_alpha/loss', alpha_loss, step)
            logger.log('train_alpha/value', self.alpha, step)
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    def update(self, replay_buffer, logger, writer, step):
        """Sample a batch and run one critic / actor / target update step."""
        (obs, action, reward, next_obs, not_done, not_done_no_max, goal,
         env_id) = replay_buffer.sample(self.batch_size)

        logger.log('train/batch_reward', reward.mean(), step)

        # The critic bootstraps with `not_done_no_max` (presumably done flags
        # that ignore time-limit truncation -- confirm in the replay buffer).
        self.update_critic(obs, action, reward, next_obs, not_done_no_max,
                           goal, env_id, logger, writer, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, goal, env_id, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)

    def save(self, model_dir, step):
        """Save actor, critic, temperature and their optimizers for `step`."""
        # exist_ok avoids the check-then-create race of a separate exists().
        os.makedirs(model_dir, exist_ok=True)
        torch.save(self.actor.state_dict(),
                   '%s/actor_%s.pt' % (model_dir, step))
        torch.save(self.critic.state_dict(),
                   '%s/critic_%s.pt' % (model_dir, step))
        torch.save(self.actor_optimizer.state_dict(),
                   '%s/actor_optim_%s.pt' % (model_dir, step))
        torch.save(self.critic_optimizer.state_dict(),
                   '%s/critic_optim_%s.pt' % (model_dir, step))

        torch.save(self.log_alpha, '%s/log_alpha_%s.pt' % (model_dir, step))
        torch.save(self.log_alpha_optimizer.state_dict(),
                   '%s/alpha_optim_%s.pt' % (model_dir, step))

    def load(self, model_dir, step):
        """Restore what `save` wrote for `step`, mapped onto `self.device`."""
        self.actor.load_state_dict(
            torch.load('%s/actor_%s.pt' % (model_dir, step),
                       map_location=self.device))
        self.critic.load_state_dict(
            torch.load('%s/critic_%s.pt' % (model_dir, step),
                       map_location=self.device))
        self.actor_optimizer.load_state_dict(
            torch.load('%s/actor_optim_%s.pt' % (model_dir, step),
                       map_location=self.device))
        self.critic_optimizer.load_state_dict(
            torch.load('%s/critic_optim_%s.pt' % (model_dir, step),
                       map_location=self.device))

        loaded_log_alpha = torch.load('%s/log_alpha_%s.pt' % (model_dir, step),
                                      map_location=self.device)
        # Copy in place: rebinding `self.log_alpha` would leave
        # `log_alpha_optimizer` updating the old tensor, silently freezing
        # the temperature after a reload.
        with torch.no_grad():
            self.log_alpha.copy_(loaded_log_alpha)
        self.log_alpha_optimizer.load_state_dict(
            torch.load('%s/alpha_optim_%s.pt' % (model_dir, step),
                       map_location=self.device))

    def load_phi_w(self, model_dir, step):
        """Load the pre-trained phi and/or w for `step` (per load_phi/load_w)."""
        if self.load_w and self.load_phi:
            self.W.load_state_dict(
                torch.load('%s/w_%s.pt' % (model_dir, step),
                           map_location=self.device))
            self.phi.load_state_dict(
                torch.load('%s/phi_%s.pt' % (model_dir, step),
                           map_location=self.device))

        elif self.load_w and not self.load_phi:
            if self.w_model == 'distr' or self.w_model == 'mlp':
                self.W.load_state_dict(
                    torch.load('%s/w_%s.pt' % (model_dir, step),
                               map_location=self.device))
            else:
                # 'vector' model: w is stored as a raw tensor.
                self.W = torch.load('%s/w_%s.pt' % (model_dir, step),
                                    map_location=self.device)

        elif not self.load_w and self.load_phi:
            self.phi.load_state_dict(
                torch.load('%s/phi_%s.pt' % (model_dir, step),
                           map_location=self.device))
