import math
import numpy as np
import re

import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd

import tools

#[todo] start
import os
from torch.distributions.utils import broadcast_all
from itertools import cycle
# import bnpy.bnpy as bnpy
# from bnpy.data.XData import XData
from scipy import stats
from matplotlib import pylab
from transformers.activations import ACT2FN
from transformers.modeling_utils import (
    Conv1D,
    PreTrainedModel,
    SequenceSummary,
    find_pruneable_heads_and_indices,
    prune_conv1d_layer,
)
class Expert(nn.Module):
    """Two-projection feed-forward expert: ``c_fc`` -> activation -> ``c_proj`` -> dropout."""

    def __init__(self, n_state, n_embed, config):
        """
        Args:
            n_state: First argument of ``c_fc`` and second argument of
                ``c_proj``. The original note says this is usually
                ``4 * n_embed``; assuming ``Conv1D(nf, nx)`` maps ``nx`` input
                features to ``nf`` output features, it is the width of the
                intermediate activation -- TODO confirm against the installed
                ``transformers`` version.
            n_embed: Second argument of ``c_fc`` and first argument of
                ``c_proj``; per the original note it corresponds to the
                backbone's ``hidden_size``.
            config: Object exposing ``expert_act`` (a key of ``ACT2FN``) and
                ``resid_pdrop`` (dropout probability).
        """
        super().__init__()
        self.c_fc = Conv1D(n_state, n_embed)
        self.c_proj = Conv1D(n_embed, n_state)
        self.act = ACT2FN[config.expert_act]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        """Apply the two projections with the activation in between, then dropout."""
        hidden = self.c_fc(x)
        hidden = self.act(hidden)
        projected = self.c_proj(hidden)
        return self.dropout(projected)

class MoE(nn.Module):
    """Dispatch an input to a set of experts, either softly or by environment.

    * With a router (``use_router`` true) every expert is evaluated and the
      results are combined with the router's weights. Tensor outputs are
      summed; distribution / dict outputs are wrapped in ``MultiDist``.
    * Without a router the expert is selected from the environment name
      (``env_name``) or, for batches that stack several environments along
      dim 0, from the position of each row in the batch.
    """

    def __init__(self, input_dim, num_experts=None, experts=None, multi_stage=False, config=None, use_router=None):
        """
        Args:
            input_dim: Feature size handed to the router.
            num_experts: Number of router outputs; defaults to ``len(experts)``.
            experts: Sequence of expert callables. It is stored as given.
                NOTE(review): a plain Python list is not registered as
                sub-modules of this module -- confirm the caller passes an
                ``nn.ModuleList`` if the experts must follow ``.to()`` /
                ``.parameters()``.
            multi_stage: When true, every expert parameter is frozen.
            config: Object exposing ``wm_use_router``, ``train_env_name_list``,
                ``batch_size`` and the fields read in ``forward``.
            use_router: Overrides ``config.wm_use_router`` when not ``None``.

        Raises:
            NotImplementedError: If there are fewer experts than environment
                categories.
        """
        super().__init__()
        if num_experts is None:
            num_experts = len(experts)
        self.use_router = config.wm_use_router if use_router is None else use_router
        if self.use_router:
            self.router = Router(input_dim, 2 * input_dim, num_experts)
            self.router.apply(tools.weight_init)
        self.experts = experts

        self._config = config
        if multi_stage:
            for expert in self.experts:
                for param in expert.parameters():
                    param.requires_grad = False
        self.size_per_env = len(self._config.train_env_name_list) * self._config.batch_size
        if self._config.train_env_name_list != "None":
            # Category = prefix before the first "_" of each environment name;
            # unique categories keep first-seen order.
            cats = [e.split("_")[0] for e in self._config.train_env_name_list]
            self.unique_cats = list(dict.fromkeys(cats))
            if len(self.experts) < len(self.unique_cats):
                raise NotImplementedError("experts nums is not enough according to current method")
            self.cat2idx = {c: i for i, c in enumerate(self.unique_cats)}
            # Environment-level category index, shape (num_envs,). It is
            # expanded to the batch dimension lazily in ``forward``.
            self.env_cat_idx = torch.tensor([self.cat2idx[c] for c in cats])

    def forward(self, x, env_name=None):
        """Run the experts on ``x`` and return the combined output.

        Args:
            x: Input tensor, or a dict of tensors when the experts are
                encoders (router-less mode only).
            env_name: Environment name used to pick an expert in router-less
                mode when the batch belongs to a single environment.

        Returns:
            A tensor, a ``MultiDist``, or a dict of ``MultiDist`` objects,
            depending on what the experts return.
        """
        output_is_tensor = True
        # Only RSSM / RSSM_append experts may change the feature size, so for
        # every other expert type the accumulator can be pre-allocated.
        if "RSSM" not in self._config.expert_type:
            outputs = torch.zeros_like(x)
            expert_outputs = None
        else:
            outputs = None
            expert_outputs = []

        if self.use_router:
            route_weight = self.router(x)  # shape: [..., num_experts]
            for i, expert in enumerate(self.experts):
                expert_output = expert(x)
                if outputs is not None:
                    # ``[..., i]`` selects expert i's weight for any number of
                    # leading dimensions; unsqueeze broadcasts it over features.
                    outputs += route_weight[..., i].unsqueeze(-1) * expert_output
                else:
                    # With RSSM-type experts a head may return a dict / dist.
                    if i == 0:
                        is_dist_or_dict = self.is_dist_or_dict(expert_output)
                    if not is_dist_or_dict:
                        weight = route_weight[..., i].unsqueeze(-1)  # broadcast on output_dim
                        expert_outputs.append(weight * expert_output)
                    else:
                        expert_outputs.append(expert_output)
                        output_is_tensor = False

            # A non-tensor output is either a dict of distributions or a distribution.
            if not output_is_tensor:
                if isinstance(expert_outputs[0], dict):
                    outputs = {}
                    for k in expert_outputs[0].keys():
                        dists = [out[k] for out in expert_outputs]
                        outputs[k] = MultiDist(dists, route_weight)
                else:
                    outputs = MultiDist(expert_outputs, route_weight)
            else:
                # Sum the already-weighted expert outputs.
                if expert_outputs is not None:
                    outputs = torch.stack(expert_outputs, dim=0).sum(dim=0)
        else:
            if not isinstance(x, dict):
                if (x.shape[0] != self.size_per_env) and (x.shape[0] != self._config.envs):
                    # Single-environment batch (video prediction / actor training).
                    if self._config.expert_type != "RSSM":
                        # Several environments may share one expert: route by
                        # the category prefix of the environment name.
                        expert_idx = self.cat2idx[env_name.split("_")[0]]
                        outputs = self.experts[expert_idx](x)
                    else:
                        expert_idx = self._config.experts_name_list.index(self._config.task_group_map[env_name])
                        outputs = self.experts[expert_idx](x)
                else:
                    # Batch stacks all environments along dim 0 (world-model
                    # training / agent policy).
                    if self._config.expert_type != "RSSM":
                        # NOTE(review): ``expert_type == "RSSM_append"`` also
                        # reaches this branch, but ``outputs`` is only
                        # pre-allocated when "RSSM" is absent from the type --
                        # confirm the intended routing for that case.
                        # Rows per environment, derived from the actual input
                        # so any multiple of the environment count works.
                        batch_size = x.shape[0] // len(self._config.train_env_name_list)
                        batch_cat_idx = self.env_cat_idx.to(x.device).repeat_interleave(batch_size)
                        for cat_idx in range(len(self.unique_cats)):
                            # Rows of this category -> matching expert -> write back.
                            mask = batch_cat_idx == cat_idx
                            outputs[mask] = self.experts[cat_idx](x[mask])
                    else:
                        xs = torch.chunk(x, chunks=len(self._config.train_env_name_list), dim=0)
                        outputs = []
                        for i in range(len(self._config.train_env_name_list)):
                            expert_idx = self._config.experts_name_list.index(self._config.task_group_map[self._config.train_env_name_list[i]])
                            outputs.append(self.experts[expert_idx](xs[i]))
                        outputs = torch.cat(outputs, dim=0)
            else:
                # Dict input: the experts are encoders (currently RSSM experts only).
                if env_name is not None:  # video prediction / world-model training
                    expert_idx = self._config.experts_name_list.index(self._config.task_group_map[env_name])
                    outputs = self.experts[expert_idx](x)
                else:  # policy
                    batch_size = next(iter(x.values())).size(0) // len(self._config.train_env_name_list)
                    # Split every entry of the dict per environment.
                    split_data = [
                        {k: torch.split(v, batch_size, dim=0)[i] for k, v in x.items()}
                        for i in range(len(self._config.train_env_name_list))
                    ]

                    # Feed each environment's slice to its expert.
                    outputs = []
                    for i in range(len(self._config.train_env_name_list)):
                        env_input = split_data[i]
                        expert_idx = self._config.experts_name_list.index(self._config.task_group_map[self._config.train_env_name_list[i]])
                        out = self.experts[expert_idx](env_input)
                        outputs.append(out)

                    # Re-assemble along the batch dimension.
                    outputs = torch.cat(outputs, dim=0)

        return outputs

    def is_dist_or_dict(self, instance):
        """Return ``True`` when ``instance`` is a dict or one of the known distribution types.

        Args:
            instance: Any object returned by an expert.

        Returns:
            bool: Result of an ``isinstance`` check (subclasses included).
        """
        possible_types = (
            tools.SampleDist,
            tools.ContDist,
            tools.OneHotDist,
            tools.Bernoulli,
            tools.DiscDist,
            tools.SymlogDist,
            torchd.Distribution,
            dict
        )

        return isinstance(instance, possible_types)


