import copy
import gc

import torch
from torch import nn

import networks
import tools
import numpy as np #[todo]
def to_np(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and copied to the CPU."""
    return x.detach().cpu().numpy()


class RewardEMA:
    """Exponential moving average of the 5th and 95th percentiles of a tensor.

    Calling the instance updates the running percentiles in place and returns
    an offset (the low percentile) and a scale (the percentile spread, floored
    at 1.0).
    """

    def __init__(self, device, alpha=1e-2):
        """Store the smoothing factor ``alpha`` and build the quantile levels on ``device``."""
        self.device = device
        self.alpha = alpha
        self.range = torch.tensor([0.05, 0.95], device=device)

    def __call__(self, x, ema_vals):
        """Blend the percentiles of ``x`` into ``ema_vals`` and return ``(offset, scale)``.

        Args:
            x: Tensor of any shape; it is detached and flattened first.
            ema_vals: Two-element tensor holding the running low/high
                percentiles. It is overwritten in place.

        Returns:
            Tuple ``(offset, scale)`` of detached scalar tensors.
        """
        samples = x.detach().flatten()
        quantiles = torch.quantile(input=samples, q=self.range)
        blended = self.alpha * quantiles + (1 - self.alpha) * ema_vals
        # Slice assignment writes into the caller's tensor instead of rebinding
        # the local name, so the running statistics persist between calls.
        ema_vals[:] = blended
        low, high = ema_vals[0], ema_vals[1]
        scale = torch.clamp(high - low, min=1.0)
        return low.detach(), scale.detach()


class WorldModel(nn.Module):
    def __init__(self, obs_space, act_space, step, config):
        """Build the encoder, RSSM dynamics, prediction heads and model optimizer.

        Args:
            obs_space: Observation space; ``obs_space.spaces`` maps each
                observation key to an object exposing ``.shape``.
            act_space: Action space (not read by this constructor).
            step: Step counter object, stored as ``self._step``.
            config: Experiment configuration namespace.
        """
        super().__init__()
        self._step = step
        self._use_amp = bool(config.precision == 16)
        self._config = config
        shapes = {key: tuple(space.shape) for key, space in obs_space.spaces.items()}
        self.encoder = networks.MultiEncoder(shapes, **config.encoder)
        if self._config.encoder_with_moe:
            # With an encoder MoE configured, the plain encoder's weights are
            # marked as non-trainable.
            for weight in self.encoder.parameters():
                weight.requires_grad = False
        self.embed_size = self.encoder.outdim
        self.dynamics = networks.RSSM(
            config.dyn_stoch,
            config.dyn_deter,
            config.dyn_hidden,
            config.dyn_rec_depth,
            config.dyn_discrete,
            config.act,
            config.norm,
            config.dyn_mean_act,
            config.dyn_std_act,
            config.dyn_min_std,
            config.unimix_ratio,
            config.initial,
            config.num_actions,
            self.embed_size,
            config.device,
            config=config,
        )
        self.heads = nn.ModuleDict()
        if config.dyn_discrete:
            stoch_size = config.dyn_stoch * config.dyn_discrete
        else:
            stoch_size = config.dyn_stoch
        feat_size = stoch_size + config.dyn_deter

        # One lazy builder per head, in registration order. A builder only runs
        # for heads that are actually created here.
        head_builders = {
            "decoder": lambda: networks.MultiDecoder(
                feat_size, shapes, **config.decoder
            ),
            "reward": lambda: networks.MLP(
                feat_size,
                (255,) if config.reward_head["dist"] == "symlog_disc" else (),
                config.reward_head["layers"],
                config.units,
                config.act,
                config.norm,
                dist=config.reward_head["dist"],
                outscale=config.reward_head["outscale"],
                device=config.device,
                name="Reward",
            ),
            "cont": lambda: networks.MLP(
                feat_size,
                (),
                config.cont_head["layers"],
                config.units,
                config.act,
                config.norm,
                dist="binary",
                outscale=config.cont_head["outscale"],
                device=config.device,
                name="Cont",
            ),
        }
        if not self._config.head_with_moe:
            for head_name, build_head in head_builders.items():
                self.heads[head_name] = build_head()
            for name in config.grad_heads:
                assert name in self.heads, name
        else:
            # Heads listed in moe_heads_list are created later by init_moe().
            for head_name, build_head in head_builders.items():
                if head_name not in self._config.moe_heads_list:
                    self.heads[head_name] = build_head()

        self._model_opt = tools.Optimizer(
            "model",
            self.parameters(),
            config.model_lr,
            config.opt_eps,
            config.grad_clip,
            config.weight_decay,
            opt=config.opt,
            use_amp=self._use_amp,
        )
        print(
            f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
        )
        # Losses without an entry here are scaled by 1.0.
        self._scales = {
            "reward": config.reward_head["loss_scale"],
            "cont": config.cont_head["loss_scale"],
        }

        # MoE bookkeeping, filled in by init_moe() and _train_router_stage().
        self.moe_has_opt = False
        self.experts = None
        self.feat_size = feat_size
        self.task_group_map = None

    #[todo] start
    def init_moe(self):
        """Attach mixture-of-experts (MoE) layers and register their parameters.

        Does nothing when the dynamics model already has MoE layers. When
        ``self.experts`` is ``None`` the RSSM builds its own MoE. Otherwise
        ``self.experts`` (a mapping from expert name to an object exposing
        ``dynamics``, ``encoder`` and ``heads``) supplies the expert layers for
        the RSSM and, if configured, for the encoder and the prediction heads.
        """
        if self.dynamics.has_moe:
            return
        if self.experts is None:
            self.dynamics.init_moe()
        else:
            print("Expert type: RSSM")
            # At this point self.experts holds several world models.
            expert_names = list(self.experts.keys())
            self._config.experts_name_list = expert_names
            # Only the imagination layers are turned into MoE layers; the
            # mapping below also knows the "obs_out"/"obs_stat" layers.
            moe_layers_list = ["img_out", "img_stat"]
            name_layer_map = {
                "img_out": "_img_out_layers",
                "obs_out": "_obs_out_layers",
                "img_stat": "_imgs_stat_layer",
                "obs_stat": "_obs_stat_layer",
            }
            for name in moe_layers_list:
                layer_experts = nn.ModuleList(
                    [
                        getattr(self.experts[key].dynamics, name_layer_map[name])
                        for key in expert_names
                    ]
                )
                self.dynamics.init_moe_with_rssm_expert(name, layer_experts)

            if self._config.encoder_with_moe:
                print("Initing Encoder MoE")
                encoder_experts = nn.ModuleList(
                    [self.experts[key].encoder for key in expert_names]
                )
                self.encoder = networks.MoE(
                    self.feat_size,
                    experts=encoder_experts,
                    multi_stage=self._config.multi_stage,
                    config=self._config,
                    use_router=False,
                ).to(self._config.device)

            if self._config.head_with_moe:
                print("Initing heads MoE")
                head_params = []
                for name in self._config.moe_heads_list:
                    head_experts = nn.ModuleList(
                        [self.experts[key].heads[name] for key in expert_names]
                    )
                    self.heads[name] = networks.MoE(
                        self.feat_size,
                        experts=head_experts,
                        multi_stage=self._config.multi_stage,
                        config=self._config,
                    ).to(self._config.device)
                    head_params.extend(
                        param
                        for param in self.heads[name].parameters()
                        if param.requires_grad
                    )
                # Register once, after the loop. Doing it per head with the
                # accumulating list would add the parameters of earlier heads
                # to the optimizer several times.
                self.add_opt_params(head_params)
            self.dynamics.has_moe = True
        moe_params = self.dynamics.get_moe_params()
        if not self.moe_has_opt:
            self.add_opt_params(moe_params)
            self.moe_has_opt = True
        print("Init moe success.")
    def remove_experts(self):
        del self.experts
        gc.collect()
        torch.cuda.empty_cache()
        print(f"self.dynamics._img_out_layers_moe:{self.dynamics._img_out_layers_moe}")
        self.experts = None



    def add_opt_params(self, params):
        # 获取原有参数组（通常第一个是默认组）
        original_param_group = self._model_opt._opt.param_groups[0]
        # 向原有参数组的params中添加参数
        original_param_group['params'].extend(params)

    def save_dpmm(self):
        self.dynamics.dpmm.save_all()

    def _train_router_stage(self, data, step=None, total_update=None, N=None):
        """
        将 data 按 size 维拆分为 N 份，分别输入 X()，然后拼接返回结果。
        - data: dict[str, np.ndarray]，其中每个 value 的第 0 维是 batch size
        - X: 函数，输入为 data 子集，输出为 (post, context, metrics)
          post/context 为 dict[str, torch.Tensor]
          metrics 为 dict[str, 标量]
        - N: 拆分份数
        """
        if N is None:
            N = len(self._config.train_env_name_list)
        size = list(data.values())[0].shape[0]  # 假设所有 key 的 size 维一致
        indices = np.linspace(0, size, N + 1, dtype=int)

        if self._config.encoder_with_moe:
            if self.task_group_map is None:
                self.task_group_map = {
                    task: group_name
                    for group_name, task_list in self._config.grouping_dict.items()
                    for task in task_list
                }

        posts, contexts, metrics_list = [], [], []

        for i in range(N):
            start, end = indices[i], indices[i + 1]
            # numpy 切片
            data_i = {k: v[start:end] for k, v in data.items()}

            # if self._config.encoder_with_moe:
            #     group_name = self.task_group_map[self._config.train_env_name_list[i]]
            #     self.encoder = self.experts[group_name].encoder

            # 调用模型函数
            post_i, context_i, metrics_i = self._train(data_i, step=step, total_update=total_update, env_name=self._config.train_env_name_list[i])

            # 收集结果
            posts.append(post_i)
            contexts.append(context_i)
            # 在 key 前加上前缀
            metrics_list.append({f"{self._config.train_env_name_list[i]}_{k}": v for k, v in metrics_i.items()})

        # 拼接 post / context
        post_all = {
            k: torch.cat([p[k] for p in posts], dim=0)
            for k in posts[0].keys()
        }
        context_all = {
            k: torch.cat([c[k] for c in contexts], dim=0)
            for k in contexts[0].keys()
        }

        # 合并 metrics
        metrics_all = {}
        for m in metrics_list:
            metrics_all.update(m)

        return post_all, context_all, metrics_all

    #[todo] end

    def _train(self, data, step=None, total_update=None, env_name=None): #[todo]
        """Compute the world-model loss for one batch and pass it to the model optimizer.

        Args:
            data: Raw batch dict; it is converted by ``preprocess`` first.
            step: Forwarded to ``self.dynamics.observe``.
            total_update: Forwarded to ``self.dynamics.observe``.
            env_name: Forwarded to the encoder, the dynamics model and every
                prediction head.

        Returns:
            ``(post, context, metrics)``: the detached posterior states, a dict
            with ``embed``, ``feat``, ``kl`` and ``postent`` tensors, and a dict
            of metrics (optimizer metrics plus per-head losses, KL terms and
            prior/posterior entropies).
        """
        # action (batch_size, batch_length, act_dim)
        # image (batch_size, batch_length, h, w, ch)
        # reward (batch_size, batch_length)
        # discount (batch_size, batch_length)
        data = self.preprocess(data)
        with tools.RequiresGrad(self):
            with torch.cuda.amp.autocast(self._use_amp):
                embed = self.encoder(data, env_name=env_name) #[todo]
                post, prior = self.dynamics.observe(
                    embed, data["action"], data["is_first"], env_name=env_name, step=step, total_update=total_update #[todo]
                )
                kl_free = self._config.kl_free
                dyn_scale = self._config.dyn_scale
                rep_scale = self._config.rep_scale
                kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
                    post, prior, kl_free, dyn_scale, rep_scale
                )
                assert kl_loss.shape == embed.shape[:2], kl_loss.shape
                preds = {}
                for name, head in self.heads.items():
                    # Heads not listed in grad_heads see detached features, so
                    # their loss does not back-propagate into the dynamics.
                    grad_head = name in self._config.grad_heads
                    feat = self.dynamics.get_feat(post)
                    feat = feat if grad_head else feat.detach()
                    pred = head(feat, env_name=env_name) #[todo]
                    # A head may return a dict of predictions keyed by target
                    # name, or a single prediction for the head's own name.
                    if type(pred) is dict:
                        preds.update(pred)
                    else:
                        preds[name] = pred
                losses = {}
                for name, pred in preds.items():
                    # Negative log-likelihood of the matching entry of the batch.
                    loss = -pred.log_prob(data[name])
                    assert loss.shape == embed.shape[:2], (name, loss.shape)
                    losses[name] = loss
                # Losses without an entry in self._scales keep a weight of 1.0.
                scaled = {
                    key: value * self._scales.get(key, 1.0)
                    for key, value in losses.items()
                }
                model_loss = sum(scaled.values()) + kl_loss
            metrics = self._model_opt(torch.mean(model_loss), self.parameters())

        metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
        metrics["kl_free"] = kl_free
        metrics["dyn_scale"] = dyn_scale
        metrics["rep_scale"] = rep_scale
        metrics["dyn_loss"] = to_np(dyn_loss)
        metrics["rep_loss"] = to_np(rep_loss)
        metrics["kl"] = to_np(torch.mean(kl_value))
        with torch.cuda.amp.autocast(self._use_amp):
            metrics["prior_ent"] = to_np(
                torch.mean(self.dynamics.get_dist(prior).entropy())
            )
            metrics["post_ent"] = to_np(
                torch.mean(self.dynamics.get_dist(post).entropy())
            )
            context = dict(
                embed=embed,
                feat=self.dynamics.get_feat(post),
                kl=kl_value,
                postent=self.dynamics.get_dist(post).entropy(),
            )
        # Only the returned posterior is detached; the context tensors are not.
        post = {k: v.detach() for k, v in post.items()}
        return post, context, metrics

    # this function is called during both rollout and training
    def preprocess(self, obs):
        obs = {
            k: torch.tensor(v, device=self._config.device, dtype=torch.float32)
            for k, v in obs.items() if k != "env_name" #[todo]
        }
        obs["image"] = obs["image"] / 255.0
        if "discount" in obs:
            obs["discount"] *= self._config.discount
            # (batch_size, batch_length) -> (batch_size, batch_length, 1)
            obs["discount"] = obs["discount"].unsqueeze(-1)
        # 'is_first' is necesarry to initialize hidden state at training
        assert "is_first" in obs
        # 'is_terminal' is necesarry to train cont_head
        assert "is_terminal" in obs
        obs["cont"] = (1.0 - obs["is_terminal"]).unsqueeze(-1)
        return obs

    def video_pred(self, data, env_name=None):
        """Build a truth / prediction / error comparison video for a batch.

        The first 6 sequences of the batch are used. The model observes the
        first 5 steps and reconstructs them; the remaining steps are imagined
        from the recorded actions only (open loop) and decoded.

        Args:
            data: Raw batch dict; it is converted by ``preprocess`` first.
            env_name: Forwarded to the encoder, the dynamics model and the
                decoder head, matching how ``_train`` calls them.

        Returns:
            Tensor with ground truth, model output and the rescaled error
            ``(model - truth + 1) / 2`` concatenated along axis 2.
        """
        num_videos = 6
        context_len = 5  # observed images are given for this many steps
        data = self.preprocess(data)
        embed = self.encoder(data, env_name=env_name)

        states, _ = self.dynamics.observe(
            embed[:num_videos, :context_len],
            data["action"][:num_videos, :context_len],
            data["is_first"][:num_videos, :context_len],
            env_name=env_name,
        )
        decoder = self.heads["decoder"]
        # The decoder gets env_name like every head in _train, so an MoE
        # decoder is routed for the right environment.
        recon = decoder(self.dynamics.get_feat(states), env_name=env_name)[
            "image"
        ].mode()[:num_videos]
        # Continue from the last observed state using the recorded actions.
        init = {k: v[:, -1] for k, v in states.items()}
        prior = self.dynamics.imagine_with_action(
            data["action"][:num_videos, context_len:], init, env_name=env_name
        )
        openl = decoder(self.dynamics.get_feat(prior), env_name=env_name)[
            "image"
        ].mode()
        model = torch.cat([recon[:, :context_len], openl], 1)
        truth = data["image"][:num_videos]
        error = (model - truth + 1.0) / 2.0

        return torch.cat([truth, model, error], 2)

    #[todo] start
    def get_input_gradients(self, raw_data):
        """Return the clipped world-model loss gradient for one batch as a flat vector.

        Runs the same forward pass and loss as ``_train`` (without an
        environment name), back-propagates, and clips the gradient norm to
        ``config.grad_clip``. No optimizer step is taken. The gradient scaler
        state is left untouched, the parameter gradients are cleared again,
        and the state dict is restored before returning.

        Args:
            raw_data: Raw batch dict; it is converted by ``preprocess`` first.

        Returns:
            ``(gradient, shapes)`` as produced by ``parameters_to_gradvector``.
        """
        data = self.preprocess(raw_data)
        # Snapshot so anything the forward pass changes in the state dict is
        # put back at the end.
        original_param = copy.deepcopy(self.state_dict())
        optimizer = self._model_opt._opt
        scaler = self._model_opt._scaler  # the scaler used during training
        with tools.RequiresGrad(self):
            with torch.cuda.amp.autocast(self._use_amp):
                embed = self.encoder(data)
                post, prior = self.dynamics.observe(
                    embed, data["action"], data["is_first"]
                )
                kl_free = self._config.kl_free
                dyn_scale = self._config.dyn_scale
                rep_scale = self._config.rep_scale
                kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
                    post, prior, kl_free, dyn_scale, rep_scale
                )
                assert kl_loss.shape == embed.shape[:2], kl_loss.shape
                preds = {}
                for name, head in self.heads.items():
                    grad_head = name in self._config.grad_heads
                    feat = self.dynamics.get_feat(post)
                    feat = feat if grad_head else feat.detach()
                    pred = head(feat)
                    if type(pred) is dict:
                        preds.update(pred)
                    else:
                        preds[name] = pred
                losses = {}
                for name, pred in preds.items():
                    head_loss = -pred.log_prob(data[name])
                    assert head_loss.shape == embed.shape[:2], (name, head_loss.shape)
                    losses[name] = head_loss
                scaled = {
                    key: value * self._scales.get(key, 1.0)
                    for key, value in losses.items()
                }
                loss = (sum(scaled.values()) + kl_loss).mean()

            # Clear gradients on every module parameter, not only those the
            # optimizer tracks, so nothing stale leaks into the result.
            optimizer.zero_grad()
            self.zero_grad()
            # Backward runs outside the autocast region.
            scaler.scale(loss).backward()
            # Undo the loss scaling by hand. GradScaler.unscale_() may only be
            # called once per optimizer between step()/update() calls; since no
            # step is taken here, using it would make the next unscale_() on
            # this optimizer raise when AMP is enabled.
            # NOTE(review): assumes _scaler is a torch GradScaler (get_scale()
            # returns 1.0 when scaling is disabled) - confirm in tools.Optimizer.
            loss_scale = scaler.get_scale()
            if loss_scale != 1.0:
                for param in self.parameters():
                    if param.grad is not None:
                        param.grad.div_(loss_scale)

        torch.nn.utils.clip_grad_norm_(self.parameters(), self._config.grad_clip)
        gradient, shapes = self.parameters_to_gradvector()
        # The vector is a copy, so the parameter gradients can be dropped.
        optimizer.zero_grad()
        self.zero_grad()
        self.load_state_dict(original_param)
        return gradient, shapes

    def parameters_to_gradvector(self):
        vec = []
        shapes = []
        for name, param in self.named_parameters():
            if param.grad is not None:
                # print(f"{name}:param.grad is not None")
                vec.append(param.grad.view(-1))
            else:
                vec.append(torch.zeros(param.size()).view(-1).cuda())
            shapes.append(param.shape)
        return torch.cat(vec), shapes



class ImagBehavior(nn.Module):
    def __init__(self, config, world_model):
        """Build the actor, the critic, their optimizers and the optional extras.

        A slow copy of the critic (``_slow_value``) and its update counter are
        only created when ``config.critic["slow_target"]`` is set; the
        ``ema_vals`` buffer and ``reward_ema`` only when ``config.reward_EMA``
        is set.

        Args:
            config: Experiment configuration namespace.
            world_model: World model used for imagination.
        """
        super().__init__()
        self._use_amp = bool(config.precision == 16)
        self._config = config
        self._world_model = world_model
        if config.dyn_discrete:
            stoch_size = config.dyn_stoch * config.dyn_discrete
        else:
            stoch_size = config.dyn_stoch
        feat_size = stoch_size + config.dyn_deter

        actor_cfg = config.actor
        critic_cfg = config.critic
        self.actor = networks.MLP(
            feat_size,
            (config.num_actions,),
            actor_cfg["layers"],
            config.units,
            config.act,
            config.norm,
            actor_cfg["dist"],
            actor_cfg["std"],
            actor_cfg["min_std"],
            actor_cfg["max_std"],
            absmax=1.0,
            temp=actor_cfg["temp"],
            unimix_ratio=actor_cfg["unimix_ratio"],
            outscale=actor_cfg["outscale"],
            name="Actor",
        )
        self.value = networks.MLP(
            feat_size,
            (255,) if critic_cfg["dist"] == "symlog_disc" else (),
            critic_cfg["layers"],
            config.units,
            config.act,
            config.norm,
            critic_cfg["dist"],
            outscale=critic_cfg["outscale"],
            device=config.device,
            name="Value",
        )
        if critic_cfg["slow_target"]:
            self._slow_value = copy.deepcopy(self.value)
            self._updates = 0

        # Optimizer settings shared by the actor and the critic.
        shared_opt_kwargs = {
            "wd": config.weight_decay,
            "opt": config.opt,
            "use_amp": self._use_amp,
        }
        self._actor_opt = tools.Optimizer(
            "actor",
            self.actor.parameters(),
            actor_cfg["lr"],
            actor_cfg["eps"],
            actor_cfg["grad_clip"],
            **shared_opt_kwargs,
        )
        print(
            f"Optimizer actor_opt has {sum(param.numel() for param in self.actor.parameters())} variables."
        )
        self._value_opt = tools.Optimizer(
            "value",
            self.value.parameters(),
            critic_cfg["lr"],
            critic_cfg["eps"],
            critic_cfg["grad_clip"],
            **shared_opt_kwargs,
        )
        print(
            f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
        )
        if self._config.reward_EMA:
            # Registered as a buffer so torch.save / torch.load include it.
            self.register_buffer(
                "ema_vals", torch.zeros((2,), device=self._config.device)
            )
            self.reward_ema = RewardEMA(device=self._config.device)

    def _train(
        self,
        start,
        objective,
        env_name=None
    ):
        """Compute actor and critic losses on imagined rollouts and pass them to their optimizers.

        Args:
            start: Dict of state tensors to start imagining from.
            objective: Callable ``(feat, state, action, env_name=...)`` that
                returns the reward for the imagined trajectory.
            env_name: Forwarded to the imagination rollout and to
                ``objective``.

        Returns:
            ``(imag_feat, imag_state, imag_action, weights, metrics)``.
        """
        self._update_slow_target()
        metrics = {}

        with tools.RequiresGrad(self.actor):
            with torch.cuda.amp.autocast(self._use_amp):
                imag_feat, imag_state, imag_action = self._imagine(
                    start, self.actor, self._config.imag_horizon, env_name=env_name
                )
                reward = objective(imag_feat, imag_state, imag_action, env_name=env_name)
                actor_ent = self.actor(imag_feat).entropy()
                state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
                # this target is not scaled by ema or sym_log.
                target, weights, base = self._compute_target(
                    imag_feat, imag_state, reward
                )
                actor_loss, mets = self._compute_actor_loss(
                    imag_feat,
                    imag_action,
                    target,
                    weights,
                    base,
                )
                actor_loss -= self._config.actor["entropy"] * actor_ent[:-1, ..., None]
                actor_loss = torch.mean(actor_loss)
                metrics.update(mets)
                value_input = imag_feat

        with tools.RequiresGrad(self.value):
            with torch.cuda.amp.autocast(self._use_amp):
                value = self.value(value_input[:-1].detach())
                target = torch.stack(target, dim=1)
                # (time, batch, 1), (time, batch, 1) -> (time, batch)
                value_loss = -value.log_prob(target.detach())
                if self._config.critic["slow_target"]:
                    # The slow critic only exists when slow_target is enabled,
                    # so it must not be evaluated outside this branch.
                    slow_target = self._slow_value(value_input[:-1].detach())
                    value_loss -= value.log_prob(slow_target.mode().detach())
                # (time, batch, 1), (time, batch, 1) -> (1,)
                value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])

        metrics.update(tools.tensorstats(value.mode(), "value"))
        metrics.update(tools.tensorstats(target, "target"))
        metrics.update(tools.tensorstats(reward, "imag_reward"))
        if self._config.actor["dist"] in ["onehot"]:
            metrics.update(
                tools.tensorstats(
                    torch.argmax(imag_action, dim=-1).float(), "imag_action"
                )
            )
        else:
            metrics.update(tools.tensorstats(imag_action, "imag_action"))
        metrics["actor_entropy"] = to_np(torch.mean(actor_ent))
        with tools.RequiresGrad(self):
            metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
            metrics.update(self._value_opt(value_loss, self.value.parameters()))
        return imag_feat, imag_state, imag_action, weights, metrics

    def _imagine(self, start, policy, horizon, env_name=None):
        """Roll the dynamics model forward for ``horizon`` steps using ``policy``.

        Args:
            start: Dict of state tensors; the two leading axes of every entry
                are merged into one before the rollout.
            policy: Callable mapping features to a distribution with
                ``.sample()``.
            horizon: Number of imagination steps.
            env_name: Handed to ``tools.static_scan`` next to the step indices.

        Returns:
            ``(feats, states, actions)`` as collected by ``tools.static_scan``,
            with ``states`` shifted so that it begins with the start state.
        """
        dynamics = self._world_model.dynamics
        start = {
            key: value.reshape([-1] + list(value.shape[2:]))
            for key, value in start.items()
        }

        def step(prev, _, env_name=None):
            state, _prev_feat, _prev_action = prev
            feat = dynamics.get_feat(state)
            # The policy sees detached features.
            action = policy(feat.detach()).sample()
            succ = dynamics.img_step(state, action, env_name=env_name)
            return succ, feat, action

        # succ should correspond to prev[0] inside step.
        # NOTE(review): how static_scan forwards the second input (env_name)
        # to step is defined in tools and not visible here - confirm.
        scan_inputs = [torch.arange(horizon), env_name]
        succ, feats, actions = tools.static_scan(
            step, scan_inputs, (start, None, None)
        )
        # Prepend the start state and drop the final successor.
        states = {}
        for key, value in succ.items():
            states[key] = torch.cat([start[key][None], value[:-1]], 0)

        return feats, states, actions

    def _compute_target(self, imag_feat, imag_state, reward):
        """Compute lambda-return targets and cumulative discount weights.

        Args:
            imag_feat: Imagined features, fed to the critic.
            imag_state: Imagined states, fed to the ``cont`` head when the
                world model has one.
            reward: Rewards of the imagined trajectory.

        Returns:
            ``(target, weights, base)``: the result of ``tools.lambda_return``,
            the detached cumulative product of the discounts (starting at
            one), and the critic values without the final step.
        """
        world_heads = self._world_model.heads
        if "cont" in world_heads:
            cont_input = self._world_model.dynamics.get_feat(imag_state)
            discount = self._config.discount * world_heads["cont"](cont_input).mean
        else:
            discount = self._config.discount * torch.ones_like(reward)

        value = self.value(imag_feat).mode()
        target = tools.lambda_return(
            reward[1:],
            value[:-1],
            discount[1:],
            bootstrap=value[-1],
            lambda_=self._config.discount_lambda,
            axis=0,
        )
        # Shift the discounts by one step, so the first weight is exactly one.
        shifted_discount = torch.cat(
            [torch.ones_like(discount[:1]), discount[:-1]], 0
        )
        weights = torch.cumprod(shifted_discount, 0).detach()
        return target, weights, value[:-1]

    def _compute_actor_loss(
        self,
        imag_feat,
        imag_action,
        target,
        weights,
        base,
    ):
        metrics = {}
        inp = imag_feat.detach()
        policy = self.actor(inp)
        # Q-val for actor is not transformed using symlog
        target = torch.stack(target, dim=1)
        if self._config.reward_EMA:
            offset, scale = self.reward_ema(target, self.ema_vals)
            normed_target = (target - offset) / scale
            normed_base = (base - offset) / scale
            adv = normed_target - normed_base
            metrics.update(tools.tensorstats(normed_target, "normed_target"))
            metrics["EMA_005"] = to_np(self.ema_vals[0])
            metrics["EMA_095"] = to_np(self.ema_vals[1])

        if self._config.imag_gradient == "dynamics":
            actor_target = adv
        elif self._config.imag_gradient == "reinforce":
            actor_target = (
                policy.log_prob(imag_action)[:-1][:, :, None]
                * (target - self.value(imag_feat[:-1]).mode()).detach()
            )
        elif self._config.imag_gradient == "both":
            actor_target = (
                policy.log_prob(imag_action)[:-1][:, :, None]
                * (target - self.value(imag_feat[:-1]).mode()).detach()
            )
            mix = self._config.imag_gradient_mix
            actor_target = mix * target + (1 - mix) * actor_target
            metrics["imag_gradient_mix"] = mix
        else:
            raise NotImplementedError(self._config.imag_gradient)
        actor_loss = -weights[:-1] * actor_target
        return actor_loss, metrics

    def _update_slow_target(self):
        if self._config.critic["slow_target"]:
            if self._updates % self._config.critic["slow_target_update"] == 0:
                mix = self._config.critic["slow_target_fraction"]
                for s, d in zip(self.value.parameters(), self._slow_value.parameters()):
                    d.data = mix * s.data + (1 - mix) * d.data
            self._updates += 1
