from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
from tianshou.exploration import BaseNoise
from torch import nn
from torch.distributions import Independent, Normal

from fsrl.policy.lagrangian_base import LagrangianPolicy
from fsrl.utils import BaseLogger
from fsrl.utils.optim_util import LagrangianOptimizer
from fsrl.policy import BasePolicy
from fsrl.policy.base_policy import nstep_return


class USPCSACLagrangian(LagrangianPolicy):
    """Soft Actor-Critic (SAC) with a PID-Lagrangian safety term computed from a
    learned USPC safe-set network.

    The actor and the reward critic ``critics[0]`` are trained as in SAC
    (https://arxiv.org/abs/1801.01290). The remaining critics ``critics[1:]`` form an
    ensemble whose n-step targets all use the second metric returned by
    ``get_metrics`` (presumably the cost). They are not used directly in the actor
    loss. Instead, they build the regression target of ``safeset_net``: for every
    action ``a`` sampled from the current policy, the target is the minimum over
    neighbour actions ``a'`` of ``UCB(s, a') + USPC_L * ||a - a'||``, where ``UCB`` is
    the ensemble mean plus ``USPC_beta`` times the ensemble standard deviation of
    ``critics[1:]``. The actor's safety term then uses ``safeset_net``'s prediction,
    weighted by a single Lagrange multiplier that is stepped from
    ``stats_train["cost"]`` with the PID coefficients ``lagrangian_pid``
    (https://arxiv.org/abs/2007.03964).

    :param torch.nn.Module actor: the actor network; called as
        ``actor(obs, state=..., info=...)`` and must return a ``(mu, sigma)`` tuple of
        a diagonal Gaussian over pre-tanh actions, plus a hidden state.
    :param Union[nn.Module, List[nn.Module]] critics: the critic network(s), called as
        ``critics[i](obs, act)`` (s, a -> Q(s, a)); ``critics[0]`` is the reward critic.
    :param nn.Module safeset_net: the safe-set network, trained by
        :meth:`safeset_net_loss` and used in the actor's safety term.
    :param nn.Module expanderset_net: the expander-set network. It is stored, copied to
        a target network and switched between train/eval modes, but no method of this
        class trains or queries it.
    :param Optional[torch.optim.Optimizer] actor_optim: the optimizer for the actor.
    :param Optional[torch.optim.Optimizer] critic_optim: the optimizer for the critics.
    :param torch.optim.Optimizer safeset_optim: the optimizer for ``safeset_net``.
    :param torch.optim.Optimizer expanderset_optim: the optimizer for
        ``expanderset_net`` (stored only).
    :param Optional[BaseLogger] logger: the logger instance for training statistics.
        A fresh ``BaseLogger()`` is created when omitted. (default: None)
    :param Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] alpha: the
        entropy temperature. If a tuple ``(target_entropy, log_alpha, alpha_optim)`` is
        given, alpha is tuned automatically. (default: 0.005)
    :param float tau: soft-update coefficient for the target networks. (default: 0.05)
    :param Optional[BaseNoise] exploration_noise: additive exploration noise.
        (default: None)
    :param int n_step: number of steps for multi-step returns. (default: 2)
    :param bool use_lagrangian: whether to use the Lagrangian safety term.
        (default: True)
    :param Tuple lagrangian_pid: PID coefficients for the Lagrange multiplier.
        (default: (0.05, 0.0005, 0.1))
    :param Union[List, float] cost_limit: the cost limit. Only the first entry is used
        when a list is given. (default: np.inf)
    :param bool rescaling: whether to use the rescaling trick for the Lagrange
        multiplier, see Alg. 1 in
        http://proceedings.mlr.press/v119/stooke20a/stooke20a.pdf
    :param float USPC_L: weight of the action-distance term in the safe-set target.
        (default: 10.0)
    :param float USPC_beta: multiplier of the ensemble standard deviation in the UCB.
        (default: 2.0)
    :param float USPC_cov_scale: scale of the policy standard deviation used to draw
        local neighbour actions. (default: 1.0)
    :param int USPC_sample_act_num: number of candidate actions sampled from the
        policy per state in :meth:`learn`. (default: 64)
    :param float expander_eta: stored as ``_expander_eta``; not used by the methods of
        this class. (default: 0.0)
    :param float gamma: the discount factor. (default: 0.99)
    :param bool reward_normalization: normalize rewards if True. (default: False)
    :param bool deterministic_eval: use the Gaussian mean as the action when not
        training. (default: True)
    :param bool action_scaling: whether to scale actions to the action-space bounds.
        (default: True)
    :param str action_bound_method: how out-of-bound actions are handled.
        (default: "clip")
    :param Optional[gym.Space] observation_space: the observation space.
        (default: None)
    :param Optional[gym.Space] action_space: the action space; its ``low``/``high``
        bounds are read by :meth:`safeset_net_loss`. (default: None)
    :param Optional[torch.optim.lr_scheduler.LambdaLR] lr_scheduler: learning-rate
        scheduler. (default: None)

    .. seealso::

        Please refer to :class:`~fsrl.policy.BasePolicy` and
        :class:`~fsrl.policy.LagrangianPolicy` for more detailed hyperparameter
        explanations and usage.
    """

    def __init__(
        self,
        actor: nn.Module,
        critics: Union[nn.Module, List[nn.Module]],
        safeset_net: nn.Module,
        expanderset_net: nn.Module,
        actor_optim: Optional[torch.optim.Optimizer],
        critic_optim: Optional[torch.optim.Optimizer],
        safeset_optim: torch.optim.Optimizer,
        expanderset_optim: torch.optim.Optimizer,
        logger: Optional[BaseLogger] = None,
        # SAC specific arguments
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.005,
        tau: float = 0.05,
        exploration_noise: Optional[BaseNoise] = None,
        n_step: int = 2,
        # Lagrangian specific arguments
        use_lagrangian: bool = True,
        lagrangian_pid: Tuple = (0.05, 0.0005, 0.1),
        cost_limit: Union[List, float] = np.inf,
        rescaling: bool = True,
        # USPC
        USPC_L: float = 10.0,
        USPC_beta: float = 2.0,
        USPC_cov_scale: float = 1.0,
        USPC_sample_act_num: int = 64,
        expander_eta: float = 0.0,
        # Base policy common arguments
        gamma: float = 0.99,
        reward_normalization: bool = False,
        deterministic_eval: bool = True,
        action_scaling: bool = True,
        action_bound_method: str = "clip",
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
    ) -> None:
        # Build the default logger per instance: a `BaseLogger()` default argument
        # would be evaluated once and shared by every policy created without one.
        if logger is None:
            logger = BaseLogger()
        # NOTE(review): positional call; the meaning of the 3rd (None) and 10th (10000)
        # arguments is defined by LagrangianPolicy's signature.
        super().__init__(
            actor,
            critics,
            None,
            logger,
            use_lagrangian,
            lagrangian_pid,
            cost_limit,
            rescaling,
            gamma,
            10000,
            reward_normalization,
            deterministic_eval,
            action_scaling,
            action_bound_method,
            observation_space,
            action_space,
            lr_scheduler,
        )

        # The cost critics form an ensemble of identical constraints, so a single
        # multiplier / cost limit replaces the per-critic ones set up by the parent.
        if self.use_lagrangian:
            self.lag_optims = [LagrangianOptimizer(lagrangian_pid)]
            self.cost_limit = [cost_limit if np.isscalar(cost_limit) else cost_limit[0]]

        self.actor_optim = actor_optim
        self.critics_old = deepcopy(self.critics)
        self.critics_old.eval()
        self.critics_optim = critic_optim

        self.safeset_net = safeset_net
        self.expanderset_net = expanderset_net
        self.safeset_net_old = deepcopy(self.safeset_net)
        self.expanderset_net_old = deepcopy(self.expanderset_net)
        self.safeset_net_old.eval()
        self.expanderset_net_old.eval()
        self.safeset_optim = safeset_optim
        self.expanderset_optim = expanderset_optim

        assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
        self.tau = tau
        self._is_auto_alpha = False
        self._alpha: Union[float, torch.Tensor]
        if isinstance(alpha, tuple):
            self._is_auto_alpha = True
            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = alpha
        self._noise = exploration_noise
        self._n_step = n_step
        # Added inside the log of the tanh correction in `forward` to avoid log(0).
        self.__eps = np.finfo(np.float32).eps.item()

        self.device = next(self.actor.parameters()).device
        self.dtype = next(self.actor.parameters()).dtype

        # USPC hyperparameters
        self._USPC_sample_act_num = USPC_sample_act_num
        self._USPC_cov_scale = USPC_cov_scale
        self._USPC_L = USPC_L
        self._USPC_beta = USPC_beta
        self._expander_eta = expander_eta
        print("expander eta", self._expander_eta)

    # override
    def compute_nstep_returns(
        self,
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        n_step: int = 1,
    ) -> Batch:
        r"""Compute n-step return targets for every critic.

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
        :math:`d_t` is the done flag of step :math:`t`.

        When ``get_metrics`` yields exactly two metrics but there are more than two
        critics, the second metric is reused for every critic after the first, so the
        whole ensemble ``critics[1:]`` regresses onto the same signal.

        :param Batch batch: a data batch, which is equal to buffer[indice].
        :param ReplayBuffer buffer: the data buffer.
        :param ndarray indice: the sampled batch indices in the buffer.
        :param function target_q_fn: a function which returns a list (one entry per
            critic) of target Q values of "obs_next" for the given indices.
        :param int n_step: the number of estimation steps, an int greater than 0.
            Default to 1.

        :return: the batch, with the per-critic returns stacked along a new last
            dimension and stored in ``batch.rets``.
        """
        metrics = self.get_metrics(buffer)

        if len(metrics) == 2 and self.critics_num > 2:
            cost_metric = metrics[1]
            metrics = [metrics[0]] + [cost_metric] * (self.critics_num - 1)

        bsz = len(indice)
        indices = [indice]
        for _ in range(n_step - 1):
            indices.append(buffer.next(indices[-1]))
        indices = np.stack(indices)  # (nstep, bsz)
        # `terminal` holds the buffer indexes n_step after `indice`, truncated at the
        # end of each episode by `buffer.next`.
        terminal = indices[-1]
        value_mask = BasePolicy.value_mask(buffer, terminal).reshape(-1, 1)  # (bsz, 1)
        # Treat unfinished episodes as ended so the n-step sum does not run past the
        # last stored transition.
        end_flag = buffer.done.copy()
        end_flag[buffer.unfinished_index()] = True
        with torch.no_grad():
            target_q_list = target_q_fn(buffer, terminal)  # one tensor per critic
        returns = []
        for i in range(self.critics_num):
            target_q = to_numpy(target_q_list[i].reshape(bsz, -1)) * value_mask
            target_q = nstep_return(
                metrics[i], end_flag, target_q, indices, self._gamma, n_step
            )
            returns.append(to_torch_as(target_q, target_q_list[i]))

        batch.rets = torch.stack(returns, dim=-1)
        if hasattr(batch, "weight"):  # prio buffer update
            batch.weight = to_torch_as(batch.weight, target_q_list[0])
        return batch

    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
        """Set the exploration noise."""
        self._noise = noise

    def train(self, mode: bool = True):
        """Set the trainable networks' mode; target networks stay in eval mode."""
        self.training = mode
        self.actor.train(mode)
        self.critics.train(mode)
        self.safeset_net.train(mode)
        self.expanderset_net.train(mode)
        return self

    def sync_weight(self) -> None:
        """Soft-update the critic and safe-set target networks by ``tau``.

        ``expanderset_net_old`` is not updated here; ``expanderset_net`` is not
        trained by any method of this class.
        """
        self.soft_update(self.critics_old, self.critics, self.tau)
        self.soft_update(self.safeset_net_old, self.safeset_net, self.tau)

    def _target_q(
        self, buffer: ReplayBuffer, indices: np.ndarray
    ) -> List[torch.Tensor]:
        """Return, per critic, the soft target ``Q_old(s', a') - alpha * log pi(a'|s')``
        with ``a'`` drawn from the current policy at ``obs_next``."""
        batch = buffer[indices]  # batch.obs_next: s_{t+n}
        obs_next_result = self(batch, input="obs_next")
        act = obs_next_result.act
        log_prob = obs_next_result.log_prob
        target_q_list = []
        for i in range(self.critics_num):
            target_q, _ = self.critics_old[i].predict(batch.obs_next, act)
            target_q_list.append(target_q - self._alpha * log_prob)
        return target_q_list

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Attach n-step critic targets (``batch.rets``) to the sampled batch."""
        batch = self.compute_nstep_returns(
            batch, buffer, indices, self._target_q, self._n_step
        )
        return batch

    def pre_update_fn(self, stats_train: Dict, **kwarg) -> None:
        """Step the single Lagrange multiplier with the mean of ``stats_train["cost"]``.

        Accepts a scalar, list/tuple, numpy array or torch tensor for the cost. Does
        nothing when the Lagrangian is disabled.
        """
        if not self.use_lagrangian:
            return
        cost_val = stats_train.get("cost")
        if isinstance(cost_val, (list, tuple, np.ndarray)):
            cost_val = float(np.mean(cost_val))
        elif isinstance(cost_val, torch.Tensor):
            cost_val = float(cost_val.mean().detach().cpu().item())
        else:
            cost_val = float(cost_val)
        self.lag_optims[0].step(cost_val, float(self.cost_limit[0]))

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute a tanh-squashed action and its log-probability.

        :return: a Batch with ``logits`` (the Gaussian ``(mu, sigma)``), ``act`` (the
            tanh-squashed action), ``state``, ``dist`` (the pre-tanh Gaussian) and
            ``log_prob`` (log-probability of ``act`` with the tanh correction).
        """
        obs = batch[input]
        logits, hidden = self.actor(obs, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = Independent(Normal(*logits), 1)
        if self._deterministic_eval and not self.training:
            act = logits[0]
        else:
            act = dist.rsample()
        log_prob = dist.log_prob(act).unsqueeze(-1)
        # Change-of-variables correction for the tanh squashing, see Eq. 21 in
        # Appendix C of the SAC paper (arXiv 1801.01290).
        squashed_action = torch.tanh(act)
        log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
            -1, keepdim=True
        )
        return Batch(
            logits=logits,
            act=squashed_action,
            state=hidden,
            dist=dist,
            log_prob=log_prob,
        )

    def critics_loss(
        self, batch: Batch, critics: torch.nn.Module, optimizer: torch.optim.Optimizer
    ) -> Tuple[torch.Tensor, dict]:
        """Take one optimizer step on all critics against ``batch.rets``.

        :return: ``(td, stats)`` where ``td`` is the detached sum of TD errors over
            every critic output head, divided by ``critics_num`` (used as
            prioritized-replay weights), and ``stats`` holds scalar losses.
        """
        weight = getattr(batch, "weight", 1.0)
        loss_critic = 0
        td_average = 0
        stats_critic = {}
        for i in range(self.critics_num):
            target_q = batch.rets[..., i].flatten()
            current_q_list = critics[i](batch.obs, batch.act)
            loss_i = 0
            for j in range(len(current_q_list)):
                td = current_q_list[j].flatten() - target_q
                td_average += td
                loss_i += (td.pow(2) * weight).mean()

            loss_critic += loss_i
            stats_critic["loss/loss_q" + str(i)] = loss_i.item()
            stats_critic["estep/val_q" + str(i)] = torch.mean(target_q).item()

        optimizer.zero_grad()
        loss_critic.backward()
        optimizer.step()
        # NOTE(review): this sums over output heads and averages over critics only
        # (the original hinted at also dividing by 2); kept as-is.
        # Detach: the value is only consumed as replay weights, and keeping the
        # already back-propagated critic graph alive serves no purpose.
        td_average = td_average.detach() / self.critics_num
        stats_critic["loss/q_total"] = loss_critic.item()
        return td_average, stats_critic

    def policy_loss(self, batch: Batch, **kwarg):
        """Take one actor step (and one alpha step when auto-tuned).

        The reward part is the SAC objective against the minimum of the first two
        output heads of ``critics[0]``. The safety part is computed by
        ``safety_loss`` from ``safeset_net``'s prediction at the policy action.

        :return: ``(loss_actor_total, stats)``.
        """
        obs_result = self(batch)
        act = obs_result.act

        current_q_list = self.critics[0](batch.obs, act)
        if not isinstance(current_q_list, (tuple, list)):
            current_q_list = [current_q_list]
        if len(current_q_list) >= 2:
            current_q = torch.min(current_q_list[0], current_q_list[1]).flatten()
        else:  # single critic head
            current_q = current_q_list[0].flatten()

        loss_actor_rew = (
            self._alpha * obs_result.log_prob.flatten() - current_q
        ).mean()

        # Safety loss uses safeset_net instead of critics[i]. The net is queried in
        # eval mode and its previous mode is restored afterwards (instead of being
        # forced to train mode, which would flip it while the policy is evaluating).
        values = []
        if self.use_lagrangian:
            was_training = self.safeset_net.training
            self.safeset_net.eval()
            try:
                safeset_pred, _ = self.safeset_net.predict(batch.obs, act)
            finally:
                self.safeset_net.train(was_training)
            values.append(safeset_pred.flatten())
        loss_actor_safety, stats_actor = self.safety_loss(values)

        rescaling = stats_actor["loss/rescaling"]
        loss_actor_total = rescaling * (loss_actor_rew + loss_actor_safety)

        self.actor_optim.zero_grad()
        loss_actor_total.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            # please take a look at issue #258 if you'd like to change this line
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()
            stats_actor.update(
                {
                    "loss/alpha_loss": alpha_loss.item(),
                    "loss/alpha_value": self._alpha.item(),
                }
            )

        stats_actor.update(
            {
                "loss/actor_rew": loss_actor_rew.item(),
                "loss/actor_total": loss_actor_total.item(),
            }
        )
        return loss_actor_total, stats_actor

    def safeset_net_loss(
        self,
        batch: Batch,
        sample_act: torch.Tensor,
    ) -> Dict[str, float]:
        """Update ``safeset_net`` via the USPC-style two-certificate regression.

        For every state ``s`` in ``batch`` and candidate action ``a`` in
        ``sample_act``, the regression target is::

            min_{a'} [ UCB(s, a') + USPC_L * ||a - a'||_2 ]

        ``a'`` ranges over a neighbour set made of the candidates themselves,
        ``K_pi // 2`` local perturbations of the current policy mean, and
        ``K_pi // 2`` uniform samples in the action-space box.
        ``UCB = mean + USPC_beta * std`` is taken over ``critics[1:]``.

        If an attribute ``_USPC_threshold`` exists, the minimum is restricted to
        neighbours whose ``safeset_net_old`` prediction is ``<= _USPC_threshold``.
        This restriction applies only for states where at least one such neighbour
        exists.

        :param Batch batch: training batch; ``batch.obs`` is read.
        :param torch.Tensor sample_act: candidate actions of shape ``(K_pi, B, da)``,
            expected in the same (tanh-squashed) space as ``forward(...).act``.
        :return: a dict of scalar statistics (loss and diagnostics).
        """
        stats_safeset = {}

        obs = torch.as_tensor(batch.obs, device=self.device, dtype=self.dtype)  # (B, ds)

        K_pi, B, da = sample_act.shape
        K_loc = K_pi // 2
        K_glob = K_pi // 2
        K_all = K_pi + K_loc + K_glob
        ds = obs.shape[-1]
        act_low = torch.as_tensor(
            self.action_space.low, device=obs.device, dtype=self.dtype
        )
        act_high = torch.as_tensor(
            self.action_space.high, device=obs.device, dtype=self.dtype
        )
        with torch.no_grad():
            dist_old = self(batch, input="obs").dist
            mu_old = dist_old.mean  # (B, da), pre-tanh
            std_old = dist_old.stddev  # (B, da), pre-tanh

        obs_flat = obs.unsqueeze(0).expand(K_pi, -1, -1).reshape(-1, ds)  # (K_pi*B, ds)
        sample_act_flat = sample_act.reshape(-1, da)  # (K_pi*B, da)

        # Local neighbours: perturb the pre-tanh Gaussian mean and squash with tanh
        # exactly as `forward` does, so they share the action space of the candidates
        # and of the buffer actions the critics were trained on.
        noise = torch.randn(K_loc, B, da, device=obs.device, dtype=self.dtype)
        local_samples = torch.tanh(
            mu_old.unsqueeze(0) + (self._USPC_cov_scale * std_old) * noise
        )
        local_samples = local_samples.clamp(min=act_low, max=act_high)  # (K_loc, B, da)

        # Global neighbours: uniform in the action-space box.
        # NOTE(review): assumes the action-space bounds coincide with the tanh range
        # (e.g. [-1, 1]); confirm for environments with other bounds.
        global_samples = (
            torch.rand(K_glob, B, da, device=obs.device, dtype=self.dtype)
            * (act_high - act_low)
            + act_low
        )
        global_samples = global_samples.clamp(
            min=act_low, max=act_high
        )  # (K_glob, B, da)

        nbrs = torch.cat(
            [sample_act, local_samples, global_samples], dim=0
        )  # (K_all, B, da)

        obs_flat_all = (
            obs.unsqueeze(0).expand(K_all, -1, -1).reshape(-1, ds)
        )  # (K_all*B, ds)

        # Ensemble UCB of the cost critics at every neighbour (no gradients; each
        # critic's train/eval mode is restored after the query).
        with torch.no_grad():
            q_list = []
            for i in range(1, self.critics_num):
                was_training = self.critics[i].training
                self.critics[i].eval()
                q_i, _ = self.critics[i].predict(
                    obs_flat_all, nbrs.reshape(-1, da)
                )  # (K_all*B, 1)
                q_list.append(q_i.squeeze(-1))  # (K_all*B,)
                if was_training:
                    self.critics[i].train()

        q_vals = torch.stack(q_list, dim=1)  # (K_all*B, m)
        q_mean = q_vals.mean(dim=1)  # (K_all*B,)
        q_std = q_vals.std(dim=1) if q_vals.shape[1] > 1 else 0.0
        q_ucb = (q_mean + self._USPC_beta * q_std).reshape(K_all, B)  # (K_all, B)

        nbrs_BA = nbrs.permute(1, 0, 2).contiguous()  # (B, K_all, da)
        cands_BK = sample_act.permute(1, 0, 2).contiguous()  # (B, K_pi, da)
        dists_BAK = torch.cdist(nbrs_BA, cands_BK)  # (B, K_all, K_pi)
        q_ucb_BA = q_ucb.permute(1, 0).unsqueeze(-1)  # (B, K_all, 1)

        h = getattr(self, "_USPC_threshold", None)
        safeset_net_old = getattr(self, "safeset_net_old", None)

        obj_BAK = q_ucb_BA + self._USPC_L * dists_BAK  # (B, K_all, K_pi)

        if (h is not None) and (safeset_net_old is not None):
            with torch.no_grad():
                f_old_flat = safeset_net_old(obs_flat_all, nbrs.reshape(-1, da))[
                    0
                ].squeeze(-1)  # (K_all*B,)
                f_old_BA = f_old_flat.reshape(K_all, B).permute(1, 0)  # (B, K_all)
            safe_mask_BA = f_old_BA <= h  # (B, K_all)

            # Exclude unsafe neighbours from the min; fall back to the unrestricted
            # min for states that have no safe neighbour at all.
            masked_obj_BAK = obj_BAK.masked_fill(
                ~safe_mask_BA.unsqueeze(-1), float("inf")
            )
            min_masked_BK = masked_obj_BAK.min(dim=1).values  # (B, K_pi)
            has_safe = safe_mask_BA.any(dim=1, keepdim=True)  # (B, 1)
            y_BK = torch.where(
                has_safe, min_masked_BK, obj_BAK.min(dim=1).values
            )  # (B, K_pi)
        else:
            y_BK = obj_BAK.min(dim=1).values  # (B, K_pi)

        # Back to the (K_pi, B) flattening used by obs_flat / sample_act_flat.
        y_flat = y_BK.permute(1, 0).reshape(-1)  # (K_pi*B,)

        safe_pred_flat = self.safeset_net(obs_flat, sample_act_flat)[0].squeeze(
            -1
        )  # (K_pi*B,)
        loss = ((safe_pred_flat - y_flat) ** 2).mean()

        self.safeset_optim.zero_grad(set_to_none=True)
        loss.backward()
        self.safeset_optim.step()

        # Fraction of targets whose minimizing neighbour is one of the candidates.
        witness_idx_BK = obj_BAK.argmin(dim=1)  # (B, K_pi)
        self_witness_frac = (witness_idx_BK < K_pi).float().mean().item()

        stats_safeset["loss/safeset"] = loss.item()
        stats_safeset["diag/self_witness_frac"] = self_witness_frac
        stats_safeset["diag/mean_ucb_anchor"] = q_ucb.mean().item()
        stats_safeset["diag/mean_dist"] = dists_BAK.mean().item()
        stats_safeset["diag/ensemble_std_mean"] = (
            (q_std if isinstance(q_std, torch.Tensor) else torch.tensor(q_std))
            .mean()
            .item()
        )

        return stats_safeset

    def learn(self, batch: Batch, **kwargs: Any) -> None:
        """Run one update of critics, safe-set net and actor, then sync targets.

        Statistics are stored in ``self.logger``; nothing is returned.
        """
        td, stats_critic = self.critics_loss(batch, self.critics, self.critics_optim)
        batch.weight = td  # prio-buffer

        # Candidate actions for the safe-set regression. The policy distribution is
        # over pre-tanh actions, so squash the samples exactly like `forward`: the
        # critics are trained on, and `policy_loss` queries safeset_net with,
        # tanh-squashed actions.
        with torch.no_grad():
            old_dist = self(batch, input="obs").dist
            sample_act: torch.Tensor = torch.tanh(
                old_dist.sample((self._USPC_sample_act_num,))
            )  # (K, B, da)

        stats_safeset = self.safeset_net_loss(batch, sample_act)

        # actor (with safeset_net predictions instead of cost critics)
        _, stats_actor = self.policy_loss(batch)

        self.sync_weight()
        self.logger.store(**stats_actor)
        self.logger.store(**stats_critic)
        self.logger.store(**stats_safeset)

    def exploration_noise(
        self, act: Union[np.ndarray, Batch], batch: Batch
    ) -> Union[np.ndarray, Batch]:
        """Add exploration noise to numpy actions; other inputs pass through unchanged."""
        if self._noise is None:
            return act
        if isinstance(act, np.ndarray):
            return act + self._noise(act.shape)
        return act