class Router(nn.Module):
    """MLP gate that maps features to softmax weights over the experts."""

    def __init__(self, input_dim, hidden_dim, num_experts):
        """
        Args:
            input_dim: Size of the last dimension of the input.
            hidden_dim: Base hidden width; the first three layers use twice
                this width and the fourth narrows back to it.
            num_experts: Number of output weights.
        """
        super().__init__()
        wide = 2 * hidden_dim
        self.fc1 = nn.Linear(input_dim, wide)
        self.fc2 = nn.Linear(wide, wide)
        self.fc_new = nn.Linear(wide, wide)
        self.fc3 = nn.Linear(wide, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, num_experts)
        self.activation = nn.ReLU()
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(
            f"Router has {trainable} variables requiring grads."
        )

    def forward(self, x):
        """Return softmax routing weights over the last dimension."""
        hidden = x
        # Four ReLU-activated layers, then a linear read-out to expert logits.
        for layer in (self.fc1, self.fc2, self.fc_new, self.fc3):
            hidden = self.activation(layer(hidden))
        logits = self.fc4(hidden)
        return F.softmax(logits, dim=-1)

class MultiDist(torchd.Distribution):
    """Router-weighted combination of several expert distributions.

    ``log_prob`` evaluates the mixture density ``sum_i w_i * p_i(x)``.
    ``mean`` / ``mode`` / ``sample`` / ``rsample`` return the weighted sum of
    the per-expert values. NOTE(review): a weighted sum of samples is not a
    draw from the mixture that ``log_prob`` describes -- confirm this is the
    intended semantics.

    ``mean``, ``mode`` and ``variance`` are methods here (called with
    parentheses). The wrapped distributions may expose these statistics either
    as methods or as properties; both are accepted.
    """
    arg_constraints = {}
    has_rsample = True  # rsample() is provided

    def __init__(self, dists, weights):
        """
        Args:
            dists: List of expert distributions, one per router output.
            weights: Tensor ``[..., num_experts]``; the last dimension indexes
                the experts. Everything before it is the batch shape.
        """
        self.dists = dists
        self.weights = weights  # [..., num_experts]
        self._batch_shape = torch.Size(weights.shape[:-1])
        # Event shape = whatever is left of the first expert's mean after the
        # batch dimensions.
        first_mean = self._stat(dists[0], "mean")
        self._event_shape = first_mean.shape[len(self._batch_shape):]
        super().__init__(self._batch_shape, self._event_shape)

    @staticmethod
    def _stat(dist, name):
        """Read statistic ``name`` from ``dist`` whether it is a method or a property."""
        value = getattr(dist, name)
        return value() if callable(value) else value

    def _event_weights(self):
        """Weights reshaped to ``batch + (1,) * event_ndim + (num_experts,)``.

        This lines the weights up with per-expert values stacked on a new last
        dimension, for any number of event dimensions (including zero).
        """
        shape = (
            tuple(self._batch_shape)
            + (1,) * len(self._event_shape)
            + (self.weights.shape[-1],)
        )
        return self.weights.reshape(shape)

    def _stacked(self, name):
        """Stack statistic ``name`` of every expert along a new last dimension."""
        return torch.stack([self._stat(d, name) for d in self.dists], dim=-1)

    # === Basic statistics ===
    def mean(self):
        """Weighted sum of the expert means."""
        return (self._stacked("mean") * self._event_weights()).sum(dim=-1)

    def mode(self):
        """Weighted sum of the expert modes; falls back to ``mean()`` if the first expert has no ``mode``."""
        if hasattr(self.dists[0], "mode"):
            return (self._stacked("mode") * self._event_weights()).sum(dim=-1)
        return self.mean()

    def variance(self):
        """Var = sum_i w_i * (var_i + mean_i^2) - (sum_i w_i * mean_i)^2"""
        means = self._stacked("mean")
        vars_ = self._stacked("variance")
        weights = self._event_weights()
        weighted_mean = (means * weights).sum(dim=-1)
        weighted_sq = (weights * (vars_ + means ** 2)).sum(dim=-1)
        return weighted_sq - weighted_mean ** 2

    # === Sampling ===
    def sample(self, sample_shape=torch.Size()):
        """Weighted sum of one sample from every expert."""
        samples = [d.sample(sample_shape) for d in self.dists]
        samples = torch.stack(samples, dim=-1)
        return (samples * self._event_weights()).sum(dim=-1)

    def rsample(self, sample_shape=torch.Size()):
        """Weighted sum of one reparameterised sample from every expert."""
        samples = [d.rsample(sample_shape) for d in self.dists]
        samples = torch.stack(samples, dim=-1)
        return (samples * self._event_weights()).sum(dim=-1)

    # === log_prob ===
    def log_prob(self, x):
        """log p(x) = log(sum_i w_i * p_i(x)), computed with the max subtracted for stability."""
        logps = torch.stack([d.log_prob(x) for d in self.dists], dim=-1)
        max_logp, _ = torch.max(logps, dim=-1, keepdim=True)
        weighted = self.weights * torch.exp(logps - max_logp)
        return (torch.log(weighted.sum(dim=-1) + 1e-8) + max_logp.squeeze(-1))

    # === Entropy ===
    def entropy(self):
        """Approximation: H(p) ~= sum_i w_i * H(p_i) + H(w)"""
        entropies = torch.stack([d.entropy() for d in self.dists], dim=-1)
        weighted_entropy = (self.weights * entropies).sum(dim=-1)
        entropy_weights = -(self.weights * torch.log(self.weights + 1e-8)).sum(dim=-1)
        return weighted_entropy + entropy_weights

    # === KL divergence ===
    def kl_divergence(self, other):
        """Approximation: KL(p || q) ~= sum_i w_i * KL(p_i || q_i), using this object's weights."""
        assert isinstance(other, MultiDist)
        kls = torch.stack([
            torchd.kl_divergence(p, q) for p, q in zip(self.dists, other.dists)
        ], dim=-1)
        return (self.weights * kls).sum(dim=-1)

    def __repr__(self):
        return f"MultiDist(num_experts={len(self.dists)}, batch_shape={self._batch_shape}, event_shape={self._event_shape})"

