from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
from tianshou.exploration import BaseNoise, GaussianNoise
from torch import nn

from fsrl.policy.lagrangian_base import LagrangianPolicy
from fsrl.utils import BaseLogger
from fsrl.utils.optim_util import LagrangianOptimizer
from fsrl.policy import BasePolicy
from fsrl.policy.base_policy import nstep_return


class USPCDDPGLagrangian(LagrangianPolicy):
    """The Deep Deterministic Policy Gradient (DDPG) with PID Lagrangian and USPC.

    ``critics[0]`` is the reward critic and ``critics[1:]`` form an ensemble of cost
    critics that all model the same constraint, so a single Lagrangian optimizer and
    a single cost limit are used. A separate safe-set network is regressed onto an
    optimistic neighbour bound built from the cost-critic ensemble (see
    :meth:`safeset_net_loss`), and the actor's Lagrangian safety term is computed
    from that safe-set network's prediction.

    More details, please refer to https://arxiv.org/abs/1509.02971 (DDPG) and
    https://arxiv.org/abs/2007.03964 (PID Lagrangian).

    :param torch.nn.Module actor: the deterministic actor network following the rules
        in :class:`~fsrl.policy.BasePolicy`; called as ``actor(obs, state=..., info=...)``
        and returning ``(actions, hidden)``. (s -> a)
    :param Union[nn.Module, List[nn.Module]] critics: the critic network(s); ``critics[i]
        (obs, act)[0]`` is critic ``i``'s Q-value. (s, a -> Q(s, a))
    :param nn.Module safeset_net: the safe-set network; called as ``net(obs, act)``
        (first element of the result is the prediction) and via ``net.predict(obs, act)``.
    :param nn.Module expanderset_net: the expander-set network. In this class it is
        only copied into a target network, soft-updated and toggled between train/eval.
    :param Optional[torch.optim.Optimizer] actor_optim: the optimizer for the actor
        network.
    :param Optional[torch.optim.Optimizer] critic_optim: the optimizer for the critic
        network(s).
    :param torch.optim.Optimizer safeset_optim: the optimizer for ``safeset_net``.
    :param torch.optim.Optimizer expanderset_optim: the optimizer for
        ``expanderset_net`` (stored, not used inside this class).
    :param BaseLogger logger: the logger instance for logging training information.
        Default is ``BaseLogger()``.
    :param float tau: the soft update coefficient for updating target networks. Default
        is 0.05.
    :param Optional[BaseNoise] exploration_noise: the noise instance for exploration.
        Its ``_sigma`` (0.1 if absent or no noise) is also used as the std when
        sampling candidate actions around the target actor. Default is
        GaussianNoise(sigma=0.1).
    :param int n_step: the number of steps for multi-step bootstrap targets. Default is
        2.
    :param bool use_lagrangian: whether to use the Lagrangian constraint optimization.
        Default is True.
    :param Tuple lagrangian_pid: the PID coefficients for the Lagrangian constraint
        optimization. Default is (0.05, 0.0005, 0.1).
    :param Union[List, float] cost_limit: the constraint limit for the Lagrangian
        optimization; if a list is given only its first entry is used. Default is
        np.inf.
    :param bool rescaling: whether to rescale the Lagrangian multiplier. Default is True.
    :param float USPC_L: multiplier of the action distance in the safe-set target.
        Default is 10.0.
    :param float USPC_beta: multiplier of the cost-ensemble std in the optimistic
        (mean + beta * std) cost estimate. Default is 2.0.
    :param float USPC_cov_scale: scale applied to the exploration std when drawing
        "local" neighbour actions around the target actor's action. Default is 1.0.
    :param int USPC_sample_act_num: number of candidate actions sampled per state in
        :meth:`learn`. Default is 64.
    :param float expander_eta: stored as ``self._expander_eta`` (and printed); not
        otherwise used inside this class. Default is 0.0.
    :param float gamma: the discount factor for future rewards. Default is 0.99.
    :param bool reward_normalization: normalize rewards if True. Default is False.
    :param bool deterministic_eval: whether to use deterministic action selection during
        evaluation. Default is True.
    :param bool action_scaling: whether to scale the actions according to the action
        space bounds. Default is True.
    :param str action_bound_method: the method for handling actions that exceed the
        action space bounds ("clip" or other custom methods). Default is "clip".
    :param Optional[gym.Space] observation_space: the observation space of the
        environment. Default is None.
    :param Optional[gym.Space] action_space: the action space of the environment; its
        ``low``/``high`` bounds are required by :meth:`safeset_net_loss`. Default is
        None.
    :param Optional[torch.optim.lr_scheduler.LambdaLR] lr_scheduler: learning rate
        scheduler for the optimizer. Default is None.

    .. seealso::

        Please refer to :class:`~fsrl.policy.BasePolicy` and
        :class:`~fsrl.policy.LagrangianPolicy` for more detailed \
           hyperparameter explanations and usage.
    """

    def __init__(
        self,
        actor: nn.Module,
        critics: Union[nn.Module, List[nn.Module]],
        safeset_net: nn.Module,
        expanderset_net: nn.Module,
        actor_optim: Optional[torch.optim.Optimizer],
        critic_optim: Optional[torch.optim.Optimizer],
        safeset_optim: torch.optim.Optimizer,
        expanderset_optim: torch.optim.Optimizer,
        logger: BaseLogger = BaseLogger(),
        # DDPG specific arguments
        tau: float = 0.05,
        exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
        n_step: int = 2,
        # Lagrangian specific arguments
        use_lagrangian: bool = True,
        lagrangian_pid: Tuple = (0.05, 0.0005, 0.1),
        cost_limit: Union[List, float] = np.inf,
        rescaling: bool = True,
        # USPC
        USPC_L: float = 10.0,
        USPC_beta: float = 2.0,
        USPC_cov_scale: float = 1.0,
        USPC_sample_act_num: int = 64,
        expander_eta: float = 0.0,
        # Base policy common arguments
        gamma: float = 0.99,
        reward_normalization: bool = False,
        deterministic_eval: bool = True,
        action_scaling: bool = True,
        action_bound_method: str = "clip",
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
    ) -> None:
        # NOTE(review): the `logger` and `exploration_noise` defaults are evaluated
        # once at definition time and shared by every instance relying on them.
        super().__init__(
            actor,
            critics,
            None,  # positional slot presumably for dist_fn (DDPG has none) — confirm
            logger,
            use_lagrangian,
            lagrangian_pid,
            cost_limit,
            rescaling,
            gamma,
            99999,  # positional slot presumably for max_batchsize — confirm
            reward_normalization,
            deterministic_eval,
            action_scaling,
            action_bound_method,
            observation_space,
            action_space,
            lr_scheduler,
        )

        # we need to do this since we model the ensemble with M identical constraints in the agent class
        if self.use_lagrangian:
            # replace M lag optimizers with a single one
            self.lag_optims = [LagrangianOptimizer(lagrangian_pid)]
            # ensure a single cost limit entry
            self.cost_limit = [cost_limit if np.isscalar(cost_limit) else cost_limit[0]]

        self.actor_old = deepcopy(self.actor)
        self.actor_old.eval()
        self.actor_optim = actor_optim
        self.critics_old = deepcopy(self.critics)
        self.critics_old.eval()
        self.critics_optim = critic_optim

        self.safeset_net = safeset_net
        self.expanderset_net = expanderset_net
        self.safeset_net_old = deepcopy(self.safeset_net)
        self.expanderset_net_old = deepcopy(self.expanderset_net)
        self.safeset_net_old.eval()
        self.expanderset_net_old.eval()
        self.safeset_optim = safeset_optim
        self.expanderset_optim = expanderset_optim

        assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
        self.tau = tau
        self._noise = exploration_noise
        self._n_step = n_step

        self.device = next(self.actor.parameters()).device
        self.dtype = next(self.actor.parameters()).dtype

        # USPC init
        self._USPC_sample_act_num = USPC_sample_act_num
        self._USPC_cov_scale = USPC_cov_scale
        self._USPC_L = USPC_L
        self._USPC_beta = USPC_beta
        self._expander_eta = expander_eta
        print("expander eta", self._expander_eta)

        # The DDPG actor is deterministic, so the std used to sample candidate
        # actions around it is borrowed from the exploration noise. Fall back to
        # 0.1 when there is no noise or the noise object has no `_sigma`
        # attribute (e.g. a custom BaseNoise subclass).
        self._exploration_std = getattr(exploration_noise, "_sigma", 0.1)

    # override
    def compute_nstep_returns(
        self,
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], List[Any]],
        n_step: int = 1,
    ) -> Batch:
        r"""Compute n-step return targets for every critic (reward and cost ensemble).

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
        :math:`d_t` is the done flag of step :math:`t`.

        When ``get_metrics`` yields exactly two metric arrays but there are more than
        two critics, the second (cost) metric is replicated so every cost critic in
        the ensemble is trained on the same cost signal.

        :param Batch batch: a data batch, which is equal to buffer[indice].
        :param ReplayBuffer buffer: the data buffer.
        :param ndarray indice: the sampled batch indices in the buffer.
        :param function target_q_fn: a function returning one entry per critic for
            "obs_next" at the wanted indices; element ``[0]`` of each entry is that
            critic's target Q-value tensor.
        :param int n_step: the number of estimation step, should be an int greater than
            0. Default to 1.

        :return: the same Batch with ``batch.rets`` set to a torch.Tensor whose last
            dimension indexes the critics (``batch.rets[..., i]`` is critic ``i``'s
            target). If the batch has a ``weight`` field (prioritized replay) it is
            converted to a tensor matching the target Q-values' dtype and device.
        """
        metrics = self.get_metrics(buffer)

        if len(metrics) == 2 and self.critics_num > 2:
            # one cost signal shared by every member of the cost-critic ensemble
            cost_metric = metrics[1]
            metrics = [metrics[0]] + [cost_metric] * (self.critics_num - 1)

        bsz = len(indice)
        indices = [indice]
        for _ in range(n_step - 1):
            indices.append(buffer.next(indices[-1]))
        indices = np.stack(indices)  # (n_step, bsz)
        terminal = indices[-1]
        # (bsz, 1) mask from BasePolicy.value_mask; presumably False where s_{t+n}
        # must not be bootstrapped from (terminal) — see tianshou docs
        value_mask = BasePolicy.value_mask(buffer, terminal).reshape(-1, 1)
        end_flag = buffer.done.copy()
        end_flag[buffer.unfinished_index()] = True
        with torch.no_grad():
            # one entry per critic; entry[0] is the target Q-value tensor
            target_q_list = target_q_fn(buffer, terminal)
        returns = []
        for i in range(self.critics_num):
            q_tensor = target_q_list[i][0]
            target_q = to_numpy(q_tensor.reshape(bsz, -1)) * value_mask
            target_q = nstep_return(
                metrics[i], end_flag, target_q, indices, self._gamma, n_step
            )
            returns.append(to_torch_as(target_q, q_tensor))

        batch.rets = torch.stack(returns, dim=-1)
        if hasattr(batch, "weight"):  # prio buffer update
            # `target_q_list[0]` is the critic's output container, not a tensor;
            # its Q-value tensor is the dtype/device reference.
            batch.weight = to_torch_as(batch.weight, target_q_list[0][0])
        return batch

    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
        """Set the exploration noise."""
        self._noise = noise

    def train(self, mode: bool = True):
        """Set the online networks to training (or eval) mode; target networks stay in eval."""
        self.training = mode
        self.actor.train(mode)
        self.critics.train(mode)
        self.safeset_net.train(mode)
        self.expanderset_net.train(mode)
        return self

    def sync_weight(self) -> None:
        """Soft-update every target network towards its online counterpart."""
        self.soft_update(self.actor_old, self.actor, self.tau)
        self.soft_update(self.critics_old, self.critics, self.tau)
        self.soft_update(self.safeset_net_old, self.safeset_net, self.tau)
        # The expander target is created next to the safe-set target in __init__;
        # keep it in sync too so it does not stay frozen at its initial weights.
        self.soft_update(self.expanderset_net_old, self.expanderset_net, self.tau)

    def _target_q(
        self, buffer: ReplayBuffer, indices: np.ndarray
    ) -> List[Any]:
        """Return each target critic's output at ``(obs_next, actor_old(obs_next))``.

        Each list entry is the raw critic output; its element ``[0]`` is the Q-value.
        """
        batch = buffer[indices]  # batch.obs_next: s_{t+n}
        action_next = self(batch, model="actor_old", input="obs_next").act
        target_q = []
        for i in range(self.critics_num):
            target_q.append(self.critics_old[i](batch.obs_next, action_next))
        return target_q

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Attach n-step targets (``batch.rets``) computed with the target networks."""
        batch = self.compute_nstep_returns(
            batch, buffer, indices, self._target_q, self._n_step
        )
        return batch

    def pre_update_fn(self, stats_train: Dict, **kwarg) -> None:
        """Step the single PID Lagrangian optimizer with the latest training cost.

        ``stats_train["cost"]`` may be a scalar, a list/tuple/ndarray or a torch
        tensor; sequences and tensors are averaged. No-op when Lagrangian
        optimization is disabled.
        """
        if not self.use_lagrangian:
            return
        cost_val = stats_train.get("cost")
        # accept scalar, list/tuple, or numpy array
        if isinstance(cost_val, (list, tuple, np.ndarray)):
            cost_val = float(np.mean(cost_val))
        elif isinstance(cost_val, torch.Tensor):
            cost_val = float(cost_val.mean().detach().cpu().item())
        else:
            cost_val = float(cost_val)
        self.lag_optims[0].step(cost_val, float(self.cost_limit[0]))

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "actor",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        :param str model: attribute name of the network to use (e.g. ``"actor"`` or
            ``"actor_old"``).
        :param str input: key of ``batch`` holding the observations (e.g. ``"obs"``
            or ``"obs_next"``).

        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for more detailed
            explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        actions, hidden = model(obs, state=state, info=batch.info)
        return Batch(act=actions, state=hidden)

    def critics_loss(
        self, batch: Batch, critics: torch.nn.Module, optimizer: torch.optim.Optimizer
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Take one optimizer step on the summed (weighted) TD loss of all critics.

        :param Batch batch: training batch with ``obs``, ``act``, ``rets`` (from
            :meth:`compute_nstep_returns`) and optionally ``weight``.
        :param critics: indexable critics; ``critics[i](obs, act)[0]`` is critic
            ``i``'s Q-value.
        :param optimizer: the optimizer over the critics' parameters.
        :return: ``(td_average, stats)`` where ``td_average`` is the per-sample TD
            error averaged over critics (detached from the graph; used as the
            prioritized-replay weight) and ``stats`` maps logging keys to floats.
        """
        weight = getattr(batch, "weight", 1.0)
        loss_critic = 0
        td_average = 0
        stats_critic = {}
        for i in range(self.critics_num):
            current_q = critics[i](batch.obs, batch.act)[0].flatten()
            target_q = batch.rets[..., i].flatten()
            td = current_q - target_q
            td_average += td
            loss_i = (td.pow(2) * weight).mean()
            loss_critic += loss_i
            stats_critic["loss/q" + str(i)] = loss_i.item()
            # log every critic's mean target, not only the last critic's
            stats_critic["critic/val_q" + str(i)] = torch.mean(target_q).item()
        optimizer.zero_grad()
        loss_critic.backward()
        optimizer.step()
        # detach: the average is only a priority weight and must not keep the graph
        td_average = td_average.detach() / self.critics_num
        stats_critic["loss/q_total"] = loss_critic.item()
        return td_average, stats_critic

    def policy_loss(self, batch: Batch, **kwarg):
        """Take one actor step on ``rescaling * (reward loss + Lagrangian safety loss)``.

        The reward loss is ``-critics[0](obs, actor(obs))``; the safety term is built
        by ``safety_loss`` from the safe-set network's prediction for the actor's
        action (only when Lagrangian optimization is enabled).

        :return: ``(loss_actor_total, stats_actor)``.
        """
        action = self(batch, model="actor", input="obs").act
        # normal loss
        loss_actor_rew = -self.critics[0](batch.obs, action)[0].mean()
        # compute safety loss

        values = []
        if self.use_lagrangian:
            # predict in eval mode, then restore whatever mode the net was in
            was_training = self.safeset_net.training
            self.safeset_net.eval()
            safeset_pred, _ = self.safeset_net.predict(batch.obs, action)
            self.safeset_net.train(was_training)
            values.append(safeset_pred.flatten())
        loss_actor_safety, stats_actor = self.safety_loss(values)

        rescaling = stats_actor["loss/rescaling"]
        loss_actor_total = rescaling * (loss_actor_rew + loss_actor_safety)

        self.actor_optim.zero_grad()
        loss_actor_total.backward()
        self.actor_optim.step()

        stats_actor.update(
            {
                "loss/actor_rew": loss_actor_rew.item(),
                "loss/actor_total": loss_actor_total.item(),
            }
        )
        return loss_actor_total, stats_actor

    def safeset_net_loss(
        self,
        batch: Batch,
        sample_act: torch.Tensor,
    ) -> Dict[str, float]:
        """Take one step regressing the safe-set network onto a neighbour bound.

        For each state the ``K_pi`` candidate actions in ``sample_act`` are compared
        against a neighbour set made of the candidates themselves, ``K_pi // 2``
        "local" actions drawn around ``actor_old``'s action and ``K_pi // 2``
        "global" actions drawn uniformly from the action box. The regression target
        for candidate ``a`` is::

            y(s, a) = min_j [ UCB(s, a_j) + USPC_L * ||a - a_j|| ]

        with ``UCB = mean + USPC_beta * std`` over the cost critics ``critics[1:]``.
        If a ``_USPC_threshold`` attribute has been set on the policy (this class
        never sets it), neighbours whose ``safeset_net_old`` prediction exceeds it
        are excluded for every state that has at least one neighbour within it.

        :param Batch batch: training batch; ``batch.obs`` is treated as a flat
            ``(B, ds)`` observation array.
        :param torch.Tensor sample_act: candidate actions of shape ``(K_pi, B, da)``.
        :return: dict of logging statistics (loss and diagnostics).
        """
        stats_safeset = {}

        obs = torch.as_tensor(
            batch.obs, device=self.device, dtype=self.dtype
        )  # (B, ds)

        K_pi, B, da = sample_act.shape
        K_loc = K_pi // 2
        K_glob = K_pi // 2
        K_all = K_pi + K_loc + K_glob
        ds = obs.shape[-1]
        act_low = torch.as_tensor(
            self.action_space.low, device=obs.device, dtype=self.dtype
        )
        act_high = torch.as_tensor(
            self.action_space.high, device=obs.device, dtype=self.dtype
        )

        with torch.no_grad():
            old_out = self(batch, model="actor_old", input="obs")
            mu_old = old_out.act  # (B, da)
            std_old = torch.ones_like(mu_old) * self._exploration_std  # (B, da)

        obs_flat = (
            obs.unsqueeze(0).expand(K_pi, -1, -1).reshape(-1, ds)
        )  # (K_pi * B, ds)
        sample_act_flat = sample_act.reshape(-1, da)  # (K_pi * B, da)

        # local neighbours: Gaussian perturbations of the target actor's action
        local_samples = mu_old.unsqueeze(0) + (
            self._USPC_cov_scale * std_old
        ) * torch.randn(K_loc, B, da, device=obs.device, dtype=self.dtype)
        local_samples = local_samples.clamp(min=act_low, max=act_high)  # (K_loc, B, da)

        # global neighbours: uniform over the action box
        global_samples = (
            torch.rand(K_glob, B, da, device=obs.device, dtype=self.dtype)
            * (act_high - act_low)
            + act_low
        )
        global_samples = global_samples.clamp(
            min=act_low, max=act_high
        )  # (K_glob, B, da)

        nbrs = torch.cat(
            [sample_act, local_samples, global_samples], dim=0
        )  # (K_all, B, da); the first K_pi neighbours are the candidates themselves
        nbrs_flat = nbrs.reshape(-1, da)  # (K_all * B, da)

        obs_flat_all = (
            obs.unsqueeze(0).expand(K_all, -1, -1).reshape(-1, ds)
        )  # (K_all * B, ds)
        with torch.no_grad():
            q_list = []
            for i in range(1, self.critics_num):  # cost-critic ensemble only
                was_training = self.critics[i].training
                self.critics[i].eval()
                q_i, _ = self.critics[i].predict(obs_flat_all, nbrs_flat)  # (K_all*B, 1)
                q_list.append(q_i.squeeze(-1))  # (K_all*B,)
                if was_training:
                    self.critics[i].train()

        q_vals = torch.stack(q_list, dim=1)  # (K_all*B, m)
        q_mean = q_vals.mean(dim=1)  # (K_all*B,)
        # a single-member ensemble has no spread (std over one sample would be NaN)
        q_std = q_vals.std(dim=1) if q_vals.shape[1] > 1 else 0.0
        q_ucb = (q_mean + self._USPC_beta * q_std).reshape(K_all, B)  # (K_all, B)

        # Shapes for cdist: (B, K_all, da) x (B, K_pi, da) -> (B, K_all, K_pi)
        nbrs_BA = nbrs.permute(1, 0, 2).contiguous()  # (B, K_all, da)
        cands_BK = sample_act.permute(1, 0, 2).contiguous()  # (B, K_pi,  da)
        dists_BAK = torch.cdist(nbrs_BA, cands_BK)  # (B, K_all, K_pi)
        q_ucb_BA = q_ucb.permute(1, 0).unsqueeze(-1)  # (B, K_all, 1)

        h = getattr(self, "_USPC_threshold", None)
        safeset_net_old = getattr(self, "safeset_net_old", None)

        obj_BAK = q_ucb_BA + self._USPC_L * dists_BAK  # (B, K_all, K_pi)

        if (h is not None) and (safeset_net_old is not None):
            with torch.no_grad():
                f_old_flat = safeset_net_old(obs_flat_all, nbrs_flat)[0].squeeze(
                    -1
                )  # (K_all*B,)
                f_old_BA = f_old_flat.reshape(K_all, B).permute(1, 0)  # (B, K_all)
            safe_mask_BA = f_old_BA <= h  # (B, K_all)

            masked_obj_BAK = obj_BAK.masked_fill(
                ~safe_mask_BA.unsqueeze(-1), float("inf")
            )
            # states without any neighbour under the threshold use the unmasked objective
            has_safe = safe_mask_BA.any(dim=1, keepdim=True)  # (B, 1)
            eff_obj_BAK = torch.where(
                has_safe.unsqueeze(-1), masked_obj_BAK, obj_BAK
            )  # (B, K_all, K_pi)
        else:
            eff_obj_BAK = obj_BAK

        # The target and its minimising neighbour ("witness") come from the same
        # objective, so the diagnostic also reflects any threshold masking.
        y_BK, witness_idx_BK = eff_obj_BAK.min(dim=1)  # each (B, K_pi)
        y_flat = y_BK.permute(1, 0).reshape(-1)  # (K_pi*B,)

        safe_pred_flat = self.safeset_net(obs_flat, sample_act_flat)[0].squeeze(
            -1
        )  # (K_pi*B,)
        loss = ((safe_pred_flat - y_flat) ** 2).mean()

        self.safeset_optim.zero_grad(set_to_none=True)
        loss.backward()
        self.safeset_optim.step()

        # fraction of candidates whose bound is attained at a candidate action
        # (indices < K_pi) rather than at a local/global neighbour
        self_witness_frac = (witness_idx_BK < K_pi).float().mean().item()

        stats_safeset["loss/safeset"] = loss.item()
        stats_safeset["diag/self_witness_frac"] = self_witness_frac
        stats_safeset["diag/mean_ucb_anchor"] = q_ucb.mean().item()
        stats_safeset["diag/mean_dist"] = dists_BAK.mean().item()
        stats_safeset["diag/ensemble_std_mean"] = (
            (q_std if isinstance(q_std, torch.Tensor) else torch.tensor(q_std))
            .mean()
            .item()
        )

        return stats_safeset

    def learn(self, batch: Batch, **kwargs: Any) -> None:
        """Run one update of critics, safe-set network and actor, then sync targets.

        Candidate actions for the safe-set target are ``USPC_sample_act_num``
        samples per state from ``N(actor_old(obs), exploration_std)``. All
        statistics are written to ``self.logger``; nothing is returned.
        """
        # critic
        td, stats_critic = self.critics_loss(batch, self.critics, self.critics_optim)
        batch.weight = td  # prio-buffer

        with torch.no_grad():
            mu_old = self(batch, model="actor_old", input="obs").act  # (B, da)
            std_old = torch.ones_like(mu_old) * self._exploration_std  # (B, da)
            old_dist = torch.distributions.Normal(mu_old, std_old)
            sample_act: torch.Tensor = old_dist.sample(
                (self._USPC_sample_act_num,)
            )  # (K, B, da)

        stats_safeset = self.safeset_net_loss(batch, sample_act)

        # actor
        _, stats_actor = self.policy_loss(batch)
        self.sync_weight()
        self.logger.store(**stats_actor)
        self.logger.store(**stats_critic)
        self.logger.store(**stats_safeset)

    def exploration_noise(
        self, act: Union[np.ndarray, Batch], batch: Batch
    ) -> Union[np.ndarray, Batch]:
        """Add exploration noise to numpy actions; other action types pass through."""
        if self._noise is None:
            return act
        if isinstance(act, np.ndarray):
            return act + self._noise(act.shape)
        return act
