import time
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from torch import nn
from torch.distributions import kl_divergence

from fsrl.policy import BasePolicy
from fsrl.policy.base_policy import nstep_return
from fsrl.utils import BaseLogger


class USPCCVPO(BasePolicy):
    """Implementation of the Constrained Variational Policy Optimization (CVPO) with USPC.

    More details, please refer to https://arxiv.org/abs/2201.11927.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~fsrl.policy.BasePolicy`. (s -> logits)
    :param Union[nn.Module, List[nn.Module]] critics: the critic network(s). (s -> V(s))
    :param torch.optim.Optimizer actor_optim: the optimizer for the actor network.
    :param torch.optim.Optimizer critic_optim: the optimizer for the critic network(s).
    :param gym.Space action_space: the action space of the environment.
    :param Type[torch.distributions.Distribution] dist_fn: the probability distribution
        function for sampling actions.
    :param int max_episode_steps: the maximum number of steps per episode for computing
        the step-wise qc threshold.
    :param Optional[BaseLogger] logger: the logger instance for logging training
        information. (default=DummyLogger)
    :param Union[List, float] cost_limit: the constraint limit(s) for the optimization.
        (default=np.inf)
    :param float tau: target smoothing coefficient for soft update of target networks.
        (default=0.05)
    :param float gamma: the discount factor for future rewards. (default=0.99)
    :param int n_step: number of steps for multi-step learning. (default=2)
    :param int estep_iter_num: the number of iterations for the E-step. (default=1)
    :param float estep_kl: the KL divergence threshold for the E-step. (default=0.02)
    :param float estep_dual_max: the maximum value for the dual variable in the E-step.
        (default=20)
    :param float estep_dual_lr: the learning rate for the dual variable in the E-step.
        (default=0.02)
    :param int sample_act_num: the number of actions to sample for the E-step.
        (default=16)
    :param int mstep_iter_num: the number of iterations for the M-step. (default=1)
    :param float mstep_kl_mu: the KL divergence threshold for the M-step (mean).
        (default=0.005)
    :param float mstep_kl_std: the KL divergence threshold for the M-step (standard
        deviation). (default=0.0005)
    :param float mstep_dual_max: the maximum value for the dual variable in the M-step.
        (default=0.5)
    :param float mstep_dual_lr: the learning rate for the dual variable in the M-step.
        (default=0.1)
    :param bool deterministic_eval: whether to use deterministic action selection during
        evaluation. (default=True)
    :param bool action_scaling: whether to scale the actions according to the action
        space bounds. (default=True)
    :param str action_bound_method: the method for handling actions that exceed the
        action space bounds ("clip" or other custom methods). (default="clip")
    :param Optional[torch.optim.lr_scheduler.LambdaLR] lr_scheduler: learning rate
        scheduler for the optimizer.

    .. seealso::

        Please refer to :class:`~fsrl.policy.BasePolicy` for more detailed hyperparameter
        explanations and usage.
    """

    def __init__(
        self,
        actor: nn.Module,
        critics: Union[nn.Module, List[nn.Module]],
        safeset_net: nn.Module,
        expanderset_net: nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic_optim: torch.optim.Optimizer,
        safeset_optim: torch.optim.Optimizer,
        expanderset_optim: torch.optim.Optimizer,
        action_space: gym.Space,
        # CVPO specific arguments
        dist_fn: Type[torch.distributions.Distribution],
        max_episode_steps: int,
        logger: Optional[BaseLogger] = None,
        cost_limit: Union[List, float] = np.inf,
        tau: float = 0.05,
        gamma: float = 0.99,
        n_step: int = 2,
        # E-step
        estep_iter_num: int = 1,
        estep_kl: float = 0.02,
        estep_dual_max: float = 20,
        estep_dual_lr: float = 0.02,
        sample_act_num: int = 16,
        # M-step
        mstep_iter_num: int = 1,
        mstep_kl_mu: float = 0.005,
        mstep_kl_std: float = 0.0005,
        mstep_dual_max: float = 0.5,
        mstep_dual_lr: float = 0.1,
        # USPC
        USPC_L: float = 10.0,
        USPC_beta: float = 2.0,
        USPC_cov_scale: float = 1.0,
        USPC_sample_act_num: int = 64,
        expander_eta: float = 0.0,
        # other param
        deterministic_eval: bool = True,
        action_scaling: bool = True,
        action_bound_method: str = "clip",
        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
    ) -> None:
        """Build the policy, its frozen "old" copies and the E/M-step dual variables.

        Arguments specific to USPC (the rest are described on the class):

        :param nn.Module safeset_net: network whose first output is regressed in
            :meth:`safeset_net_loss` onto a Lipschitz-extrapolated cost UCB target.
        :param nn.Module expanderset_net: network whose first output is regressed in
            :meth:`expanderset_net_loss` onto an expansion ratio in [0, 1].
        :param torch.optim.Optimizer safeset_optim: optimizer for ``safeset_net``.
        :param torch.optim.Optimizer expanderset_optim: optimizer for
            ``expanderset_net``.
        :param float USPC_L: Lipschitz constant multiplying action distances in the
            safe-set / expander targets and in the expander bonus. (default=10.0)
        :param float USPC_beta: multiplier of the cost-critic ensemble std in the
            upper / lower confidence bounds. (default=2.0)
        :param float USPC_cov_scale: scale applied to the old policy's std when
            drawing local neighbour actions for the safe-set target. (default=1.0)
        :param int USPC_sample_act_num: actions sampled per state for the expander
            target and the expander bonus. (default=64)
        :param float expander_eta: weight of the intrinsic expander bonus added to
            the reward; a value ``<= 0`` disables the bonus. (default=0.0)
        :param Optional[BaseLogger] logger: training logger. When ``None`` a fresh
            ``BaseLogger`` is created for this instance.
        """
        if logger is None:
            # Build per instance: a ``BaseLogger()`` default argument is evaluated
            # once at definition time and would be shared by every policy object.
            logger = BaseLogger()
        super().__init__(
            actor=actor,
            critics=critics,
            dist_fn=dist_fn,
            logger=logger,
            gamma=gamma,
            deterministic_eval=deterministic_eval,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            action_space=action_space,
            lr_scheduler=lr_scheduler,
        )

        # Frozen copies: actor_old is refreshed in post_update_fn, critics_old and
        # safeset_net_old are soft-updated in sync_weight.
        self.actor_old = deepcopy(self.actor)
        self.actor_old.eval()
        self.actor_optim = actor_optim
        self.critics_old = deepcopy(self.critics)
        self.critics_old.eval()
        self.critics_optim = critic_optim

        self.safeset_net = safeset_net
        self.expanderset_net = expanderset_net
        self.safeset_net_old = deepcopy(self.safeset_net)
        self.expanderset_net_old = deepcopy(self.expanderset_net)
        self.safeset_net_old.eval()
        self.expanderset_net_old.eval()
        self.safeset_optim = safeset_optim
        self.expanderset_optim = expanderset_optim

        self.device = next(self.actor.parameters()).device
        self.dtype = next(self.actor.parameters()).dtype

        self.max_episode_steps = max_episode_steps
        # Sets self.cost_limit and the per-step qc thresholds used in the E-step
        # (single source of truth for the threshold formula).
        self.update_cost_limit(cost_limit)

        # E-step init
        self._estep_kl = estep_kl
        self._estep_iter_num = estep_iter_num
        self._estep_dual_max = estep_dual_max
        self._estep_dual_lr = estep_dual_lr
        self._sample_act_num = sample_act_num
        # the first dim is eta, others are lambda in the paper
        d = np.zeros(self.critics_num)
        d[0] = 1  # init eta to be 1
        self.estep_dual = torch.tensor(
            d, requires_grad=True, device=self.device, dtype=self.dtype
        )
        self.estep_optim = torch.optim.Adam([self.estep_dual], lr=self._estep_dual_lr)

        # M-step init
        self._mstep_kl_mu = mstep_kl_mu
        self._mstep_kl_std = mstep_kl_std
        self._mstep_iter_num = mstep_iter_num
        self._mstep_dual_max = mstep_dual_max
        self._mstep_dual_lr = mstep_dual_lr

        # USPC init
        self._USPC_sample_act_num = USPC_sample_act_num
        self._USPC_cov_scale = USPC_cov_scale
        self._USPC_L = USPC_L
        self._USPC_beta = USPC_beta
        self._expander_eta = expander_eta
        print("expander eta", self._expander_eta)

        self._estep_duration = 0
        self._mstep_duration = 0

        assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
        self.tau = tau
        self._discrete = self.action_type == "discrete"
        self._n_step = n_step
        self.__eps = np.finfo(np.float32).eps.item() * 10  # around 1e-6

    def update_cost_limit(self, cost_limit: Union[List, float]):
        """Update the cost limit(s) and the derived per-step E-step thresholds.

        A scalar ``cost_limit`` is repeated for each cost critic
        (``critics_num - 1`` entries); a list is used as given. Each episodic limit
        ``c`` becomes the per-step Q_c threshold
        ``c * (1 - gamma**T) / (1 - gamma) / T`` with ``T = max_episode_steps``,
        i.e. ``c`` times the average discount weight over an episode. For
        ``gamma == 1`` that average is exactly 1, so the threshold is ``c``.

        :param Union[List, float] cost_limit: new cost threshold(s).
        """
        self.cost_limit = (
            [cost_limit] * (self.critics_num - 1)
            if np.isscalar(cost_limit)
            else cost_limit
        )

        gamma = self._gamma
        horizon = self.max_episode_steps
        if gamma == 1:
            # The closed-form geometric sum would divide by zero here; its limit is
            # an average discount weight of exactly 1.
            self.qc_thres = [float(c) for c in self.cost_limit]
        else:
            self.qc_thres = [
                c * (1 - gamma**horizon) / (1 - gamma) / horizon
                for c in self.cost_limit
            ]

    def pre_update_fn(self, **kwarg: Any) -> Any:
        """Reset both M-step dual variables to zero and build a fresh Adam for them."""

        def _zero_dual() -> torch.Tensor:
            # One trainable scalar on the policy's device / dtype.
            return torch.zeros(
                1, requires_grad=True, device=self.device, dtype=self.dtype
            )

        self.mstep_dual_mu = _zero_dual()
        self.mstep_dual_std = _zero_dual()
        duals = [self.mstep_dual_mu, self.mstep_dual_std]
        self.mstep_optim = torch.optim.Adam(duals, lr=self._mstep_dual_lr)

    def post_update_fn(self, **kwarg: Any) -> Any:
        """Copy the current actor's weights into ``actor_old``."""
        latest_weights = self.actor.state_dict()
        with torch.no_grad():
            self.actor_old.load_state_dict(latest_weights)

    def train(self, mode: bool = True):
        """Switch the online networks to train/eval mode and return ``self``.

        The frozen copies (``actor_old``, ``critics_old``, ``safeset_net_old``,
        ``expanderset_net_old``) are deliberately left untouched.
        """
        self.training = mode
        online_nets = (
            self.actor,
            self.critics,
            self.safeset_net,
            self.expanderset_net,
        )
        for net in online_nets:
            net.train(mode)
        return self

    def sync_weight(self) -> None:
        """Soft-update the target critics and target safe-set net with rate ``tau``.

        NOTE: ``expanderset_net_old`` is not updated here.
        """
        pairs = (
            (self.critics_old, self.critics),
            (self.safeset_net_old, self.safeset_net),
        )
        for target_net, online_net in pairs:
            self.soft_update(target_net, online_net, self.tau)

    def _target_q(
        self, buffer: ReplayBuffer, indices: np.ndarray
    ) -> List[torch.Tensor]:
        """Return one target-critic estimate per critic at ``(obs_next, a')``.

        ``a'`` comes from the current actor through :meth:`forward`, so it is
        sampled or deterministic depending on the training mode.
        """
        batch = buffer[indices]
        next_act = self(batch, model="actor", input="obs_next").act
        return [
            q_value
            for q_value, _ in (
                self.critics_old[i].predict(batch.obs_next, next_act)
                for i in range(self.critics_num)
            )
        ]

    def _compute_expander_bonus_for_indices(
        self,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute the intrinsic expander bonus for the transitions in ``indices``.

        For each unique stored transition ``(s, a_exec)``, K actions are sampled from
        ``actor_old``. Samples outside the current safe set
        (``safeset_eval(s, a) > h``) are the candidates. If
        ``_expander_use_uncertain_only`` is set, candidates are further limited to
        those whose cost-ensemble bounds straddle ``h``. A candidate is expandable
        when ``min(h, LCB(s, a_exec)) + L * ||a_exec - a|| <= h``. The bonus is
        ``eta * #expandable / max(#candidates, 1)``.

        :param ReplayBuffer buffer: buffer the transitions are read from.
        :param np.ndarray indices: buffer indices of any shape; duplicates are
            collapsed.
        :return: ``(uniq, bonus)``, the sorted unique indices and one float32 bonus
            per index. The bonus is all zeros when ``_expander_eta <= 0``.
        """
        uniq = np.unique(indices.reshape(-1))
        eta = float(getattr(self, "_expander_eta", 0.0))
        if eta <= 0.0 or uniq.size == 0:
            # Bonus disabled, or nothing to score: skip touching the buffer.
            return uniq, np.zeros(len(uniq), dtype=np.float32)

        # Sample count: explicit override, else the USPC, else the E-step count.
        K = int(
            getattr(
                self,
                "_expander_K",
                getattr(
                    self,
                    "_USPC_sample_act_num",
                    getattr(self, "_sample_act_num", 16),
                ),
            )
        )
        beta = float(getattr(self, "_USPC_beta", 1.0))
        L_lip = float(getattr(self, "_USPC_L", 0.0))
        use_uncertain_only = bool(getattr(self, "_expander_use_uncertain_only", False))

        h = getattr(self, "_USPC_threshold", None)
        if h is None:
            h = float(self.qc_thres[0]) if hasattr(self, "qc_thres") else 0.0
        h_t = torch.tensor(h, device=self.device, dtype=self.dtype)

        # Explicit None check: nn.Module truthiness can depend on __len__
        # (e.g. nn.Sequential / nn.ModuleList), so ``or`` is not a safe fallback.
        safeset_eval = getattr(self, "safeset_net_old", None)
        if safeset_eval is None:
            safeset_eval = self.safeset_net

        bflat = buffer[uniq]
        obs_b = torch.as_tensor(
            bflat.obs, device=self.device, dtype=self.dtype
        )  # (Bf, ds)
        act_exec = torch.as_tensor(
            bflat.act, device=self.device, dtype=self.dtype
        )  # (Bf, da)
        Bf = obs_b.shape[0]

        with torch.no_grad():
            old_res = self(bflat, model="actor_old", input="obs")
            A = old_res.dist.sample((K,))  # (K, Bf, da)
            A_BK = A.permute(1, 0, 2).contiguous()  # (Bf, K, da)

            # Distance from the executed action to every sampled candidate.
            dists = torch.cdist(act_exec.unsqueeze(1), A_BK).squeeze(1)  # (Bf, K)

            obs_rep = (
                obs_b[None].expand(K, -1, -1).reshape(-1, obs_b.shape[-1])
            )  # (K*Bf, ds)
            A_flat = A.reshape(-1, act_exec.shape[-1])  # (K*Bf, da)
            f_old = safeset_eval(obs_rep, A_flat)[0].squeeze(-1)  # (K*Bf,)
            safe_old = (f_old <= h_t).reshape(K, Bf).permute(1, 0)  # (Bf, K)
            outside = ~safe_old  # (Bf, K)

            if use_uncertain_only:
                q_cand = []
                for i_c in range(1, self.critics_num):
                    qi, _ = self.critics[i_c].predict(obs_rep, A_flat)  # (K*Bf, 1)
                    q_cand.append(qi.squeeze(-1))
                q_cand = torch.stack(q_cand, dim=1)
                q_mean_c = q_cand.mean(dim=1)
                q_std_c = (
                    q_cand.std(dim=1)
                    if q_cand.shape[1] > 1
                    else torch.zeros_like(q_mean_c)
                )
                U_c = (
                    (q_mean_c + beta * q_std_c).reshape(K, Bf).permute(1, 0)
                )  # (Bf, K)
                L_c = (
                    (q_mean_c - beta * q_std_c).reshape(K, Bf).permute(1, 0)
                )  # (Bf, K)
                outside = outside & (L_c <= h_t) & (U_c > h_t)

            # LCB of the cost ensemble at the executed action
            q_exec = []
            for i_c in range(1, self.critics_num):
                qi, _ = self.critics[i_c].predict(obs_b, act_exec)  # (Bf, 1)
                q_exec.append(qi.squeeze(-1))  # (Bf,)
            q_exec = torch.stack(q_exec, dim=1)  # (Bf, m_cost)
            q_mean = q_exec.mean(dim=1)
            q_std = (
                q_exec.std(dim=1) if q_exec.shape[1] > 1 else torch.zeros_like(q_mean)
            )
            lcb = q_mean - beta * q_std
            y_plus = torch.minimum(h_t, lcb)  # (Bf,)

            # Lipschitz expansion condition against each candidate
            lhs = y_plus[:, None] + L_lip * dists  # (Bf, K)
            can_expand = (lhs <= h_t) & outside  # (Bf, K)

            cnt = can_expand.sum(dim=1)  # (Bf,)
            denom = outside.sum(dim=1).clamp(min=1)  # (Bf,)
            ratio = (cnt / denom) * eta  # (Bf,)
            bonus = ratio.detach().cpu().numpy().astype(np.float32)  # (Bf,)

        return uniq, bonus

    # override
    def compute_nstep_returns(
        self,
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        n_step: int = 1,
    ) -> Batch:
        """Store n-step bootstrapped returns for every critic in ``batch.rets``.

        Critic 0 uses ``metrics[0]`` (reward). When ``_expander_eta > 0`` the
        intrinsic expander bonus is added to it. Every other critic uses
        ``metrics[1]`` (cost). The bootstrap value at the last of the n indices
        comes from ``target_q_fn`` and is multiplied by ``BasePolicy.value_mask``.

        :param Batch batch: sampled batch. It receives ``rets`` (one slice per critic
            on the last dim) and, if it has one, a tensor-converted ``weight``.
        :param ReplayBuffer buffer: buffer the batch was sampled from.
        :param np.ndarray indice: sampled buffer indices (``bsz`` of them).
        :param target_q_fn: returns one bootstrap tensor per critic.
        :param int n_step: number of reward steps to accumulate.
        :return: the same ``batch`` object.
        """
        metrics = self.get_metrics(buffer)
        batch_size = len(indice)

        # Row t holds, for every sample, the buffer index t steps later.
        chain = [indice]
        while len(chain) < n_step:
            chain.append(buffer.next(chain[-1]))
        step_idx = np.stack(chain)  # (nstep, bsz)

        last_idx = step_idx[-1]
        bootstrap_mask = BasePolicy.value_mask(buffer, last_idx).reshape(-1, 1)
        with torch.no_grad():
            boot_qs = target_q_fn(buffer, last_idx)

        bonus_idx, bonus_vals = self._compute_expander_bonus_for_indices(
            buffer, step_idx
        )
        expander_on = float(getattr(self, "_expander_eta", 0.0)) > 0.0
        intrinsic = None
        if bonus_vals.size > 0 and expander_on:
            intrinsic = np.zeros_like(metrics[0], dtype=np.float32)
            # A 2-D reward array receives the bonus in its first column.
            selector = (bonus_idx,) if intrinsic.ndim == 1 else (bonus_idx, 0)
            intrinsic[selector] += bonus_vals

        per_critic = []
        for i in range(self.critics_num):
            signal = metrics[1] if i else metrics[0]
            if i == 0 and intrinsic is not None:
                signal = signal + intrinsic
            q_boot = to_numpy(boot_qs[i].reshape(batch_size, -1)) * bootstrap_mask
            q_boot = nstep_return(
                signal, buffer.done, q_boot, step_idx, self._gamma, n_step
            )
            per_critic.append(to_torch_as(q_boot, boot_qs[i]))
        batch.rets = torch.stack(per_critic, dim=-1)

        if hasattr(batch, "weight"):
            batch.weight = to_torch_as(batch.weight, boot_qs[0])

        if intrinsic is not None:
            applied = intrinsic[selector]
            self.logger.store(
                **{
                    "diag/intrinsic_bonus_mean": float(np.mean(applied)),
                    "diag/intrinsic_bonus_max": float(np.max(applied)),
                }
            )
        return batch

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Attach n-step returns (with target-critic bootstrapping) to ``batch``."""
        return self.compute_nstep_returns(
            batch, buffer, indices, self._target_q, self._n_step
        )

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "actor",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Run the network named ``model`` on ``batch[input]``.

        The raw output (``logits``) is turned into a distribution via ``dist_fn``. A
        tuple output is unpacked into its arguments.

        :return: ``Batch(logits, act, state, dist)``. ``act`` is sampled from
            ``dist``. The exception is deterministic evaluation (``deterministic_eval``
            and not training): then ``act`` is the argmax for discrete spaces, or
            ``logits[0]`` for continuous ones.
        """
        net = getattr(self, model)
        logits, hidden = net(batch[input], state=state)
        dist = (
            self.dist_fn(*logits) if isinstance(logits, tuple) else self.dist_fn(logits)
        )
        greedy = self._deterministic_eval and not self.training
        if not greedy:
            act = dist.sample()
        elif self.action_type == "discrete":
            act = logits.argmax(-1)
        elif self.action_type == "continuous":
            act = logits[0]
        return Batch(logits=logits, act=act, state=hidden, dist=dist)

    def critics_loss(
        self, batch: Batch, critics: torch.nn.Module, optimizer: torch.optim.Optimizer
    ) -> Tuple[torch.Tensor, dict]:
        """Take one gradient step on all critics against ``batch.rets``.

        ``critics[i](obs, act)`` returns a sequence of Q estimates, e.g. the heads of
        a double critic. Each estimate is regressed onto ``batch.rets[..., i]``. The
        loss is the squared error, weighted by ``batch.weight`` when present, and
        the per-estimate losses are summed.

        :return: ``(td_average, stats)``. ``td_average`` is the TD error averaged over
            every Q estimate of every critic. ``stats`` holds per-critic loss, mean
            target value and cost threshold.
        """
        weight = getattr(batch, "weight", 1.0)
        loss_critic = 0
        td_sum = 0
        n_td = 0
        stats_critic = {}
        for i in range(self.critics_num):
            target_q = batch.rets[..., i].flatten()
            current_q_list = critics[i](batch.obs, batch.act)
            loss_i = 0
            for current_q in current_q_list:
                td = current_q.flatten() - target_q
                td_sum = td_sum + td
                n_td += 1
                loss_i += (td.pow(2) * weight).mean()

            loss_critic += loss_i
            stats_critic["loss/loss_q" + str(i)] = loss_i.item()
            stats_critic["estep/val_q" + str(i)] = torch.mean(target_q).item()
            if i >= 1:
                stats_critic["estep/thres_q" + str(i)] = self.qc_thres[i - 1]
        optimizer.zero_grad()
        loss_critic.backward()
        optimizer.step()
        # Divide by the real number of TD terms rather than assuming every critic
        # has exactly two heads; the result is identical for double critics.
        td_average = td_sum / max(n_td, 1)
        stats_critic["loss/q_total"] = loss_critic.item()
        return td_average, stats_critic

    def safeset_net_loss(
        self,
        batch: Batch,
        sample_act: torch.Tensor,
        safeset_net: nn.Module,
        optimizer: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        """Take one MSE regression step of ``safeset_net`` on policy-sampled actions.

        Neighbour actions per state are built by concatenating, in this order:

        - the ``K_pi`` policy samples ``sample_act``;
        - ``K_pi // 2`` local samples ``mu_old + USPC_cov_scale * std_old * N(0, 1)``;
        - ``K_pi // 2`` uniform samples in the action-space box.

        Local and uniform samples are clamped to the box. For each policy sample
        a_k the target is ``y_k = min_j [ UCB(s, n_j) + USPC_L * ||n_j - a_k|| ]``,
        where ``UCB = mean + USPC_beta * std`` over cost critics
        ``1 .. critics_num - 1``. When ``_USPC_threshold`` (h) is set and
        ``safeset_net_old`` exists, the min only runs over neighbours with
        ``safeset_net_old(s, n_j) <= h``, provided at least one exists.
        ``safeset_net(s, a_k)[0]`` is regressed onto ``y_k``.

        :param batch: batch providing ``obs`` (and ``act``).
        :param sample_act: policy actions, shape (K_pi, B, da).
        :param safeset_net: network updated by this call.
        :param optimizer: optimizer stepped once.
        :return: dict with the regression loss and diagnostic scalars.
        """
        stats_safeset = {}

        # initialize some values
        obs = torch.as_tensor(
            batch.obs, device=self.device, dtype=self.dtype
        )  # (B, ds)
        # NOTE(review): ``act`` is not used anywhere below.
        act = torch.as_tensor(
            batch.act, device=self.device, dtype=self.dtype
        )  # (B, da)

        K_pi, B, da = sample_act.shape
        # Local and global neighbour counts are each half the policy sample count.
        K_loc = K_pi // 2
        K_glob = K_pi // 2
        K_all = K_pi + K_loc + K_glob
        ds = obs.shape[-1]
        # NOTE(review): bounds come from ``action_space`` directly; confirm they
        # live in the same (unscaled?) space as the actor's samples.
        act_low = torch.as_tensor(
            self.action_space.low, device=obs.device, dtype=self.dtype
        )
        act_high = torch.as_tensor(
            self.action_space.high, device=obs.device, dtype=self.dtype
        )
        with torch.no_grad():
            old_out = self(batch, model="actor_old", input="obs")
            dist_old = old_out.dist
            mu_old = dist_old.mean  # (B, da)
            std_old = dist_old.stddev  # (B, da)

        # Inputs for the trained prediction, laid out as index k * B + b.
        obs_flat = (
            obs.unsqueeze(0).expand(K_pi, -1, -1).reshape(-1, ds)
        )  # (K_pi * B, ds)
        sample_act_flat = sample_act.reshape(-1, da)  # (K_pi * B, da)

        # Local neighbours: Gaussian perturbations around the old policy mean.
        local_samples = mu_old.unsqueeze(0) + (
            self._USPC_cov_scale * std_old
        ) * torch.randn(K_loc, B, da, device=obs.device, dtype=self.dtype)
        local_samples = local_samples.clamp(min=act_low, max=act_high)  # (K_loc, B, da)

        # Global neighbours: uniform over the action box.
        global_samples = (
            torch.rand(K_glob, B, da, device=obs.device, dtype=self.dtype)
            * (act_high - act_low)
            + act_low
        )
        global_samples = global_samples.clamp(
            min=act_low, max=act_high
        )  # (K_glob, B, da)

        # Policy samples come first, so neighbour index < K_pi means "policy sample".
        nbrs = torch.cat(
            [sample_act, local_samples, global_samples], dim=0
        )  # (K_all, B, da)

        obs_flat_all = (
            obs.unsqueeze(0).expand(K_all, -1, -1).reshape(-1, ds)
        )  # (K_all * B, ds)
        # Cost-critic ensemble at every neighbour, evaluated in eval mode; each
        # critic's previous training flag is restored afterwards.
        with torch.no_grad():
            q_list = []
            for i in range(1, self.critics_num):
                was_training = self.critics[i].training
                self.critics[i].eval()
                q_i, _ = self.critics[i].predict(
                    obs_flat_all, nbrs.reshape(-1, da)
                )  # (K_all*B, 1)
                q_list.append(q_i.squeeze(-1))  # (K_all*B,)
                if was_training:
                    self.critics[i].train()

        q_vals = torch.stack(q_list, dim=1)  # (K_all*B, m)
        q_mean = q_vals.mean(dim=1)  # (K_all*B,)
        # With a single cost critic the std term is the scalar 0.0.
        q_std = q_vals.std(dim=1) if q_vals.shape[1] > 1 else 0.0
        q_ucb = (q_mean + self._USPC_beta * q_std).reshape(K_all, B)  # (K_all, B)

        nbrs_BA = nbrs.permute(1, 0, 2).contiguous()  # (B, K_all, da)
        cands_BK = sample_act.permute(1, 0, 2).contiguous()  # (B, K_pi,  da)
        dists_BAK = torch.cdist(nbrs_BA, cands_BK)  # (B, K_all, K_pi)
        q_ucb_BA = q_ucb.permute(1, 0).unsqueeze(-1)  # (B, K_all, 1)

        # NOTE(review): unlike expanderset_net_loss, there is no fallback to
        # qc_thres[0] here. ``_USPC_threshold`` is not assigned anywhere in the
        # visible code, so the masked branch below presumably only runs if it is
        # set externally. Confirm this is intended.
        h = getattr(self, "_USPC_threshold", None)
        safeset_net_old = getattr(self, "safeset_net_old", None)

        # obj[b, j, k] = UCB at neighbour j + L * distance(neighbour j, sample k)
        obj_BAK = q_ucb_BA + self._USPC_L * dists_BAK  # (B, K_all, K_pi)

        if (h is not None) and (safeset_net_old is not None):
            with torch.no_grad():
                f_old_flat = safeset_net_old(obs_flat_all, nbrs.reshape(-1, da))[
                    0
                ].squeeze(
                    -1
                )  # (K_all*B,)
                f_old_BA = f_old_flat.reshape(K_all, B).permute(1, 0)  # (B, K_all)
            safe_mask_BA = f_old_BA <= h  # (B, K_all)

            # Exclude neighbours outside the old safe set from the min.
            masked_obj_BAK = obj_BAK.masked_fill(
                ~safe_mask_BA.unsqueeze(-1), float("inf")
            )
            min_masked_BK = masked_obj_BAK.min(dim=1).values  # (B, K_pi)

            # States with no safe neighbour fall back to the unmasked min.
            has_safe = safe_mask_BA.any(dim=1, keepdim=True)  # (B,1)
            y_BK = torch.where(
                has_safe, min_masked_BK, obj_BAK.min(dim=1).values
            )  # (B, K_pi)
        else:
            y_BK = obj_BAK.min(dim=1).values  # (B, K_pi)

        # Match the k * B + b layout of obs_flat / sample_act_flat.
        y_flat = y_BK.permute(1, 0).reshape(-1)  # (K_pi*B,)

        safe_pred_flat = safeset_net(obs_flat, sample_act_flat)[0].squeeze(
            -1
        )  # (K_pi*B,)
        loss = ((safe_pred_flat - y_flat) ** 2).mean()

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Fraction of (unmasked) argmin witnesses that are policy samples themselves.
        witness_idx_BK = obj_BAK.argmin(dim=1)  # (B, K_pi)
        self_witness_frac = (witness_idx_BK < K_pi).float().mean().item()

        stats_safeset["loss/safeset"] = loss.item()
        stats_safeset["diag/self_witness_frac"] = self_witness_frac
        stats_safeset["diag/mean_ucb_anchor"] = q_ucb.mean().item()
        stats_safeset["diag/mean_dist"] = dists_BAK.mean().item()
        stats_safeset["diag/ensemble_std_mean"] = (
            (q_std if isinstance(q_std, torch.Tensor) else torch.tensor(q_std))
            .mean()
            .item()
        )

        return stats_safeset

    def expanderset_net_loss(
        self,
        batch: Batch,
        expanderset_net: nn.Module,
        optimizer: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        """Take one MSE regression step of ``expanderset_net`` on its expansion target.

        For every state, K actions (``_USPC_sample_act_num``, falling back to
        ``_sample_act_num``) are drawn from ``actor_old``. The cost-ensemble bounds
        are ``U = mean + beta * std`` and ``Lc = mean - beta * std``. A sample is
        "maybe safe" when ``Lc <= h < U``. The target for sample k is the fraction
        of maybe-safe samples k' with ``min(h, Lc_k) + L * ||a_k - a_k'|| <= h``,
        and 0 when there are none. ``expanderset_net(s, a_k)[0]`` is regressed
        onto that target.

        :param Batch batch: batch providing ``obs`` (and ``act`` for its last dim).
        :param nn.Module expanderset_net: network updated by this call.
        :param torch.optim.Optimizer optimizer: optimizer stepped once.
        :return: dict with the regression loss and diagnostic scalars.
        """
        stats = {}

        # Use the policy dtype, as the other USPC methods do, so critics and the
        # expander net get inputs in their own precision.
        obs = torch.as_tensor(
            batch.obs, device=self.device, dtype=self.dtype
        )  # (B, ds)
        B = obs.shape[0]
        ds = obs.shape[-1]
        da = batch.act.shape[-1]
        K = (
            self._USPC_sample_act_num
            if hasattr(self, "_USPC_sample_act_num")
            else self._sample_act_num
        )

        with torch.no_grad():
            old_dist = self(batch, model="actor_old", input="obs").dist
            sample_act: torch.Tensor = old_dist.sample((K,))  # (K, B, da)

        obs_flat = obs[None, ...].expand(K, -1, -1).reshape(-1, ds)  # (K*B, ds)
        acts_flat = sample_act.reshape(-1, da)  # (K*B, da)

        with torch.no_grad():
            q_list = []
            for i in range(1, self.critics_num):  # cost critics only
                was_training = self.critics[i].training
                self.critics[i].eval()
                qi, _ = self.critics[i].predict(obs_flat, acts_flat)  # (K*B, 1)
                q_list.append(qi.squeeze(-1))  # (K*B,)
                if was_training:
                    self.critics[i].train()

        q_vals = torch.stack(q_list, dim=1)  # (K*B, m_cost)
        q_mean = q_vals.mean(dim=1)  # (K*B,)
        q_std = q_vals.std(dim=1) if q_vals.shape[1] > 1 else torch.zeros_like(q_mean)
        U = (q_mean + self._USPC_beta * q_std).reshape(K, B)  # (K, B)
        Lc = (q_mean - self._USPC_beta * q_std).reshape(K, B)  # (K, B)

        h = getattr(self, "_USPC_threshold", None)
        if h is None:
            h = float(self.qc_thres[0]) if hasattr(self, "qc_thres") else 0.0
        h_t = torch.tensor(h, device=obs.device, dtype=q_mean.dtype)

        # NOTE(review): candidates are not filtered by the safe-set net here, whereas
        # _compute_expander_bonus_for_indices only counts samples with
        # safeset_net_old(s, a) > h. Confirm the difference is intended.
        maybe_mask = (Lc <= h_t) & (U > h_t)  # (K, B)

        acts_BK = sample_act.permute(1, 0, 2).contiguous()  # (B, K, da)
        dists = torch.cdist(acts_BK, acts_BK)  # (B, K, K)

        y_plus = torch.minimum(h_t, Lc.permute(1, 0))  # (B, K)

        L_lip = torch.tensor(self._USPC_L, device=obs.device, dtype=q_mean.dtype)
        # lhs[b, k, k'] = min(h, Lc_k) + L * ||a_k - a_k'||
        lhs = y_plus.unsqueeze(-1) + L_lip * dists  # (B, K, K)
        # maybe_mask broadcasts over k and selects maybe-safe k'.
        cond = (lhs <= h_t) & maybe_mask.permute(1, 0).unsqueeze(1)  # (B, K, K)

        counts = cond.sum(dim=-1)  # (B, K)
        M_sizes = maybe_mask.permute(1, 0).sum(dim=-1)  # (B,)
        denom = torch.clamp(M_sizes, min=1)
        targets = counts / denom.unsqueeze(-1)  # (B, K) in [0,1]

        preds_flat = expanderset_net(obs_flat, acts_flat)[0].squeeze(-1)  # (K*B,)
        preds = preds_flat.reshape(K, B).permute(1, 0)  # (B, K)
        loss = torch.mean((preds - targets) ** 2)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        pos_rate = (targets > 0).float().mean().item()
        stats["loss/expanderset"] = loss.item()
        stats["expanderset/target_mean"] = targets.mean().item()
        stats["expanderset/pos_rate"] = pos_rate
        stats["expanderset/maybe_frac"] = (M_sizes > 0).float().mean().item()
        stats["expanderset/mean_M_size"] = M_sizes.float().mean().item()
        stats["expanderset/mean_y_plus"] = y_plus.mean().item()
        stats["expanderset/mean_dist"] = dists.mean().item()
        stats["expanderset/lip"] = float(self._USPC_L)
        stats["expanderset/beta"] = float(self._USPC_beta)
        stats["expanderset/h"] = float(h)

        return stats

    def _estep_dual_loss(self, q_values):
        """E-step dual loss with respect to ``self.estep_dual = [eta, lambda_1, ...]``.

        ``loss = eta * estep_kl + sum_i lambda_i * qc_thres[i-1]
        + eta * mean_b(logsumexp_k((Q_0 - sum_i lambda_i Q_i) / eta) - log K)``

        Only the dual variables receive gradients. ``q_values`` are used detached
        and are never modified.

        :param q_values: list of tensors whose last dim indexes the K sampled
            actions (the combination is commented as (B, K)). ``q_values[0]`` is the
            reward term; ``q_values[1:]`` are the constrained terms.
        :return: scalar loss tensor.
        """
        eta = self.estep_dual[0]
        K = q_values[0].shape[-1]
        loss = eta * self._estep_kl
        combined_q = q_values[0].detach()  # (B, K)
        for i in range(1, len(q_values)):
            # Out-of-place on purpose: ``.detach()`` shares storage with the caller's
            # tensor, so an in-place ``-=`` would overwrite q_values[0].
            combined_q = combined_q - self.estep_dual[i] * q_values[i].detach()
            loss = loss + self.estep_dual[i] * self.qc_thres[i - 1]
        loss = loss + eta * torch.mean(
            torch.logsumexp(combined_q / eta, dim=1) - np.log(K)
        )

        # Diagnostics: expected first constrained value under the induced
        # variational weights, versus the first threshold.
        if len(q_values) > 1:
            with torch.no_grad():
                w = torch.softmax(combined_q / eta, dim=1)  # (B, K)
                mean_sur = (w * q_values[1].detach()).sum(dim=1).mean().item()
            thr = float(self.qc_thres[0])
            self.logger.store("debug_estep", mean_surrogate=mean_sur)
            self.logger.store("debug_estep", thres=thr)
            self.logger.store("debug_estep", grad_sign=thr - mean_sur)
        return loss

    @staticmethod
    def gaussian_kl(
        mu_old: torch.Tensor, std_old: torch.Tensor, mu: torch.Tensor, std: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decoupled KL terms between two diagonal Gaussians (MPO-style).

        See https://arxiv.org/pdf/1812.02256.pdf Sec. 4.2.1.
        kl_mu = KL( pi(mu_old, std_old) || pi(mu, std_old) ) and
        kl_std = KL( pi(mu_old, std_old) || pi(mu_old, std) ).

        Both terms are summed over the last dimension and averaged over the rest.
        Variances are floored at 1e-6 for numerical stability.

        :param mu_old: (B, n)
        :param std_old: (B, n)
        :param mu: (B, n)
        :param std: (B, n)
        :return: kl_mu, kl_std as scalar tensors.
        """
        var_old = torch.clamp_min(std_old**2, 1e-6)
        var_new = torch.clamp_min(std**2, 1e-6)

        # Mean term: the denominator is the *old* variance, not the new one.
        mean_term = 0.5 * (mu_old - mu) ** 2 / var_old
        kl_mu = mean_term.sum(dim=-1).mean()

        std_term = 0.5 * (torch.log(var_new / var_old) + var_old / var_new - 1)
        kl_std = std_term.sum(dim=-1).mean()

        return kl_mu, kl_std

    def policy_loss(
        self,
        batch: Batch,
        old_result,
        sample_act: torch.Tensor,
        q_values: List[torch.Tensor],
        **kwarg,
    ):
        """Run one policy-improvement step: an E-step followed by an M-step.

        E-step: the E-step dual variables are optimized on ``q_values``. The
        non-parametric variational distribution over the K sampled actions is
        then built as a softmax (over the particle axis) of
        ``(q_values[0] - sum_{i>=1} dual_i * q_values[i]) / dual_0``.

        M-step: the actor is fitted to that distribution by weighted maximum
        likelihood. The mean and std are updated separately: ``dist1`` pairs the
        new mean with the old std, and ``dist2`` pairs the old mean with the new
        std. The KL terms to the old policy are penalized with M-step dual
        variables, which are themselves updated in the same loop.

        :param Batch batch: training batch; ``batch.obs`` is fed to the actor.
        :param old_result: output of the old actor on ``batch``. Its ``logits``
            are unpacked as ``(mu_old, std_old)``.
        :param torch.Tensor sample_act: actions sampled from the old actor,
            shape ``(K, B, da)``.
        :param List[torch.Tensor] q_values: Q-value estimates for the sampled
            actions, each of shape ``(B, K)``. Entry ``i`` is paired with E-step
            dual variable ``i``. The tensors are not modified by this method.
        """
        # E-step begin
        t_start = time.time()

        K = self._sample_act_num  # number of sampled action particles per state
        B = len(batch.obs)  # batch size
        num_q = len(q_values)  # one E-step dual variable per q-value set

        # optimize the E-step dual variables
        for _ in range(self._estep_iter_num):
            self.estep_optim.zero_grad()
            estep_loss = self._estep_dual_loss(q_values)
            estep_loss.backward()
            self.estep_optim.step()
            self.logger.store(tab="loss", estep_loss=estep_loss.item())
        self.estep_dual.data.clamp_(min=self.__eps, max=self._estep_dual_max)

        # detach the E-step dual variables as plain floats for the M-step
        estep_dual = []
        for i in range(num_q):
            estep_dual.append(self.estep_dual[i].detach().item())
            self.logger.store(**{"estep/dual" + str(i): estep_dual[i]})

        # Compute the optimal non-parametric variational distribution.
        # Built out-of-place on purpose: `q_values[0].T` is a view of the
        # caller's tensor, so an in-place `-=` would overwrite q_values[0].
        optimal_q = q_values[0].T  # (K, B)
        for i in range(1, num_q):
            optimal_q = optimal_q - estep_dual[i] * q_values[i].T
        optimal_q = torch.softmax(optimal_q / estep_dual[0], dim=0).detach()  # (K, B)

        t_estep = time.time()
        self._estep_duration += t_estep - t_start
        self.logger.store(tab="estep", estep_time=self._estep_duration)

        # M-step begin
        mu_old, std_old = old_result.logits
        mu_old, std_old = mu_old.detach(), std_old.detach()
        for _ in range(self._mstep_iter_num):
            result = self(batch, model="actor", input="obs")

            # MLE loss with decoupled mean / std updates
            mu, std = result.logits
            dist1 = self.dist_fn(mu, std_old)  # new mean, old std
            dist2 = self.dist_fn(mu_old, std)  # old mean, new std
            likelihood = (
                dist1.expand((K, B)).log_prob(sample_act)
                + dist2.expand((K, B)).log_prob(sample_act)
            )  # (K, B)
            loss_mle = -torch.mean(optimal_q * likelihood)

            # update dual variables to regularize the KL
            kl_mu, kl_std = self.gaussian_kl(mu_old, std_old, mu, std)
            mstep_dual_loss = (
                self.mstep_dual_mu * (self._mstep_kl_mu - kl_mu).detach()
                + self.mstep_dual_std * (self._mstep_kl_std - kl_std).detach()
            )
            self.mstep_optim.zero_grad()
            mstep_dual_loss.backward()
            self.mstep_optim.step()

            # KL loss, weighted by the clipped (non-differentiable) dual values
            dual_mu = np.clip(self.mstep_dual_mu.item(), 0.0, self._mstep_dual_max)
            dual_std = np.clip(self.mstep_dual_std.item(), 0.0, self._mstep_dual_max)
            loss_kl = dual_mu * (kl_mu - self._mstep_kl_mu) + dual_std * (
                kl_std - self._mstep_kl_std
            )

            loss_actor = loss_mle + loss_kl

            # optimize the policy network
            self.actor_optim.zero_grad()
            loss_actor.backward()
            self.actor_optim.step()

            entropy = torch.mean(dist1.entropy() + dist2.entropy()).item()

            self.logger.store(
                tab="mstep",
                mstep_kl_mu=kl_mu.item(),
                mstep_kl_std=kl_std.item(),
                mstep_loss_kl=loss_kl.item(),
                mstep_loss_mle=loss_mle.item(),
                mstep_loss_total=loss_actor.item(),
                mstep_dual_mu=dual_mu,
                mstep_dual_std=dual_std,
                entropy=entropy,
            )
        self._mstep_duration += time.time() - t_estep
        self.logger.store(tab="mstep", mstep_time=self._mstep_duration)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """Run one training update: critics, then the safe-set network, then the actor.

        Actions are sampled once from the old actor. Their Q-values are taken
        from ``self.critics[0]`` and from ``self.safeset_net``, and the same
        samples feed both the safe-set network loss and ``policy_loss``.

        NOTE(review): despite the ``Dict[str, float]`` annotation, this method
        has no return statement and therefore returns ``None``. Statistics are
        reported through ``self.logger`` instead.
        """
        # critic
        td, stats_critic = self.critics_loss(batch, self.critics, self.critics_optim)

        # sample actions & store q-values for safeset and policy steps
        obs = torch.as_tensor(
            batch.obs, device=self.device, dtype=torch.float32
        )  # (B, ds)
        # for continuous action space, sample K particles
        K = self._sample_act_num
        B = obs.shape[0]
        da = batch.act.shape[-1]
        ds = obs.shape[-1]

        with torch.no_grad():
            old_result = self(batch, model="actor_old", input="obs")
            sample_act: torch.Tensor = old_result.dist.sample((K,))  # (K, B, da)
            # flatten the (K, B) particle grid once and reuse it for both networks
            flat_obs = obs[None, ...].expand(K, -1, -1).reshape(-1, ds)  # (K*B, ds)
            flat_act = sample_act.reshape(-1, da)  # (K*B, da)

            def _predict_q(net) -> torch.Tensor:
                """Predict Q for every (state, particle) pair, returned as (B, K)."""
                q, _ = net.predict(flat_obs, flat_act)
                return q.reshape(K, B).T

            q_values = [_predict_q(self.critics[0])]

            # Evaluate the safe-set net in eval mode. Restore train mode even if
            # predict raises, so the following safeset_net_loss is not affected.
            self.safeset_net.eval()
            try:
                q_values.append(_predict_q(self.safeset_net))
            finally:
                self.safeset_net.train()

        stats_safeset_net = self.safeset_net_loss(
            batch=batch,
            sample_act=sample_act,
            safeset_net=self.safeset_net,
            optimizer=self.safeset_optim,
        )

        batch.weight = td  # prio-buffer
        # actor TODO FIXME
        self.policy_loss(
            batch=batch, old_result=old_result, sample_act=sample_act, q_values=q_values
        )

        self.sync_weight()
        self.logger.store(**stats_critic)
        self.logger.store(**stats_safeset_net)

    def get_extra_state(self):
        """Return the extra state to include in ``policy.state_dict()``.

        PyTorch calls this from ``state_dict()`` because the method is
        overridden, see
        https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state

        NOTE(review): persisting the dual variables' optimizer states is not
        implemented yet. This method returns ``None``, so no extra state is
        saved. Whether the dual variables themselves end up in the state dict
        depends on how they are registered on the module (TODO confirm).
        """
        return None

    def set_extra_state(self, state):
        """Consume the extra state produced by ``get_extra_state``.

        This function is called from ``load_state_dict()`` to handle any extra
        state found within the state_dict.

        NOTE(review): restoring the dual variables' optimizer states is not
        implemented yet, so this is currently a no-op and ``state`` is ignored.
        """
        del state  # intentionally unused until extra state is actually saved