class RandomActor:
    """Callable that ignores its arguments and returns a fixed action distribution."""

    def __init__(self, act_space, config):
        """
        Args:
            act_space: Action space. If it has a ``discrete`` attribute a
                one-hot distribution with all-zero logits is built; otherwise
                its ``low`` / ``high`` bounds define a uniform distribution.
            config: Object exposing ``num_actions`` (used for the discrete case).
        """
        if hasattr(act_space, "discrete"):
            logits = torch.zeros(config.num_actions).repeat(1, 1)
            self.random_actor = tools.OneHotDist(logits)
            return

        lower = torch.tensor(act_space.low).repeat(1, 1)
        upper = torch.tensor(act_space.high).repeat(1, 1)
        # Treat the last dimension as the event dimension.
        self.random_actor = torchd.independent.Independent(
            UniformWithMode(lower, upper),
            1,
        )

    def __call__(self, *args, **kwargs):
        """Return the pre-built distribution regardless of the arguments."""
        return self.random_actor


class UniformWithMode(torchd.uniform.Uniform):
    """Uniform distribution with a ``mode()`` method that returns the interval midpoint."""

    def mode(self):
        """Return ``(low + high) / 2`` after broadcasting the two bounds together."""
        lower, upper = broadcast_all(self.low, self.high)
        span_sum = lower + upper
        return span_sum / 2.0  # could be changed to lower or upper instead

