# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the USPCSACPID (USPC version of SACPID) algorithm."""


import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.clip_grad import clip_grad_norm_

from copy import deepcopy

from omnisafe.algorithms import registry
from omnisafe.algorithms.off_policy.sac_pid import SACPID
from omnisafe.common.pid_lagrange import PIDLagrangian
from omnisafe.models.base import Critic
from omnisafe.models.critic.critic_builder import CriticBuilder

from omnisafe.models.actor_critic.ssn_constraint_actor_q_critic import SSNConstraintActorQCritic


@registry.register
# pylint: disable-next=too-many-instance-attributes, too-few-public-methods
class USPCSACPID(SACPID):
    """The USPCSACPID (USPC version of SACPID) algorithm.

    Compared with the parent :class:`SACPID`, this class

    - trains every member of the cost-critic ensemble (:meth:`_update_cost_critic`),
    - trains an additional critic-shaped network, the SSN ("safe set net"), together with a
      Polyak-averaged target copy (:meth:`_update_ssn`), and
    - penalises the actor with the SSN output instead of the cost critic (:meth:`_loss_pi`).

    References:
        - Title: Responsive Safety in Reinforcement Learning by PID Lagrangian Methods
        - Authors: Adam Stooke, Joshua Achiam, Pieter Abbeel.
        - URL: `SACPID <https://arxiv.org/abs/2007.03964>`_
    """

    # Logger keys owned by this class. Kept in one place so that registration and the
    # placeholder values written while no update happens cannot drift apart.
    _SSN_LOG_KEYS = (
        "SSN/SSN_loss",
        "SSN/SSN_value",
        "SSN/witness_fraction",
        "SSN/mean_ucb_anchor",
        "SSN/ensemble_std_mean",
        "SSN/selected_anchor_dist_to_policy_mean",
    )

    def _init(self) -> None:
        """Run the parent initialisation and cache the USPC/SSN hyper-parameters.

        The lower and upper bounds of the action space are additionally stored as ``float32``
        tensors on the training device; they are used to clamp and to sample anchor actions.
        """
        super()._init()

        uspc_cfgs = self._cfgs.USPC_cfgs
        self.ensemble_size = uspc_cfgs.USPC_ensemble_size

        self.ssn_local_samples = uspc_cfgs.ssn_local_samples
        self.ssn_global_samples = uspc_cfgs.ssn_global_samples
        self.ssn_do_self_witness = uspc_cfgs.ssn_do_self_witness
        self.ssn_lipschitz = uspc_cfgs.ssn_lipschitz
        self.ssn_beta = uspc_cfgs.ssn_beta
        self.ssn_cov_scale = uspc_cfgs.ssn_cov_scale

        self.act_min = torch.as_tensor(
            self._env.action_space.low,
            device=self._device,
            dtype=torch.float32,
        )
        self.act_max = torch.as_tensor(
            self._env.action_space.high,
            device=self._device,
            dtype=torch.float32,
        )

    def _init_model(self) -> None:
        """Build the actor-critic, the SSN, its frozen target copy and the SSN optimizer.

        The SSN is built through :class:`CriticBuilder` as a ``'q'`` critic with
        ``USPC_ensemble_size`` members and the hidden sizes / activation of the regular critics.
        """
        self._cfgs.model_cfgs.critic['num_critics'] = 2
        self._actor_critic = SSNConstraintActorQCritic(
            obs_space=self._env.observation_space,
            act_space=self._env.action_space,
            model_cfgs=self._cfgs.model_cfgs,
            USPC_cfgs=self._cfgs.USPC_cfgs,
            epochs=self._epochs,
        ).to(self._device)

        self._ssn: Critic = (
            CriticBuilder(
                obs_space=self._env.observation_space,
                act_space=self._env.action_space,
                hidden_sizes=self._cfgs.model_cfgs.critic.hidden_sizes,
                activation=self._cfgs.model_cfgs.critic.activation,
                weight_initialization_mode=self._cfgs.model_cfgs.weight_initialization_mode,
                num_critics=self._cfgs.USPC_cfgs.USPC_ensemble_size,
                use_obs_encoder=False,
            )
            .build_critic('q')
            .to(self._device)
        )

        # The target copy is only ever changed by Polyak averaging, never by gradients.
        self._target_ssn: Critic = deepcopy(self._ssn)
        for param in self._target_ssn.parameters():
            param.requires_grad = False

        # NOTE(review): without a critic learning rate no optimizer is created and
        # ``_update_ssn`` would fail on ``self._ssn_optimizer`` - confirm ``lr`` is always set.
        if self._cfgs.model_cfgs.critic.lr is not None:
            self._ssn_optimizer: optim.Optimizer
            self._ssn_optimizer = optim.Adam(
                self._ssn.parameters(),
                lr=self._cfgs.model_cfgs.critic.lr,
            )

    def _init_log(self) -> None:
        """Register the SSN/USPC related keys in addition to the parent's keys."""
        super()._init_log()
        for key in self._SSN_LOG_KEYS:
            self._logger.register_key(key)

    def _update(self) -> None:
        """Update the critics, the SSN and the actor, then the Lagrange multiplier.

        -  Get the ``data`` from buffer

        .. note::

            +----------+----------------------------------------+
            | obs      | ``observation`` stored in buffer.      |
            +==========+========================================+
            | act      | ``action`` stored in buffer.           |
            +----------+----------------------------------------+
            | reward   | ``reward`` stored in buffer.           |
            +----------+----------------------------------------+
            | cost     | ``cost`` stored in buffer.             |
            +----------+----------------------------------------+
            | next_obs | ``next observation`` stored in buffer. |
            +----------+----------------------------------------+
            | done     | ``terminated`` stored in buffer.       |
            +----------+----------------------------------------+

        -  Update value net by :meth:`_update_reward_critic`.
        -  Update cost net by :meth:`_update_cost_critic` (only if ``use_cost`` is set).
        -  Update the SSN by :meth:`_update_ssn`.
        -  Update policy net by :meth:`_update_actor` every ``policy_delay`` updates.

        The basic process of each update is as follows:

        #. Get the mini-batch data from buffer.
        #. Get the loss of network.
        #. Update the network by loss.
        #. Repeat steps 2, 3 until the ``update_iters`` times.

        After the loop the PID Lagrange multiplier is updated from the logged episode cost
        (once the warm-up epochs have passed) and its current value is logged.
        """
        for _ in range(self._cfgs.algo_cfgs.update_iters):
            data = self._buf.sample_batch()
            self._update_count += 1
            obs, act, reward, cost, done, next_obs = (
                data['obs'],
                data['act'],
                data['reward'],
                data['cost'],
                data['done'],
                data['next_obs'],
            )

            self._update_reward_critic(obs, act, reward, done, next_obs)

            if self._cfgs.algo_cfgs.use_cost:
                self._update_cost_critic(obs, act, cost, done, next_obs)

            self._update_ssn(obs, act)  # safe set net

            if self._update_count % self._cfgs.algo_cfgs.policy_delay == 0:
                self._update_actor(obs)
                self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak)

        # lagrange multiplier update (from SACPID)
        Jc = self._logger.get_stats('Metrics/EpCost')[0]
        if self._epoch > self._cfgs.algo_cfgs.warmup_epochs:
            self._lagrange.pid_update(Jc)
        self._logger.store(
            {
                'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier,
            },
        )

    def _update_ssn(self, obs: torch.Tensor, acts: torch.Tensor) -> None:
        """Regress the SSN onto a Lipschitz-style target built from anchor actions.

        For every sample a set of anchor actions is assembled from

        - the buffer action itself (only if ``ssn_do_self_witness``),
        - ``ssn_local_samples`` draws from a Normal centred on the target actor's mean whose
          standard deviation is scaled by ``ssn_cov_scale`` (clamped to the action bounds), and
        - ``ssn_global_samples`` uniform draws between the action bounds.

        Each anchor is scored with ``mean + ssn_beta * std`` of the cost-critic ensemble plus
        ``ssn_lipschitz`` times its Euclidean distance to the buffer action. Anchors whose
        target-SSN prediction exceeds ``lagrange_cfgs.cost_limit`` are excluded; the smallest
        remaining score is the regression target. Rows in which every anchor is excluded use
        the smallest unmasked score instead. Only the first SSN member is regressed, and the
        target SSN is Polyak-averaged afterwards.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            acts (torch.Tensor): The ``action`` sampled from buffer, indexed as
                ``(batch, act_dim)``.
        """
        batch_size = obs.shape[0]
        act_dim = acts.shape[1]
        num_anchors = (
            self.ssn_local_samples
            + self.ssn_global_samples
            + (1 if self.ssn_do_self_witness else 0)
        )

        # The whole target construction is gradient-free.
        with torch.no_grad():
            # local anchors: (batch, ssn_local_samples, act_dim) after the permute
            old_dist = self._actor_critic.target_actor._distribution(obs)
            local_dist = torch.distributions.Normal(
                old_dist.mean,
                old_dist.stddev * self.ssn_cov_scale,
            )
            anc_locals = local_dist.sample((self.ssn_local_samples,)).permute(1, 0, 2)
            anc_locals.clamp_(self.act_min, self.act_max)

            # global anchors: (batch, ssn_global_samples, act_dim)
            global_dist = torch.distributions.Uniform(self.act_min, self.act_max)
            anc_globals = global_dist.sample((batch_size, self.ssn_global_samples))

            # The self-witness, if enabled, is always anchor index 0.
            anchor_parts = [anc_locals, anc_globals]
            if self.ssn_do_self_witness:
                anchor_parts.insert(0, acts.unsqueeze(1))
            anchors = torch.cat(anchor_parts, dim=1).to(dtype=acts.dtype)

            # Evaluate the cost-critic ensemble on every (obs, anchor) pair.
            # NOTE(review): assumes each ensemble member returns one value per row - confirm.
            obs_rep = obs.repeat_interleave(num_anchors, dim=0)
            anchors_flat = anchors.reshape(-1, act_dim)
            q_values = torch.stack(self._actor_critic.cost_critic(obs_rep, anchors_flat), dim=1)
            q_mean = q_values.mean(dim=1)
            if q_values.shape[1] > 1:
                q_std = q_values.std(dim=1)
            else:
                # The sample standard deviation of a single value is NaN and would poison every
                # target; a one-member ensemble simply carries no spread.
                q_std = torch.zeros_like(q_mean)
            q_ucb = (q_mean + self.ssn_beta * q_std).reshape(batch_size, -1)

            # compute y(s, a') = min_(a: a safe) ucb(Q(s,a)) + L * d(a, a')
            dists = torch.linalg.vector_norm(anchors - acts.unsqueeze(1), ord=2, dim=-1)
            obj_vals = q_ucb + self.ssn_lipschitz * dists

            # Anchors the target SSN considers unsafe get an infinite objective value.
            ssn_target_preds = self._target_ssn(obs_rep, anchors_flat)[0].reshape(batch_size, -1)
            unsafe_mask = ssn_target_preds > self._cfgs.lagrange_cfgs.cost_limit
            obj_vals_masked = obj_vals.masked_fill(unsafe_mask, float('inf'))
            masked_min, masked_inds = torch.min(obj_vals_masked, dim=1)

            # If all anchors are unsafe for a batch element, fall back to the unmasked minimum.
            # The selected index must follow the same fallback, otherwise the statistics logged
            # below would refer to an anchor that did not produce the target.
            fallback_min, fallback_inds = torch.min(obj_vals, dim=1)
            all_unsafe = unsafe_mask.all(dim=1)
            y_sa = torch.where(all_unsafe, fallback_min, masked_min)
            inds = torch.where(all_unsafe, fallback_inds, masked_inds)

        # update ssn
        ssn_pred = self._ssn(obs, acts)[0]
        ssn_loss = nn.functional.mse_loss(ssn_pred, y_sa)
        self._ssn_optimizer.zero_grad()
        ssn_loss.backward()
        self._ssn_optimizer.step()

        # do polyak averaging for target ssn
        tau = self._cfgs.algo_cfgs.polyak
        for target_param, param in zip(self._target_ssn.parameters(), self._ssn.parameters()):
            target_param.data.mul_(1 - tau).add_(param.data, alpha=tau)

        # log stuff
        if self.ssn_do_self_witness:
            self_witness_frac = (inds == 0).float().mean().item()
        else:
            self_witness_frac = 0.0
        row_index = torch.arange(batch_size, device=dists.device)
        selected_anchor_dists = dists[row_index, inds].mean().item()

        # NOTE(review): "SSN/mean_ucb_anchor" receives the plain ensemble mean, not the UCB, and
        # "SSN/selected_anchor_dist_to_policy_mean" is the distance to the buffer action rather
        # than to a policy mean. Values are kept as they were; confirm the intended meaning.
        self._logger.store(
            {
                "SSN/SSN_loss": ssn_loss.item(),
                "SSN/SSN_value": ssn_pred.mean().item(),
                "SSN/witness_fraction": self_witness_frac,
                "SSN/mean_ucb_anchor": q_mean.mean().item(),
                "SSN/ensemble_std_mean": q_std.mean().item(),
                "SSN/selected_anchor_dist_to_policy_mean": selected_anchor_dists,
            },
        )

    # override
    def _update_cost_critic(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        cost: torch.Tensor,
        done: torch.Tensor,
        next_obs: torch.Tensor,
    ) -> None:
        """Update the cost critic ensemble.

        - Get the TD loss of each of the first ``ensemble_size`` cost critics against its own
          target network and sum them.
        - Optionally add the L2 penalty on the critic parameters (``use_critic_norm``).
        - Update critic network by loss, clipping gradients if ``max_grad_norm`` is set.
        - Log useful information.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.
            action (torch.Tensor): The ``action`` sampled from buffer.
            cost (torch.Tensor): The ``cost`` sampled from buffer.
            done (torch.Tensor): The ``terminated`` sampled from buffer.
            next_obs (torch.Tensor): The ``next observation`` sampled from buffer.
        """
        algo_cfgs = self._cfgs.algo_cfgs

        with torch.no_grad():
            next_action = self._actor_critic.actor.predict(next_obs, deterministic=True)
            # One forward pass of the target ensemble serves every member; it does not depend
            # on the loop index below.
            next_q_value_c = self._actor_critic.target_cost_critic(next_obs, next_action)
            discount = algo_cfgs.gamma * (1 - done)

        q_value_c = self._actor_critic.cost_critic(obs, action)
        loss = obs.new_zeros(())

        for i in range(self.ensemble_size):
            with torch.no_grad():
                target_q_value_c_i = cost + discount * next_q_value_c[i]
            loss = loss + nn.functional.mse_loss(q_value_c[i], target_q_value_c_i)

        if algo_cfgs.use_critic_norm:
            for param in self._actor_critic.cost_critic.parameters():
                loss = loss + param.pow(2).sum() * algo_cfgs.critic_norm_coeff

        self._actor_critic.cost_critic_optimizer.zero_grad()
        loss.backward()

        if algo_cfgs.max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.cost_critic.parameters(),
                algo_cfgs.max_grad_norm,
            )
        self._actor_critic.cost_critic_optimizer.step()

        self._logger.store(
            {
                'Loss/Loss_cost_critic': loss.mean().item(),
                'Value/cost_critic': torch.stack(q_value_c).mean().item(),
            },
        )

    def _loss_pi(
        self,
        obs: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computing ``pi/actor`` loss.

        The loss function used here is:

        .. math::

            L = \frac{\alpha \log \pi (a | s) - Q^V (s, a) + \lambda S (s, a)}{1 + \lambda},
            \quad a \sim \pi (\cdot | s)

        where :math:`Q^V` is the min value of two reward critic networks outputs, :math:`S` is
        the output of the first SSN member (it takes the place of the cost critic),
        :math:`\lambda` is the Lagrange multiplier and :math:`\pi` is the policy network.

        Args:
            obs (torch.Tensor): The ``observation`` sampled from buffer.

        Returns:
            The loss of pi/actor.
        """
        action = self._actor_critic.actor.predict(obs, deterministic=False)
        log_prob = self._actor_critic.actor.log_prob(action)
        loss_q_r_1, loss_q_r_2 = self._actor_critic.reward_critic(obs, action)
        loss_r = self._alpha * log_prob - torch.min(loss_q_r_1, loss_q_r_2)

        # The cost term is driven by the SSN rather than by the cost critic.
        multiplier = self._lagrange.lagrangian_multiplier
        loss_c = multiplier * self._ssn(obs, action)[0]

        return (loss_r + loss_c).mean() / (1 + multiplier)

    def _log_when_not_update(self) -> None:
        """Write placeholder values for the keys that are only produced by updates.

        The Lagrange multiplier is stored as in the original implementation; every SSN key
        registered in :meth:`_init_log` is stored as ``0.0`` so that it also has a value on
        steps where :meth:`_update_ssn` has not run.
        """
        super()._log_when_not_update()
        placeholders = {key: 0.0 for key in self._SSN_LOG_KEYS}
        placeholders['Metrics/LagrangeMultiplier'] = self._lagrange.lagrangian_multiplier
        self._logger.store(placeholders)
