# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os

from agent import Agent
from agent.actor import DiagGaussianW

import utils

import hydra
from torchmetrics import MeanSquaredError
from torch import linalg as LA


class SACAgent(Agent):
    """SAC algorithm driven by a frozen, pre-trained reward model.

    The actor, critic and target critic are instantiated through hydra from
    the supplied configs. For ``rep_model`` values ``'mlp'`` and ``'vae'`` a
    reward model ``phi_w_joint`` (``utils.PhiWJointMLP`` or
    ``utils.PhiWJointVAE``) is created with every parameter frozen and is
    always kept in evaluation mode. The critic update regresses onto that
    model's per-transition output plus the discounted target value, and the
    actor update weights the critic output by the model's ``W1``..``W4``
    layer weights, selected per sample from ``env_id``.
    """
    def __init__(self, obs_dim, action_dim, goal_dim, env_id_dim, action_range,
                 device, critic_cfg, actor_cfg, discount, init_temperature,
                 alpha_lr, alpha_betas, actor_lr, actor_betas,
                 actor_update_frequency, critic_lr, critic_betas, critic_tau,
                 critic_target_update_frequency, batch_size,
                 learnable_temperature, load_w, load_phi, rep_model) -> None:
        """Create the networks, their optimizers and the frozen reward model.

        Args:
            obs_dim, goal_dim, env_id_dim: accepted for interface
                compatibility; not read by this constructor.
            action_dim: the target entropy is set to ``-action_dim``.
            action_range: stored; unpacked into ``clamp`` when acting.
            device: anything accepted by ``torch.device``.
            critic_cfg, actor_cfg: configs handed to
                ``hydra.utils.instantiate``. The critic config is
                instantiated twice (online and target network).
            discount: factor applied to the bootstrapped target value.
            init_temperature: initial entropy temperature; its log is the
                learnable ``log_alpha``.
            alpha_lr, alpha_betas: Adam settings for ``log_alpha``.
            actor_lr, actor_betas: Adam settings for the actor.
            actor_update_frequency: the actor is updated on steps divisible
                by this value.
            critic_lr, critic_betas: Adam settings for the critic.
            critic_tau: rate passed to ``utils.soft_update_params``.
            critic_target_update_frequency: the target critic is updated on
                steps divisible by this value.
            batch_size: number of transitions sampled per update.
            learnable_temperature: whether ``log_alpha`` is optimized.
            load_w, load_phi: stored on the instance; not used here.
            rep_model: ``'mlp'`` or ``'vae'``; selects the reward model.

        Raises:
            ValueError: if ``rep_model`` is neither ``'mlp'`` nor ``'vae'``.
        """
        super().__init__()

        # Fail fast with a clear message; otherwise no reward model would be
        # created and the failure would surface later as an AttributeError.
        if rep_model not in ('mlp', 'vae'):
            raise ValueError(
                "rep_model must be 'mlp' or 'vae', got %r" % (rep_model, ))

        self.action_range = action_range
        self.device = torch.device(device)
        self.discount = discount
        self.critic_tau = critic_tau
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_update_frequency = critic_target_update_frequency
        self.batch_size = batch_size
        self.learnable_temperature = learnable_temperature

        self.critic = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target = hydra.utils.instantiate(critic_cfg).to(
            self.device)
        # Start the target network as an exact copy of the online critic.
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)

        # The temperature is optimized in log space.
        self.log_alpha = torch.tensor(np.log(init_temperature)).to(self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -action_dim

        # optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr,
                                                betas=actor_betas)

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr,
                                                 betas=critic_betas)

        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr,
                                                    betas=alpha_betas)

        self.load_w = load_w
        self.load_phi = load_phi
        self.rep_model = rep_model

        # Width of the reward-model input; also the width the critic output
        # is reshaped to in the actor update.
        self.state_features = 14
        # self.state_features = 10

        if self.rep_model == 'mlp':
            self.phi_w_joint = utils.PhiWJointMLP(
                input_dim=self.state_features,
                hidden_dim_1=128,
                hidden_dim_2=256,
                output_dim=1).to(self.device)
        else:  # 'vae' (validated above)
            self.phi_w_joint = utils.PhiWJointVAE(
                state_dim=self.state_features,
                hidden_dim_1=750,
                hidden_dim_2=750,
                latent_dim=32,
                output_dim=1).to(self.device)

        # The reward model is used as a fixed function: freeze every weight.
        for param in self.phi_w_joint.parameters():
            param.requires_grad = False

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        self.phi_w_joint.train(False)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def act(self, obs, goal, sample=False):
        """Return one action for a single observation/goal pair.

        Args:
            obs: a single, unbatched observation; converted with
                ``torch.FloatTensor`` and given a leading batch axis.
            goal: a single, unbatched goal; converted the same way.
            sample: if True draw from the policy distribution, otherwise use
                its mean.

        Returns:
            The action, clamped to ``self.action_range`` and converted by
            ``utils.to_np``.
        """
        # Pure inference: skip building an autograd graph through the actor.
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0)
            goal = torch.FloatTensor(goal).to(self.device).unsqueeze(0)
            dist = self.actor(obs, goal)
            action = dist.sample() if sample else dist.mean
            action = action.clamp(*self.action_range)
        # Internal invariant: exactly one action row for the one input row.
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])

    def update_critic(self, obs, action, reward, next_obs, not_done, goal,
                      env_id, logger, writer, step):
        """Take one gradient step on the critic.

        The regression target is the frozen reward model's per-transition
        output plus ``not_done * discount * target_V``, where ``target_V`` is
        the target-critic estimate with the smaller per-sample sum, minus the
        entropy term ``alpha * log_prob``.

        Args:
            obs, action, next_obs, goal: batched transition tensors. The
                reward model is fed ``cat([obs, action, next_obs], -1)``.
            reward: observed rewards; only used to log the reward model's
                prediction error.
            not_done: multiplier applied to the bootstrapped term.
            env_id: per-sample environment ids passed to the reward model.
            logger: object exposing ``log(key, value, step)``.
            writer: object exposing ``add_scalar(tag, value, step)``.
            step: current training step, used for logging.

        Raises:
            ValueError: if ``self.rep_model`` is neither ``'mlp'`` nor
                ``'vae'``.
        """
        # The target is a constant for the critic loss, so none of it needs
        # an autograd graph.
        with torch.no_grad():
            dist = self.actor(next_obs, goal)
            next_action = dist.rsample()
            log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action,
                                                      goal)

            # Per sample, keep the whole estimate whose sum over dim 1 is
            # smaller. `torch.where` picks exactly one of the two; a pair of
            # equality masks would add both estimates together on a tie.
            target_q1_sum = torch.sum(target_Q1, 1).unsqueeze(-1)
            target_q2_sum = torch.sum(target_Q2, 1).unsqueeze(-1)
            target_Q_min = torch.where(target_q1_sum <= target_q2_sum,
                                       target_Q1, target_Q2)
            target_V = target_Q_min - self.alpha * log_prob

            obs_action = torch.cat([obs, action, next_obs], dim=-1)
            if self.rep_model == 'mlp':
                latent_phi_w = self.phi_w_joint(obs_action, env_id)
            elif self.rep_model == 'vae':
                latent_phi_w, mean, std = self.phi_w_joint(obs_action, env_id)
            else:
                raise ValueError("rep_model must be 'mlp' or 'vae', got %r" %
                                 (self.rep_model, ))

            # Track how well the frozen reward model predicts the rewards.
            mean_squared_error = MeanSquaredError()
            rew_pred = latent_phi_w.detach().cpu()
            mse_error = mean_squared_error(rew_pred, reward.cpu())
            writer.add_scalar('eval_phi_w/mse', mse_error, step)

            if self.rep_model == 'mlp':
                reward_features = self.phi_w_joint.latent_phi(obs_action)
            else:
                # Reparameterized latent sample fed to the decoder.
                z = mean + std * torch.randn_like(std)
                reward_features = self.phi_w_joint.decode(obs_action, z)

            target_Q = reward_features + (not_done * self.discount * target_V)

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action, goal)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)
        logger.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(logger, step)

    def update_actor_and_alpha(self, obs, action, goal, env_id, logger, step):
        """Take one gradient step on the actor and, optionally, on alpha.

        The critic output for a freshly sampled action is projected to a
        scalar per sample with the reward-model weights selected by
        ``env_id``: ``W1`` for ids 0 and 4, ``W2`` for ids 1 and 5, ``W3``
        for id 2 and ``W4`` for id 3. Any other id gets an all-zero weight
        vector. The actor loss is ``alpha * log_prob - min(Q1, Q2)``.

        Args:
            obs, goal: batched tensors fed to the actor and critic.
            action: ignored; a new action is sampled from the policy.
            env_id: per-sample environment ids; must broadcast against a
                ``(1, batch, state_features)`` tensor (assumes shape
                ``(batch, 1)`` — TODO confirm against the replay buffer).
            logger: object exposing ``log(key, value, step)``.
            step: current training step, used for logging.
        """
        dist = self.actor(obs, goal)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        actor_Q1, actor_Q2 = self.critic(obs, action, goal)

        # Derive the batch size from the data rather than hard-coding it, so
        # any configured `batch_size` works. Assumes the critic output's
        # leading axis is the batch axis.
        batch_size = actor_Q1.shape[0]

        # Each reward head paired with the environments it belongs to.
        head_masks = (
            (self.phi_w_joint.W1,
             torch.logical_or(torch.eq(env_id, 0), torch.eq(env_id, 4))),
            (self.phi_w_joint.W2,
             torch.logical_or(torch.eq(env_id, 1), torch.eq(env_id, 5))),
            (self.phi_w_joint.W3, torch.eq(env_id, 2)),
            (self.phi_w_joint.W4, torch.eq(env_id, 3)),
        )

        # Broadcast every head's weight row over the batch and zero it where
        # the sample belongs to another environment. The masks are disjoint,
        # so the sum leaves each sample with its own head's weights.
        masked = [
            layer.weight.detach().unsqueeze(1).expand(-1, batch_size, -1) *
            mask for layer, mask in head_masks
        ]
        w_weights = masked[0] + masked[1] + masked[2] + masked[3]
        w_column = w_weights.reshape(batch_size, self.state_features, 1)

        # Per-sample dot product <critic output, w> -> (batch, 1).
        actor_Q1 = torch.bmm(
            actor_Q1.view(batch_size, 1, self.state_features),
            w_column).squeeze(-1)
        actor_Q2 = torch.bmm(
            actor_Q2.view(batch_size, 1, self.state_features),
            w_column).squeeze(-1)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()

        logger.log('train_actor/loss', actor_loss, step)
        logger.log('train_actor/target_entropy', self.target_entropy, step)
        logger.log('train_actor/entropy', -log_prob.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(logger, step)

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            # Only alpha receives gradient; the entropy gap is a constant.
            alpha_loss = (self.alpha *
                          (-log_prob - self.target_entropy).detach()).mean()
            logger.log('train_alpha/loss', alpha_loss, step)
            logger.log('train_alpha/value', self.alpha, step)
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    def update(self, replay_buffer, logger, writer, step):
        """Run one training iteration on a freshly sampled batch.

        The critic is updated on every call. The actor (with alpha) and the
        target critic are updated only on steps divisible by their respective
        update frequencies.
        """
        batch = replay_buffer.sample(self.batch_size)
        (obs, action, reward, next_obs, _not_done, not_done_no_max, goal,
         env_id) = batch

        mean_reward = reward.mean()
        logger.log('train/batch_reward', mean_reward, step)

        # The critic bootstraps with the `not_done_no_max` flag; the plain
        # `not_done` flag from the batch is not used.
        self.update_critic(obs, action, reward, next_obs, not_done_no_max,
                           goal, env_id, logger, writer, step)

        if not step % self.actor_update_frequency:
            self.update_actor_and_alpha(obs, action, goal, env_id, logger,
                                        step)

        if not step % self.critic_target_update_frequency:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)

    def save(self, model_dir, step):
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        torch.save(self.actor.state_dict(),
                   '%s/actor_%s.pt' % (model_dir, step))
        torch.save(self.critic.state_dict(),
                   '%s/critic_%s.pt' % (model_dir, step))
        torch.save(self.actor_optimizer.state_dict(),
                   '%s/actor_optim_%s.pt' % (model_dir, step))
        torch.save(self.critic_optimizer.state_dict(),
                   '%s/critic_optim_%s.pt' % (model_dir, step))

        torch.save(self.log_alpha, '%s/log_alpha_%s.pt' % (model_dir, step))
        torch.save(self.log_alpha_optimizer.state_dict(),
                   '%s/alpha_optim_%s.pt' % (model_dir, step))

    def load(self, model_dir, step):
        self.actor.load_state_dict(
            torch.load('%s/actor_%s.pt' % (model_dir, step)))
        self.critic.load_state_dict(
            torch.load('%s/critic_%s.pt' % (model_dir, step)))
        self.actor_optimizer.load_state_dict(
            torch.load('%s/actor_optim_%s.pt' % (model_dir, step)))
        self.critic_optimizer.load_state_dict(
            torch.load('%s/critic_optim_%s.pt' % (model_dir, step)))

        self.log_alpha = torch.load('%s/log_alpha_%s.pt' % (model_dir, step))
        self.log_alpha_optimizer.load_state_dict(
            torch.load('%s/alpha_optim_%s.pt' % (model_dir, step)))

    def load_phi_w(self, model_dir, step):

        self.phi_w_joint.load_state_dict(
            torch.load('%s/phi_w_joint%s.pt' % (model_dir, step)))
