from argparse import Action
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from copy import deepcopy

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_torch_as
from torch import nn

from fsrl.policy.lagrangian_base import LagrangianPolicy
from fsrl.utils import BaseLogger
from fsrl.utils.net.common import ActorCritic
from fsrl.utils.net.continuous import SingleCritic


class USPCPPOLagrangian(LagrangianPolicy):
    """Implementation of the Proximal Policy Optimization (PPO) with PID Lagrangian and USPC.

    More details, please refer to https://arxiv.org/abs/1707.06347 (PPO) and
    https://arxiv.org/abs/2007.03964 (PID Lagrangian).

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~fsrl.policy.BasePolicy`. (s -> logits)
    :param Union[nn.Module, List[nn.Module]] critics: the critic network(s). (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for the actor and critic networks.
    :param Type[torch.distributions.Distribution] dist_fn: the distribution class for
        action sampling.
    :param BaseLogger logger: the logger instance for logging training information.
        Default is BaseLogger().
    :param float target_kl: the target KL divergence for the PPO update. Default is 0.02.
    :param float vf_coef: the value function coefficient for the loss function. Default
        is 0.25.
    :param Optional[float] max_grad_norm: the maximum gradient norm for gradient clipping
        (None for no clipping). Default is None.
    :param float gae_lambda: the Generalized Advantage Estimation (GAE) parameter.
        Default is 0.95.
    :param float eps_clip: the PPO clipping parameter for the policy update. Default is
        0.2.
    :param Optional[float] dual_clip: the PPO dual clipping parameter (None for no dual
        clipping). Default is None.
    :param bool value_clip: whether to clip the value function update. Default is False.
    :param bool advantage_normalization: whether to normalize the advantages. Default is
        True.
    :param bool recompute_advantage: whether to recompute the advantages during the
        optimization process. Default is False.
    :param bool use_lagrangian: whether to use the Lagrangian constraint optimization.
        Default is True.
    :param List lagrangian_pid: the PID coefficients for the Lagrangian constraint
        optimization. Default is [0.05, 0.0005, 0.1].
    :param Union[List, float] cost_limit: the constraint limit(s) for the Lagrangian
        optimization. Default is np.inf.
    :param bool rescaling: whether use the rescaling trick for Lagrangian multiplier, see
        Alg. 1 in http://proceedings.mlr.press/v119/stooke20a/stooke20a.pdf
    :param float gamma: the discount factor for future rewards. Default is 0.99.
    :param int max_batchsize: the maximum size of the batch when computing GAE, depends
        on the size of available memory and the memory cost of the model; should be as
        large as possible within the memory constraint. Default to 99999.
    :param bool reward_normalization: whether to normalize the rewards. Default is False.
    :param bool deterministic_eval: whether to use deterministic actions during
        evaluation. Default is True.
    :param bool action_scaling: whether to scale actions based on the action space.
        Default is True.
    :param str action_bound_method: the method used to handle out-of-bound actions.
        Default is "clip".
    :param Optional[gym.Space] observation_space: the observation space of the
        environment. Default is None.
    :param Optional[gym.Space] action_space: the action space of the environment. Default
        is None.
    :param Optional[torch.optim.lr_scheduler.LambdaLR] lr_scheduler: the learning rate
        scheduler. Default is None.

    .. seealso::

       Please refer to :class:`~fsrl.policy.BasePolicy` and
       :class:`~fsrl.policy.LagrangianPolicy` for more detailed \
           hyperparameter explanations and usage.
    """

    def __init__(
        self,
        actor: nn.Module,
        critics: Union[nn.Module, List[nn.Module]],
        safeset_net: nn.Module,
        expanderset_net: nn.Module,
        optim: torch.optim.Optimizer,
        safeset_optim: torch.optim.Optimizer,
        expanderset_optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        logger: BaseLogger = BaseLogger(),
        # PPO specific argumentsß
        target_kl: float = 0.02,
        vf_coef: float = 0.25,
        max_grad_norm: Optional[float] = None,
        gae_lambda: float = 0.95,
        eps_clip: float = 0.2,
        dual_clip: Optional[float] = None,
        value_clip: bool = False,
        advantage_normalization: bool = True,
        recompute_advantage: bool = False,
        # Lagrangian specific arguments
        use_lagrangian: bool = True,
        lagrangian_pid: Tuple = (0.05, 0.0005, 0.1),
        cost_limit: Union[List, float] = np.inf,
        rescaling: bool = True,
        # USPC
        USPC_L: float = 10.0,
        USPC_beta: float = 2.0,
        USPC_cov_scale: float = 1.0,
        USPC_sample_act_num: int = 64,
        # Base policy common arguments
        gamma: float = 0.99,
        max_batchsize: int = 99999,
        reward_normalization: bool = False,
        deterministic_eval: bool = True,
        action_scaling: bool = True,
        action_bound_method: str = "clip",
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
    ) -> None:
        super().__init__(
            actor,
            critics,
            dist_fn,
            logger,
            use_lagrangian,
            lagrangian_pid,
            cost_limit,
            rescaling,
            gamma,
            max_batchsize,
            reward_normalization,
            deterministic_eval,
            action_scaling,
            action_bound_method,
            observation_space,
            action_space,
            lr_scheduler,
        )
        self.optim = optim

        self.safeset_net = safeset_net
        self.expanderset_net = expanderset_net
        self.safeset_net_old = deepcopy(self.safeset_net)
        self.expanderset_net_old = deepcopy(self.expanderset_net)
        self.safeset_net_old.eval()
        self.expanderset_net_old.eval()
        self.safeset_optim = safeset_optim
        self.expanderset_optim = expanderset_optim

        self._lambda = gae_lambda
        self._weight_vf = vf_coef
        self._grad_norm = max_grad_norm
        self._target_kl = target_kl
        self._eps_clip = eps_clip
        assert (
            dual_clip is None or dual_clip > 1.0
        ), "Dual-clip PPO parameter should greater than 1.0."
        self._dual_clip = dual_clip
        self._value_clip = value_clip
        if not self._rew_norm:
            assert (
                not self._value_clip
            ), "value clip is available only when `reward_normalization` is True"
        self._norm_adv = advantage_normalization
        self._recompute_adv = recompute_advantage
        self._actor_critic: ActorCritic

        # USPC init
        self._USPC_sample_act_num = USPC_sample_act_num
        self._USPC_cov_scale = USPC_cov_scale
        self._USPC_L = USPC_L
        self._USPC_beta = USPC_beta

        self.device = next(self.actor.parameters()).device
        self.dtype = next(self.actor.parameters()).dtype

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        if self._recompute_adv:
            self._buffer, self._indices = buffer, indices
        batch = self.compute_gae_returns(batch, buffer, indices, self._lambda)
        batch.act = to_torch_as(batch.act, batch.values[..., 0])
        old_log_prob = []
        with torch.no_grad():
            for minibatch in batch.split(
                self._max_batchsize, shuffle=False, merge_last=True
            ):
                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
        batch.logp_old = torch.cat(old_log_prob, dim=0)
        return batch

    def critics_loss(self, minibatch):
        critic_losses = 0
        stats = {}
        B, ds = minibatch.obs.shape
        B, da = minibatch.act.shape

        for i, critic in enumerate(self.critics):
            value = critic.predict(
                minibatch.obs.reshape(-1, ds), minibatch.act.reshape(-1, da)
            )[0].flatten()
            ret = minibatch.rets[..., i]
            if self._value_clip:
                value_target = minibatch.values[..., i]
                v_clip = value_target + (value - value_target).clamp(
                    -self._eps_clip, self._eps_clip
                )
                vf1 = (ret - value).pow(2)
                vf2 = (ret - v_clip).pow(2)
                vf_loss = torch.max(vf1, vf2).mean()
            else:
                vf_loss = (ret - value).pow(2).mean()
            critic_losses += vf_loss

            stats["loss/vf" + str(i)] = vf_loss.item()
        stats["loss/vf_total"] = critic_losses.item()
        return critic_losses, stats

    def safeset_net_loss(
        self,
        batch: Batch,
        sample_act: torch.Tensor,
        safeset_net: nn.Module,
        optimizer: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        stats_safeset = {}

        obs = torch.as_tensor(
            batch.obs, device=self.device, dtype=self.dtype
        )  # (B, ds)
        act = torch.as_tensor(
            batch.act, device=self.device, dtype=self.dtype
        )  # (B, da)

        K_pi, B, da = sample_act.shape
        K_loc = K_pi // 2
        K_glob = K_pi // 2
        K_all = K_pi + K_loc + K_glob
        ds = obs.shape[-1]
        act_low = torch.as_tensor(
            self.action_space.low, device=obs.device, dtype=self.dtype
        )
        act_high = torch.as_tensor(
            self.action_space.high, device=obs.device, dtype=self.dtype
        )
        with torch.no_grad():
            dist = self.forward(batch).dist
            mu_old = dist.mean  # (B, da)
            std_old = dist.stddev  # (B, da)

        obs_flat = (
            obs.unsqueeze(0).expand(K_pi, -1, -1).reshape(-1, ds)
        )  # (K_pi * B, ds)
        sample_act_flat = sample_act.reshape(-1, da)  # (K_pi * B, da)

        print(mu_old.shape, std_old.shape, sample_act.shape)
        local_samples = mu_old.unsqueeze(0) + (
            self._USPC_cov_scale * std_old
        ) * torch.randn(K_loc, B, da, device=obs.device, dtype=self.dtype)
        local_samples = local_samples.clamp(min=act_low, max=act_high)  # (K_loc, B, da)

        global_samples = (
            torch.rand(K_glob, B, da, device=obs.device, dtype=self.dtype)
            * (act_high - act_low)
            + act_low
        )
        global_samples = global_samples.clamp(
            min=act_low, max=act_high
        )  # (K_glob, B, da)

        nbrs = torch.cat(
            [sample_act, local_samples, global_samples], dim=0
        )  # (K_all, B, da)

        obs_flat_all = (
            obs.unsqueeze(0).expand(K_all, -1, -1).reshape(-1, ds)
        )  # (K_all * B, ds)
        with torch.no_grad():
            self.critics[1].eval()
            _, q_heads = self.critics[1].predict(
                obs_flat_all, nbrs.reshape(-1, da)
            )  # (K_all * B, n_heads)
            self.critics[1].train()
        q_vals = torch.stack(
            [q.squeeze(-1) for q in q_heads], dim=1
        )  # (K_all * B, n_heads)
        q_mean = q_vals.mean(dim=1)
        q_std = q_vals.std(dim=1)
        q_ucb = (q_mean + self._USPC_beta * q_std).reshape(K_all, B)  # (K_all, B)

        nbrs_BA = nbrs.permute(1, 0, 2).contiguous()  # (B, K_all, da)
        cands_BK = sample_act.permute(1, 0, 2).contiguous()  # (B, K_pi,  da)
        dists_BAK = torch.cdist(nbrs_BA, cands_BK)  # (B, K_all, K_pi)
        q_ucb_BA = q_ucb.permute(1, 0).unsqueeze(-1)  # (B, K_all, 1)

        h = getattr(self, "_USPC_threshold", None)
        safeset_net_old = getattr(self, "safeset_net_old", None)

        obj_BAK = q_ucb_BA + self._USPC_L * dists_BAK  # (B, K_all, K_pi)

        if (h is not None) and (safeset_net_old is not None):
            with torch.no_grad():
                f_old_flat = safeset_net_old(obs_flat_all, nbrs.reshape(-1, da))[
                    0
                ].squeeze(
                    -1
                )  # (K_all*B,)
                f_old_BA = f_old_flat.reshape(K_all, B).permute(1, 0)  # (B, K_all)
            safe_mask_BA = f_old_BA <= h  # (B, K_all)

            masked_obj_BAK = obj_BAK.masked_fill(
                ~safe_mask_BA.unsqueeze(-1), float("inf")
            )
            min_masked_BK = masked_obj_BAK.min(dim=1).values  # (B, K_pi)

            has_safe = safe_mask_BA.any(dim=1, keepdim=True)  # (B,1)
            y_BK = torch.where(
                has_safe, min_masked_BK, obj_BAK.min(dim=1).values
            )  # (B, K_pi)
        else:
            y_BK = obj_BAK.min(dim=1).values  # (B, K_pi)

        y_flat = y_BK.permute(1, 0).reshape(-1)  # (K_pi*B,)

        safe_pred_flat = safeset_net(obs_flat, sample_act_flat)[0].squeeze(
            -1
        )  # (K_pi*B,)
        loss = ((safe_pred_flat - y_flat) ** 2).mean()

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        witness_idx_BK = obj_BAK.argmin(dim=1)  # (B, K_pi)
        self_witness_frac = (witness_idx_BK < K_pi).float().mean().item()

        stats_safeset["loss/safeset"] = loss.item()
        stats_safeset["diag/self_witness_frac"] = self_witness_frac
        stats_safeset["diag/mean_ucb_anchor"] = q_ucb.mean().item()
        stats_safeset["diag/mean_dist"] = dists_BAK.mean().item()

        return stats_safeset

    def policy_loss(self, batch: Batch, dist: Type[torch.distributions.Distribution]):

        log_p = dist.log_prob(batch.act)
        ratio = (log_p - batch.logp_old).exp().float()
        ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
        if self._norm_adv:
            for i in range(self.critics_num):
                adv = batch.advs[..., i]
                mean, std = adv.mean(), adv.std()
                batch.advs[..., i] = (adv - mean) / std  # per-batch norm

        # compute normal ppo loss
        rew_adv = batch.advs[..., 0]
        surr1 = ratio * rew_adv
        surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * rew_adv
        if self._dual_clip:
            clip1 = torch.min(surr1, surr2)
            clip2 = torch.max(clip1, self._dual_clip * rew_adv)
            loss_actor_rew = -torch.where(rew_adv < 0, clip2, clip1).mean()
        else:
            loss_actor_rew = -torch.min(surr1, surr2).mean()

        # compute safety loss
        values = (
            [ratio * batch.advs[..., i] for i in range(1, self.critics_num)]
            if self.use_lagrangian
            else []
        )
        loss_actor_safety, stats_actor = self.safety_loss(values)

        rescaling = stats_actor["loss/rescaling"]
        loss_actor_total = rescaling * (loss_actor_rew + loss_actor_safety)

        # compute approx KL for early stop
        approx_kl = (batch.logp_old - log_p).mean().item()
        stats_actor.update(
            {
                "loss/actor_rew": loss_actor_rew.item(),
                "loss/actor_total": loss_actor_total.item(),
                "loss/kl": approx_kl,
            }
        )
        return loss_actor_total, stats_actor

    def learn(
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        for step in range(repeat):
            if self._recompute_adv and step > 0:
                batch = self.compute_gae_returns(
                    batch, self._buffer, self._indices, self._lambda
                )
            iter_counts, approx_kl = 0, 0.0
            for minibatch in batch.split(batch_size, merge_last=True):
                dist = self.forward(minibatch).dist
                loss_actor, stats_actor = self.policy_loss(minibatch, dist)
                approx_kl += stats_actor["loss/kl"]
                iter_counts += 1
                loss_vf, stats_critic = self.critics_loss(minibatch)

                K = self._USPC_sample_act_num
                sample_act: torch.Tensor = dist.sample((K,))  # (K, B, da)

                stats_safeset = self.safeset_net_loss(
                    minibatch, sample_act, self.safeset_net, self.safeset_optim
                )

                loss = loss_actor + self._weight_vf * loss_vf

                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm:
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(), max_norm=self._grad_norm
                    )
                self.optim.step()
                self.gradient_steps += 1

                ent = dist.entropy().mean()
                self.logger.store(**stats_actor)
                self.logger.store(**stats_critic)
                self.logger.store(**stats_safeset)
                self.logger.store(total=loss.item(), entropy=ent.item(), tab="loss")

            # trick in
            # https://github.com/liuzuxin/robust-safe-rl/blob/main/rsrl/policy/robust_ppo.py # noqa
            approx_kl /= iter_counts + 1e-7
            if approx_kl > 1.5 * self._target_kl:
                # early stop
                self.logger.print(
                    "Early stop at step %d due to reaching max kl." % step
                )
                break

        self.logger.store(gradient_steps=self.gradient_steps, tab="update")
