from contextlib import contextmanager, nullcontext
from collections import defaultdict, deque
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from erl_lib.agent.svg import SVGAgent, ContextModules

from erl_lib.base import KEY_CRITIC_LOSS, KEY_ACTOR_LOSS, Q_MEAN
from erl_lib.util.misc import calc_grad_norm, soft_update_params
from erl_lib.agent.model_based.modules.gaussian_mlp import PS_TS1, PS_INF


# Keys under which values are stored in the agent's ``_info`` dictionary.

# Spread of the predicted Q-values (recorded by the critic-ensemble agents).
Q_EPI_STD = "q_epi_std"
Q_DIFF_STD = "q_diff_std"
# Bonus-augmented ("extreme") soft value of the first rollout step.
XTRME_Q = "q_extreme"

# Further keys kept for compatibility with other modules / configs.
V_PRIOR = "v_prior"
IMPRV = "improvement"
REL_IMPRV = "relative_improvement"
XVE_TARGET = "xve_target"


class VariationalCriticEnsemble(SVGAgent):
    """SVG agent whose critic ensemble is partitioned per dynamics-model member.

    The critic holds ``M * C`` members, where ``M`` is the number of dynamics
    model members and ``C`` (``num_critic_ensemble`` after construction) is the
    number of critics per model member. Model-based policy evaluation and
    ``_update`` force the ``"full"`` prediction strategy, while
    ``distribution_rollout`` forces ``PS_TS1``.
    """

    def __init__(
        self,
        *args,
        target_improvement=0.1,
        beta_init=1.0,
        lr_beta=0.2,
        separate_alpha: bool = False,
        extreme_epi_q: bool = False,
        **kwargs,
    ):
        """
        Args:
            target_improvement: Target improvement used by the beta update
                rules; floored at 1e-8 to avoid division by zero.
            beta_init: Initial ``beta`` (stored in log space in ``raw_beta``).
            lr_beta: Interpolation / learning rate for ``beta`` updates.
            separate_alpha: Accepted for interface compatibility; unused here.
            extreme_epi_q: If True, ``mb_policy_evaluation`` does not record
                ``Q_EPI_STD``.
            **kwargs: Forwarded to ``SVGAgent``; must contain ``dynamics_model``.

        Raises:
            ValueError: If the critic ensemble size is not a multiple of the
                number of dynamics-model members.
        """
        # Set before super().__init__ — presumably the parent constructor already
        # needs it (e.g. when building the per-member critics).
        self.num_model_members = kwargs["dynamics_model"].num_members
        super(VariationalCriticEnsemble, self).__init__(*args, **kwargs)
        if self.num_critic_ensemble % self.num_model_members != 0:
            raise ValueError(
                f"num_critic_ensemble ({self.num_critic_ensemble}) must be a "
                f"multiple of num_model_members ({self.num_model_members})"
            )
        # From here on, num_critic_ensemble counts critics *per model member*.
        self.num_critic_ensemble = self.critic.num_members // self.num_model_members
        size_exp = self.num_model_members, self.batch_size
        # Termination flags broadcast to one row per model member.
        self.done_act = self.done.unsqueeze(0).expand(size_exp)
        # beta is kept in log space and updated manually unless gradient-based
        # tuning enables requires_grad later.
        self.raw_beta = torch.nn.Parameter(
            torch.tensor(np.log(beta_init), dtype=torch.float32, device=self.device),
            requires_grad=False,
        )
        self.lr_beta = lr_beta
        self.extreme_epi_q = extreme_epi_q
        self.target_improvement = max(target_improvement, 1e-8)
        # Statistics for beta tuning; initialized lazily on the first update.
        self.beta_norm = None
        self.last_baseline = None
        self.beta_std = None
        # NOTE(review): self.lr_norm is not assigned in this class; it must be
        # set by a subclass before calling this constructor (or by SVGAgent).
        self.baseline_que = deque(maxlen=int(1 / self.lr_norm))

        self.init_optimizer()

    def eval_rollout(
        self,
        batch_sa,
        log_pis,
        batch_masks,
        rewards,
        alpha,
        last_sa,
        discounts,
        critic,
        critic_target,
    ):
        """Compute discounted soft-value targets and masked critic predictions.

        Args:
            batch_sa: [M, (L+1)B, D]
            log_pis: [L+1, M, B]
            batch_masks: [L+1, M, B]
            rewards: [L, M, B]
            alpha: Entropy temperature (detached for the targets).
            last_sa: [M, B, D]
            discounts: [L, L+1]
            critic: Critic producing the predictions (receives gradients).
            critic_target: Critic used to bootstrap from ``last_sa``.

        Returns:
            target_values: [L, B, M, 1]
            pred_values: [L, B, M, C]
        """
        with torch.no_grad():
            if self.bounded_critic or self.scaled_critic:
                # Refresh the critic bounds from quantiles of the soft rewards.
                reward_pi = rewards - alpha * log_pis[:-1, :]
                reward_lb, reward_ub = torch.quantile(reward_pi, self.q_th)
                self.update_critic_bound(reward_lb, reward_ub)
            q_values = critic_target(last_sa)
            target_rewards = torch.cat(
                [rewards.transpose(1, 2)[..., None], q_values]
            )  # [L+1, B, M, 1]
            target_rewards.sub_(alpha.detach() * log_pis.transpose(1, 2)[..., None])
            target_values = target_rewards * batch_masks.transpose(1, 2)[..., None]
            # Row i of `discounts` weights the L+1 soft rewards of target i.
            target_values = torch.sum(
                discounts[..., None, None, None] * target_values[None], 1
            )

        pred_values = critic(batch_sa.detach())  # [L, B, M, C]
        horizon = discounts.shape[1] - discounts.shape[0]
        # Explicit end index: ``batch_masks[:-horizon]`` would be empty when
        # horizon == 0, because ``[:-0]`` is ``[:0]``.
        num_pred_steps = batch_masks.shape[0] - horizon
        mask_ensemble = batch_masks[:num_pred_steps].transpose(1, 2)[..., None]
        pred_values *= mask_ensemble
        return target_values, pred_values

    def pred_q_value(self, obs_action, batch_size=None):
        """Predict Q-values with every critic of every model member.

        Each member's inputs are repeated ``C`` times along dim 0 (presumably
        one copy per critic of that member).

        Args:
            obs_action: [M, B, D]
            batch_size: Samples per member; defaults to ``self.batch_size``.

        Returns:
            pred_q: [1, B, M, C], rescaled when ``scaled_critic`` is set.
        """
        if batch_size is None:
            batch_size = self.batch_size
        shape = (-1, batch_size, self.num_model_members, self.num_critic_ensemble)
        obs_action = obs_action.repeat_interleave(self.num_critic_ensemble, 0)
        pred_q = self.critic(obs_action).contiguous().view(shape)
        if self.scaled_critic:
            pred_q = pred_q * self._q_width + self._q_center
        return pred_q

    def pred_target_q_value(self, obs_action):
        """Predict target Q-values reduced over each member's critics.

        Args:
            obs_action: [M, B, D]

        Returns:
            target_q: [1, B, M, 1], rescaled when ``scaled_critic`` is set and
            clipped into ``[_q_lb, _q_ub]`` when ``bounded_critic`` is set.
        """
        # -> [MC, B, D]
        obs_action = obs_action.repeat_interleave(self.num_critic_ensemble, 0)
        shape = (1, self.batch_size, self.num_model_members, self.num_critic_ensemble)
        target_q = self.critic_target(obs_action).view(shape)
        target_q = self._reduce(target_q, dim=3, keepdim=True)
        if self.scaled_critic:
            target_q = target_q * self._q_width + self._q_center
        if self.bounded_critic:
            # min(q, ub) followed by max(q, lb), written with relu.
            target_q = self._q_ub - torch.relu(self._q_ub - target_q)
            target_q = self._q_lb + torch.relu(target_q - self._q_lb)
        return target_q

    def pred_terminal_q(self, obs_action):
        """Terminal Q-value: per-critic Q reduced over model members (dim 2)."""
        pred_qss = self.pred_q_value(obs_action)
        pred_qs = self._reduce(pred_qss, self.actor_reduction, dim=2, keepdim=False)
        return pred_qs.unsqueeze(0)

    def actor_loss(
        self,
        ctx_modules,
        batch_sa,
        batch_mask,
        log_pis,
        rewards,
        log=False,
    ):
        """Take one actor step on the discounted soft return and tune beta.

        Args:
            ctx_modules: Context providing the actor, its optimizer, ``alpha``
                and ``pred_terminal_q``.
            batch_sa: [M, B, D] terminal state-action pairs.
            batch_mask: Masks applied to the per-step soft rewards.
            log_pis: [L+1, M, B]
            rewards: Per-step rewards concatenated with the terminal Q-value.
            log: If True, also log gradient norm and value statistics.
        """
        pred_qs = ctx_modules.pred_terminal_q(batch_sa)
        rewards = torch.cat([rewards, pred_qs])
        rewards.sub_(ctx_modules.alpha.detach() * log_pis)
        pred_qs = self.discount_mat[0, :, None, None] * rewards * batch_mask
        pred_qs = pred_qs.sum(0)

        softmax_q = pred_qs
        loss_actor = -softmax_q.mean()

        entropy = -log_pis[0].detach().mean()
        self._info.update(**{KEY_ACTOR_LOSS: loss_actor.detach(), "entropy": entropy})

        if self.learnable_alpha:
            alpha_loss = -ctx_modules.alpha * (self.target_entropy - entropy)
            loss_actor += alpha_loss

            self._info.update(
                **{
                    "alpha_loss": alpha_loss.detach(),
                    "alpha_value": ctx_modules.alpha.detach(),
                }
            )

        # Take a SGD step
        ctx_modules.actor_optimizer.zero_grad()
        loss_actor.backward()

        # Beta tuning
        with torch.no_grad():
            baseline = pred_qs.mean()
            abs_baseline = torch.abs(baseline).clamp(min=1e-5)
            if self.beta_norm is None:
                self.beta_norm = abs_baseline
            else:
                self.beta_norm.lerp_(abs_baseline, 0.1)

            # NOTE(review): loss_actor already includes alpha_loss here, so this
            # "improvement" only reflects the alpha term; it is used for logging.
            soft_q_mean = -loss_actor
            improvement = (soft_q_mean - baseline) / self.beta_norm
            q_var = pred_qs.var(0).mean()

            new_beta = q_var / (2 * self.target_improvement * self.beta_norm)
            # NOTE(review): this class does not define `beta`; this relies on
            # SVGAgent providing a readable/writable `beta` — confirm.
            self.beta = (1 - self.lr_beta) * self.beta + self.lr_beta * new_beta

        if log:
            self._info["actor_grad_norm"] = calc_grad_norm(ctx_modules.actor)
            with torch.no_grad():
                raw_q_std = q_var.sqrt()
                relative_q_std = raw_q_std / abs_baseline
                self._info.update(
                    **{
                        "soft_q_mean": soft_q_mean,
                        "improvement_v": improvement,
                        "soft_q_std": raw_q_std,
                        "relative_soft_q_std": relative_q_std,
                        "beta": self.beta,
                    }
                )

        ctx_modules.actor_optimizer.step()

    def mb_policy_evaluation(
        self,
        ctx_modules,
        obs,
        log=False,
        **kwargs,
    ):
        """Run the parent model-based evaluation with one copy of obs per member.

        ``done`` and ``prediction_strategy`` are fixed here; extra ``kwargs``
        are accepted for interface compatibility but are not forwarded.
        Unless ``extreme_epi_q`` is set, the spread of the target Q-values is
        recorded under ``Q_EPI_STD``.
        """
        obs = obs.expand((self.num_model_members,) + obs.shape)
        returns = super().mb_policy_evaluation(
            ctx_modules,
            obs,
            done=self.done_act.clone(),
            log=log,
            prediction_strategy="full",
        )
        target_q = returns[4]
        if not self.extreme_epi_q:
            # This class does not set `sample_std` itself (subclasses may), so
            # fall back to False instead of raising AttributeError.
            if getattr(self, "sample_std", False):
                self._info[Q_EPI_STD] = target_q[0].detach().std((0, 1)).mean()
            else:
                self._info[Q_EPI_STD] = target_q[0].detach().std(1).mean()

        return returns

    def distribution_rollout(self, **rollout_kwargs):
        """Parent distribution rollout, forced to the ``PS_TS1`` strategy."""
        rollout_kwargs["prediction_strategy"] = PS_TS1
        super().distribution_rollout(**rollout_kwargs)

    def _update(self, log=False, **ctx_kwargs):
        """Parent update with ``prediction_strategy`` forced to ``"full"``."""
        return super()._update(log=log, **dict(ctx_kwargs, prediction_strategy="full"))