#[todo] end
class RSSM(nn.Module):
    def __init__(
        self,
        stoch=30,
        deter=200,
        hidden=200,
        rec_depth=1,
        discrete=False,
        act="SiLU",
        norm=True,
        mean_act="none",
        std_act="softplus",
        min_std=0.1,
        unimix_ratio=0.01,
        initial="learned",
        num_actions=None,
        embed=None,
        device=None,
        config=None,
    ):
        """Build the recurrent state-space model and its MoE bookkeeping.

        Args:
            stoch: Number of stochastic latent variables.
            deter: Size of the deterministic (recurrent) state.
            hidden: Width of the dense layers.
            rec_depth: Number of recurrent-cell applications per step.
            discrete: Falsy for a Gaussian latent; otherwise the number of
                classes per stochastic variable.
            act: Name of an activation class in ``torch.nn``.
            norm: Whether a ``LayerNorm`` follows each dense layer.
            mean_act, std_act, min_std: Stored for the sufficient-statistics code.
            unimix_ratio: Passed to ``tools.OneHotDist`` in ``get_dist``.
            initial: ``"learned"`` or ``"zeros"``; see ``initial``.
            num_actions: Action vector size.
            embed: Observation embedding size.
            device: Device for created tensors.
            config: Object exposing ``wm_with_expert``, ``expert_nums``,
                ``train_env_name_list``, ``use_dpmm`` and, when DPMM is
                enabled, ``dpmm["save_dir"]``.
        """
        super().__init__()
        self._stoch = stoch
        self._deter = deter
        self._hidden = hidden
        self._min_std = min_std
        self._rec_depth = rec_depth
        self._discrete = discrete
        activation_cls = getattr(torch.nn, act)
        self._mean_act = mean_act
        self._std_act = std_act
        self._unimix_ratio = unimix_ratio
        self._initial = initial
        self._num_actions = num_actions
        self._embed = embed
        self._device = device

        # --- expert / MoE bookkeeping ---
        self.with_expert = config.wm_with_expert
        self.has_moe = False
        self.moe_type = None
        self.has_expert = False
        # -1 means "one expert per training environment".
        self.expert_nums = (
            len(config.train_env_name_list)
            if config.expert_nums == -1
            else config.expert_nums
        )
        # MoE counterparts of the four backbone blocks; filled in later by
        # ``init_moe_with_rssm_expert``.
        self._img_out_layers_moe = None
        self._obs_out_layers_moe = None
        self._imgs_stat_layer_moe = None
        self._obs_stat_layer_moe = None
        if config.use_dpmm:
            self.dpmm = _make_bnpModel(config, device=device)
            self.dpmm.load_model(config.dpmm["save_dir"])
        else:
            self.dpmm = None
        self._config = config

        def dense_block(in_features):
            """Linear (no bias) -> optional LayerNorm -> activation, weight-initialised."""
            layers = [nn.Linear(in_features, self._hidden, bias=False)]
            if norm:
                layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
            layers.append(activation_cls())
            block = nn.Sequential(*layers)
            block.apply(tools.weight_init)
            return block

        # Flattened stochastic state size depends on the latent type.
        if self._discrete:
            latent_width = self._stoch * self._discrete
        else:
            latent_width = self._stoch
        self._img_in_layers = dense_block(latent_width + num_actions)
        self._cell = GRUCell(self._hidden, self._deter, norm=norm)
        self._cell.apply(tools.weight_init)

        self._img_out_layers = dense_block(self._deter)
        self._obs_out_layers = dense_block(self._deter + self._embed)

        # Discrete latents need one logit per class; Gaussian latents need a
        # mean and a std per variable.
        if self._discrete:
            stat_width = self._stoch * self._discrete
        else:
            stat_width = 2 * self._stoch
        self._imgs_stat_layer = nn.Linear(self._hidden, stat_width)
        self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
        self._obs_stat_layer = nn.Linear(self._hidden, stat_width)
        self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))

        if self._initial == "learned":
            # Learnable pre-tanh initial deterministic state.
            self.W = torch.nn.Parameter(
                torch.zeros((1, self._deter), device=torch.device(self._device)),
                requires_grad=True,
            )



    #[todo] start

    def init_moe_with_rssm_expert(self, name, expert, expert_nums=None):
        """Create the MoE that accompanies one backbone block (RSSM-type experts).

        For example ``'_img_out_layers_moe'`` is the MoE counterpart of
        ``'_img_out_layers'``.

        Args:
            name: One of ``"img_out"``, ``"obs_out"``, ``"img_stat"``,
                ``"obs_stat"``.
            expert: Collection of experts handed to ``MoE``.
            expert_nums: Accepted for interface compatibility; not used (the
                ``MoE`` derives the count from ``expert``).

        Raises:
            NotImplementedError: For an unknown ``name`` or an unsupported
                ``config.expert_type``.

        Note:
            ``has_moe`` is not set here; the caller is responsible for it.
        """
        if name == "img_out":
            self._img_out_layers_moe = MoE(self._deter, None, expert, multi_stage=self._config.multi_stage, config=self._config).to(self._device)
        elif name == "obs_out":
            self._obs_out_layers_moe = MoE(self._deter + self._embed, None, expert, multi_stage=self._config.multi_stage, config=self._config).to(self._device)
        elif name == "img_stat":
            self._imgs_stat_layer_moe = MoE(self._hidden, None, expert, multi_stage=self._config.multi_stage, config=self._config).to(self._device)
        elif name == "obs_stat":
            self._obs_stat_layer_moe = MoE(self._hidden, None, expert, multi_stage=self._config.multi_stage, config=self._config).to(self._device)
        else:
            raise NotImplementedError(f"Wrong name in initing MoE")
        if self._config.expert_type == "RSSM":
            # The step methods of this class dispatch on "rssm" for the mode in
            # which the MoE replaces the backbone block, so that exact value
            # must be used here.
            self.moe_type = "rssm"
        elif self._config.expert_type == "RSSM_append":
            self.moe_type = "rssm_append"  # MoE output is added as a residual
        else:
            raise NotImplementedError("Not implemented expert type")


    def get_moe_params(self):
        """Return a flat list with the parameters of every MoE block that has been created."""
        candidates = (
            self._img_out_layers_moe,
            self._obs_out_layers_moe,
            self._imgs_stat_layer_moe,
            self._obs_stat_layer_moe,
        )
        # Blocks that were never initialised are still ``None`` and are skipped.
        return [
            param
            for moe in candidates
            if moe is not None
            for param in moe.parameters()
        ]

    def freeze_backbone(self):
        """Disable gradients for the four backbone blocks that have an MoE counterpart."""
        backbone_blocks = (
            self._img_out_layers,
            self._obs_out_layers,
            self._imgs_stat_layer,
            self._obs_stat_layer,
        )
        for block in backbone_blocks:
            for weight in block.parameters():
                weight.requires_grad = False

    def unfreeze_backbone(self):
        """Re-enable gradients for the four backbone blocks that have an MoE counterpart."""
        backbone_blocks = (
            self._img_out_layers,
            self._obs_out_layers,
            self._imgs_stat_layer,
            self._obs_stat_layer,
        )
        for block in backbone_blocks:
            for weight in block.parameters():
                weight.requires_grad = True


    #[todo] end

    def initial(self, batch_size, env_name=None):
        """Build the initial latent state for ``batch_size`` sequences.

        Args:
            batch_size: Number of rows in every state tensor.
            env_name: Forwarded to ``get_stoch`` when the initial state is learned.

        Returns:
            dict with ``logit``/``stoch``/``deter`` (discrete latent) or
            ``mean``/``std``/``stoch``/``deter`` (Gaussian latent).

        Raises:
            NotImplementedError: If ``self._initial`` is neither ``"zeros"``
                nor ``"learned"``.
        """
        def zeros(shape):
            return torch.zeros(shape, device=self._device)

        deter = torch.zeros(batch_size, self._deter, device=self._device)
        if self._discrete:
            latent_shape = [batch_size, self._stoch, self._discrete]
            state = {
                "logit": zeros(latent_shape),
                "stoch": zeros(latent_shape),
                "deter": deter,
            }
        else:
            latent_shape = [batch_size, self._stoch]
            state = {
                "mean": zeros(latent_shape),
                "std": zeros(latent_shape),
                "stoch": zeros(latent_shape),
                "deter": deter,
            }

        if self._initial == "learned":
            # Learned deterministic start, and the stochastic part derived from it.
            state["deter"] = torch.tanh(self.W).repeat(batch_size, 1)
            state["stoch"] = self.get_stoch(state["deter"], env_name=env_name)
            return state
        if self._initial == "zeros":
            return state
        raise NotImplementedError(self._initial)

    def observe(self, embed, action, is_first, state=None, env_name=None, step=None, total_update=None):
        """Roll the model over a sequence of embeddings and actions.

        The inputs are swapped from (batch, time, ...) to (time, batch, ...),
        scanned with ``obs_step`` via ``tools.static_scan``, and the resulting
        posterior and prior dicts are swapped back.

        Returns:
            ``(post, prior)``: dicts of tensors with the first two dimensions
            swapped back after the scan.
        """
        def swap_batch_time(tensor):
            order = [1, 0] + list(range(2, len(tensor.shape)))
            return tensor.permute(order)

        def scan_fn(prev_state, prev_act, embed, is_first, env_name, step, total_update):
            # ``prev_state`` is the (post, prior) pair of the previous step;
            # only the posterior (index 0) is carried forward.
            return self.obs_step(
                prev_state[0], prev_act, embed, is_first, env_name=env_name, step=step, total_update=total_update
            )

        # (batch, time, ch) -> (time, batch, ch)
        embed = swap_batch_time(embed)
        action = swap_batch_time(action)
        is_first = swap_batch_time(is_first)

        post, prior = tools.static_scan(
            scan_fn,
            (action, embed, is_first, env_name, step, total_update),
            (state, state),
        )

        # (time, batch, ...) -> (batch, time, ...)
        post = {key: swap_batch_time(value) for key, value in post.items()}
        prior = {key: swap_batch_time(value) for key, value in prior.items()}
        return post, prior

    def imagine_with_action(self, action, state, env_name=None):
        """Roll the prior forward over a sequence of actions starting from ``state``.

        Args:
            action: Tensor whose first two dimensions are swapped before the scan.
            state: Start state; must be a dict.
            env_name: Passed through the scan to ``img_step``.

        Returns:
            dict of prior tensors with the first two dimensions swapped back.
        """
        def swap_batch_time(tensor):
            order = [1, 0] + list(range(2, len(tensor.shape)))
            return tensor.permute(order)

        assert isinstance(state, dict), state
        time_major_action = swap_batch_time(action)
        scanned = tools.static_scan(self.img_step, [time_major_action, env_name], state)
        prior = scanned[0]
        return {key: swap_batch_time(value) for key, value in prior.items()}

    def get_feat(self, state):
        """Concatenate the (flattened) stochastic state with the deterministic state."""
        latent = state["stoch"]
        if self._discrete:
            # Merge the last two dimensions (variables x classes) into one.
            flat_shape = [*latent.shape[:-2], self._stoch * self._discrete]
            latent = latent.reshape(flat_shape)
        return torch.cat((latent, state["deter"]), dim=-1)

    def get_dist(self, state, dtype=None):
        """Build the latent distribution described by ``state``.

        Discrete latents read ``state["logit"]``; Gaussian latents read
        ``state["mean"]`` and ``state["std"]``. ``dtype`` is accepted but not used.
        """
        if not self._discrete:
            normal = torchd.normal.Normal(state["mean"], state["std"])
            return tools.ContDist(torchd.independent.Independent(normal, 1))

        one_hot = tools.OneHotDist(state["logit"], unimix_ratio=self._unimix_ratio)
        return torchd.independent.Independent(one_hot, 1)

    def obs_step(self, prev_state, prev_action, embed, is_first, env_name=None, step=None, total_update=None, sample=True):
        """Advance one step and condition on the observation embedding.

        Args:
            prev_state: Previous posterior state dict, or ``None``.
            prev_action: Previous action. It is modified in place when only
                some rows of ``is_first`` are set.
            embed: Observation embedding for this step.
            is_first: Per-row flag marking the start of an episode.
            env_name: Forwarded to the MoE blocks and helper methods.
            step, total_update: Forwarded to ``img_step``.
            sample: Sample the stochastic state if true, otherwise take the mode.

        Returns:
            ``(post, prior)`` state dicts.

        Raises:
            NotImplementedError: If an MoE is installed for the observation
                block and ``self.moe_type`` is not a known value.
        """
        # Reset everything when there is no previous state or every row starts
        # a new episode.
        if prev_state is None or torch.sum(is_first) == len(is_first):
            prev_state = self.initial(len(is_first), env_name=env_name)
            prev_action = torch.zeros(
                (len(is_first), self._num_actions), device=self._device
            )
        # Otherwise overwrite the previous state only where is_first is set.
        elif torch.sum(is_first) > 0:
            is_first = is_first[:, None]
            prev_action *= 1.0 - is_first
            init_state = self.initial(len(is_first), env_name=env_name)
            for key, val in prev_state.items():
                # Pad is_first with trailing singleton dims to match ``val``.
                is_first_r = torch.reshape(
                    is_first,
                    is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
                )
                prev_state[key] = (
                    val * (1.0 - is_first_r) + init_state[key] * is_first_r
                )

        prior = self.img_step(prev_state, prev_action, env_name=env_name, step=step, total_update=total_update)
        x = torch.cat([prior["deter"], embed], -1)
        # (batch_size, prior_deter + embed) -> (batch_size, hidden)
        if self._obs_out_layers_moe is not None:
            if self.moe_type == "append":
                # NOTE(review): ``x`` is left un-projected in this mode --
                # confirm whether the backbone block should still be applied.
                pass
            elif self.moe_type == "rssm_append":
                # The MoE sees the raw input; its output is added as a residual.
                moe_output = self._obs_out_layers_moe(x, env_name=env_name)
                x = self._obs_out_layers(x)
                x = x + moe_output
            elif self.moe_type == "rssm":
                x = self._obs_out_layers_moe(x, env_name=env_name)
            else:
                raise NotImplementedError("Not implemented moe type")
        else:
            x = self._obs_out_layers(x)
        # (batch_size, hidden) -> (batch_size, stoch, discrete_num)
        stats = self._suff_stats_layer("obs", x, env_name=env_name)
        if sample:
            stoch = self.get_dist(stats).sample()
        else:
            stoch = self.get_dist(stats).mode()
        post = {"stoch": stoch, "deter": prior["deter"], **stats}
        return post, prior

    def img_step(self, prev_state, prev_action, env_name=None, step=None, total_update=None, sample=True):
        """Predict the next prior state from the previous state and action.

        Args:
            prev_state: dict with at least ``stoch`` and ``deter``.
            prev_action: Action taken in ``prev_state``.
            env_name: Forwarded to the MoE block and ``_suff_stats_layer``.
            step, total_update: Forwarded to ``self.dpmm.get_information``
                when a DPMM is attached.
            sample: Sample the stochastic state if true, otherwise take the mode.

        Returns:
            dict with ``stoch``, ``deter`` and the sufficient statistics.

        Raises:
            NotImplementedError: If an MoE is installed for the imagination
                block and ``self.moe_type`` is not a known value.
        """
        latent = prev_state["stoch"]
        if self._discrete:
            # (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
            latent = latent.reshape([*latent.shape[:-2], self._stoch * self._discrete])
        # (batch, stoch [* discrete_num] + action) -> (batch, hidden)
        x = self._img_in_layers(torch.cat((latent, prev_action), dim=-1))

        # Original note: rec depth is not correctly implemented -- every
        # iteration restarts from ``prev_state["deter"]``.
        for _ in range(self._rec_depth):
            # (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
            x, cell_state = self._cell(x, [prev_state["deter"]])
            deter = cell_state[0]  # the cell wraps its state in a list

        # (batch, deter) -> (batch, hidden)
        moe = self._img_out_layers_moe
        if moe is None:
            x = self._img_out_layers(x)
        elif self.moe_type == "rssm_append":
            # The MoE sees the raw input; its output is added as a residual.
            residual = moe(x, env_name=env_name)
            x = self._img_out_layers(x) + residual
        elif self.moe_type == "rssm":
            x = moe(x, env_name=env_name)
        elif self.moe_type != "append":
            raise NotImplementedError("Not implemented moe type")
        # NOTE(review): with moe_type == "append" ``x`` is left un-projected.

        if self.dpmm is not None:
            extra = self.dpmm.get_information(x, step=step, total_update=total_update)
            if extra is not None:
                x = x + extra

        # (batch, hidden) -> (batch_size, stoch, discrete_num)
        stats = self._suff_stats_layer("ims", x, env_name=env_name)
        latent_dist = self.get_dist(stats)
        stoch = latent_dist.sample() if sample else latent_dist.mode()
        return {"stoch": stoch, "deter": deter, **stats}

    def get_stoch(self, deter, env_name=None):
        """Return the mode of the prior latent distribution for a deterministic state.

        Args:
            deter: Deterministic state tensor.
            env_name: Forwarded to the MoE block and ``_suff_stats_layer``.

        Raises:
            NotImplementedError: If an MoE is installed for the imagination
                block and ``self.moe_type`` is not a known value.
        """
        moe = self._img_out_layers_moe
        if moe is None:
            x = self._img_out_layers(deter)
        elif self.moe_type == "rssm_append":
            # The MoE sees the raw input; its output is added as a residual.
            residual = moe(deter, env_name=env_name)
            x = self._img_out_layers(deter) + residual
        elif self.moe_type == "rssm":
            x = moe(deter, env_name=env_name)
        elif self.moe_type != "append":
            raise NotImplementedError("Not implemented moe type")
        # NOTE(review): with moe_type == "append" no value is assigned to ``x``
        # before it is used below -- confirm the intended behaviour.
        stats = self._suff_stats_layer("ims", x, env_name=env_name)
        return self.get_dist(stats).mode()

    def _suff_stats_layer(self, name, x, env_name=None): #[todo]
        """Project features to the parameters of the latent distribution.

        Args:
            name: "ims" selects ``_imgs_stat_layer`` (and its ``_moe`` variant),
                "obs" selects ``_obs_stat_layer`` (and its ``_moe`` variant).
            x: feature tensor fed to the selected layer(s).
            env_name: forwarded to the MoE layer when one is configured.

        Returns:
            ``{"logit": ...}`` with the last axis reshaped to
            ``(self._stoch, self._discrete)`` when ``self._discrete`` is truthy,
            otherwise ``{"mean": ..., "std": ...}``.

        Raises:
            NotImplementedError: unknown ``name`` or unsupported ``self.moe_type``.
            KeyError: unknown ``self._mean_act`` / ``self._std_act``.
        """
        if name == "ims":
            layer_attr = "_imgs_stat_layer"
        elif name == "obs":
            layer_attr = "_obs_stat_layer"
        else:
            raise NotImplementedError

        # The dense layer is looked up lazily so that it is only touched in the
        # modes that actually use it, exactly as before.
        moe_layer = getattr(self, layer_attr + "_moe")
        if moe_layer is None:
            x = getattr(self, layer_attr)(x)
        elif self.moe_type == "append":
            # Behaviour kept from the original: no statistics layer is applied.
            # NOTE(review): confirm that pass-through is intended for "append".
            pass
        elif self.moe_type == "rssm_append":
            # Residual combination: dense layer output plus MoE output.
            moe_output = moe_layer(x, env_name=env_name)
            x = getattr(self, layer_attr)(x)
            x = x + moe_output
        elif self.moe_type == "rssm":
            # The MoE fully replaces the dense layer.
            x = moe_layer(x, env_name=env_name)
        else:
            raise NotImplementedError("Not implemented moe type")

        if self._discrete:
            logit = x.reshape(list(x.shape[:-1]) + [self._stoch, self._discrete])
            return {"logit": logit}

        mean, std = torch.split(x, [self._stoch] * 2, -1)
        mean = {
            "none": lambda: mean,
            "tanh5": lambda: 5.0 * torch.tanh(mean / 5.0),
        }[self._mean_act]()
        std = {
            # `torch.softplus` does not exist; the functional form is the
            # public API (the original raised AttributeError here).
            "softplus": lambda: F.softplus(std),
            "abs": lambda: torch.abs(std + 1),
            "sigmoid": lambda: torch.sigmoid(std),
            "sigmoid2": lambda: 2 * torch.sigmoid(std / 2),
        }[self._std_act]()
        std = std + self._min_std
        return {"mean": mean, "std": std}

    def kl_loss(self, post, prior, free, dyn_scale, rep_scale):
        """Compute the combined dynamics/representation KL loss.

        Args:
            post: dict of posterior distribution parameters (tensors).
            prior: dict of prior distribution parameters (tensors).
            free: lower bound applied to both KL terms.
            dyn_scale: weight of the dynamics term.
            rep_scale: weight of the representation term.

        Returns:
            Tuple ``(loss, value, dyn_loss, rep_loss)`` where ``value`` is the
            unclipped KL(post || stop_grad(prior)).
        """

        def stop_grad(state):
            # Detach every tensor so no gradient flows into that side.
            return {key: tensor.detach() for key, tensor in state.items()}

        def as_dist(state):
            # The non-discrete wrapper keeps the torch distribution in `._dist`.
            wrapped = self.get_dist(state)
            return wrapped if self._discrete else wrapped._dist

        kl = torchd.kl.kl_divergence
        value = kl(as_dist(post), as_dist(stop_grad(prior)))
        dyn_raw = kl(as_dist(stop_grad(post)), as_dist(prior))

        # The original repo uses a maximum here: terms already below `free`
        # receive no gradient.
        rep_loss = torch.clip(value, min=free)
        dyn_loss = torch.clip(dyn_raw, min=free)

        loss = dyn_scale * dyn_loss + rep_scale * rep_loss
        return loss, value, dyn_loss, rep_loss


