import torch
import torch.nn as nn
from torch.distributions import Independent, kl_divergence
from tensordict import TensorDict
import numpy as np

from typing import Dict, List, Tuple, Callable

from utils.tools import n2t, scan, RequiresGrad, Optimizer, latent_to_input, get_st_onehot
from utils.networks import MLP, GRU, CNN, Transformer
from utils.output_head import MLPHead, FigureHead
from actor_critic.actor import Actor
import elements

class MAWorldModel(nn.Module):
    _name = "world_model"
    def __init__(self, config, obs_shape, n_actions, n_agents, device):
        """Build the encoder, RSSM dynamics, prediction heads and optimizer.

        Args:
            config: Experiment configuration. The ``world_model.rssm`` and
                ``train.optim.world_model`` sections are read here.
            obs_shape: Per-agent observation shape. A bare ``int`` is treated
                as a rank-1 shape (``ObsEncoder`` accepts the same). Rank 1
                selects an ``MLPHead`` observation predictor, rank 3 a
                ``FigureHead``.
            n_actions: Width of the per-agent action vector; also the output
                width of the optional action-mask head.
            n_agents: Number of agents.
            device: Device handed to every sub-module.

        Raises:
            ValueError: Raised by ``ObsEncoder`` (built first) when the rank
                of ``obs_shape`` is neither 1 nor 3.
            NotImplementedError: Guard on the observation-predictor selection
                for the same unsupported ranks.
        """
        super().__init__()
        # ObsEncoder accepts an int shape; normalise it here as well so the
        # len() checks below do not fail with a TypeError on the same input.
        if isinstance(obs_shape, int):
            obs_shape = (obs_shape,)
        self.config = config
        self.obs_shape = obs_shape
        self.n_actions = n_actions
        self.n_agents = n_agents
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)

        rssm_cfg = config.world_model.rssm

        # world model
        self.encoder = ObsEncoder(config, obs_shape, device)
        self.dynamics = RSSM(config, n_actions, n_agents, device)

        # obs predictor: consumes deter features concatenated with the
        # flattened (stochastic_dim * classes) stochastic features
        obs_predictor_in_dim = (
            rssm_cfg.deterministic_dim
            + rssm_cfg.stochastic_dim * rssm_cfg.classes
        )
        if len(obs_shape) == 1:
            obs_cfg = rssm_cfg.obs_predictor
            self.obs_predictor = MLPHead(
                in_dim=obs_predictor_in_dim,
                hidden_dim=obs_cfg.hidden_dim,
                hidden_layers=obs_cfg.hidden_layers,
                out_dim=obs_shape[0],
                act=obs_cfg.act,
                use_layernorm=obs_cfg.use_layernorm,
                output=obs_cfg.output,
                device=device,
            )
        elif len(obs_shape) == 3:
            obs_cfg = rssm_cfg.obs_predictor
            self.obs_predictor = FigureHead(
                in_dim=obs_predictor_in_dim,
                out_shape=obs_shape,
                depth=obs_cfg.depth,
                mults=obs_cfg.mults,
                kernel=obs_cfg.kernel,
                act=obs_cfg.act,
                use_layernorm=obs_cfg.use_layernorm,
                device=device,
            )
        else:
            raise NotImplementedError(
                f"Observation shape {obs_shape} not supported"
            )

        # other predictors: all of them read the transformer's per-agent
        # embeddings, which have width deterministic_dim
        tf_cfg = rssm_cfg.global_agent_embedding_transformer
        self.global_agent_embedding_transformer = Transformer(
            d_model=rssm_cfg.deterministic_dim,
            nhead=tf_cfg.nhead,
            num_layers=tf_cfg.num_layers,
            act=tf_cfg.activation,
            norm_first=tf_cfg.norm_first,
            device=device,
        )

        reward_cfg = rssm_cfg.reward_predictor
        self.reward_predictor = MLPHead(
            in_dim=rssm_cfg.deterministic_dim,
            hidden_dim=reward_cfg.hidden_dim,
            hidden_layers=reward_cfg.hidden_layers,
            out_dim=1,
            act=reward_cfg.act,
            use_layernorm=reward_cfg.use_layernorm,
            output=reward_cfg.output,
            out_scale=reward_cfg.out_scale,
            device=device,
        )

        # Optional heads: their config sections are only read when enabled,
        # so a config without them keeps working.
        if rssm_cfg.use_cont_predictor:
            cont_cfg = rssm_cfg.cont_predictor
            self.cont_predictor = MLPHead(
                in_dim=rssm_cfg.deterministic_dim,
                hidden_dim=cont_cfg.hidden_dim,
                hidden_layers=cont_cfg.hidden_layers,
                out_dim=1,
                act=cont_cfg.act,
                use_layernorm=cont_cfg.use_layernorm,
                output=cont_cfg.output,
                device=device,
            )
        else:
            self.cont_predictor = None

        if rssm_cfg.use_act_mask_predictor:
            mask_cfg = rssm_cfg.act_mask_predictor
            self.act_mask_predictor = MLPHead(
                in_dim=rssm_cfg.deterministic_dim,
                hidden_dim=mask_cfg.hidden_dim,
                hidden_layers=mask_cfg.hidden_layers,
                out_dim=n_actions,
                act=mask_cfg.act,
                use_layernorm=mask_cfg.use_layernorm,
                output=mask_cfg.output,
                device=device,
            )
        else:
            self.act_mask_predictor = None

        # context manager for gradient control
        self._requires_grad = RequiresGrad(self)

        # optimizer over every parameter registered above
        optim_cfg = config.train.optim.world_model
        self._optim = Optimizer(
            name=self._name,
            parameters=self.parameters(),
            lr=optim_cfg.lr,
            eps=optim_cfg.eps,
            use_max_grad_norm=optim_cfg.use_max_grad_norm,
            max_grad_norm=optim_cfg.max_grad_norm,
        )

    def observe(
        self,
        obs: torch.Tensor | np.ndarray,
        prev_actions: torch.Tensor | np.ndarray,
        is_first: torch.Tensor | np.ndarray,
    ) -> List[TensorDict]:
        """Encode an observation sequence and compute the posterior latent per step.

        NumPy inputs are converted with ``n2t`` (float32 for ``obs`` and
        ``prev_actions``, bool for ``is_first``); tensors are used as given.

        Args:
            obs: shape = (ts, bs, n_agents, obs_shape)
            prev_actions: shape = (ts, bs, n_agents, n_actions)
            is_first: shape = (ts, bs, 1)

        Returns:
            The per-timestep results of ``scan`` over ``observe_step``.
            Callers slice this sequence and ``torch.stack`` it, so it is a
            list of per-step ``TensorDict`` objects rather than one stacked
            ``TensorDict``.
        """
        if isinstance(obs, np.ndarray):
            obs = n2t(obs, **self.tpdv)
        if isinstance(prev_actions, np.ndarray):
            prev_actions = n2t(prev_actions, **self.tpdv)
        if isinstance(is_first, np.ndarray):
            is_first = n2t(is_first, device=self.device, dtype=torch.bool)

        embed = self.encoder(obs)
        # init_state=None makes observe_step build the zero state itself on
        # the first step.
        post_latent: List[TensorDict] = scan(
            fn=self.observe_step,
            inputs=(embed, prev_actions, is_first),
            init_state=None,
        )
        return post_latent

    def prep_init_latent_for_imagination(self, samples: Dict[str, np.ndarray]) -> TensorDict:
        """Turn a replay batch into flattened start states for ``imagine``.

        The batch is run through ``observe`` to get posterior latents, which
        are stacked along time, extended with agent embeddings, agent mask,
        rewards and (when the matching head exists) available actions and
        termination flags, and finally flattened over the first two
        dimensions.

        Args:
            samples: Replay batch. Keys read: ``obs``, ``prev_actions``,
                ``rewards``, ``is_first``, ``agent_mask``; ``avail_actions``
                only when the action-mask head exists and
                ``is_trailing_absorbing_state`` only when the continuation
                head exists.

        Returns:
            The stacked latents with dimensions 0 and 1 merged.
        """
        bool_kwargs = dict(device=self.device, dtype=torch.bool)

        obs = n2t(samples["obs"], **self.tpdv)
        prev_actions = n2t(samples["prev_actions"], **self.tpdv)
        rewards = n2t(samples["rewards"], **self.tpdv)
        is_first = n2t(samples["is_first"], **bool_kwargs)
        agent_mask = n2t(samples["agent_mask"], **self.tpdv)

        # optional inputs stay None when their predictor head is disabled
        avail_actions = None
        if self.act_mask_predictor is not None:
            avail_actions = n2t(samples["avail_actions"], **self.tpdv)
        terminated = None
        if self.cont_predictor is not None:
            terminated = n2t(samples["is_trailing_absorbing_state"], **bool_kwargs)

        # posterior latents over the whole batch are the imagination starts
        posterior_steps: List[TensorDict] = self.observe(
            obs=obs,
            prev_actions=prev_actions,
            is_first=is_first,
        )
        start: TensorDict = torch.stack(posterior_steps, dim=0)

        # attach everything else imagine() reads at index 0
        start["agent_embeddings"] = self.global_agent_embedding_transformer(
            start["deter"], num_batch_dims=2
        )
        start["agent_mask"] = agent_mask
        start["rewards"] = rewards
        start["avail_actions"] = avail_actions
        if terminated is not None:
            # add an agent axis at dim 2 and copy the flag to every agent
            terminated = terminated.unsqueeze(2).repeat(1, 1, self.n_agents, 1)
        start["terminated"] = terminated

        return start.flatten(0, 1)

    @elements.timer.section("imagination")
    @torch.no_grad()
    def imagine(
        self,
        actors: List[Actor],
        init_latent: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Roll the learned dynamics forward under the given actor policies.

        Runs without gradients for ``config.train.imagination_steps`` steps.
        Index 0 of every returned buffer (except ``actions_env``) holds the
        values taken from ``init_latent``; indices 1.. hold imagined values.

        Args:
            actors: One policy per agent; ``actors[i]`` is called on the
                latent sliced at agent index ``i``.
            init_latent: Start state providing ``agent_embeddings``,
                ``deter``, ``stoch``, ``rewards`` and, when the matching head
                exists, ``terminated`` / ``avail_actions``. It is sliced as
                ``latent[:, i]``, so it must support that indexing.

        Returns:
            Dict with ``agent_embeddings``, ``deter``, ``stoch``,
            ``terminated``, ``rewards`` (leading dim steps + 1) and
            ``actions_env`` (leading dim steps), plus ``avail_actions`` when
            the action-mask head exists. ``terminated`` stays all-zero when
            there is no continuation head.
        """
        rssm_cfg = self.config.world_model.rssm
        horizon = self.config.train.imagination_steps
        batch = init_latent["deter"].shape[0]
        n_agents = self.n_agents
        has_cont_head = self.cont_predictor is not None
        has_mask_head = self.act_mask_predictor is not None

        def blank(*shape):
            # zero-filled buffer on the model device
            return torch.zeros(shape, device=self.device)

        # pre-allocate rollout buffers
        agent_embeddings = blank(horizon + 1, batch, n_agents, rssm_cfg.deterministic_dim)
        deter = blank(horizon + 1, batch, n_agents, rssm_cfg.deterministic_dim)
        stoch = blank(horizon + 1, batch, n_agents, rssm_cfg.stochastic_dim, rssm_cfg.classes)
        rewards = blank(horizon + 1, batch, n_agents, 1)
        actions_env = blank(horizon, batch, n_agents)
        terminated = blank(horizon + 1, batch, n_agents, 1)
        avail_actions = blank(horizon + 1, batch, n_agents, self.n_actions) if has_mask_head else None

        # slot 0 comes from the provided start state
        agent_embeddings[0] = init_latent["agent_embeddings"]
        deter[0] = init_latent["deter"]
        stoch[0] = init_latent["stoch"]
        rewards[0] = init_latent["rewards"]
        if has_cont_head:
            terminated[0] = init_latent["terminated"]
        if has_mask_head:
            avail_actions[0] = init_latent["avail_actions"]

        def imagine_step(step: torch.Tensor, latent: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
            # query every agent's policy on its own slice of the latent
            per_agent: List[TensorDict] = []
            for idx, actor in enumerate(actors):
                mask = avail_actions[step, :, idx] if has_mask_head else None
                per_agent.append(
                    actor(
                        latent=latent[:, idx],
                        avail_actions=mask,
                        evaluation=False,
                    )
                )
            joint = torch.stack(per_agent, dim=1)
            next_latent: TensorDict = self.dynamics.imagine(joint["actions"], latent)

            # record this step's action and the next state's predictions
            nxt = step + 1
            actions_env[step] = joint["actions_env"]
            deter[nxt] = next_latent["deter"]
            stoch[nxt] = next_latent["stoch"]
            agent_embeddings[nxt] = self.global_agent_embedding_transformer(deter[nxt])
            rewards[nxt] = self.reward_predictor(agent_embeddings[nxt]).pred()
            if has_cont_head:
                terminated[nxt] = self.cont_predictor(agent_embeddings[nxt]).pred()
            if has_mask_head:
                avail_actions[nxt] = self.act_mask_predictor(agent_embeddings[nxt]).sample()
            return next_latent

        # the scan result is unused; the step function fills the buffers
        scan(
            fn=imagine_step,
            inputs=(torch.arange(horizon),),
            init_state=init_latent,
        )

        imaginary_transitions: Dict[str, torch.Tensor] = {
            "agent_embeddings": agent_embeddings,
            "deter": deter,
            "stoch": stoch,
            "terminated": terminated,
            "rewards": rewards,
            "actions_env": actions_env,
        }
        if has_mask_head:
            imaginary_transitions["avail_actions"] = avail_actions
        return imaginary_transitions

    def post_prior_loss(
            self,
            post: torch.Tensor,
            prior: torch.Tensor,
            loss_fn: Callable,
            transform: Callable = lambda x: x,
            free: float = 0.0,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the two stop-gradient variants of a posterior/prior loss.

        Args:
            post: Posterior quantity (e.g. logits).
            prior: Prior quantity of the same kind.
            loss_fn: Called as ``loss_fn(transformed_post, transformed_prior)``.
            transform: Applied to each argument before ``loss_fn``; identity
                by default.
            free: Lower bound applied element-wise to both losses via clamp.

        Returns:
            ``(dyn_loss, rep_loss)``: the first is computed with ``post``
            detached (gradient reaches only the prior), the second with
            ``prior`` detached (gradient reaches only the posterior).
        """
        # dynamics term: posterior is the fixed target
        target_post = transform(post.detach())
        trainable_prior = transform(prior)
        dyn_loss = loss_fn(target_post, trainable_prior).clamp(min=free)

        # representation term: prior is the fixed target
        trainable_post = transform(post)
        target_prior = transform(prior.detach())
        rep_loss = loss_fn(trainable_post, target_prior).clamp(min=free)

        return dyn_loss, rep_loss

    def _latent_dist(self, logits: torch.Tensor, unimix: float = 0.0):
        """Wrap ``get_st_onehot(logits, unimix)`` in ``Independent`` over one trailing batch dim."""
        return Independent(get_st_onehot(logits, unimix=unimix), 1)

    @elements.timer.section("update_world_model")
    def update(self, samples: Dict[str, np.ndarray]) -> Dict[str, float]:
        """Run one optimisation step of the world model on a replay batch.

        The total loss is the scaled sum of: the dynamics and representation
        KL terms between posterior and prior latents, the observation
        prediction loss, the reward prediction loss and, when the matching
        heads are enabled, the continuation and action-mask prediction losses.
        Reward, continuation and action-mask terms are averaged only over
        entries that are not the first step of an episode.

        Args:
            samples: Replay batch. Keys read: ``obs``, ``prev_actions``,
                ``rewards``, ``is_first``, ``agent_mask``; ``avail_actions``
                only when the action-mask head exists and
                ``is_trailing_absorbing_state`` only when the continuation
                head exists.

        Returns:
            Scalar metrics for each loss term, the latent entropies and the
            accuracies, merged with whatever ``self._optim(loss)`` returns.
        """
        obs = n2t(samples["obs"], **self.tpdv)
        prev_actions = n2t(samples["prev_actions"], **self.tpdv)
        rewards = n2t(samples["rewards"], **self.tpdv)
        is_first = n2t(samples["is_first"], device=self.device, dtype=torch.bool)
        agent_mask = n2t(samples["agent_mask"], **self.tpdv)
        avail_actions = n2t(samples["avail_actions"], **self.tpdv) if self.act_mask_predictor is not None else None
        terminated = n2t(samples["is_trailing_absorbing_state"], device=self.device, dtype=torch.bool) if self.cont_predictor is not None else None
        # obs.shape = (ts, bs, n_agents, obs_shape)
        # prev_actions.shape = (ts, bs, n_agents, n_actions)
        # rewards.shape = (ts, bs, n_agents, 1)
        # is_first.shape = (ts, bs, 1)
        # agent_mask.shape = (ts, bs, n_agents, 1)
        # avail_actions.shape = (ts, bs, n_agents, n_actions)
        # terminated.shape = (ts, bs, 1)
        metrics = {}
        loss = 0

        with self._requires_grad:
            # posterior latents for every timestep, burn-in steps included
            post_latent: List[TensorDict] = self.observe(
                obs=obs,
                prev_actions=prev_actions,
                is_first=is_first,
            )
            # remove the burn-in steps: they only warm up the recurrent state
            # and are dropped from every tensor that enters a loss below
            if self.config.train.burn_in_length:
                obs = obs[self.config.train.burn_in_length:]
                rewards = rewards[self.config.train.burn_in_length:]
                is_first = is_first[self.config.train.burn_in_length:]
                agent_mask = agent_mask[self.config.train.burn_in_length:]
                post_latent = post_latent[self.config.train.burn_in_length:]
                avail_actions = avail_actions[self.config.train.burn_in_length:] if avail_actions is not None else None
                terminated = terminated[self.config.train.burn_in_length:] if terminated is not None else None
            post_latent: TensorDict = torch.stack(post_latent, dim=0)

            # 1. stochastic latent loss
            # prior logits are predicted from the posterior deterministic state
            prior_stoch: TensorDict = self.dynamics.imagine_stoch_from_deter(post_latent["deter"], num_batch_dims=2)
            dyn_loss, rep_loss = self.post_prior_loss(
                post=post_latent["logits"],
                prior=prior_stoch["logits"],
                loss_fn=kl_divergence,
                transform=lambda x: self._latent_dist(x, unimix=self.config.world_model.rssm.unimix),
                free=self.config.train.free_bits,
            )
            # NOTE(review): agent_mask is only used for shape checks in this
            # method; no loss term is multiplied by it. Confirm that is intended.
            assert dyn_loss.unsqueeze(-1).shape == agent_mask.shape, (dyn_loss.shape, agent_mask.shape)
            assert rep_loss.unsqueeze(-1).shape == agent_mask.shape, (rep_loss.shape, agent_mask.shape)
            dyn_loss = dyn_loss.mean()
            rep_loss = rep_loss.mean()
            loss += self.config.train.stoch_dyn_scale * dyn_loss
            loss += self.config.train.stoch_rep_scale * rep_loss
            metrics.update({"dyn_loss": dyn_loss.item(), "rep_loss": rep_loss.item()})

            # log post and prior latent entropy
            prior_dist = self._latent_dist(prior_stoch["logits"], unimix=self.config.world_model.rssm.unimix)
            post_dist = self._latent_dist(post_latent["logits"], unimix=self.config.world_model.rssm.unimix)
            prior_entropy = prior_dist.entropy().mean()
            post_entropy = post_dist.entropy().mean()

            metrics.update(
                {
                    "post_stoch_entropy": post_entropy.item(),
                    "prior_stoch_entropy": prior_entropy.item(),
                }
            )

            predictor_inputs = latent_to_input(post_latent)
            # 2. obs prediction loss
            obs_output = self.obs_predictor(predictor_inputs)
            obs_loss = obs_output.loss(obs)
            # obs_loss.shape = (ts, bs, n_agents, ...)
            # sum over the observation dimensions, keeping a trailing unit axis
            obs_loss = obs_loss.sum(dim=list(range(3, obs.ndim))).unsqueeze(-1)
            assert obs_loss.shape == agent_mask.shape, (obs_loss.shape, agent_mask.shape)
            obs_loss = obs_loss.mean()
            loss += self.config.train.obs_scale * obs_loss
            metrics.update({"obs_loss": obs_loss.item()})

            # generate global embedding for agents
            agent_embeddings = self.global_agent_embedding_transformer(post_latent["deter"], num_batch_dims=2)
            # mask the prediction of the first step
            # (is_first gets an agent axis and is repeated per agent before negation)
            not_first = ~is_first.unsqueeze(-2).repeat(1, 1, self.n_agents, 1)

            # 3. continuation prediction loss
            if self.config.world_model.rssm.use_cont_predictor:
                # enable_feat_grad decides whether this head's loss also
                # trains the features it reads, or only the head itself
                if self.config.world_model.rssm.cont_predictor.enable_feat_grad:
                    cont_output = self.cont_predictor(agent_embeddings)
                else:
                    cont_output = self.cont_predictor(agent_embeddings.detach())
                terminated = terminated.unsqueeze(2).repeat(1, 1, self.n_agents, 1)
                assert terminated.shape == agent_mask.shape, (terminated.shape, agent_mask.shape)
                # NOTE(review): the head is named "cont" but its target is the
                # terminated flag — confirm the intended polarity.
                cont_loss = cont_output.loss(terminated.float())
                # masked mean; the 1e-5 guards against a batch with no valid entry
                cont_loss = (cont_loss * not_first).sum() / (not_first.sum() + 1e-5)
                loss += self.config.train.cont_scale * cont_loss
                metrics.update({"cont_loss": cont_loss.item()})

                cont_acc = (cont_output.mode == terminated).float()
                cont_acc = (cont_acc * not_first).sum() / (not_first.sum() + 1e-5)
                metrics.update({"cont_acc": cont_acc.item()})

            # 4. reward prediction loss
            if self.config.world_model.rssm.reward_predictor.enable_feat_grad:
                reward_output = self.reward_predictor(agent_embeddings)
            else:
                reward_output = self.reward_predictor(agent_embeddings.detach())
            reward_loss = reward_output.loss(rewards)
            reward_loss = (reward_loss * not_first).sum() / (not_first.sum() + 1e-5)
            loss += self.config.train.reward_scale * reward_loss
            metrics.update({"reward_loss": reward_loss.item()})

            # log mean squared reward error (metric only, not added to the loss)
            reward_preds = reward_output.pred()
            assert reward_preds.shape == rewards.shape, (reward_preds.shape, rewards.shape)
            reward_loss_mse = (reward_preds - rewards) ** 2
            reward_loss_mse = (reward_loss_mse * not_first).sum() / (not_first.sum() + 1e-5)
            metrics.update({"reward_loss_mse": reward_loss_mse.item()})

            # 5. action mask prediction loss
            if self.config.world_model.rssm.use_act_mask_predictor:
                if self.config.world_model.rssm.act_mask_predictor.enable_feat_grad:
                    act_mask_output = self.act_mask_predictor(agent_embeddings)
                else:
                    act_mask_output = self.act_mask_predictor(agent_embeddings.detach())
                # sum the per-action losses so the result has a trailing unit axis
                act_mask_loss = act_mask_output.loss(avail_actions).sum(dim=-1, keepdim=True)
                act_mask_loss = (act_mask_loss * not_first).sum() / (not_first.sum() + 1e-5)
                loss += self.config.train.act_mask_scale * act_mask_loss
                metrics.update({"act_mask_loss": act_mask_loss.item()})

                # an entry counts as correct only if every action's mask matches
                act_mask_acc = (act_mask_output.mode == avail_actions).float().all(dim=-1, keepdim=True)
                act_mask_acc = (act_mask_acc * not_first).sum() / (not_first.sum() + 1e-5)
                metrics.update({"act_mask_acc": act_mask_acc.item()})

            # hand the summed loss to the optimizer wrapper; it returns its own metrics
            optim_metric = self._optim(loss)

        metrics.update(optim_metric)
        return metrics

    def initialize_agent_state(self, batch_size: int) -> TensorDict:
        """Return an all-zero latent state and previous action.

        The tensors are allocated directly on ``self.device`` rather than on
        the default device and moved afterwards; this method is called from
        ``observe_step`` whenever an episode starts, so the extra copy would
        be paid repeatedly.

        Args:
            batch_size: Number of sequences to create a state for.

        Returns:
            TensorDict with batch size ``(batch_size, n_agents)`` holding
            ``latents/deter`` of width ``deterministic_dim``,
            ``latents/stoch`` and ``latents/logits`` of shape
            ``(stochastic_dim, classes)`` and ``actions`` of width
            ``n_actions``, all zeros.
        """
        rssm_cfg = self.config.world_model.rssm
        lead = (batch_size, self.n_agents)
        stoch_shape = (*lead, rssm_cfg.stochastic_dim, rssm_cfg.classes)

        def zeros(*shape):
            # created on the target device, so the TensorDict has nothing to move
            return torch.zeros(*shape, device=self.device)

        agent_states = TensorDict(
            {
                "latents": {
                    "deter": zeros(*lead, rssm_cfg.deterministic_dim),
                    "stoch": zeros(*stoch_shape),
                    "logits": zeros(*stoch_shape),
                },
                "actions": zeros(*lead, self.n_actions),
            },
            batch_size=lead,
            device=self.device,
        )
        return agent_states

    def observe_step(
        self,
        embed: torch.Tensor,
        prev_actions: torch.Tensor,
        is_first: torch.Tensor,
        prev_latent: TensorDict | None,
    ) -> TensorDict:
        """Advance the posterior by one timestep, resetting state at episode starts.

        Args:
            embed: Encoded observation for this step.
            prev_actions: Actions taken before this step.
            is_first: Boolean flags with a trailing unit axis marking the
                batch entries whose episode starts at this step.
            prev_latent: Latent from the previous step, or ``None`` on the
                first call, in which case every entry starts from the zero
                state and zero previous action.

        Returns:
            The posterior latent produced by ``self.dynamics.observe``.
        """
        if prev_latent is None:
            # no history yet: start every sequence from the initial state
            agent_states = self.initialize_agent_state(batch_size=embed.size(0))
            prev_latent = {
                "deter": agent_states["latents"]["deter"],
                "stoch": agent_states["latents"]["stoch"],
            }
            prev_actions = agent_states["actions"]
        elif is_first.any():
            reset = is_first.squeeze(-1)
            agent_states = self.initialize_agent_state(batch_size=reset.sum().item())
            # Reset on copies. ``prev_latent`` holds the very tensors that
            # were returned for the previous step, so writing into them in
            # place would overwrite that step's stored posterior and modify
            # tensors that are part of the autograd graph. The inputs are
            # treated the same way as ``prev_actions`` below.
            prev_deter = prev_latent["deter"].clone()
            prev_stoch = prev_latent["stoch"].clone()
            prev_deter[reset] = agent_states["latents"]["deter"]
            prev_stoch[reset] = agent_states["latents"]["stoch"]
            # dynamics.observe only reads "deter" and "stoch"
            prev_latent = {"deter": prev_deter, "stoch": prev_stoch}
            prev_actions = prev_actions.clone()
            prev_actions[reset] = agent_states["actions"]

        latent: TensorDict = self.dynamics.observe(
            embed=embed,
            prev_actions=prev_actions,
            prev_latent=prev_latent,
        )
        return latent

    def save(self):
        """Return a checkpoint dict with the model and optimizer state dicts."""
        return {
            "model_state_dict": self.state_dict(),
            "optim_state_dict": self._optim.state_dict(),
        }

    def load(self, data):
        """Restore model and optimizer state from a dict produced by ``save``."""
        # model first, then optimizer
        model_state = data["model_state_dict"]
        self.load_state_dict(model_state)
        optim_state = data["optim_state_dict"]
        self._optim.load_state_dict(optim_state)

class RSSM(nn.Module):
    def __init__(self, config, n_actions, n_agents, device):
        """Create the recurrent core and the posterior / prior logit networks.

        Args:
            config: Experiment configuration; ``world_model.rssm`` and
                ``world_model.encoder.hidden_dim`` are read.
            n_actions: Width of the per-agent action vector fed to the action MLP.
            n_agents: Number of agents (stored only).
            device: Device handed to every sub-module.
        """
        super().__init__()
        self.config = config
        self.n_actions = n_actions
        self.n_agents = n_agents
        self.device = device

        rssm_cfg = config.world_model.rssm
        mlp_cfg = rssm_cfg.mlp

        # flattened size of the (stochastic_dim, classes) latent
        self.tot_stoch_dim = rssm_cfg.stochastic_dim * rssm_cfg.classes

        # the two input MLPs differ only in their input width
        input_mlp_kwargs = dict(
            hidden_dim=mlp_cfg.hidden_dim,
            hidden_layers=mlp_cfg.hidden_layers,
            act=mlp_cfg.act,
            use_layernorm=mlp_cfg.use_layernorm,
            device=device,
        )
        self._mlp_stoch = MLP(in_dim=self.tot_stoch_dim, **input_mlp_kwargs)
        self._mlp_action = MLP(in_dim=n_actions, **input_mlp_kwargs)

        # recurrent cell over the concatenated stoch and action features
        self._rnn = GRU(
            in_dim=2 * mlp_cfg.hidden_dim,
            hidden_dim=rssm_cfg.deterministic_dim,
            use_layernorm=rssm_cfg.rnn.use_layernorm,
            device=device,
        )

        # settings shared by the MLP variants of the logit networks
        logits_mlp_kwargs = dict(
            hidden_dim=rssm_cfg.hidden_dim,
            hidden_layers=rssm_cfg.obs_layers,
            out_dim=self.tot_stoch_dim,
            act=rssm_cfg.act,
            use_layernorm=rssm_cfg.use_layernorm,
            device=device,
        )

        # posterior logits: deterministic state plus observation embedding
        self._obs_stoch_logits = MLP(
            in_dim=rssm_cfg.deterministic_dim + config.world_model.encoder.hidden_dim,
            **logits_mlp_kwargs,
        )

        # prior logits: deterministic state only, transformer or MLP
        if rssm_cfg.use_img_stoch_transformer:
            tf_cfg = rssm_cfg.img_stoch_transformer
            self._img_stoch_logits = Transformer(
                d_model=rssm_cfg.deterministic_dim,
                out_dim=self.tot_stoch_dim,
                nhead=tf_cfg.nhead,
                num_layers=tf_cfg.num_layers,
                act=tf_cfg.activation,
                norm_first=tf_cfg.norm_first,
                device=device,
            )
        else:
            self._img_stoch_logits = MLP(
                in_dim=rssm_cfg.deterministic_dim,
                **logits_mlp_kwargs,
            )

    def observe(
            self,
            embed: torch.Tensor,
            prev_actions: torch.Tensor,
            prev_latent: TensorDict | Dict[str, torch.Tensor],
        ) -> TensorDict:
        """Compute the posterior latent for one step.

        Args:
            embed: Observation embedding for the current step.
            prev_actions: Previous actions.
            prev_latent: Mapping providing the previous ``deter`` and ``stoch``.

        Returns:
            TensorDict with ``deter``, ``logits`` and ``stoch``; its batch
            size is ``deter.shape[:-1]``.
        """
        rssm_cfg = self.config.world_model.rssm
        prev_deter, prev_stoch = prev_latent["deter"], prev_latent["stoch"]

        # 1. generate deterministic embedding
        lead_shape = prev_stoch.shape[:-2]
        # flatten (stochastic_dim, classes) for the MLP, then restore leading dims
        stoch_feat = self._mlp_stoch(prev_stoch.reshape(-1, self.tot_stoch_dim))
        stoch_feat = stoch_feat.reshape(*lead_shape, -1)
        action_feat = self._mlp_action(prev_actions)
        rnn_in = torch.cat([stoch_feat, action_feat], dim=-1)
        deter = self._rnn(rnn_in, prev_deter)

        # 2. generate stochastic embedding
        posterior_in = torch.cat([embed, deter], dim=-1)
        stoch_logits = self._obs_stoch_logits(posterior_in).reshape(
            *posterior_in.shape[:-1],
            rssm_cfg.stochastic_dim,
            rssm_cfg.classes,
        )
        # stoch_logits.shape = (batch_size, n_agents, n_stochastic_dim, n_classes)
        stoch = get_st_onehot(stoch_logits, unimix=rssm_cfg.unimix).rsample()

        return TensorDict(
            {
                "deter": deter,
                "logits": stoch_logits,
                "stoch": stoch,
            },
            batch_size=deter.shape[:-1],
            device=self.device,
        )

    def imagine(
            self,
            prev_actions: torch.Tensor,
            prev_latent: Dict[str, torch.Tensor],
        ) -> TensorDict:
        """Predict the next latent from the previous latent and action alone.

        Args:
            prev_actions: Actions taken from the previous latent.
            prev_latent: Mapping providing the previous ``deter`` and ``stoch``.

        Returns:
            TensorDict with the next ``deter`` plus prior ``logits`` and a
            sampled ``stoch``; its batch size is ``deter.shape[:-1]``.
        """
        prev_deter, prev_stoch = prev_latent["deter"], prev_latent["stoch"]

        # 1. generate deterministic embedding at next timestep
        lead_shape = prev_stoch.shape[:-2]
        # flatten (stochastic_dim, classes) for the MLP, then restore leading dims
        stoch_feat = self._mlp_stoch(prev_stoch.reshape(-1, self.tot_stoch_dim))
        stoch_feat = stoch_feat.reshape(*lead_shape, -1)
        action_feat = self._mlp_action(prev_actions)
        rnn_in = torch.cat([stoch_feat, action_feat], dim=-1)
        deter = self._rnn(rnn_in, prev_deter)

        # 2. generate stochastic embedding at next timestep
        prior = self.imagine_stoch_from_deter(deter, num_batch_dims=1)

        return TensorDict(
            {
                "deter": deter,
                "logits": prior["logits"],
                "stoch": prior["stoch"],
            },
            batch_size=deter.shape[:-1],
            device=self.device,
        )

    def imagine_stoch_from_deter(
        self,
        deter: torch.Tensor,
        num_batch_dims: int = 1,
    ) -> TensorDict:
        """
        Predict prior logits from the deterministic state and sample from them.

        Args:
            deter: shape = (..., n_agents, n_deterministic_dim); within this
                file it is called with a leading (ts, bs) and
                ``num_batch_dims=2``, or a leading (bs,) and
                ``num_batch_dims=1``.
            num_batch_dims: Forwarded to the transformer variant of the prior
                network; unused by the MLP variant.

        Returns:
            TensorDict with ``logits`` of shape
            ``(*deter.shape[:-1], stochastic_dim, classes)`` and a matching
            ``stoch`` sample; its batch size is ``deter.shape[:-1]``.
        """
        rssm_cfg = self.config.world_model.rssm

        if rssm_cfg.use_img_stoch_transformer:
            flat_logits = self._img_stoch_logits(deter, num_batch_dims=num_batch_dims)
        else:
            flat_logits = self._img_stoch_logits(deter)

        # stoch_logits.shape = (ts, bs, n_agents, n_stochastic_dim, n_classes)
        stoch_logits = flat_logits.reshape(
            *deter.shape[:-1],
            rssm_cfg.stochastic_dim,
            rssm_cfg.classes,
        )
        # re-parametrization trick
        stoch = get_st_onehot(stoch_logits, unimix=rssm_cfg.unimix).rsample()

        return TensorDict(
            {
                "logits": stoch_logits,
                "stoch": stoch,
            },
            batch_size=deter.shape[:-1],
            device=self.device,
        )

class ObsEncoder(nn.Module):
    """Observation encoder: an MLP for rank-1 observations, a CNN for rank-3 ones."""

    def __init__(self, config, obs_shape, device):
        """Select and build the backbone from the observation rank.

        Args:
            config: Experiment configuration; ``world_model.encoder`` is read.
            obs_shape: Observation shape; a bare ``int`` is treated as rank 1.
            device: Device handed to the backbone.

        Raises:
            ValueError: If the rank of ``obs_shape`` is neither 1 nor 3.
        """
        super().__init__()
        self.config = config
        # stored exactly as passed, before the int -> tuple normalisation
        self.obs_shape = obs_shape
        if isinstance(obs_shape, int):
            obs_shape = (obs_shape,)

        rank = len(obs_shape)
        if rank not in (1, 3):
            raise ValueError(f"Observation shape {obs_shape} not supported")

        enc_cfg = config.world_model.encoder
        if rank == 1:
            # vector observations
            self._base = MLP(
                in_dim=obs_shape[0],
                hidden_dim=enc_cfg.hidden_dim,
                hidden_layers=enc_cfg.hidden_layers,
                act=enc_cfg.act,
                use_layernorm=enc_cfg.use_layernorm,
                use_symlog=enc_cfg.use_symlog,
                device=device,
            )
        else:
            # rank-3 observations
            self._base = CNN(
                input_shape=obs_shape,
                depth=enc_cfg.depth,
                mults=enc_cfg.mults,
                kernel=enc_cfg.kernel,
                act=enc_cfg.act,
                output_dim=enc_cfg.hidden_dim,
                use_layernorm=enc_cfg.use_layernorm,
                device=device,
            )

    def forward(self, obs: torch.Tensor):
        """Return the backbone's output for ``obs``."""
        return self._base(obs)
