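# NOTE(review): the classes below are disabled (commented out). Before
# re-enabling them, check the following:
#   * This class is named `_VariationalCriticEnsemble`, but the subclasses
#     below and `super(VariationalCriticEnsemble, self)` refer to
#     `VariationalCriticEnsemble` (no leading underscore). One of the two
#     names must be changed.
#   * In `eval_rollout`, `batch_masks[:-deviation, ...]` is an empty slice
#     when `deviation == 0`, i.e. when `discounts.shape[0] == discounts.shape[1]`.
#   * In `actor_loss`, the discounted returns are reduced with `.sum(1)`.
#     The same expression in `ExtremeActorGradient.actor_loss` uses `.sum(0)`.
#     Confirm the intended axis.
# class _VariationalCriticEnsemble(SVGAgent):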
#     def __init__(
#         self, *args, beta_init=1.0, lr_beta=0.2, target_improvement=0.1, **kwargs
#     ):
#         super().__init__(*args, **kwargs)
#         self.num_model_members = self.dynamics_model.num_members
#         if self.num_critic_ensemble != self.num_model_members:
#             raise ValueError(f"{self.num_critic_ensemble} != {self.num_model_members}")
#         self.num_critic_ensemble = 1
#         self.is_explore = False
#         size_exp = self.num_model_members, self.batch_size
#         self.done_act = self.done.unsqueeze(0).expand(size_exp)
#
#         self.beta = beta_init
#         self.lr_beta = lr_beta
#         self.target_improvement = target_improvement
#         self.beta_norm = None
#
#     def eval_rollout(
#         self,
#         batch_sa,
#         log_pis,
#         batch_masks,
#         rewards,
#         alpha,
#         last_sa,
#         discounts,
#         critic,
#         critic_target,
#     ):
#         with torch.no_grad():
#             q_values = critic_target(last_sa)
#             target_rewards = torch.cat([rewards, q_values[None, ..., 0]])
#             target_rewards[1:, ...].sub_(alpha.detach() * log_pis[1:, ...])
#             target_values = discounts[0, :, None, None] * target_rewards * batch_masks
#             target_values = torch.sum(target_values, 0)
#
#         deviation = discounts.shape[1] - discounts.shape[0]
#         pred_values = critic(batch_sa.detach())
#         pred_values = pred_values * batch_masks[:-deviation, :, None]
#         return target_values[..., None], pred_values
#
#     def pred_q_value(self, obs_action):
#         pred_q = self.critic(obs_action).t()[..., None]
#         if self.scaled_critic:
#             pred_q = pred_q * self._q_width + self._q_center
#         return pred_q
#
#     def pred_target_q_value(self, obs_action):
#         target_q = self.critic_target(obs_action).t()[..., None]
#         if self.scaled_critic:
#             target_q = target_q * self._q_width + self._q_center
#         if self.bounded_critic:
#             target_q = self._q_ub - torch.relu(self._q_ub - target_q)
#             target_q = self._q_lb + torch.relu(target_q - self._q_lb)
#         return target_q
#
#     def actor_loss(
#         self,
#         ctx_modules,
#         batch_sa,
#         batch_mask,
#         log_pis,
#         rewards,
#         last_sa,
#         info,
#         log=False,
#     ):
#         pred_qs = ctx_modules.pred_q(batch_sa).squeeze(-1)
#         rewards = torch.cat([rewards, pred_qs.unsqueeze(0)])
#         rewards.sub_(ctx_modules.alpha.detach() * log_pis)
#         pred_qs = self.discount_mat[0, :, None, None] * rewards * batch_mask
#         pred_qs = pred_qs.sum(1)
#
#         # softmax_q = pred_qs.max(0).values
#         q_max = pred_qs.max(0, keepdim=True).values
#         softmax_q = (
#             q_max + ((pred_qs - q_max) / self.beta).exp().mean(0).log() * self.beta
#         )
#         loss_actor = -softmax_q.mean()
#
#         entropy = -log_pis[0].detach().mean()
#         self._info.update(**{KEY_ACTOR_LOSS: loss_actor.detach(), "entropy": entropy})
#
#         if self.learnable_alpha:
#             alpha_loss = -ctx_modules.alpha * (self.target_entropy - entropy)
#             loss_actor += alpha_loss
#
#             self._info.update(
#                 **{
#                     "alpha_loss": alpha_loss.detach(),
#                     "alpha_value": ctx_modules.alpha.detach(),
#                 }
#             )
#
#         # Compute actor gradients; the optimizer step is taken below, after beta tuning
#         ctx_modules.actor_optimizer.zero_grad()
#         loss_actor.backward()
#
#         # Beta tuning
#         with torch.no_grad():
#             baseline = pred_qs.mean()
#             abs_baseline = torch.abs(baseline).clamp(min=1e-5)
#             if self.beta_norm is None:
#                 self.beta_norm = abs_baseline
#             else:
#                 self.beta_norm.lerp_(abs_baseline, 0.1)
#
#             soft_q_mean = -loss_actor
#             improvement = (soft_q_mean - baseline) / self.beta_norm
#             q_var = pred_qs.var(0).mean()
#
#             # new_beta = torch.relu(improvement - self.target_improvement) * q_var / 2.0
#             new_beta = q_var / (2 * self.target_improvement * self.beta_norm)
#             self.beta = (1 - self.lr_beta) * self.beta + self.lr_beta * new_beta
#
#         if log:
#             self._info["actor_grad_norm"] = calc_grad_norm(ctx_modules.actor)
#             with torch.no_grad():
#                 raw_q_std = q_var.sqrt()
#                 relative_q_std = raw_q_std / abs_baseline
#                 self._info.update(
#                     **{
#                         "soft_q_mean": soft_q_mean,
#                         "improvement_v": improvement,
#                         "soft_q_std": raw_q_std,
#                         "relative_soft_q_std": relative_q_std,
#                         "beta": self.beta,
#                     }
#                 )
#
#         ctx_modules.actor_optimizer.step()
#
#         return info
#
#     def mb_policy_evaluation(
#         self,
#         ctx_modules,
#         obs,
#         log=False,
#         **kwargs,
#     ):
#         obs = obs.expand((self.num_model_members,) + obs.shape)
#         return super().mb_policy_evaluation(
#             ctx_modules,
#             obs,
#             done=self.done_act.clone(),
#             log=log,
#             prediction_strategy="full",
#         )
#
#     def distribution_rollout(self, num_rollout_samples=None, **rollout_kwargs):
#         rollout_kwargs["prediction_strategy"] = PS_TS1
#         super().distribution_rollout(
#             num_rollout_samples=num_rollout_samples, **rollout_kwargs
#         )
#
#     def _update(self, log=False, **ctx_kwargs):
#         return super()._update(log=log, **dict(ctx_kwargs, prediction_strategy="full"))


# class ExtremeActorGradient(VariationalCriticEnsemble):
#     def _update(self, log=False, **ctx_kwargs):
#         ctx_kwargs["detach"] = True
#         return super()._update(log=log, **ctx_kwargs)
#
#     def actor_loss(
#         self,
#         ctx_modules,
#         batch_sa,
#         batch_mask,
#         log_pis,
#         rewards,
#         last_sa,
#         info,
#         log=False,
#     ):
#         with torch.no_grad():
#             pred_qs = ctx_modules.pred_terminal_q(batch_sa)
#             rewards = torch.cat([rewards, pred_qs])
#             rewards.sub_(ctx_modules.alpha.detach() * log_pis)
#             mve = self.discount_mat[0, :, None, None] * rewards * batch_mask
#             mve = mve.sum(0)
#
#             baseline = mve.mean()
#             abs_baseline = torch.abs(baseline)
#             q_var = mve.var(0).mean().clamp(min=1e-4)
#             if self.beta_norm is None:
#                 self.beta_norm = q_var
#             else:
#                 self.beta_norm.lerp_(q_var, 0.1)
#
#         #####################
#         mc_v_pred = self.evg_rollout(batch_sa, log)
#         # z = (mve - mc_v_pred) / (self.beta_norm * self.beta)
#         # z.clamp_(min=-1.0, max=0.5)
#         # z = (mve - mc_v_pred) / self.beta
#         # max_z = z.detach().max(0, keepdims=True).values
#         # max_z.clamp_(min=-1.0)
#         # loss_actor = z.exp() - (z + 1)
#         loss_actor = -mc_v_pred
#         # loss_actor = torch.exp(z - max_z) - (z + 1) * torch.exp(-max_z.clamp(min=0.0))
#         loss_actor = loss_actor.mean()
#         #####################
#
#         entropy = -log_pis[0].detach().mean()
#         self._info.update(**{KEY_ACTOR_LOSS: loss_actor.detach(), "entropy": entropy})
#
#         if self.learnable_alpha:
#             alpha_loss = -ctx_modules.alpha * (self.target_entropy - entropy)
#             loss_actor += alpha_loss
#
#             self._info.update(
#                 **{
#                     "alpha_loss": alpha_loss.detach(),
#                     "alpha_value": ctx_modules.alpha.detach(),
#                 }
#             )
#
#         # Compute actor gradients; the optimizer step is taken below, after beta tuning
#         ctx_modules.actor_optimizer.zero_grad()
#         loss_actor.backward()
#
#         # Beta tuning
#         with torch.no_grad():
#             soft_q_mean = mc_v_pred.mean()
#             improvement = (soft_q_mean - baseline) / self.beta_norm
#
#             # new_beta = torch.relu(improvement - self.target_improvement) * q_var / 2.0
#             new_beta = q_var / (2 * self.target_improvement * self.beta_norm)
#             self.beta = (1 - self.lr_beta) * self.beta + self.lr_beta * new_beta
#
#         if log:
#             self._info["actor_grad_norm"] = calc_grad_norm(ctx_modules.actor)
#             with torch.no_grad():
#                 raw_q_std = q_var.sqrt()
#                 relative_q_std = raw_q_std / abs_baseline
#                 self._info.update(
#                     **{
#                         "soft_q_mean": soft_q_mean,
#                         "improvement_v": improvement,
#                         "soft_q_std": raw_q_std,
#                         "relative_soft_q_std": relative_q_std,
#                         "beta": self.beta,
#                     }
#                 )
#
#         ctx_modules.actor_optimizer.step()
#
#         return info
#
#     def evg_rollout(self, batch_sa, log=False):
#         if 0 < self.rollout_horizon:
#             buffer = self.rollout_buffer
#         else:
#             buffer = self.replay_buffer
#
#         ctx_modules = ContextModules(
#             self.actor,
#             self.actor_optimizer,
#             None,
#             None,
#             self.pred_evg_terminal_q,
#             self.critic,
#             self.critic_target,
#             self.critic_optimizer,
#             self.alpha,
#             self.model_step_context,
#             buffer,
#             False,
#         )
#         obs = batch_sa[0, :, : self.dim_obs]
#         action, log_pi, _ = self.sample_action(ctx_modules.actor, obs, log=log)
#         (
#             obss,
#             actions,
#             log_pis,
#             rewards,
#             masks,
#             info,
#         ) = self.rollout(
#             ctx_modules,
#             obs,
#             action,
#             self.training_rollout_horizon,
#             mve_horizon=self.mve_horizon,
#             log_pi=log_pi.squeeze(-1),
#             log=log,
#             prediction_strategy=PS_INF,
#         )
#         last_sa = torch.cat([obss[-1], actions[-1]], -1)
#         log_pis = torch.stack(log_pis)
#         rewards = torch.stack(rewards)
#         masks = torch.stack(masks).float()
#         mc_v_pred = self._actor_loss(ctx_modules, log_pis, last_sa, rewards, masks).t()
#         return mc_v_pred
#
#     def pred_evg_terminal_q(self, obs_action):
#         """
#
#         Args:
#             obs_action: [M, B, D]
#
#         Returns:
#             pred_q: [1, B]
#
#         """
#         m_c = self.num_model_members * self.num_critic_ensemble
#         n = int(self.batch_size / self.num_model_members)
#         obs_action = obs_action.repeat_interleave(self.num_critic_ensemble, 0)
#         obs_action = obs_action.view(m_c, n, -1)
#         pred_q = (
#             self.critic(obs_action)
#             .t()
#             .view(self.num_model_members, self.num_critic_ensemble, n)
#             .mean(1)
#             .view(1, self.batch_size)
#         )
#         if self.scaled_critic:
#             pred_q = pred_q * self._q_width + self._q_center
#         return pred_q


# def _extreme_value_expansion(self, batch_sa, log=False):
#     if 0 < self.rollout_horizon:
#         buffer = self.rollout_buffer
#     else:
#         buffer = self.replay_buffer
#
#     ctx_modules = ContextModules(
#         self.actor,
#         self.actor_optimizer,
#         None,
#         None,
#         None,
#         self.critic,
#         self.critic_target,
#         self.critic_optimizer,
#         self.alpha,
#         self.model_step_context,
#         buffer,
#         False,
#     )
#     obs = batch_sa.detach()[0, : self.batch_size, : self.dim_obs]
#     action, log_pi, _ = self.sample_action(ctx_modules.actor, obs, log=log)
#     (
#         obss,
#         actions,
#         log_pis,
#         rewards,
#         masks,
#         info,
#     ) = self.rollout(
#         ctx_modules,
#         obs,
#         action,
#         self.training_rollout_horizon,
#         mve_horizon=self.mve_horizon,
#         log_pi=log_pi.squeeze(-1),
#         log=log,
#         prediction_strategy=PS_INF,
#     )
#     last_sa = torch.cat([obss[-1], actions[-1]], -1)
#     log_pis = torch.stack(log_pis)
#     rewards = torch.stack(rewards)
#     masks = torch.stack(masks).float()
#
#     pred_values = self.extreme_critic(last_sa)
#     if self.scaled_critic:
#         pred_values = pred_values * self._q_width + self._q_center
#     raw_xve = torch.cat([rewards, pred_values.t()])
#     raw_xve[1:, ...] = raw_xve[1:, ...].sub(self.alpha.detach() * log_pis[1:, ...])
#     xve = self.discount_mat.mm(raw_xve * masks).unsqueeze(-1)
#     return masks, log_pis, rewards, last_sa, xve, info
#
#
# class eXtremeValueGradient(VariationalCriticEnsemble):
#     def __init__(
#         self,
#         *args,
#         xve_loss: str = "mse",
#         extreme_loss_ratio: float = 1.0,
#         v_lr_ratio: float = 1.0,
#         xve_horizon: int = 1,
#         target_critic_softmax=False,
#         beta_min: float = 1e-5,
#         beta_update: str = "softmax",
#         **kwargs,
#     ):
#         if not xve_loss in ("mse", "softmax", "gumbel", "gauss"):
#             raise KeyError(f"not {xve_loss} in (mse, softmax, gumbel, gauss)")
#         self.xve_loss = xve_loss
#         self.extreme_loss_ratio = extreme_loss_ratio
#         self.v_lr_ratio = v_lr_ratio
#         self.xve_horizon = xve_horizon
#         self.target_critic_softmax = target_critic_softmax
#         self.beta_min = beta_min
#         self.beta_update = beta_update
#         super().__init__(*args, **kwargs)
#
#         z_ub = 5.0
#         y_ub = np.exp(z_ub) - z_ub - 1
#         slope = np.exp(z_ub) - 1.0
#         bias = -slope * z_ub + y_ub
#
#         self.z_ub = torch.as_tensor(z_ub, dtype=torch.float32, device=self.device)
#         self.slope = torch.as_tensor(slope, dtype=torch.float32, device=self.device)
#         self.bias = torch.as_tensor(bias, dtype=torch.float32, device=self.device)
#
#     def build_critics(self, critic_cfg):
#         super(SVGAgent, self).build_critics(critic_cfg)
#         num_critic_ensemble = int(self.critic.num_members / self.num_model_members)
#
#         if self.weighted_critic:
#             size = (
#                 self.num_model_members,
#                 self.batch_size,
#                 num_critic_ensemble,
#             )
#             weight_rate = torch.ones(size, device=self.device)
#             # self.critic_loss_weight = Exponential(weight_rate).sample()
#             self.critic_loss_weight = torch.poisson(weight_rate)
#
#         self.extreme_critic = instantiate(
#             critic_cfg,
#             num_members=num_critic_ensemble,
#         ).to(self.device)
#         self.extreme_target_critic = instantiate(
#             critic_cfg,
#             num_members=num_critic_ensemble,
#         ).to(self.device)
#         self.extreme_target_critic.load_state_dict(self.extreme_critic.state_dict())
#
#     def init_optimizer(self):
#         actor_params = [
#             {"params": self.actor.parameters()},
#             {"params": [self.raw_alpha], "lr": self.lr_alpha},
#         ]
#
#         critic_params = [
#             {"params": self.critic.parameters()},
#             {
#                 "params": self.extreme_critic.parameters(),
#                 # "weight_decay": self.v_lr_ratio,
#             },
#         ]
#         if hasattr(self, "dynamics_model") and hasattr(
#             self.dynamics_model, "variational_parameters"
#         ):
#             # actor_params.append(
#             #     {
#             #         "params": self.dynamics_model.variational_parameters(),
#             #         # "lr": self.lr * 0.1,
#             #     }
#             # )
#             critic_params.append(
#                 {
#                     "params": self.dynamics_model.variational_parameters(),
#                     "lr": self.lr * self.v_lr_ratio,
#                 }
#             )
#
#         self.actor_optimizer = torch.optim.Adam(actor_params, lr=self.lr)
#         self.critic_optimizer = torch.optim.Adam(critic_params, lr=self.lr)
#
#     # def _update(self, log=False, **ctx_kwargs):
#     #     return super()._update(
#     #         log=log,
#     #         **dict(ctx_kwargs, xve_horizon=-1, prediction_strategy="full"),
#     #     )
#
#     def mb_policy_evaluation(self, ctx_modules, obs, done=None, log=False, **kwargs):
#         (
#             batch_sa,
#             batch_mask,
#             log_pis,
#             rewards,
#             # _,
#             target_q,
#             pred_q,
#             loss_critic,
#             _,
#         ) = super(eXtremeValueGradient, self).mb_policy_evaluation(
#             ctx_modules, obs, done=done, log=log, **kwargs
#         )
#         total_loss_critic = loss_critic
#         info = {}
#
#         with torch.no_grad():
#             pred_q = self._reduce(pred_q, "mean", 2, False)
#             # xve_target = pred_q - self.alpha * log_pis[0]
#             xve_target = pred_q.detach()
#             info[V_PRIOR] = xve_target.mean()
#
#         if 0 <= ctx_modules.xve_horizon:
#             # if self.target_critic_softmax:
#             soft_update_params(
#                 self.extreme_critic,
#                 self.extreme_target_critic,
#                 self.critic_tau,
#             )
#
#             if ctx_modules.xve_horizon == 0:
#                 sa_init = batch_sa.detach()[0, : self.batch_size, :]
#
#                 xve = self.extreme_critic(sa_init)
#                 if self.scaled_critic:
#                     xve = xve * self._q_width + self._q_center
#             else:
#                 ##############
#                 xve, loss_xtreme_td = self.one_step_xve(
#                     batch_sa.detach(), log_pis, ctx_modules.xve_horizon, pred_q, log=log
#                 )
#
#                 #######################
#                 # xve = self.extreme_critic(batch_sa[0, : self.batch_size, :].detach())
#                 # xve = xve * self._q_width + self._q_center
#                 # loss_xtreme_td = xve.detach()[0, 0] * 0
#                 ##############
#                 total_loss_critic = total_loss_critic + loss_xtreme_td
#                 info["loss_xtreme_td"] = loss_xtreme_td.detach()
#
#             loss_xve, info = self.xtreme_loss(xve, xve_target, info, log)
#
#             ##########
#             total_loss_critic = total_loss_critic + loss_xve * self.extreme_loss_ratio
#             with torch.no_grad():
#                 # improvement = diff.mean(2)
#                 # diff_std = improvement.std(0).mean()
#                 self._info.update(
#                     **{
#                         "loss_td": loss_critic.detach(),
#                         KEY_CRITIC_LOSS: total_loss_critic.detach(),
#                         "loss_extreme_value": loss_xve.detach(),
#                     },
#                     **info,
#                 )
#         else:
#             total_loss_critic = loss_critic
#             info[XVE_TARGET] = xve_target
#
#         return (
#             batch_sa,
#             batch_mask,
#             log_pis,
#             rewards,
#             # None,
#             target_q,
#             pred_q,
#             total_loss_critic,
#             info,
#         )
#
#     def one_step_xve(self, batch_sa, log_pis, xve_horizon, pred_q, log=False):
#         sa_init = batch_sa[0, : self.batch_size, :]
#
#         xtreme_q = self.extreme_critic(sa_init)
#         if self.target_critic_softmax:
#             # if self.scaled_critic:
#             #     xtreme_q = xtreme_q * self._q_width
#
#             prior_q = pred_q.detach().mean(0, keepdim=True).t()
#             xtreme_q += prior_q
#         else:
#             if self.scaled_critic:
#                 xtreme_q = xtreme_q * self._q_width + self._q_center
#
#         state, action = torch.tensor_split(sa_init, [self.dim_obs], 1)
#         log_pis = log_pis.detach()
#         log_pi = log_pis[0, 0, :]
#         ctx_modules = ContextModules(
#             actor=self.actor,
#             model_step=self.model_step_context,
#             detach=False,
#         )
#         (
#             obss,
#             actions,
#             log_pis,
#             rewards,
#             masks,
#             info,
#         ) = self.detached_rollout(
#             ctx_modules,
#             state,
#             action,
#             xve_horizon,
#             mve_horizon=self.mve_horizon,
#             log_pi=log_pi,
#             log=log,
#             prediction_strategy=CTX_EXPLR,
#         )
#         masks = torch.stack(masks).float()
#         log_pis = torch.stack(log_pis)
#         rewards = torch.stack(rewards)
#         last_sa = torch.cat([obss[-1], actions[-1]], -1)
#
#         with torch.no_grad():
#             if self.target_critic_softmax:
#                 prior_q = self.pred_q_value(
#                     last_sa[None, ...].repeat((self.num_model_members, 1, 1))
#                 ).mean(0)
#
#                 # prior_q = prior_q.mean(1).t()
#                 # prior_q = self.pred_target_q_value(
#                 #     last_sa[None, ...].repeat((self.num_model_members, 1, 1))
#                 # )
#                 # prior_q = prior_q.mean(1).t()
#
#                 xtrm_q_diff = self.extreme_target_critic(last_sa)
#                 # if self.scaled_critic:
#                 #     xtrm_q_diff = xtrm_q_diff * self._q_width
#                 pred_qs = xtrm_q_diff + prior_q
#                 pred_qs = self._reduce(pred_qs, dim=1, keepdim=True).t()
#                 if self.bounded_critic:
#                     pred_qs = self._q_ub - torch.relu(self._q_ub - pred_qs)
#                     pred_qs = self._q_lb + torch.relu(pred_qs - self._q_lb)
#             else:
#                 # pred_qs = self.extreme_critic(last_sa)
#                 pred_qs = self.extreme_target_critic(last_sa)
#                 pred_qs = self._reduce(pred_qs, dim=1, keepdim=True).t()
#
#                 if self.scaled_critic:
#                     pred_qs = pred_qs * self._q_width + self._q_center
#                 if self.bounded_critic:
#                     pred_qs = self._q_ub - torch.relu(self._q_ub - pred_qs)
#                     pred_qs = self._q_lb + torch.relu(pred_qs - self._q_lb)
#
#         rewards = torch.cat([rewards, pred_qs])
#         rewards[1:, ...] -= self.alpha.detach() * log_pis[1:, ...]
#         xve = self.discount_mat[:1, : len(rewards)].mm(rewards * masks).t()
#         td_target = xve.detach()
#         loss_xtreme_td = F.mse_loss(xtreme_q, td_target, reduction="none")
#         if not self.target_critic_softmax:
#             with torch.no_grad():
#                 extreme_target = (
#                     self.beta
#                     * (
#                         torch.logsumexp(pred_q / self.beta, 0)
#                         - np.log(self.num_model_members)
#                     )[:, None]
#                 )
#             loss_xtreme_td += F.mse_loss(xtreme_q, extreme_target, reduction="none")
#         if self.weighted_critic:
#             loss_xtreme_td *= self.critic_loss_weight[0]
#         loss_xtreme_td = loss_xtreme_td.sum(1).mean()
#
#         # xve = xve - self.alpha.detach() * log_pis[0, :, None]
#
#         self._info[XTRME_Q] = xtreme_q.detach().mean()
#         return xve, loss_xtreme_td
#
#     def detached_rollout(
#         self,
#         ctx_modules,
#         obs,
#         action,
#         rollout_horizon: int,
#         done=None,
#         mve_horizon=None,
#         log_pi=None,
#         log=False,
#         prediction_strategy=None,
#         **kwargs,
#     ):
#         done = self.done.clone() if done is None else done
#         obs = obs.detach()
#         obss, actions, log_pis, rewards, masks, info = (
#             [obs.detach()],
#             [action],
#             [log_pi],
#             [],
#             [~done],
#             [],
#         )
#         for step in range(rollout_horizon):
#             # Sample action
#             if 0 < step:
#                 with torch.no_grad():
#                     action, log_pi, pi = self.sample_action(
#                         ctx_modules.actor, obs, **kwargs
#                     )
#
#                 log_pis.append(log_pi.squeeze(-1))
#                 # if step < mve_horizon:
#                 obss.append(obs)
#                 actions.append(action)
#
#             obs, rewards_i, done_i, info_s = ctx_modules.model_step(
#                 action,
#                 obs,
#                 prediction_strategy=prediction_strategy,
#                 log=log,
#             )
#             obs = obs.detach()
#
#             rewards.append(rewards_i.squeeze(-1))
#             done |= done_i.squeeze(-1)
#             masks.append(~done)
#             self._info.update(**info_s)
#
#         # Record the final observation reached after the last model step
#         obss.append(obs)
#
#         # Terminal condition
#         with torch.no_grad():
#             action, log_pi, pi = self.sample_action(ctx_modules.actor, obs, **kwargs)
#             actions.append(action)
#
#         log_pis.append(log_pi.squeeze(-1))
#
#         return obss, actions, log_pis, rewards, masks, info
#
#     def xtreme_loss(self, xve, xve_target, info, log):
#         q_epi = xve_target.detach().std(0).mean()
#         ##########
#         if self.xve_loss == "softmax":
#             with torch.no_grad():
#                 extreme_target = self.beta * (
#                     torch.logsumexp(xve_target / self.beta, 0, keepdim=True)
#                     - np.log(self.num_model_members)
#                 )
#             loss_xve = F.mse_loss(xve.unsqueeze(0), extreme_target.unsqueeze(2))
#             if log:
#                 with torch.no_grad():
#                     diff = torch.mean(extreme_target.unsqueeze(2) - xve.unsqueeze(0))
#                     info["extreme_diff"] = diff
#         elif self.xve_loss == "gauss":
#             extreme_target = xve_target.mean(0) + xve_target.var(0) / (2 * self.beta)
#             loss_xve = F.mse_loss(xve.squeeze(1), extreme_target)
#         elif self.xve_loss == "gumbel":
#             # z = diff / self.beta
#             # loss_xve = torch.exp(z) - z - 1.0
#             # large_z = self.z_ub < z
#             # loss_xve[large_z] = self.slope * z[large_z] + self.bias
#             # loss_xve = loss_xve.sum(2).mean()
#
#             # # max_z = z.max()
#             # # diff_z = soft_bound(z - max_z, 2.0)
#             # # neg_max_z = soft_bound(-max_z, 2.0)
#             # # loss_xve = torch.exp(diff_z) - (z + 1) * torch.exp(neg_max_z)
#             # # loss_xve = loss_xve.sum(2).mean() * max_z.clamp(max=2.0).exp()
#             # # info["max_z"] = max_z.detach
#
#             dist = torch.distributions.gumbel.Gumbel(xve_target.unsqueeze(2), self.beta)
#             loss_xve = -dist.log_prob(xve.unsqueeze(0))
#             loss_xve = soft_bound(loss_xve, ub=1e3)
#             loss_xve = loss_xve.sum(2).mean()
#         else:
#             assert self.xve_loss == "mse"
#             diff = xve_target.unsqueeze(2) - xve.unsqueeze(0)
#             loss_xve = torch.square(diff).sum(2).mean()
#
#         info[Q_EPI_STD] = q_epi
#         return loss_xve, info
#
#     def extreme_value_expansion(self, log_pis, batch_sa, rewards, masks, info):
#         idx = torch.randint(low=0, high=2, size=(), device=self.device)
#         sa, last_sa = torch.tensor_split(batch_sa[idx], [-self.batch_size], 0)
#         log_pis = log_pis[:, idx, :]
#         ##########
#         rewards = rewards[:, idx, :]
#         if hasattr(self.dynamics_model, "variational_parameters"):
#             # sa = batch_sa[idx, : -self.batch_size, :]
#             reward_x = self.dynamics_model.internal_reward(sa, denormalize=True)[0]
#             rewards += reward_x.view(self.training_rollout_horizon, self.batch_size)
#         ##########
#         masks = masks[:, idx, :]
#
#         if self.target_critic_softmax:
#             prior_q = self.pred_q_value(
#                 last_sa[None, ...].repeat((self.num_model_members, 1, 1))
#             )[idx]
#             # prior_q = self._reduce(prior_q, self.actor_reduction, 1, keepdim=True).t()
#             pred_values = self.extreme_critic(last_sa)
#             # if self.scaled_critic:
#             #     pred_values = pred_values * self._q_width
#             pred_values += prior_q
#             pred_values = self._reduce(
#                 pred_values, self.actor_reduction, 1, keepdim=True
#             ).t()
#
#             # pred_values += prior_q
#         else:
#             pred_values = self.extreme_critic(last_sa).mean(1, keepdim=True).t()
#             if self.scaled_critic:
#                 pred_values = pred_values * self._q_width + self._q_center
#         raw_xve = torch.cat([rewards, pred_values])
#         raw_xve = raw_xve.sub(self.alpha.detach() * log_pis)
#         xve = self.discount_mat.mm(raw_xve * masks).unsqueeze(-1)
#         return xve
#
#     def _actor_loss(
#         self, ctx_modules, log_pis, batch_sa, rewards, masks, info, log=False
#     ):
#         xve = self.extreme_value_expansion(log_pis, batch_sa, rewards, masks, info)
#         return -xve.mean()
#
#     def actor_loss(
#         self,
#         ctx_modules,
#         batch_sa,
#         batch_mask,
#         log_pis,
#         rewards,
#         info,
#         log=False,
#     ):
#         super(VariationalCriticEnsemble, self).actor_loss(
#             ctx_modules, batch_sa, batch_mask, log_pis, rewards, info, log
#         )
#         # Beta tuning
#         with torch.no_grad():
#             baseline = self._info[V_PRIOR]
#             q_std = self._info[Q_EPI_STD]
#             abs_baseline = torch.abs(baseline).clamp(min=1e-5)
#             if self.beta_norm is None:
#                 self.beta_norm = abs_baseline
#                 self.beta_std = q_std
#             else:
#                 self.beta_norm.lerp_(abs_baseline, self.lr_beta)
#                 self.beta_std.lerp_(q_std, self.lr_beta)
#
#             improvement = (self._info[XTRME_Q] - baseline) / self.beta_norm
#
#             if self.beta_update == "gumbel":
#                 new_beta = self.beta_std / self.target_improvement
#             elif self.beta_update == "relu":
#                 new_beta = (
#                     torch.relu(improvement - self.target_improvement) * self.beta_std
#                 ).clamp(min=self.beta_min)
#             else:
#                 q_var = self.beta_std.square()
#                 new_beta = q_var / (2 * self.target_improvement * self.beta_norm)
#
#             # self.beta = new_beta
#             self.beta = (1 - self.lr_beta) * self.beta + self.lr_beta * new_beta
#             # self.beta.clamp_(min=self.beta_min)
#             self._info.update(
#                 **{
#                     "relative_soft_q_std": self.beta_std / self.beta_norm,
#                     "improvement_v": improvement,
#                     # : relative_q_std,
#                     "beta": self.beta,
#                 }
#             )
#
#     def pred_terminal_q(self, obs_action):
#         exq = self.extreme_critic(obs_action)
#         if self.scaled_critic:
#             exq = exq * self._q_width + self._q_center
#         exq = self._reduce(exq, self.actor_reduction)
#         return exq.t()
#
#     @contextmanager
#     def policy_evaluation_context(self, detach=False, xve_horizon=None, **kwargs):
#         if 0 < self.rollout_horizon:
#             buffer = self.rollout_buffer
#         else:
#             buffer = self.replay_buffer
#
#         if xve_horizon is None:
#             # For call in update_critic
#             xve_horizon = self.xve_horizon
#         context_modules = ExtContextModules(
#             self.actor,
#             self.actor_optimizer,
#             self.pred_q_value,
#             self.pred_target_q_value,
#             self.pred_terminal_q,
#             self.critic,
#             self.critic_target,
#             self.critic_optimizer,
#             self.alpha,
#             self.model_step_context,
#             buffer,
#             detach,
#             xve_horizon=xve_horizon,
#         )
#         try:
#             yield context_modules
#         finally:
#             pass
#
#     @property
#     def description(self) -> str:
#         # Print model's architecture
#         num_total_params = 0
#         print_str = ""
#         for m in (self.dynamics_model, self.critic, self.extreme_critic, self.actor):
#             num_params = sum(np.prod(p.shape) for p in m.parameters())
#             num_total_params += num_params
#             print_str += f"{str(m)}\n"
#             print_str += f"#Params for {m._get_name()}: {num_params}\n"
#
#         print_str += f"#Params in total: {num_total_params}\n"
#         return print_str
#
#
# class MultieXtremeValueGradient(eXtremeValueGradient):
#     def one_step_xve(self, batch_sa, log_pis, xve_horizon, pred_q, log=False):
#         sa_init = batch_sa[0, : self.batch_size, :]
#
#         xtreme_q = self.extreme_critic(sa_init)
#         xtreme_q = xtreme_q * self._q_width + self._q_center
#
#         state, action = torch.tensor_split(sa_init, [self.dim_obs], 1)
#         log_pis = log_pis.detach()
#         log_pi = log_pis[0, 0, :]
#         ctx_modules = ContextModules(
#             actor=self.actor,
#             model_step=self.model_step_context,
#             detach=False,
#         )
#         (
#             obss,
#             actions,
#             log_pis,
#             rewards,
#             masks,
#             info,
#         ) = self.detached_rollout(
#             ctx_modules,
#             state,
#             action,
#             xve_horizon,
#             mve_horizon=self.mve_horizon,
#             log_pi=log_pi,
#             log=log,
#             prediction_strategy=CTX_EXPLR,
#         )
#         masks = torch.stack(masks).float()
#         log_pis = torch.stack(log_pis)
#         rewards = torch.stack(rewards)
#         last_sa = torch.cat([obss[-1], actions[-1]], -1)
#
#         with torch.no_grad():
#             if self.target_critic_softmax:
#                 pred_qs = self.extreme_target_critic(last_sa)
#             else:
#                 pred_qs = self.extreme_critic(last_sa)
#
#             pred_qs = pred_qs * self._q_width + self._q_center
#             pred_qs = self._q_ub - torch.relu(self._q_ub - pred_qs)
#             pred_qs = self._q_lb + torch.relu(pred_qs - self._q_lb)
#
#         rewards = rewards[..., None].repeat(1, 1, self.num_critic_ensemble)
#         rewards = torch.cat([rewards, pred_qs[None, ...]], dim=0)
#         rewards[1:, ...] -= self.alpha.detach() * log_pis[1:, ..., None]
#         gammas = self.discount_mat[0, : len(rewards), None, None]
#         xve = torch.sum(gammas * (rewards * masks[..., None]), 0, keepdim=True)
#
#         extreme_target = self._reduce(xve.detach()[0], "min", 1)
#         loss_xtreme_td = F.mse_loss(xtreme_q, extreme_target, reduction="none")
#         loss_xtreme_td *= self.critic_loss_weight[0]
#         loss_xtreme_td = loss_xtreme_td.sum(1).mean()
#
#         xve = torch.cat([xtreme_q[None], xve], 0)
#
#         self._info[XTRME_Q] = xtreme_q.detach().mean()
#         return xve, loss_xtreme_td
#
#     def xtreme_loss(self, xve, xve_target, info, log):
#         q_epi = xve_target.detach().std(0).mean()
#         ##########
#         if self.xve_loss == "softmax":
#             with torch.no_grad():
#                 xve_mean = xve.mean((0, 2), keepdims=True)
#                 diff = xve_target[..., None] - xve_mean
#                 extreme_target = (
#                     self.beta
#                     * (
#                         torch.logsumexp(diff / self.beta, 0, keepdim=True)
#                         - np.log(self.num_model_members)
#                     )
#                     + xve_mean
#                 )
#
#             #     extreme_target = self.beta * (
#             #         torch.logsumexp(xve_target / self.beta, 0, keepdim=True)
#             #         - np.log(self.num_model_members)
#             #     )
#             # loss_xve = F.mse_loss(xve, extreme_target[..., None], reduction="none")
#             loss_xve = F.mse_loss(xve, extreme_target, reduction="none")
#             loss_xve = loss_xve.sum(0) * self.critic_loss_weight[0]
#             loss_xve = loss_xve.mean(0).sum()
#             if log:
#                 with torch.no_grad():
#                     extreme_diff = (extreme_target - xve) / xve_target.mean(
#                         0, keepdims=True
#                     )[..., None].abs()
#                     info["extreme_diff"] = extreme_diff.mean()
#         elif self.xve_loss == "gauss":
#             extreme_target = xve_target.mean(0) + xve_target.var(0) / (2 * self.beta)
#             loss_xve = F.mse_loss(xve.squeeze(1), extreme_target)
#         elif self.xve_loss == "gumbel":
#             # z = diff / self.beta
#             # loss_xve = torch.exp(z) - z - 1.0
#             # large_z = self.z_ub < z
#             # loss_xve[large_z] = self.slope * z[large_z] + self.bias
#             # loss_xve = loss_xve.sum(2).mean()
#
#             # # max_z = z.max()
#             # # diff_z = soft_bound(z - max_z, 2.0)
#             # # neg_max_z = soft_bound(-max_z, 2.0)
#             # # loss_xve = torch.exp(diff_z) - (z + 1) * torch.exp(neg_max_z)
#             # # loss_xve = loss_xve.sum(2).mean() * max_z.clamp(max=2.0).exp()
#             # # info["max_z"] = max_z.detach
#
#             dist = torch.distributions.gumbel.Gumbel(xve_target.unsqueeze(2), self.beta)
#             loss_xve = -dist.log_prob(xve.unsqueeze(0))
#             loss_xve = soft_bound(loss_xve, ub=1e3)
#             loss_xve = loss_xve.sum(2).mean()
#         else:
#             assert self.xve_loss == "mse"
#             diff = xve_target.unsqueeze(2) - xve.unsqueeze(0)
#             loss_xve = torch.square(diff).sum(2).mean()
#
#         info[Q_EPI_STD] = q_epi
#         return loss_xve, info


########## Removed from VariationalValueGradient._actor_loss: the `separate_actor` branch
# if self.separate_actor:
#     pi_q = self.behavior_actor(sas[:, : self.dim_obs])
#     # pi_q = self.actor(sas[:, : self.dim_obs])
#     log_pis_q = pi_q.log_prob(sas[:, self.dim_obs :]).view(
#         self.training_rollout_horizon + 1, self.batch_size
#     )
#     kl_pi = log_pis_q - log_pis.detach()
#     alpha = self.alpha.detach()
#
#     mves = torch.cat([rewards, pred_q]) - alpha * log_pis
#     vmves = torch.cat([reward_x, pred_q_x]) - alpha * log_pis_q
#
#     vmves = kl_pi.exp().clamp(max=2.0) * vmves
#     # vmves *= kl_pi.exp()
#     vmve = self.discount_mat.mm(vmves * masks)
#     mve = self.discount_mat.mm(mves * masks)
#     actor_loss = -(vmve + mve).mean()
# else:

#### Alternative improvement-based beta_update rules ("square", "exp")
# if self.beta_update == "square":
#     raw_beta = (
#             torch.square((improvement - self.target_improvement).clamp(min=0.0))
#             * self.beta_std
#     )
# elif self.beta_update == "exp":
#     log_eps = np.log(self.beta_min)
#     raw_beta = torch.exp(
#         log_eps
#         + improvement.clamp(min=0.0)
#         * (self.beta_std.log() - log_eps).clamp(min=self.beta_min)
#         / self.target_improvement
#     )