class MultiEncoder(nn.Module):
    """Encode a dict of observations with a CNN and/or an MLP.

    Keys whose shape has three entries and whose name matches ``cnn_keys`` go
    through a ``ConvEncoder``; keys whose shape has one or two entries and whose
    name matches ``mlp_keys`` go through an ``MLP``. ``outdim`` is the size of
    the concatenated embedding.
    """

    def __init__(
        self,
        shapes,
        mlp_keys,
        cnn_keys,
        act,
        norm,
        cnn_depth,
        kernel_size,
        minres,
        mlp_layers,
        mlp_units,
        symlog_inputs,
    ):
        super().__init__()
        # Bookkeeping entries and anything prefixed with "log_" are never encoded.
        skipped = ("is_first", "is_last", "is_terminal", "reward")
        usable = {}
        for key, shape in shapes.items():
            if key in skipped or key.startswith("log_"):
                continue
            usable[key] = shape

        self.cnn_shapes = {}
        self.mlp_shapes = {}
        for key, shape in usable.items():
            if len(shape) == 3 and re.match(cnn_keys, key):
                self.cnn_shapes[key] = shape
            if len(shape) in (1, 2) and re.match(mlp_keys, key):
                self.mlp_shapes[key] = shape
        print("Encoder CNN shapes:", self.cnn_shapes)
        print("Encoder MLP shapes:", self.mlp_shapes)

        self.outdim = 0
        if self.cnn_shapes:
            image_shapes = tuple(self.cnn_shapes.values())
            # Images are concatenated along the channel axis, so channels add up
            # while height/width come from the first image.
            channels = sum(shape[-1] for shape in image_shapes)
            self._cnn = ConvEncoder(
                image_shapes[0][:2] + (channels,),
                cnn_depth,
                act,
                norm,
                kernel_size,
                minres,
            )
            self.outdim += self._cnn.outdim
        if self.mlp_shapes:
            vector_size = sum(sum(shape) for shape in self.mlp_shapes.values())
            self._mlp = MLP(
                vector_size,
                None,
                mlp_layers,
                mlp_units,
                act,
                norm,
                symlog_inputs=symlog_inputs,
                name="Encoder",
            )
            self.outdim += mlp_units

    def forward(self, obs, env_name=None): #[todo]
        """Return the concatenated CNN/MLP embedding of ``obs``.

        ``env_name`` is accepted but not used by this encoder.
        """

        def gather(keys):
            # Join the selected observations along their last axis.
            return torch.cat([obs[key] for key in keys], -1)

        embeddings = []
        if self.cnn_shapes:
            embeddings.append(self._cnn(gather(self.cnn_shapes)))
        if self.mlp_shapes:
            embeddings.append(self._mlp(gather(self.mlp_shapes)))
        return torch.cat(embeddings, -1)