class VariationalValueGradient(VariationalCriticEnsemble):
    """Critic-ensemble agent whose actor objective includes a variational bonus.

    Rollout rewards and the terminal Q-value are augmented with a bonus built
    from ``dynamics_model.internal_reward`` / ``variational_terms`` minus
    ``beta`` times a KL term. ``beta`` is stored in log space (``raw_beta``)
    and adapted in :meth:`update_actor`. The adaptation is either gradient-based
    (``beta_update`` starting with ``"opt"``) or interpolates toward a
    closed-form target (``"var"``, ``"std"``, or improvement-based otherwise).
    """

    trange_kv = dict(Beta="beta", **SVGAgent.trange_kv)
    beta_optimizer = None

    def __init__(
        self,
        actor,
        *args,
        v_lr_ratio: float = 1.0,
        lr_norm: float = 0.1,
        beta_min: float = 1e-8,
        beta_max: float = 1e4,
        beta_update: str = "std",
        discount_beta: bool = False,
        mvve_improvement: bool = False,
        max_beta_inc: float = 2.0,
        mean_improve: bool = True,
        kl_improve: bool = False,
        smoothed_baseline: bool = False,
        sample_std=False,
        episode_wise_norm: bool = True,
        **kwargs,
    ):
        """
        Args:
            actor: Policy network, forwarded to the parent as ``actor=``.
            v_lr_ratio: Learning-rate multiplier for the dynamics model's
                variational parameters (relative to ``lr``).
            lr_norm: Interpolation rate of the running baseline / std; the
                baseline queue keeps ``int(1 / lr_norm)`` entries.
            beta_min: Lower bound of ``beta`` and floor of the baseline norm.
            beta_max: Upper bound of ``beta``.
            beta_update: ``"opt"`` / ``"opt_std"`` (gradient-based), ``"var"``,
                ``"std"``, or any other value for the improvement-based rule.
            discount_beta: Scale the beta loss / target by ``1 - discount``.
            mvve_improvement: Accepted for compatibility; currently unused.
            max_beta_inc: Caps the beta target at ``max_beta_inc * beta / lr_beta``
                (disabled when falsy).
            mean_improve: Store batch means of the baseline / extreme Q in
                ``_info`` (otherwise per-sample tensors are stored).
            kl_improve: If False, the KL penalty changes only the gradients,
                not the value (its value is added back detached).
            smoothed_baseline: Use the smoothed ``beta_norm`` as baseline instead
                of the minimum over ``baseline_que``.
            sample_std: Compute ``Q_EPI_STD`` over dims (0, 1) instead of dim 1.
            episode_wise_norm: Reset ``beta_norm`` from the baseline queue in
                :meth:`update` and smooth it in :meth:`update_actor`.
        """
        # These must exist before super().__init__, which reads some of them
        # (e.g. lr_norm, sample_std, beta_update via init_optimizer).
        self.v_lr_ratio = v_lr_ratio
        self.lr_norm = lr_norm
        self.beta_min = beta_min
        self.beta_log_min = np.log(beta_min)
        self.beta_log_max = np.log(beta_max)
        self.beta_update = beta_update
        self.discount_beta = discount_beta
        self.max_beta_inc = max_beta_inc
        self.mean_improve = mean_improve
        self.kl_improve = kl_improve
        self.smoothed_baseline = smoothed_baseline
        self.sample_std = sample_std
        self.episode_wise_norm = episode_wise_norm

        super().__init__(*args, actor=actor, **kwargs)
        # Model member assigned to each sample: member m owns a contiguous slice
        # of batch_size // M samples.
        idx = (
            torch.arange(self.num_model_members)
            .repeat_interleave(self.batch_size // self.num_model_members)
            .to(self.device)
        )
        # [L+1, M, B] float one-hot selecting each sample's member.
        self.idx_mask = (
            F.one_hot(idx, self.num_model_members)
            .to(torch.float32)
            .t()[None]
            .repeat_interleave(self.training_rollout_horizon + 1, dim=0)
        )
        # [M, (L+1)B, 1] integer one-hot, the same selection for flattened steps.
        self.sa_idx_mask = F.one_hot(
            idx.repeat(self.training_rollout_horizon + 1), self.num_model_members
        ).t()[..., None]

    def _act_separate(self, obs, sample=False):
        """Select an action for ``obs`` without tracking gradients.

        Args:
            obs: Observation as a numpy array or tensor.
            sample: If True, sample from the policy with ``is_optimistic`` set;
                otherwise return the policy mean with ``is_optimistic`` cleared.

        Returns:
            The action as a numpy array.
        """
        with torch.no_grad():
            if isinstance(obs, np.ndarray):
                obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
            if self.normalize_input:
                obs = self.input_normalizer.normalize(obs)

            if sample:
                self.actor.is_optimistic = True
                dist = self.actor(obs)
                action = dist.sample()
            else:
                self.actor.is_optimistic = False
                dist = self.actor(obs)
                action = dist.mean
            action = action.cpu().numpy()
        return action

    def _actor_loss(self, ctx_modules, log_pis, batch_sa, rewards, masks, log=False):
        """Compute the actor loss; delegates to :meth:`_single_actor_loss`."""
        return self._single_actor_loss(
            ctx_modules, log_pis, batch_sa, rewards, masks, log
        )

    def _separate_actor_loss(
        self, ctx_modules, log_pis, batch_sa, rewards, masks, log=False
    ):
        """Plain MVE actor loss plus a bonus loss from a fresh exploring rollout.

        Not referenced by the visible code of this file (``_actor_loss`` uses
        :meth:`_single_actor_loss`).
        """
        pred_qs = self.pred_q_value(batch_sa[..., -self.batch_size :, :])
        pred_qs = self._reduce(pred_qs, self.actor_reduction, dim=-1, keepdim=False)[
            None
        ]
        expanded_rewards = torch.cat([rewards, pred_qs])
        expanded_rewards.sub_(ctx_modules.alpha.detach() * log_pis)
        mve = torch.sum(
            (self.discount_mat[0, :, None, None] * expanded_rewards * masks), 0
        )
        optimized_actor_loss = -mve.mean()

        with self.policy_evaluation_context(explore=True) as ctx_modules:
            obs = batch_sa[0, : self.batch_size, : self.dim_obs]
            action, log_pi, _ = self.sample_action(ctx_modules.actor, obs, log=log)
            (
                obss,
                actions,
                log_pis,
                rewards,
                masks,
            ) = self.rollout(
                ctx_modules,
                obs,
                action,
                self.training_rollout_horizon,
                done=self.done,
                mve_horizon=self.mve_horizon,
                log_pi=log_pi.squeeze(-1),
                log=log,
                prediction_strategy=PS_INF,
            )
        masks = torch.stack(masks).float()
        batch_sa = torch.cat([torch.cat(obss, -2), torch.cat(actions, -2)], -1)
        log_pis = torch.stack(log_pis)
        rewards = torch.stack(rewards)

        sa, last_sa = torch.tensor_split(batch_sa, [-self.batch_size], 0)
        exp_actor_loss = self._add_x_rewards(sa, last_sa, log_pis, rewards, masks, log)

        with torch.no_grad():
            self._info[Q_MEAN] = mve.mean()
        return optimized_actor_loss + exp_actor_loss

    def _single_actor_loss(
        self, ctx_modules, log_pis, batch_sa, rewards, masks, log=False
    ):
        """Keep, for each sample, only its assigned model member's rollout.

        Args:
            log_pis: [L+1, M, B]
            batch_sa: [M, (L+1)B, S+A]
            rewards: [L, M, B]
            masks: [L+1, M, B]
        """
        log_pis = torch.sum(log_pis * self.idx_mask, 1)  # [L+1, B]
        masks = torch.sum(masks * self.idx_mask, 1)  # [L+1, B]
        rewards = torch.sum(rewards * self.idx_mask[:-1], 1)  # [L, B]
        sas = torch.sum(batch_sa * self.sa_idx_mask, 0)  # [(L+1)B, S+A]
        sa, last_sa = torch.tensor_split(sas, [-self.batch_size], 0)
        return self._add_x_rewards(sa, last_sa, log_pis, rewards, masks, log)

    def _add_x_rewards(self, sa, last_sa, log_pis, rewards, masks, log=False):
        """Actor loss on a rollout whose rewards include the variational bonus.

        Each step reward gets ``base_scale * variational_scale - beta * kl``
        from ``dynamics_model.internal_reward``. The terminal value comes from
        :meth:`pred_terminal_q`. Also records the plain soft baseline
        (``Q_MEAN``) and the bonus-augmented value of the first step
        (``XTRME_Q``).

        Args:
            sa: State-action pairs of the first L steps (flattened over steps).
            last_sa: State-action pairs of the final step.
            log_pis: [L+1, B] (as produced by ``_single_actor_loss``)
            rewards: [L, B]
            masks: [L+1, B]
            log: If True, also log bonus / KL statistics.

        Returns:
            Actor loss: the negated mean of the discounted bonus-augmented values.
        """
        base_scale, variational_scale, kl = self.dynamics_model.internal_reward(
            sa, self.sa_idx_mask[:, : -self.batch_size, :]
        )
        reward_x = base_scale * variational_scale
        kl_term = self.beta * kl
        reward_x -= kl_term
        if not self.kl_improve:
            # Restore the value; the KL term then only shapes the gradients.
            reward_x += kl_term.detach()
        reward_x = reward_x.view(self.training_rollout_horizon, self.batch_size)

        pred_q_x, pred_q = self.pred_terminal_q(last_sa, log=log)

        # Truncated extreme rewards
        trunc_rews_x = torch.cat([rewards + reward_x, pred_q_x])
        trunc_rews_x.sub_(self.alpha.detach() * log_pis)

        vmve = self.discount_mat.mm(trunc_rews_x * masks)
        actor_loss = -vmve.mean()

        with torch.no_grad():
            # Baseline without the bonus; the first action's entropy term is
            # excluded here and added back to sdqe so both are comparable.
            rewards = torch.cat([rewards, pred_q]).clone()
            rewards[1:] -= self.alpha * log_pis[1:]
            baseline = self.discount_mat.mm(rewards * masks)[0]
            sdqe = vmve[0] + self.alpha * log_pis[0, :]
            if self.mean_improve:
                self._info[Q_MEAN] = baseline.mean()
                self._info[XTRME_Q] = sdqe.mean()
            else:
                self._info[Q_MEAN] = baseline
                self._info[XTRME_Q] = sdqe

            if self.extreme_epi_q:
                self._info[Q_EPI_STD] = sdqe.std()

            if log:
                self._info.update(
                    **{
                        "reward_bonus": reward_x.mean(),
                        "reward_scale": variational_scale.mean(),
                        "kl": kl.mean(),
                        "kl_term": kl_term.mean(),
                    }
                )
        return actor_loss

    def pred_terminal_q(self, obs_action, log=False, batch_size=None):
        """Terminal Q-value with and without the variational bonus.

        The critic output appears to be [batch_size, M*C]; it is transposed and
        viewed as [M, C, batch_size], averaged over critics, then summarized by
        the mean and std across model members.

        Args:
            obs_action: Terminal state-action pairs.
            log: If True, log bonus / KL statistics.
            batch_size: Samples in ``obs_action``; defaults to ``self.batch_size``.

        Returns:
            Tuple ``(q_x, q_mu)``, each [1, batch_size]: the bonus-augmented
            and the plain ensemble-mean Q-values.
        """
        if batch_size is None:
            batch_size = self.batch_size
        # NOTE(review): the critic is switched to eval mode and not restored
        # here; confirm the critic update re-enables training mode.
        self.critic.train(False)

        q_mcb = (
            self.critic(obs_action)
            .t()
            .view(self.num_model_members, self.num_critic_ensemble, batch_size)
        )
        q_mb = q_mcb.mean(1)

        q_mu = q_mb.mean(0)[..., None]
        q_std = q_mb.std(0)[..., None]
        q_v_scale, q_kl = self.dynamics_model.variational_terms(
            q_mu, q_std, obs_action, is_terminal=True
        )
        if self.actor_reduction == "ub":
            base_scale = q_mcb.std((0, 1))[..., None]
        else:
            base_scale = q_std

        q_bonus = base_scale * q_v_scale

        if self.scaled_critic:
            q_bonus *= self._q_width
            q_mu = q_mu * self._q_width + self._q_center

        q_kl_term = self.beta * q_kl
        q_x = q_mu + q_bonus - q_kl_term
        if not self.kl_improve:
            q_x += q_kl_term.detach()

        if log:
            with torch.no_grad():
                self._info.update(
                    **{
                        "q_bonus": q_bonus.mean(),
                        "q_scale": q_v_scale.mean(),
                        "q_kl": q_kl.mean(),
                        "q_kl_term": q_kl_term.mean(),
                    }
                )
        return q_x.t(), q_mu.t()

    def update(self, first_update=False, log=False, **kwargs):
        """Refresh the episode-wise baseline norm, then run the parent update.

        When ``episode_wise_norm`` is set and beta tuning has started, the last
        recorded baseline is pushed into ``baseline_que`` and ``beta_norm`` is
        reset to the elementwise minimum over the queue.
        """
        if self.episode_wise_norm and self.beta_norm is not None:
            self.baseline_que.append(self.last_baseline)
            # amin returns a new tensor. Built-in min() would return one of the
            # queued tensors itself, and the in-place lerp_ on beta_norm in
            # update_actor would then corrupt the stored baselines.
            self.beta_norm = torch.stack(list(self.baseline_que)).amin(0)
        super().update(first_update=first_update, log=log, **kwargs)

    def update_actor(
        self,
        ctx_modules,
        batch_sa,
        batch_mask,
        log_pis,
        rewards,
        log=False,
    ):
        """Run the SVGAgent actor update, then adapt ``beta``.

        ``beta`` is tuned from the relative improvement of the bonus-augmented
        value (``XTRME_Q``) over a baseline (``Q_MEAN`` statistics). The result
        is clamped into ``[beta_min, beta_max]`` in log space.
        """
        super(VariationalCriticEnsemble, self).update_actor(
            ctx_modules, batch_sa, batch_mask, log_pis, rewards, log
        )
        # Beta tuning
        with torch.no_grad():
            q_mean = self._info[Q_MEAN]
            q_std = self._info[Q_EPI_STD]
            q_ex = self._info[XTRME_Q]
            self.last_baseline = q_mean

            if self.beta_norm is None:
                # Clone so the in-place lerp_ below never mutates tensors that
                # are also held by self._info or self.baseline_que.
                self.beta_norm = q_mean.clone()
                self.beta_std = q_std.clone()
            elif self.episode_wise_norm:
                self.beta_norm.lerp_(q_mean, self.lr_norm)
                self.beta_std.lerp_(q_std, self.lr_norm)

            if self.smoothed_baseline:
                baseline = self.beta_norm
            else:
                self.baseline_que.append(q_mean)
                # Elementwise minimum over the queue (a scalar for mean baselines).
                baseline = torch.stack(list(self.baseline_que)).amin(0)

            improvement = torch.mean(
                (q_ex - baseline) / self.beta_norm.abs().clamp(self.beta_min)
            )
            if self.beta_update.startswith("opt"):
                # The scaling must be applied inside beta_loss (which runs with
                # grad enabled). An in-place multiply here, under no_grad, would
                # not be recorded by autograd and so would not affect the step.
                scale = 1.0
                if self.beta_update == "opt_std":
                    scale = scale * self.beta_std
                if self.discount_beta:
                    scale = scale * (1 - self.discount)
                beta_loss = self.beta_loss(improvement, self.raw_beta, scale)
                self.beta_optimizer.zero_grad()
                beta_loss.backward()
                self.beta_optimizer.step()
                self._info["beta_loss"] = beta_loss.detach()
            else:
                if self.beta_update == "var":
                    raw_beta = self.beta_std.square() * 0.5 / self.target_improvement
                elif self.beta_update == "std":
                    raw_beta = self.beta_std / self.target_improvement
                else:
                    raw_beta = improvement - self.target_improvement
                    raw_beta *= self.beta_std
                if self.discount_beta:
                    raw_beta *= 1 - self.discount

                max_beta = (
                    self.max_beta_inc * self.beta / self.lr_beta
                    if self.max_beta_inc
                    else None
                )
                new_beta = torch.clamp(raw_beta, min=0, max=max_beta)
                # Interpolate in linear space, then store back in log space.
                old_beta = self.beta
                old_beta.lerp_(new_beta, self.lr_beta)
                self.raw_beta.data.copy_(old_beta.log())
            self.raw_beta.data.clamp_(min=self.beta_log_min, max=self.beta_log_max)
            self._info.update(
                **{
                    "relative_q_std": self.beta_std
                    / self.beta_norm.mean().abs().clamp(self.beta_min),
                    "improvement_v": improvement,
                    "beta": self.beta,
                    "baseline": self.beta_norm.mean(),
                }
            )
            if not self.mean_improve:
                self._info.update(**{Q_MEAN: q_mean.mean(), XTRME_Q: q_ex.mean()})

    @torch.enable_grad()
    def beta_loss(self, improvement, log_beta, scale=1.0):
        """Loss whose gradient raises beta when improvement exceeds the target.

        Args:
            improvement: Measured relative improvement (no gradient).
            log_beta: Log-space beta parameter being optimized.
            scale: Constant factor applied to the loss (recorded by autograd).
        """
        return -(improvement - self.target_improvement) * log_beta.exp() * scale

    def build_critics(self, critic_cfg):
        """Build critics via the grandparent and optional Poisson loss weights."""
        super(SVGAgent, self).build_critics(critic_cfg)
        num_critic_ensemble = self.critic.num_members // self.num_model_members

        if self.weighted_critic:
            size = (
                self.batch_size,
                self.num_model_members,
                num_critic_ensemble,
            )
            weight_rate = torch.ones(size, device=self.device)
            self.critic_loss_weight = torch.poisson(weight_rate)[None]

    def init_optimizer(self):
        """(Re)create the actor, critic and, when enabled, beta optimizers.

        Also resets ``self._info``. The actor optimizer additionally covers
        ``raw_alpha`` and, when present, the dynamics model's variational
        parameters and a ``behavior_actor``.
        """
        self._info = defaultdict(list)
        actor_params = [
            {"params": self.actor.parameters()},
            {"params": [self.raw_alpha], "lr": self.lr_alpha},
        ]
        if hasattr(self, "dynamics_model") and hasattr(
            self.dynamics_model, "variational_parameters"
        ):
            actor_params.append(
                {
                    "params": self.dynamics_model.variational_parameters(),
                    "lr": self.lr * self.v_lr_ratio,
                }
            )
        if hasattr(self, "behavior_actor"):
            actor_params.append(
                {"params": self.behavior_actor.parameters(), "lr": self.lr}
            )

        self.actor_optimizer = torch.optim.Adam(actor_params, lr=self.lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.lr)

        # raw_beta does not exist yet when this runs from the parent constructor.
        if self.beta_update.startswith("opt") and hasattr(self, "raw_beta"):
            self.raw_beta.requires_grad = True
            self.beta_optimizer = torch.optim.Adam([self.raw_beta], lr=self.lr_beta)

    @property
    def beta(self):
        """Current beta, ``exp(raw_beta)``."""
        return self.raw_beta.exp()

    @contextmanager
    def policy_evaluation_context(self, detach=False, explore=False, **kwargs):
        """Yield the ContextModules used for policy evaluation.

        Uses the rollout buffer when ``rollout_horizon > 0``, otherwise the
        replay buffer. When ``explore`` is set, ``actor.exploring`` is True
        inside the context; it is reset to False on exit in all cases.
        """
        if 0 < self.rollout_horizon:
            buffer = self.rollout_buffer
        else:
            buffer = self.replay_buffer

        if explore:
            self.actor.exploring = True

        context_modules = ContextModules(
            self.actor,
            self.actor_optimizer,
            self.pred_q_value,
            self.pred_target_q_value,
            self.pred_terminal_q,
            self.critic,
            self.critic_target,
            self.critic_optimizer,
            self.alpha,
            self.model_step_context,
            buffer,
            detach,
        )
        try:
            yield context_modules
        finally:
            self.actor.exploring = False


# class SoftDistValueGradient(VariationalValueGradient):
#     def _add_x_rewards(self, sa, last_sa, log_pis, rewards, masks, log=False):
#         self.dynamics_model.predict_ts1_full(sa, tile=self.training_rollout_horizon)
#         reward_entropy = (
#             self.dynamics_model.state_entropy[..., :1]
#             + self.output_normalizer.std[0, 0].log()
#         )
#         reward_entropy_term = (
#             reward_entropy.view(self.training_rollout_horizon, self.batch_size)
#             * self.beta.detach()
#         )
#         reward_x = rewards + reward_entropy_term
#         pred_q_x, pred_q = self.pred_terminal_q(last_sa, log=log)
#
#         reward_x = torch.cat([reward_x, pred_q_x])
#         reward_x = reward_x - self.alpha.detach() * log_pis
#
#         vmve = self.discount_mat.mm(reward_x * masks)
#         actor_loss = -vmve.mean()
#
#         with torch.no_grad():
#             # if not self.separate_actor:
#             rewards = torch.cat([rewards, pred_q]).clone()
#             rewards[1:] -= self.alpha * log_pis[1:]
#             baseline = self.discount_mat.mm(rewards * masks)[0]
#             sdqe = vmve[0] + self.alpha * log_pis[0, :]
#             if self.mean_improve:
#                 self._info[Q_MEAN] = baseline.mean()
#                 self._info[XTRME_Q] = sdqe.mean()
#             else:
#                 self._info[Q_MEAN] = baseline
#                 self._info[XTRME_Q] = sdqe
#             if log:
#                 self._info.update(
#                     **{
#                         "reward_entropy": reward_entropy.mean(),
#                         "reward_entropy_term": reward_entropy_term.mean(),
#                     }
#                 )
#         return actor_loss
#
#     def pred_terminal_q(self, obs_action, log=False, batch_size=None):
#         """
#
#         Args:
#             obs_action: [B, D]
#
#         Returns:
#             pred_q: [B, 1]
#
#         """
#         if batch_size is None:
#             batch_size = self.batch_size
#         self.critic.train(False)
#
#         q_mcb = (
#             self.critic(obs_action)
#             .t()
#             .view(self.num_model_members, self.num_critic_ensemble, batch_size)
#         )
#         q_mb = q_mcb.mean(1)
#
#         q_mu = q_mb.mean(0)[..., None]
#         q_std = q_mb.std(0)[..., None]
#         q_entropy = (np.log(2 * np.pi) + 1) * 0.5 + (q_std * self._q_width).log()
#         q_entropy_term = q_entropy * self.beta.detach()
#         q_x = q_mu + q_entropy_term
#
#         if log:
#             with torch.no_grad():
#                 self._info.update(
#                     **{
#                         "q_entropy": q_entropy.mean(),
#                         "q_entropy_term": q_entropy_term.mean(),
#                     }
#                 )
#         return q_x.t(), q_mu.t()
#
#     def update_actor(
#         self,
#         ctx_modules,
#         batch_sa,
#         batch_mask,
#         log_pis,
#         rewards,
#         log=False,
#     ):
#         super().update_actor(ctx_modules, batch_sa, batch_mask, log_pis, rewards, log)
#         if 0.0 < self.lr_beta:
#             info = {}
#
#             with torch.no_grad():
#                 q_mean = self._info[Q_MEAN]
#                 q_soft = self._info[XTRME_Q]
#
#                 if self.beta_norm is None:
#                     self.beta_norm = q_mean
#                     # self.beta_std = q_std
#                 else:
#                     self.beta_norm.lerp_(q_mean, self.lr_norm)
#                     # self.beta_std.lerp_(q_std, self.lr_norm)
#
#                 # if self.eta_update == "opt_speed":
#                 #     norm = max(torch.abs(self.ema_fast - self.ema_slow), self.eta_min)
#                 #     change_degree = torch.abs(q_mean - self.ema_fast) / norm
#                 #     eta_loss = self.eta_speed_loss(change_degree, self.raw_eta)
#                 #     info["change_degree"] = change_degree
#                 # else:
#                 q_entropy = (q_soft - q_mean) / self.beta.detach()
#                 beta_loss, entropy_term_ratio = self.eta_relative_loss(
#                     q_entropy, self.beta_norm, self.raw_beta
#                 )
#                 info["entropy_ratio"] = entropy_term_ratio
#                 self.beta_optimizer.zero_grad()
#                 beta_loss.backward()
#                 self.beta_optimizer.step()
#
#                 info["eta_loss"] = beta_loss.detach()
#
#                 self.raw_beta.data.clamp_(min=np.log(self.beta_min))
#             self._info.update(**info, eta=self.beta)
#
#     @torch.enable_grad()
#     def eta_relative_loss(self, q_entropy, value, raw_beta):
#         entropy_term = q_entropy * raw_beta.exp()
#         denom = torch.abs(value - entropy_term.detach()).clamp_min(self.beta_min)
#         entropy_term_ratio = torch.abs(entropy_term) / denom
#         return (
#             torch.square(entropy_term_ratio - self.target_improvement),
#             entropy_term_ratio.detach(),
#         )