class MultiDecoder(nn.Module):
    """Decode features into one distribution per observation key.

    Keys whose shape has three entries and whose name matches ``cnn_keys`` are
    produced by a ``ConvDecoder``; keys whose shape has one or two entries and
    whose name matches ``mlp_keys`` are produced by an ``MLP``.
    """

    def __init__(
        self,
        feat_size,
        shapes,
        mlp_keys,
        cnn_keys,
        act,
        norm,
        cnn_depth,
        kernel_size,
        minres,
        mlp_layers,
        mlp_units,
        cnn_sigmoid,
        image_dist,
        vector_dist,
        outscale,
    ):
        super().__init__()
        skipped = ("is_first", "is_last", "is_terminal")
        self.cnn_shapes = {}
        self.mlp_shapes = {}
        for key, shape in shapes.items():
            if key in skipped:
                continue
            if len(shape) == 3 and re.match(cnn_keys, key):
                self.cnn_shapes[key] = shape
            if len(shape) in (1, 2) and re.match(mlp_keys, key):
                self.mlp_shapes[key] = shape
        print("Decoder CNN shapes:", self.cnn_shapes)
        print("Decoder MLP shapes:", self.mlp_shapes)

        if self.cnn_shapes:
            image_shapes = list(self.cnn_shapes.values())
            # All images share one decoder: channels are summed and put first,
            # followed by the leading dims of the first image shape.
            total_channels = sum(shape[-1] for shape in image_shapes)
            self._cnn = ConvDecoder(
                feat_size,
                (total_channels,) + image_shapes[0][:-1],
                cnn_depth,
                act,
                norm,
                kernel_size,
                minres,
                outscale=outscale,
                cnn_sigmoid=cnn_sigmoid,
            )
        if self.mlp_shapes:
            self._mlp = MLP(
                feat_size,
                self.mlp_shapes,
                mlp_layers,
                mlp_units,
                act,
                norm,
                vector_dist,
                outscale=outscale,
                name="Decoder",
            )
        self._image_dist = image_dist

    def forward(self, features, env_name=None): #[todo] env_name is a placeholder
        """Return ``{key: distribution}`` for every decoded observation key.

        ``env_name`` is accepted but not used by this decoder.
        """
        result = {}
        if self.cnn_shapes:
            stacked = self._cnn(features)
            # Undo the channel concatenation: one slice per image key.
            channels = [shape[-1] for shape in self.cnn_shapes.values()]
            slices = torch.split(stacked, channels, -1)
            for key, image_mean in zip(self.cnn_shapes.keys(), slices):
                result[key] = self._make_image_dist(image_mean)
        if self.mlp_shapes:
            result.update(self._mlp(features))
        return result

    def _make_image_dist(self, mean):
        """Wrap an image mean in the distribution selected by ``image_dist``."""
        kind = self._image_dist
        if kind == "normal":
            base = torchd.normal.Normal(mean, 1)
            return tools.ContDist(torchd.independent.Independent(base, 3))
        if kind == "mse":
            return tools.MSEDist(mean)
        raise NotImplementedError(kind)


class ConvEncoder(nn.Module):
    """Strided-convolution image encoder.

    Stacks ``int(log2(h) - log2(minres))`` stride-2 ``Conv2dSamePad`` stages,
    each optionally followed by ``ImgChLayerNorm`` and then the activation.
    The channel count starts at ``depth`` and doubles every stage.

    Args:
        input_shape: ``(h, w, channels)`` of the (channel-last) input image.
        depth: output channels of the first stage.
        act: name of an activation class in ``torch.nn``.
        norm: whether to add ``ImgChLayerNorm`` after each convolution.
        kernel_size: convolution kernel size.
        minres: spatial resolution at which downsampling stops.

    Attributes:
        outdim: size of the flattened embedding returned by ``forward``.
    """

    def __init__(
        self,
        input_shape,
        depth=32,
        act="SiLU",
        norm=True,
        kernel_size=4,
        minres=4,
    ):
        super().__init__()
        act = getattr(torch.nn, act)
        h, w, input_ch = input_shape
        stages = int(np.log2(h) - np.log2(minres))
        in_dim = input_ch
        out_dim = depth
        layers = []
        for _ in range(stages):
            layers.append(
                Conv2dSamePad(
                    in_channels=in_dim,
                    out_channels=out_dim,
                    kernel_size=kernel_size,
                    stride=2,
                    bias=False,
                )
            )
            if norm:
                layers.append(ImgChLayerNorm(out_dim))
            layers.append(act())
            in_dim = out_dim
            out_dim *= 2
            h, w = h // 2, w // 2

        # `in_dim` is the channel count actually leaving the stack. It equals
        # the former `out_dim // 2` whenever at least one stage exists, and is
        # also correct (the input channels) when there are no stages at all.
        self.outdim = in_dim * h * w
        self.layers = nn.Sequential(*layers)
        self.layers.apply(tools.weight_init)

    def forward(self, obs):
        """Encode channel-last images of shape ``(..., h, w, ch)``.

        Returns a tensor of shape ``(..., outdim)``. The input tensor is left
        untouched.
        """
        # Out-of-place shift: the previous `obs -= 0.5` silently modified the
        # caller's tensor (and fails on leaf tensors that require grad).
        obs = obs - 0.5
        # (batch, time, h, w, ch) -> (batch * time, h, w, ch)
        x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
        # (batch * time, h, w, ch) -> (batch * time, ch, h, w)
        x = x.permute(0, 3, 1, 2)
        x = self.layers(x)
        # (batch * time, ...) -> (batch * time, -1)
        x = x.flatten(1)
        # (batch * time, -1) -> (batch, time, -1)
        return x.reshape(list(obs.shape[:-3]) + [x.shape[-1]])


class ConvDecoder(nn.Module):
    """Transposed-convolution image decoder.

    A linear layer expands the features to a ``minres x minres`` map, which is
    then upsampled by ``int(log2(shape[1]) - log2(minres))`` stride-2
    ``ConvTranspose2d`` layers. Every layer but the last is bias-free and
    followed by optional ``ImgChLayerNorm`` plus the activation; the last one
    has a bias and emits ``shape[0]`` channels.

    Args:
        feat_size: size of the input feature vector.
        shape: ``(channels, h, w)`` of the decoded image.
        depth: base channel multiplier.
        act: activation, given either as the name of a class in ``torch.nn``
            or as the class itself.
        norm: whether to add ``ImgChLayerNorm`` after the hidden layers.
        kernel_size: transposed-convolution kernel size.
        minres: spatial resolution of the initial feature map.
        outscale: scale passed to ``tools.uniform_weight_init`` for the linear
            layer and the last transposed convolution.
        cnn_sigmoid: apply a sigmoid to the output instead of adding 0.5.
    """

    def __init__(
        self,
        feat_size,
        shape=(3, 64, 64),
        depth=32,
        act=nn.ELU,
        norm=True,
        kernel_size=4,
        minres=4,
        outscale=1.0,
        cnn_sigmoid=False,
    ):
        super().__init__()
        # The declared default is a class, which `getattr(torch.nn, act)`
        # rejects with TypeError; only resolve by name when given a string.
        if isinstance(act, str):
            act = getattr(torch.nn, act)
        self._shape = shape
        self._cnn_sigmoid = cnn_sigmoid
        layer_num = int(np.log2(shape[1]) - np.log2(minres))
        self._minres = minres
        out_ch = minres**2 * depth * 2 ** (layer_num - 1)
        self._embed_size = out_ch

        self._linear_layer = nn.Linear(feat_size, out_ch)
        self._linear_layer.apply(tools.uniform_weight_init(outscale))
        in_dim = out_ch // (minres**2)
        out_dim = in_dim // 2

        # Padding depends only on kernel/stride/dilation, so compute it once;
        # height and width use the same values.
        pad, outpad = self.calc_same_pad(k=kernel_size, s=2, d=1)

        layers = []
        for i in range(layer_num):
            is_last = i == layer_num - 1
            if is_last:
                out_dim = self._shape[0]
            layers.append(
                nn.ConvTranspose2d(
                    in_dim,
                    out_dim,
                    kernel_size,
                    2,
                    padding=(pad, pad),
                    output_padding=(outpad, outpad),
                    bias=is_last,
                )
            )
            if not is_last:
                if norm:
                    layers.append(ImgChLayerNorm(out_dim))
                layers.append(act())
            # Channels halve at every stage; this matches the closed form
            # 2 ** (layer_num - i - 1) * depth used previously.
            in_dim = out_dim
            out_dim //= 2
        for module in layers[:-1]:
            module.apply(tools.weight_init)
        layers[-1].apply(tools.uniform_weight_init(outscale))
        self.layers = nn.Sequential(*layers)

    def calc_same_pad(self, k, s, d):
        """Return ``(padding, output_padding)`` for kernel ``k``, stride ``s``, dilation ``d``."""
        val = d * (k - 1) - s + 1
        pad = math.ceil(val / 2)
        outpad = pad * 2 - val
        return pad, outpad

    def forward(self, features, dtype=None):
        """Decode ``(batch, time, feat)`` features into channel-last images.

        Returns a tensor of shape ``(batch, time, h, w, ch)``. ``dtype`` is
        accepted but not used.
        """
        x = self._linear_layer(features)
        # (batch, time, embed) -> (batch * time, minres, minres, ch)
        x = x.reshape(
            [-1, self._minres, self._minres, self._embed_size // self._minres**2]
        )
        # (batch * time, minres, minres, ch) -> (batch * time, ch, minres, minres)
        x = x.permute(0, 3, 1, 2)
        x = self.layers(x)
        # (batch * time, ch, h, w) -> (batch, time, ch, h, w)
        # tuple() lets a list-valued shape work too (torch.Size + list fails).
        mean = x.reshape(features.shape[:-1] + tuple(self._shape))
        # (batch, time, ch, h, w) -> (batch, time, h, w, ch)
        mean = mean.permute(0, 1, 3, 4, 2)
        if self._cnn_sigmoid:
            return torch.sigmoid(mean)
        return mean + 0.5


class MLP(nn.Module):
    """Fully connected trunk with optional distribution head(s).

    The trunk is ``layers`` blocks of bias-free ``Linear`` -> optional
    ``LayerNorm`` -> activation. What ``forward`` returns depends on ``shape``:

    * ``None``: the trunk output itself (used as an encoder).
    * ``dict`` of ``name -> shape``: a dict with one distribution per name.
    * int / tuple: a single distribution.

    Args:
        inp_dim: size of the input features.
        shape: output specification, see above. An empty tuple is treated as ``(1,)``.
        layers: number of trunk blocks.
        units: width of every trunk block.
        act: name of an activation class in ``torch.nn``.
        norm: whether to add ``LayerNorm`` after each linear layer.
        dist: distribution type handled by :meth:`dist`.
        std: fixed standard deviation, or the string ``"learned"`` to predict
            it with an extra linear head.
        min_std, max_std: bounds used by some distribution types.
        absmax: forwarded to ``tools.ContDist``.
        temp: temperature for ``"onehot_gumble"``.
        unimix_ratio: forwarded to ``tools.OneHotDist``.
        outscale: scale passed to ``tools.uniform_weight_init`` for the heads.
        symlog_inputs: apply ``tools.symlog`` to the inputs first.
        device: device holding the fixed ``std`` tensor and passed to ``tools.DiscDist``.
        name: prefix for the trunk sub-module names.
    """

    def __init__(
        self,
        inp_dim,
        shape,
        layers,
        units,
        act="SiLU",
        norm=True,
        dist="normal",
        std=1.0,
        min_std=0.1,
        max_std=1.0,
        absmax=None,
        temp=0.1,
        unimix_ratio=0.01,
        outscale=1.0,
        symlog_inputs=False,
        device="cuda",
        name="NoName",
    ):
        super().__init__()
        self._shape = (shape,) if isinstance(shape, int) else shape
        if self._shape is not None and len(self._shape) == 0:
            self._shape = (1,)
        act = getattr(torch.nn, act)
        self._dist = dist
        self._std = std if isinstance(std, str) else torch.tensor((std,), device=device)
        self._min_std = min_std
        self._max_std = max_std
        self._absmax = absmax
        self._temp = temp
        self._unimix_ratio = unimix_ratio
        self._symlog_inputs = symlog_inputs
        self._device = device

        self.layers = nn.Sequential()
        for i in range(layers):
            self.layers.add_module(
                f"{name}_linear{i}", nn.Linear(inp_dim, units, bias=False)
            )
            if norm:
                self.layers.add_module(
                    f"{name}_norm{i}", nn.LayerNorm(units, eps=1e-03)
                )
            self.layers.add_module(f"{name}_act{i}", act())
            # Every block after the first consumes `units` features.
            inp_dim = units
        self.layers.apply(tools.weight_init)

        # From here on `inp_dim` is the width of the trunk output (the original
        # input size when `layers == 0`), so every head is sized with it.
        learned_std = self._std_is_learned()
        if isinstance(self._shape, dict):
            self.mean_layer = nn.ModuleDict()
            for key, out_shape in self._shape.items():
                self.mean_layer[key] = nn.Linear(inp_dim, int(np.prod(out_shape)))
            self.mean_layer.apply(tools.uniform_weight_init(outscale))
            if learned_std:
                assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
                self.std_layer = nn.ModuleDict()
                for key, out_shape in self._shape.items():
                    self.std_layer[key] = nn.Linear(inp_dim, int(np.prod(out_shape)))
                self.std_layer.apply(tools.uniform_weight_init(outscale))
        elif self._shape is not None:
            self.mean_layer = nn.Linear(inp_dim, int(np.prod(self._shape)))
            self.mean_layer.apply(tools.uniform_weight_init(outscale))
            if learned_std:
                assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
                # Was sized with `units`, which disagrees with the mean head
                # (and with the trunk output) when `layers == 0`.
                self.std_layer = nn.Linear(inp_dim, int(np.prod(self._shape)))
                self.std_layer.apply(tools.uniform_weight_init(outscale))

    def _std_is_learned(self):
        """Return True when the std is predicted by a dedicated head."""
        # Guard with isinstance so a tensor-valued std is never compared to a str.
        return isinstance(self._std, str) and self._std == "learned"

    def forward(self, features, dtype=None, env_name=None): #[todo] env_name is a placeholder
        """Run the trunk and wrap the head output(s) in distribution(s).

        ``dtype`` and ``env_name`` are accepted but not used.
        """
        x = features
        if self._symlog_inputs:
            x = tools.symlog(x)
        out = self.layers(x)
        # Used for encoder output
        if self._shape is None:
            return out
        learned_std = self._std_is_learned()
        if isinstance(self._shape, dict):
            dists = {}
            for key, out_shape in self._shape.items():
                mean = self.mean_layer[key](out)
                std = self.std_layer[key](out) if learned_std else self._std
                dists[key] = self.dist(self._dist, mean, std, out_shape)
            return dists
        mean = self.mean_layer(out)
        std = self.std_layer(out) if learned_std else self._std
        return self.dist(self._dist, mean, std, self._shape)

    def dist(self, dist, mean, std, shape):
        """Build the distribution object named ``dist`` from ``mean``/``std``.

        Raises:
            NotImplementedError: for an unknown distribution name.
        """
        if dist == "tanh_normal":
            mean = torch.tanh(mean)
            std = F.softplus(std) + self._min_std
            dist = torchd.normal.Normal(mean, std)
            dist = torchd.transformed_distribution.TransformedDistribution(
                dist, tools.TanhBijector()
            )
            dist = torchd.independent.Independent(dist, 1)
            dist = tools.SampleDist(dist)
        elif dist == "normal":
            std = (self._max_std - self._min_std) * torch.sigmoid(
                std + 2.0
            ) + self._min_std
            dist = torchd.normal.Normal(torch.tanh(mean), std)
            dist = tools.ContDist(
                torchd.independent.Independent(dist, 1), absmax=self._absmax
            )
        elif dist == "normal_std_fixed":
            dist = torchd.normal.Normal(mean, self._std)
            dist = tools.ContDist(
                torchd.independent.Independent(dist, 1), absmax=self._absmax
            )
        elif dist == "trunc_normal":
            mean = torch.tanh(mean)
            std = 2 * torch.sigmoid(std / 2) + self._min_std
            dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
            dist = tools.ContDist(
                torchd.independent.Independent(dist, 1), absmax=self._absmax
            )
        elif dist == "onehot":
            dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
        elif dist == "onehot_gumble":
            dist = tools.ContDist(
                torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
            )
        elif dist == "huber":
            # `absmax` belongs to ContDist (as in every other branch);
            # Independent has no such parameter and raised TypeError.
            dist = tools.ContDist(
                torchd.independent.Independent(
                    tools.UnnormalizedHuber(mean, std, 1.0),
                    len(shape),
                ),
                absmax=self._absmax,
            )
        elif dist == "binary":
            dist = tools.Bernoulli(
                torchd.independent.Independent(
                    torchd.bernoulli.Bernoulli(logits=mean), len(shape)
                )
            )
        elif dist == "symlog_disc":
            dist = tools.DiscDist(logits=mean, device=self._device)
        elif dist == "symlog_mse":
            dist = tools.SymlogDist(mean)
        else:
            raise NotImplementedError(dist)
        return dist


class GRUCell(nn.Module):
    """Single-linear-layer GRU-style cell with optional layer normalisation.

    One bias-free linear map (optionally followed by ``LayerNorm``) produces
    the reset, candidate and update pre-activations from the concatenated
    input and previous state.
    """

    def __init__(self, inp_size, size, norm=True, act=torch.tanh, update_bias=-1):
        super().__init__()
        self._inp_size, self._size = inp_size, size
        self._act, self._update_bias = act, update_bias

        gates = nn.Sequential()
        gates.add_module(
            "GRU_linear", nn.Linear(inp_size + size, 3 * size, bias=False)
        )
        if norm:
            gates.add_module("GRU_norm", nn.LayerNorm(3 * size, eps=1e-03))
        self.layers = gates

    @property
    def state_size(self):
        """Width of the recurrent state."""
        return self._size

    def forward(self, inputs, state):
        """Advance one step; returns ``(output, [output])``.

        ``state`` is a one-element list holding the previous state (the
        Keras-style convention of wrapping the state in a list).
        """
        previous = state[0]
        joined = torch.cat([inputs, previous], -1)
        gate_sizes = [self._size] * 3
        reset_pre, cand_pre, update_pre = self.layers(joined).split(gate_sizes, -1)

        reset_gate = torch.sigmoid(reset_pre)
        candidate = self._act(reset_gate * cand_pre)
        # The additive bias shifts the update gate before the sigmoid.
        update_gate = torch.sigmoid(update_pre + self._update_bias)

        new_state = update_gate * candidate + (1 - update_gate) * previous
        return new_state, [new_state]


class Conv2dSamePad(torch.nn.Conv2d):
    """``Conv2d`` that zero-pads its input on the fly before convolving.

    The padding is computed per call from the input size so that, with the
    module's own ``padding`` left at 0, each spatial output size is
    ``ceil(input / stride)``. Any odd remainder goes to the right/bottom.
    """

    def calc_same_pad(self, i, k, s, d):
        """Total padding needed along one axis (never negative).

        ``i`` is the input size, ``k`` the kernel size, ``s`` the stride and
        ``d`` the dilation.
        """
        target = math.ceil(i / s)
        needed = (target - 1) * s + (k - 1) * d + 1 - i
        return max(needed, 0)

    def forward(self, x):
        in_h, in_w = x.size()[-2:]
        (k_h, k_w), (s_h, s_w), (d_h, d_w) = self.kernel_size, self.stride, self.dilation
        total_h = self.calc_same_pad(i=in_h, k=k_h, s=s_h, d=d_h)
        total_w = self.calc_same_pad(i=in_w, k=k_w, s=s_w, d=d_w)

        if total_h > 0 or total_w > 0:
            left, top = total_w // 2, total_h // 2
            # F.pad order for the last two dims: (left, right, top, bottom).
            x = F.pad(x, [left, total_w - left, top, total_h - top])

        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class ImgChLayerNorm(nn.Module):
    """``LayerNorm`` over the channel axis of a channel-first 4-D tensor."""

    def __init__(self, ch, eps=1e-03):
        super().__init__()
        self.norm = nn.LayerNorm(ch, eps=eps)

    def forward(self, x):
        # LayerNorm normalises the last axis, so move channels there and back:
        # (n, ch, h, w) -> (n, h, w, ch) -> normalise -> (n, ch, h, w)
        channels_last = x.permute(0, 2, 3, 1)
        return self.norm(channels_last).permute(0, 3, 1, 2)
