import torch
import torch.nn as nn
from torch.distributions import Independent, kl_divergence
from tensordict import TensorDict
import numpy as np

from typing import Dict, List, Tuple, Callable

from utils.tools import n2t, scan, RequiresGrad, Optimizer, latent_to_input, get_st_onehot, build_returns, build_cumulative_log_probs
from utils.networks import MLP, GRU, CNN, Transformer
from utils.output_head import MLPHead, FigureHead
from actor_critic.actor import Actor
from actor_critic.critic import Critic
import elements

class MAWorldModel(nn.Module):
    _name = "world_model"
    def __init__(self, config, obs_shape, n_actions, n_agents, device):
        """Assemble the encoder, the RSSM dynamics, the prediction heads and the optimizer.

        Args:
            config: Nested configuration object. `config.world_model.rssm.*` drives
                the network sizes and `config.train.optim.world_model.*` the optimizer.
            obs_shape: Observation shape. Length 1 selects an `MLPHead` decoder,
                length 3 a `FigureHead` decoder.
            n_actions: Size of the action dimension; forwarded to the RSSM and used
                as the output width of the optional action-mask head.
            n_agents: Number of agents; forwarded to the RSSM.
            device: Device handed to every sub-module.

        Raises:
            NotImplementedError: If `obs_shape` has neither 1 nor 3 dimensions.
        """
        super().__init__()
        self.config = config
        self.obs_shape = obs_shape
        self.n_actions = n_actions
        self.n_agents = n_agents
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)

        rssm_cfg = config.world_model.rssm

        # world model
        self.encoder = ObsEncoder(config, obs_shape, device)
        self.dynamics = RSSM(config, n_actions, n_agents, device)

        # obs predictor: input width is deterministic_dim + stochastic_dim * classes
        feat_dim = rssm_cfg.deterministic_dim + rssm_cfg.stochastic_dim * rssm_cfg.classes
        n_obs_dims = len(obs_shape)
        if n_obs_dims not in (1, 3):
            raise NotImplementedError
        obs_cfg = rssm_cfg.obs_predictor
        if n_obs_dims == 1:
            self.obs_predictor = MLPHead(
                in_dim=feat_dim,
                hidden_dim=obs_cfg.hidden_dim,
                hidden_layers=obs_cfg.hidden_layers,
                out_dim=obs_shape[0],
                act=obs_cfg.act,
                use_layernorm=obs_cfg.use_layernorm,
                output=obs_cfg.output,
                device=device,
            )
        else:
            self.obs_predictor = FigureHead(
                in_dim=feat_dim,
                out_shape=obs_shape,
                depth=obs_cfg.depth,
                mults=obs_cfg.mults,
                kernel=obs_cfg.kernel,
                act=obs_cfg.act,
                use_layernorm=obs_cfg.use_layernorm,
                device=device,
            )

        # other predictors: fed either by the agent-embedding transformer
        # (deterministic_dim wide) or by the same features as the obs predictor
        if rssm_cfg.use_global_agent_embedding_transformer:
            tf_cfg = rssm_cfg.global_agent_embedding_transformer
            self.global_agent_embedding_transformer = Transformer(
                d_model=rssm_cfg.deterministic_dim,
                nhead=tf_cfg.nhead,
                num_layers=tf_cfg.num_layers,
                act=tf_cfg.activation,
                norm_first=tf_cfg.norm_first,
                device=device,
            )
            head_in_dim = rssm_cfg.deterministic_dim
        else:
            head_in_dim = feat_dim
        self.other_predictor_in_dim = head_in_dim

        reward_cfg = rssm_cfg.reward_predictor
        self.reward_predictor = MLPHead(
            in_dim=head_in_dim,
            hidden_dim=reward_cfg.hidden_dim,
            hidden_layers=reward_cfg.hidden_layers,
            out_dim=1,
            act=reward_cfg.act,
            use_layernorm=reward_cfg.use_layernorm,
            output=reward_cfg.output,
            out_scale=reward_cfg.out_scale,
            device=device,
        )

        # The optional heads read their config sections only when enabled.
        if rssm_cfg.use_cont_predictor:
            cont_cfg = rssm_cfg.cont_predictor
            self.cont_predictor = MLPHead(
                in_dim=head_in_dim,
                hidden_dim=cont_cfg.hidden_dim,
                hidden_layers=cont_cfg.hidden_layers,
                out_dim=1,
                act=cont_cfg.act,
                use_layernorm=cont_cfg.use_layernorm,
                output=cont_cfg.output,
                device=device,
            )
        else:
            self.cont_predictor = None

        if rssm_cfg.use_act_mask_predictor:
            mask_cfg = rssm_cfg.act_mask_predictor
            self.act_mask_predictor = MLPHead(
                in_dim=head_in_dim,
                hidden_dim=mask_cfg.hidden_dim,
                hidden_layers=mask_cfg.hidden_layers,
                out_dim=n_actions,
                act=mask_cfg.act,
                use_layernorm=mask_cfg.use_layernorm,
                output=mask_cfg.output,
                device=device,
            )
        else:
            self.act_mask_predictor = None

        # context manager for gradient control
        self._requires_grad = RequiresGrad(self)

        # optimizer over every parameter registered above
        optim_cfg = config.train.optim.world_model
        self._optim = Optimizer(
            name=self._name,
            parameters=self.parameters(),
            lr=optim_cfg.lr,
            eps=optim_cfg.eps,
            use_max_grad_norm=optim_cfg.use_max_grad_norm,
            max_grad_norm=optim_cfg.max_grad_norm,
        )

    def observe(
        self,
        obs: torch.Tensor | np.ndarray,
        prev_actions: torch.Tensor | np.ndarray,
        is_first: torch.Tensor | np.ndarray,
    ) -> List[TensorDict]:
        """Encode a sequence of observations and compute the posterior latent per step.

        NumPy inputs are converted to tensors on `self.device`; tensor inputs are
        used as given.

        Args:
            obs: Observations, shape (ts, bs, n_agents, *obs_shape).
            prev_actions: Actions that led to each observation,
                shape (ts, bs, n_agents, n_actions).
            is_first: Episode-start flags, shape (ts, bs, 1).

        Returns:
            The result of scanning `observe_step` over the sequence. Callers in
            this class treat it as a list of per-step latents that they slice
            and `torch.stack` themselves.
        """
        obs = n2t(obs, **self.tpdv) if isinstance(obs, np.ndarray) else obs
        prev_actions = n2t(prev_actions, **self.tpdv) if isinstance(prev_actions, np.ndarray) else prev_actions
        is_first = n2t(is_first, device=self.device, dtype=torch.bool) if isinstance(is_first, np.ndarray) else is_first
        # obs.shape = (ts, bs, n_agents, obs_shape)
        # prev_actions.shape = (ts, bs, n_agents, n_actions)
        # is_first.shape = (ts, bs, 1)

        embed = self.encoder(obs)
        # init_state=None makes the first observe_step call start from a zero state.
        post_latent: List[TensorDict] = scan(
            fn=self.observe_step,
            inputs=(
                embed,
                prev_actions,
                is_first,
            ),
            init_state=None,
        )
        return post_latent

    def prep_init_latent_for_imagination(self, samples: Dict[str, np.ndarray]) -> TensorDict:
        """Turn a replayed batch into the start states used by `imagine`.

        The batch is run through `observe` to obtain posterior latents, which are
        stacked over the leading axis, annotated with the extra fields `imagine`
        reads, and finally flattened over the first two axes so that each observed
        step becomes an independent start state.

        Args:
            samples: Batch of NumPy arrays. Keys read: "obs", "prev_actions",
                "rewards", "is_first", "agent_mask"; "avail_actions" when an
                action-mask predictor exists; "is_trailing_absorbing_state" when a
                continuation predictor exists, otherwise "terminated"; and
                optionally "agent_policy_index".

        Returns:
            The stacked posterior latents with "agent_mask", "rewards",
            "avail_actions", "terminated" and "agent_policy_index" added, flattened
            over axes 0 and 1.
        """
        bool_kwargs = dict(device=self.device, dtype=torch.bool)

        obs = n2t(samples["obs"], **self.tpdv)
        prev_actions = n2t(samples["prev_actions"], **self.tpdv)
        rewards = n2t(samples["rewards"], **self.tpdv)
        is_first = n2t(samples["is_first"], **bool_kwargs)
        agent_mask = n2t(samples["agent_mask"], **self.tpdv)
        avail_actions = n2t(samples["avail_actions"], **self.tpdv) if self.act_mask_predictor is not None else None
        # With a continuation predictor the absorbing-state flag is the termination signal.
        terminated_key = "is_trailing_absorbing_state" if self.cont_predictor is not None else "terminated"
        terminated = n2t(samples[terminated_key], **bool_kwargs)
        agent_policy_index = (
            n2t(samples["agent_policy_index"], device=self.device, dtype=torch.long)
            if "agent_policy_index" in samples
            else None
        )

        # generate post latent as initial state for imagination
        init_latent: TensorDict = torch.stack(
            self.observe(obs=obs, prev_actions=prev_actions, is_first=is_first),
            dim=0,
        )

        # fill in the rest of the initial state
        init_latent["agent_mask"] = agent_mask
        init_latent["rewards"] = rewards
        init_latent["avail_actions"] = avail_actions  # None when no action-mask predictor exists
        # Insert an agent axis at dim 2 and tile the flag over the agents.
        init_latent["terminated"] = terminated.unsqueeze(2).repeat(1, 1, self.n_agents, 1)
        init_latent["agent_policy_index"] = agent_policy_index  # None when the key is absent
        return init_latent.flatten(0, 1)

    @elements.timer.section("imagination")
    @torch.no_grad()
    def imagine(
        self,
        actors: List[Actor],
        init_latent: TensorDict,
        policy_mask_dict: Dict[int, torch.Tensor] | None = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Generate imaginary rollouts given previous latent states and actor policies.

        Runs under `torch.no_grad()`. Starting from `init_latent`, the actors pick
        actions, the RSSM prior advances the latent, and the prediction heads fill
        in rewards and, when the corresponding head exists, termination flags and
        available actions, for `config.train.imagination_steps` steps.

        Args:
            actors: Actor policies. Without `policy_mask_dict`, `actors[i]` acts
                for agent index `i`.
            init_latent: Start state; must provide "deter", "stoch" and "rewards",
                plus "terminated" / "avail_actions" when the matching head exists.
            policy_mask_dict: Optional mapping from an index into `actors` to a
                mask used to index the latent and to scatter that actor's output
                back (presumably a boolean mask over (B, A) -- TODO confirm).

        Returns:
            Dict with "deter", "stoch", "terminated" and "rewards" (leading axis
            T + 1, index 0 being the start state), "actions_env" (leading axis T),
            and "avail_actions" when an action-mask predictor exists.
        """
        # pre-allocate tensors
        B = init_latent["deter"].shape[0]
        T = self.config.train.imagination_steps
        A = self.n_agents
        deter = torch.zeros((T+1, B, A, self.config.world_model.rssm.deterministic_dim), device=self.device)
        stoch = torch.zeros((T+1, B, A, self.config.world_model.rssm.stochastic_dim, self.config.world_model.rssm.classes), device=self.device)
        rewards = torch.zeros((T+1, B, A, 1), device=self.device)
        actions_env = torch.zeros((T, B, A), device=self.device)
        terminated = torch.zeros((T+1, B, A, 1), device=self.device)
        avail_actions = torch.zeros((T+1, B, A, self.n_actions), device=self.device) if self.act_mask_predictor is not None else None

        # set initial values
        deter[0] = init_latent["deter"]
        stoch[0] = init_latent["stoch"]
        rewards[0] = init_latent["rewards"]
        # Without a continuation predictor `terminated` stays all-zero for the whole rollout.
        if self.cont_predictor is not None:
            terminated[0] = init_latent["terminated"]
        if self.act_mask_predictor is not None:
            avail_actions[0] = init_latent["avail_actions"]

        def imagine_step(step: torch.Tensor, latent: TensorDict) -> TensorDict:
            # One rollout step: writes slot `step` of actions_env and slot `step+1`
            # of every other buffer, then returns the next latent as the scan state.
            # imagine next state
            if policy_mask_dict is not None:
                # Each policy acts only on the entries selected by its mask.
                actor_outputs = TensorDict(batch_size=(B, A), device=self.device)
                for policy_idx, policy_mask in policy_mask_dict.items():
                    policy_output = actors[policy_idx](
                        latent=latent[policy_mask],
                        avail_actions=avail_actions[step][policy_mask] if avail_actions is not None else None,
                    )
                    actor_outputs[policy_mask] = policy_output
            else:
                # One actor per agent index; outputs are re-stacked along the agent axis.
                actor_outputs: List[TensorDict] = [
                    actors[i](
                        latent=latent[:, i],
                        avail_actions=avail_actions[step, :, i] if self.act_mask_predictor is not None else None,
                        evaluation=False,
                    )
                    for i in range(len(actors))
                ]
                actor_outputs = torch.stack(actor_outputs, dim=1)
            next_latent: TensorDict = self.dynamics.imagine(actor_outputs["actions"], latent)

            # fill in next state
            actions_env[step] = actor_outputs["actions_env"]
            deter[step+1] = next_latent["deter"]
            stoch[step+1] = next_latent["stoch"]
            # The heads take the same kind of input they were built for in __init__.
            if self.config.world_model.rssm.use_global_agent_embedding_transformer:
                other_predictor_inputs = self.global_agent_embedding_transformer(deter[step+1])
            else:
                other_predictor_inputs = latent_to_input(
                    {
                        "deter": deter[step+1],
                        "stoch": stoch[step+1],
                    }
                )
            rewards[step+1] = self.reward_predictor(other_predictor_inputs).pred()
            if self.cont_predictor is not None:
                terminated[step+1] = self.cont_predictor(other_predictor_inputs).pred()
            if self.act_mask_predictor is not None:
                # The mask is sampled from the head's output, unlike the rewards
                # and termination flags above, which use pred().
                avail_actions[step+1] = self.act_mask_predictor(other_predictor_inputs).sample()
            return next_latent

        # The scan result is discarded; imagine_step records everything in the buffers.
        scan(
            fn = imagine_step,
            inputs=(torch.arange(self.config.train.imagination_steps),),
            init_state=init_latent,
        )
        imaginary_transitions: Dict[str, torch.Tensor] = {
            "deter": deter,
            "stoch": stoch,
            "terminated": terminated,
            "rewards": rewards,
            "actions_env": actions_env,
        }
        if self.act_mask_predictor is not None:
            imaginary_transitions["avail_actions"] = avail_actions
        return imaginary_transitions

    def post_prior_loss(
            self,
            post: torch.Tensor,
            prior: torch.Tensor,
            loss_fn: Callable,
            transform: Callable = lambda x: x,
            free: float = 0.0,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the dynamics and representation losses between posterior and prior.

        Both losses evaluate `loss_fn(transform(post), transform(prior))` and
        floor the result element-wise at `free`; they differ only in which side
        is detached from the graph.

        Args:
            post: Posterior quantity (e.g. logits).
            prior: Prior quantity of matching shape.
            loss_fn: Callable taking (transformed post, transformed prior).
            transform: Applied to each operand before `loss_fn`; identity by default.
            free: Lower bound applied to each loss element.

        Returns:
            `(dyn_loss, rep_loss)`: the dynamics loss detaches `post`, so gradients
            reach only `prior`; the representation loss detaches `prior`, so
            gradients reach only `post`.
        """
        def floored(lhs, rhs):
            # Loss between the transformed operands, floored at `free`.
            return loss_fn(transform(lhs), transform(rhs)).clamp(min=free)

        # The dynamics term is evaluated first, then the representation term.
        return floored(post.detach(), prior), floored(post, prior.detach())

    def _latent_dist(self, logits: torch.Tensor, unimix: float = 0.0):
        """Build the latent distribution for `logits`.

        The distribution returned by `get_st_onehot` is wrapped in `Independent`
        with one reinterpreted batch dimension, so its last batch axis becomes
        part of the event.
        """
        return Independent(get_st_onehot(logits, unimix=unimix), 1)

    @elements.timer.section("update_world_model")
    def update(
        self,
        samples: Dict[str, np.ndarray],
        actors: List[Actor],
        critics: List[Critic],
        target_critics: List[Critic],
    ) -> Dict[str, float]:
        """Run one optimisation step of the world model, then of the critic(s).

        The world-model loss is the sum of: the KL dynamics/representation losses
        on the stochastic latent, the observation reconstruction loss, the optional
        continuation loss, the reward loss, the optional action-mask loss, and an
        importance-weighted value loss whose gradient flows through the critic
        into the latent. After the world-model optimizer step each critic is
        updated separately on detached latents.

        Args:
            samples: Replayed batch of NumPy arrays. Keys read: "obs",
                "prev_actions", "prev_log_probs", "rewards", "is_first",
                "agent_mask", "truncated"; "avail_actions" when an action-mask
                predictor exists; "is_trailing_absorbing_state" when a continuation
                predictor exists, otherwise "terminated". Expected shapes are
                listed in the comments below.
            actors: One actor per agent index; used to recompute action
                log-probabilities for the importance weights.
            critics: Online critics. Only `critics[0]` is used when
                `config.train.share_critics` is set.
            target_critics: Critics that provide the value predictions used to
                build the target returns.

        Returns:
            Scalar metrics collected via `.item()`, merged with whatever
            `self._optim(loss)` returns.
        """
        obs = n2t(samples["obs"], **self.tpdv)
        prev_actions = n2t(samples["prev_actions"], **self.tpdv)
        prev_log_probs = n2t(samples["prev_log_probs"], **self.tpdv)
        rewards = n2t(samples["rewards"], **self.tpdv)
        is_first = n2t(samples["is_first"], device=self.device, dtype=torch.bool)
        agent_mask = n2t(samples["agent_mask"], **self.tpdv)
        avail_actions = n2t(samples["avail_actions"], **self.tpdv) if self.act_mask_predictor is not None else None
        # With a continuation predictor the absorbing-state flag is the termination signal.
        terminated = (
            n2t(samples["is_trailing_absorbing_state"], device=self.device, dtype=torch.bool)
            if self.cont_predictor is not None
            else n2t(samples["terminated"], device=self.device, dtype=torch.bool)
        )
        truncated = n2t(samples["truncated"], device=self.device, dtype=torch.bool)
        # obs.shape = (ts, bs, n_agents, obs_shape)
        # prev_actions.shape = (ts, bs, n_agents, n_actions)
        # prev_log_probs.shape = (ts, bs, n_agents, n_actions)
        # rewards.shape = (ts, bs, n_agents, 1)
        # is_first.shape = (ts, bs, 1)
        # agent_mask.shape = (ts, bs, n_agents, 1)
        # avail_actions.shape = (ts, bs, n_agents, n_actions)
        # terminated.shape = (ts, bs, 1)
        # truncated.shape = (ts, bs, 1)
        # NOTE(review): agent_mask is only used in the shape assertions below; no
        # loss term is multiplied by it -- confirm that this is intended.
        metrics = {}
        loss = 0

        with self._requires_grad:
            # The posterior is computed over the full sequence, burn-in included,
            # so the recurrent state is warmed up before the loss window starts.
            post_latent: List[TensorDict] = self.observe(
                obs=obs,
                prev_actions=prev_actions,
                is_first=is_first,
            )
            # remove the burn-in steps
            if self.config.train.burn_in_length:
                obs = obs[self.config.train.burn_in_length:]
                prev_actions = prev_actions[self.config.train.burn_in_length:]
                prev_log_probs = prev_log_probs[self.config.train.burn_in_length:]
                rewards = rewards[self.config.train.burn_in_length:]
                is_first = is_first[self.config.train.burn_in_length:]
                agent_mask = agent_mask[self.config.train.burn_in_length:]
                post_latent = post_latent[self.config.train.burn_in_length:]
                avail_actions = avail_actions[self.config.train.burn_in_length:] if avail_actions is not None else None
                terminated = terminated[self.config.train.burn_in_length:]
                truncated = truncated[self.config.train.burn_in_length:]
            post_latent: TensorDict = torch.stack(post_latent, dim=0)

            # 1. stochastic latent loss
            # The prior is predicted from the posterior's deterministic state.
            prior_stoch: TensorDict = self.dynamics.imagine_stoch_from_deter(post_latent["deter"], num_batch_dims=2)
            dyn_loss, rep_loss = self.post_prior_loss(
                post=post_latent["logits"],
                prior=prior_stoch["logits"],
                loss_fn=kl_divergence,
                transform=lambda x: self._latent_dist(x, unimix=self.config.world_model.rssm.unimix),
                free=self.config.train.free_bits,
            )
            assert dyn_loss.unsqueeze(-1).shape == agent_mask.shape, (dyn_loss.shape, agent_mask.shape)
            assert rep_loss.unsqueeze(-1).shape == agent_mask.shape, (rep_loss.shape, agent_mask.shape)
            dyn_loss = dyn_loss.mean()
            rep_loss = rep_loss.mean()
            loss += self.config.train.stoch_dyn_scale * dyn_loss
            loss += self.config.train.stoch_rep_scale * rep_loss
            metrics.update({"dyn_loss": dyn_loss.item(), "rep_loss": rep_loss.item()})

            # log post and prior latent entropy
            prior_dist = self._latent_dist(prior_stoch["logits"], unimix=self.config.world_model.rssm.unimix)
            post_dist = self._latent_dist(post_latent["logits"], unimix=self.config.world_model.rssm.unimix)
            prior_entropy = prior_dist.entropy().mean()
            post_entropy = post_dist.entropy().mean()

            metrics.update(
                {
                    "post_stoch_entropy": post_entropy.item(),
                    "prior_stoch_entropy": prior_entropy.item(),
                }
            )

            obs_predictor_inputs = latent_to_input(post_latent)
            # 2. obs prediction loss
            obs_output = self.obs_predictor(obs_predictor_inputs)
            obs_loss = obs_output.loss(obs)
            # obs_loss.shape = (ts, bs, n_agents, ...)
            # Sum over every observation axis, then restore a trailing singleton axis.
            obs_loss = obs_loss.sum(dim=list(range(3, obs.ndim))).unsqueeze(-1)
            assert obs_loss.shape == agent_mask.shape, (obs_loss.shape, agent_mask.shape)
            obs_loss = obs_loss.mean()
            loss += self.config.train.obs_scale * obs_loss
            metrics.update({"obs_loss": obs_loss.item()})

            # generate global embedding for agents
            if self.config.world_model.rssm.use_global_agent_embedding_transformer:
                other_predictor_inputs = self.global_agent_embedding_transformer(post_latent["deter"], num_batch_dims=2)
            else:
                other_predictor_inputs = obs_predictor_inputs

            # mask the prediction of the first step
            # is_first gains an agent axis and is tiled over the agents before inversion.
            not_first = ~is_first.unsqueeze(-2).repeat(1, 1, self.n_agents, 1)

            # 3. continuation prediction loss
            terminated = terminated.unsqueeze(2).repeat(1, 1, self.n_agents, 1)
            if self.config.world_model.rssm.use_cont_predictor:
                # enable_feat_grad decides whether this head's loss also trains the latent features.
                if self.config.world_model.rssm.cont_predictor.enable_feat_grad:
                    cont_output = self.cont_predictor(other_predictor_inputs)
                else:
                    cont_output = self.cont_predictor(other_predictor_inputs.detach())
                assert terminated.shape == agent_mask.shape, (terminated.shape, agent_mask.shape)
                cont_loss = cont_output.loss(terminated.float())
                # Masked mean over non-first steps; 1e-5 guards against an all-first batch.
                cont_loss = (cont_loss * not_first).sum() / (not_first.sum() + 1e-5)
                loss += self.config.train.cont_scale * cont_loss
                metrics.update({"cont_loss": cont_loss.item()})

                cont_acc = (cont_output.mode == terminated).float()
                cont_acc = (cont_acc * not_first).sum() / (not_first.sum() + 1e-5)
                metrics.update({"cont_acc": cont_acc.item()})

            # 4. reward prediction loss
            if self.config.world_model.rssm.reward_predictor.enable_feat_grad:
                reward_output = self.reward_predictor(other_predictor_inputs)
            else:
                reward_output = self.reward_predictor(other_predictor_inputs.detach())
            reward_loss = reward_output.loss(rewards)
            reward_loss = (reward_loss * not_first).sum() / (not_first.sum() + 1e-5)
            loss += self.config.train.reward_scale * reward_loss
            metrics.update({"reward_loss": reward_loss.item()})

            # log mean squared reward error (metric only, not added to the loss)
            reward_preds = reward_output.pred()
            assert reward_preds.shape == rewards.shape, (reward_preds.shape, rewards.shape)
            reward_loss_mse = (reward_preds - rewards) ** 2
            reward_loss_mse = (reward_loss_mse * not_first).sum() / (not_first.sum() + 1e-5)
            metrics.update({"reward_loss_mse": reward_loss_mse.item()})

            # 5. action mask prediction loss
            if self.config.world_model.rssm.use_act_mask_predictor:
                if self.config.world_model.rssm.act_mask_predictor.enable_feat_grad:
                    act_mask_output = self.act_mask_predictor(other_predictor_inputs)
                else:
                    act_mask_output = self.act_mask_predictor(other_predictor_inputs.detach())
                act_mask_loss = act_mask_output.loss(avail_actions).sum(dim=-1, keepdim=True)
                act_mask_loss = (act_mask_loss * not_first).sum() / (not_first.sum() + 1e-5)
                loss += self.config.train.act_mask_scale * act_mask_loss
                metrics.update({"act_mask_loss": act_mask_loss.item()})

                # A step counts as correct only when every action's mask bit matches.
                act_mask_acc = (act_mask_output.mode == avail_actions).float().all(dim=-1, keepdim=True)
                act_mask_acc = (act_mask_acc * not_first).sum() / (not_first.sum() + 1e-5)
                metrics.update({"act_mask_acc": act_mask_acc.item()})

            # 6. value prediction loss
            if self.config.train.share_critics:
                target_critic_outputs = target_critics[0](post_latent)
            else:
                target_critic_outputs: List[TensorDict] = [
                    target_critics[i](
                        latent=post_latent[:, :, i],
                    )
                    for i in range(len(target_critics))
                ]
                target_critic_outputs = torch.stack(target_critic_outputs, dim=2)
            value_preds = target_critic_outputs["value_preds"]

            truncated = truncated.unsqueeze(-2).expand_as(terminated)
            target_returns = build_returns(
                rewards=rewards,
                value_preds=value_preds,
                terminated=terminated.float(),
                truncated=truncated.float(),
                gamma=self.config.train.gamma,
                gae_lambda=self.config.train.gae_lambda,
            )

            # calculate the importance weights
            actor_outputs: List[TensorDict] = [
                actors[i](
                    latent=post_latent[:, :, i],
                    avail_actions=avail_actions[:, :, i] if avail_actions is not None else None,
                )
                for i in range(len(actors))
            ]
            actor_outputs = torch.stack(actor_outputs, dim=2)
            log_probs = actor_outputs["log_probs"]
            # prev_actions[t + 1] is paired with the log-probs at step t; multiplying
            # and summing over the last axis picks that action's log-probability
            # (presumably prev_actions is one-hot -- TODO confirm).
            action_log_probs = (log_probs[:-1] * prev_actions[1:]).sum(dim=-1, keepdim=True)
            old_action_log_probs = (prev_log_probs[1:] * prev_actions[1:]).sum(dim=-1, keepdim=True)

            # Build cumulative log probs for importance weighting
            # Uses terminated/truncated flags to reset accumulation at episode boundaries
            cum_action_log_probs = build_cumulative_log_probs(
                action_log_probs,
                terminated[:-1].float(),
                truncated[:-1].float(),
            )
            old_cum_action_log_probs = build_cumulative_log_probs(
                old_action_log_probs,
                terminated[:-1].float(),
                truncated[:-1].float(),
            )

            # Calculate importance weights
            # Log-probs are summed over dim=-2 (the agent axis per the shape comments
            # above) before exponentiating, giving one joint ratio per step.
            imp_weights = torch.exp(cum_action_log_probs.sum(dim=-2, keepdim=True) - old_cum_action_log_probs.sum(dim=-2, keepdim=True))

            if self.config.train.share_critics:
                # First pass: compute value loss with gradients flowing through critic to world model
                # This allows world model to learn representations useful for value prediction
                # Keep critic gradients enabled so gradients can flow through critic to world model
                with critics[0]._requires_grad:
                    value_output = critics[0](post_latent[:-1])["value_output"]
                    value_loss_wm = value_output.loss(target_returns[:-1].detach())
                    # Weights are detached and clipped at 4.0 to bound their influence.
                    imp_weights_expanded = imp_weights.expand_as(value_loss_wm).detach()
                    value_loss_wm = (value_loss_wm * imp_weights_expanded.clamp(max=4.0)).mean()
                    loss += value_loss_wm
            else:
                # NOTE(review): imp_weights keeps a size-1 agent axis (keepdim=True)
                # while the per-agent loss is indexed with [:, :, i]; expand_as needs
                # compatible ranks -- confirm the shapes line up in this branch.
                for i in range(len(critics)):
                    with critics[i]._requires_grad:
                        value_output = critics[i](post_latent[:-1, :, i])["value_output"]
                        value_loss_wm = value_output.loss(target_returns[:-1, :, i].detach())
                        imp_weights_expanded = imp_weights.expand_as(value_loss_wm).detach()
                        value_loss_wm = (value_loss_wm * imp_weights_expanded.clamp(max=4.0)).mean()
                        loss += value_loss_wm

            # World-model optimizer step on the accumulated loss.
            optim_metric = self._optim(loss)

            # Second pass: update critic separately AFTER world model update
            # This prevents in-place modification of critic parameters from breaking the computation graph
            if self.config.train.share_critics:
                with critics[0]._requires_grad:
                    # Latents are detached so this step trains only the critic.
                    value_output = critics[0](post_latent[:-1].detach())["value_output"]
                    value_loss = value_output.loss(target_returns[:-1].detach())
                    imp_weights_expanded = imp_weights.expand_as(value_loss).detach()
                    value_loss = (value_loss * imp_weights_expanded.clamp(max=4.0)).mean()
                    critics[0]._optim(value_loss)
                metrics.update({"value_loss": value_loss.item()})
                # The logged ratio is the mean of the unclipped weights.
                metrics.update({"value_loss_ratio": imp_weights_expanded.mean().item()})
            else:
                for i in range(len(critics)):
                    with critics[i]._requires_grad:
                        value_output = critics[i](post_latent[:-1, :, i].detach())["value_output"]
                        value_loss = value_output.loss(target_returns[:-1, :, i].detach())
                        imp_weights_expanded = imp_weights.expand_as(value_loss).detach()
                        value_loss = (value_loss * imp_weights_expanded.clamp(max=4.0)).mean()
                        critics[i]._optim(value_loss)
                    metrics.update({f"value_loss_{i}": value_loss.item()})
                    metrics.update({f"value_loss_ratio_{i}": imp_weights_expanded.mean().item()})

        metrics.update(optim_metric)
        return metrics

    def initialize_agent_state(self, batch_size: int) -> TensorDict:
        """Create an all-zero agent state for `batch_size` environments.

        Args:
            batch_size: Size of the leading batch axis.

        Returns:
            A TensorDict with batch size `(batch_size, n_agents)` on `self.device`
            holding a nested "latents" entry ("deter", "stoch", "logits") plus
            "actions" and "log_probs", every tensor filled with zeros.
        """
        rssm_cfg = self.config.world_model.rssm
        lead = (batch_size, self.n_agents)
        stoch_tail = (rssm_cfg.stochastic_dim, rssm_cfg.classes)

        latents = {
            "deter": torch.zeros(*lead, rssm_cfg.deterministic_dim),
            "stoch": torch.zeros(*lead, *stoch_tail),
            "logits": torch.zeros(*lead, *stoch_tail),
        }
        state = {
            "latents": latents,
            "actions": torch.zeros(*lead, self.n_actions),
            "log_probs": torch.zeros(*lead, self.n_actions),
        }
        # The tensors are built without a device argument; TensorDict is given self.device.
        return TensorDict(state, batch_size=lead, device=self.device)

    def observe_step(
        self,
        embed: torch.Tensor,
        prev_actions: torch.Tensor,
        is_first: torch.Tensor,
        prev_latent: TensorDict | None,
    ) -> TensorDict:
        """Advance the posterior latent by one time step.

        Entries flagged by `is_first` restart from the zero state produced by
        `initialize_agent_state`: their previous "deter", "stoch" and previous
        action are replaced by zeros before the RSSM update.

        Args:
            embed: Encoded observation for this step; its first axis is the batch.
            prev_actions: Actions taken before this observation.
            is_first: Boolean episode-start flags, broadcast against the latent
                tensors by appending trailing axes (per the shape comments in
                `observe` this is (bs, 1) -- TODO confirm).
            prev_latent: Latent from the previous step, or None on the very first
                call. It is never modified in place.

        Returns:
            The posterior latent produced by `self.dynamics.observe`.
        """
        if prev_latent is None or is_first.all():
            # Nothing to carry over: start every entry from the zero state.
            agent_states = self.initialize_agent_state(batch_size=embed.size(0))
            prev_latent = agent_states["latents"]
            prev_actions = agent_states["actions"]
        elif is_first.any():
            agent_states = self.initialize_agent_state(batch_size=embed.size(0))
            # Shallow copy before reassigning keys. The caller may still hold
            # `prev_latent` (e.g. as the previous step's recorded output), and
            # writing the reset values into it would overwrite that step's latent.
            prev_latent = prev_latent.clone(recurse=False)
            prev_latent["deter"] = torch.where(
                is_first[..., None],
                agent_states["latents"]["deter"],
                prev_latent["deter"],
            )
            prev_latent["stoch"] = torch.where(
                is_first[..., None, None],
                agent_states["latents"]["stoch"],
                prev_latent["stoch"],
            )
            prev_actions = torch.where(
                is_first[..., None],
                agent_states["actions"],
                prev_actions,
            )

        latent: TensorDict = self.dynamics.observe(
            embed=embed,
            prev_actions=prev_actions,
            prev_latent=prev_latent,
        )
        return latent

    def save(self):
        """Return a checkpoint dict holding the model and optimizer state dicts.

        Returns:
            `{"model_state_dict": ..., "optim_state_dict": ...}` taken from
            `self.state_dict()` and `self._optim.state_dict()`.
        """
        model_state = self.state_dict()
        optim_state = self._optim.state_dict()
        return {"model_state_dict": model_state, "optim_state_dict": optim_state}

    def load(self, data):
        """Restore model and optimizer state from a checkpoint made by `save`.

        Args:
            data: Mapping with "model_state_dict" and "optim_state_dict" entries.
        """
        # The model weights are restored first, then the optimizer state.
        model_state = data["model_state_dict"]
        self.load_state_dict(model_state)
        optim_state = data["optim_state_dict"]
        self._optim.load_state_dict(optim_state)

class RSSM(nn.Module):
    """Recurrent state-space model used as the world model's dynamics.

    Every step first updates a deterministic state with ``self._rnn`` from the
    previous stochastic latent and the previous actions, then produces logits
    reshaped to ``(..., stochastic_dim, classes)`` from which the stochastic
    latent is drawn with ``get_st_onehot(...).rsample()``:

    * ``observe`` also conditions the logits on the observation embedding
      (``self._obs_stoch_logits``).
    * ``imagine`` derives them from the deterministic state alone
      (``self._img_stoch_logits``, a ``Transformer`` or an ``MLP`` depending on
      ``config.world_model.rssm.use_img_stoch_transformer``).
    """

    def __init__(self, config, n_actions, n_agents, device):
        super().__init__()
        self.config = config
        self.n_actions = n_actions
        self.n_agents = n_agents
        self.device = device

        rssm_cfg = config.world_model.rssm
        mlp_cfg = rssm_cfg.mlp

        # Width of the stochastic latent once its (stochastic_dim, classes)
        # axes are flattened into one.
        self.tot_stoch_dim = rssm_cfg.stochastic_dim * rssm_cfg.classes

        # Feature extractors for the two inputs of the recurrent update.
        self._mlp_stoch = MLP(
            in_dim=self.tot_stoch_dim,
            hidden_dim=mlp_cfg.hidden_dim,
            hidden_layers=mlp_cfg.hidden_layers,
            act=mlp_cfg.act,
            use_layernorm=mlp_cfg.use_layernorm,
            device=device,
        )
        self._mlp_action = MLP(
            in_dim=n_actions,
            hidden_dim=mlp_cfg.hidden_dim,
            hidden_layers=mlp_cfg.hidden_layers,
            act=mlp_cfg.act,
            use_layernorm=mlp_cfg.use_layernorm,
            device=device,
        )

        # The recurrent input is the concatenation of both feature vectors,
        # hence twice the MLP hidden width.
        self._rnn = GRU(
            in_dim=2 * mlp_cfg.hidden_dim,
            hidden_dim=rssm_cfg.deterministic_dim,
            use_layernorm=rssm_cfg.rnn.use_layernorm,
            device=device,
        )

        # Posterior head: sees the deterministic state and the observation
        # embedding (whose width is the encoder's hidden_dim).
        self._obs_stoch_logits = MLP(
            in_dim=rssm_cfg.deterministic_dim + config.world_model.encoder.hidden_dim,
            hidden_dim=rssm_cfg.hidden_dim,
            hidden_layers=rssm_cfg.obs_layers,
            out_dim=self.tot_stoch_dim,
            act=rssm_cfg.act,
            use_layernorm=rssm_cfg.use_layernorm,
            device=device,
        )

        # Prior head: sees the deterministic state only.
        if rssm_cfg.use_img_stoch_transformer:
            tf_cfg = rssm_cfg.img_stoch_transformer
            self._img_stoch_logits = Transformer(
                d_model=rssm_cfg.deterministic_dim,
                out_dim=self.tot_stoch_dim,
                nhead=tf_cfg.nhead,
                num_layers=tf_cfg.num_layers,
                act=tf_cfg.activation,
                norm_first=tf_cfg.norm_first,
                device=device,
            )
        else:
            self._img_stoch_logits = MLP(
                in_dim=rssm_cfg.deterministic_dim,
                hidden_dim=rssm_cfg.hidden_dim,
                hidden_layers=rssm_cfg.obs_layers,
                out_dim=self.tot_stoch_dim,
                act=rssm_cfg.act,
                use_layernorm=rssm_cfg.use_layernorm,
                device=device,
            )

    def _next_deter(
            self,
            prev_actions: torch.Tensor,
            prev_latent: TensorDict | Dict[str, torch.Tensor],
        ) -> torch.Tensor:
        """Run one recurrent update and return the next deterministic state."""
        prev_deter = prev_latent["deter"]
        prev_stoch = prev_latent["stoch"]

        # Merge the trailing (stochastic_dim, classes) axes for the MLP, then
        # restore the original leading dims on its output.
        stoch_feat = self._mlp_stoch(
            prev_stoch.reshape(-1, self.tot_stoch_dim)
        ).reshape(*prev_stoch.shape[:-2], -1)
        action_feat = self._mlp_action(prev_actions)
        rnn_in = torch.cat([stoch_feat, action_feat], dim=-1)
        return self._rnn(rnn_in, prev_deter)

    def _split_logits(self, flat_logits: torch.Tensor, lead_shape) -> torch.Tensor:
        """Reshape flat logits to ``(*lead_shape, stochastic_dim, classes)``."""
        rssm_cfg = self.config.world_model.rssm
        return flat_logits.reshape(
            *lead_shape,
            rssm_cfg.stochastic_dim,
            rssm_cfg.classes,
        )

    def _sample_stoch(self, stoch_logits: torch.Tensor) -> torch.Tensor:
        """Draw a stochastic latent from ``stoch_logits`` with ``rsample()``."""
        # get_st_onehot lives in utils.tools; the original comment calls this
        # the re-parametrization trick.
        return get_st_onehot(
            stoch_logits, unimix=self.config.world_model.rssm.unimix
        ).rsample()

    def observe(
            self,
            embed: torch.Tensor,
            prev_actions: torch.Tensor,
            prev_latent: TensorDict | Dict[str, torch.Tensor],
        ) -> TensorDict:
        """Compute the posterior latent given the current observation embedding.

        Args:
            embed: Observation embedding; concatenated with the deterministic
                state along the last axis.
            prev_actions: Actions from the previous step.
            prev_latent: Mapping providing the previous ``"deter"`` and
                ``"stoch"`` tensors.

        Returns:
            TensorDict with ``"deter"``, ``"logits"`` and ``"stoch"``, batched
            over ``deter.shape[:-1]``.
        """
        # 1. generate deterministic embedding
        deter = self._next_deter(prev_actions, prev_latent)

        # 2. generate stochastic embedding
        posterior_in = torch.cat([embed, deter], dim=-1)
        stoch_logits = self._split_logits(
            self._obs_stoch_logits(posterior_in),
            posterior_in.shape[:-1],
        )
        # stoch_logits.shape = (batch_size, n_agents, n_stochastic_dim, n_classes)
        stoch = self._sample_stoch(stoch_logits)

        return TensorDict(
            {
                "deter": deter,
                "logits": stoch_logits,
                "stoch": stoch,
            },
            batch_size=deter.shape[:-1],
            device=self.device,
        )

    def imagine(
            self,
            prev_actions: torch.Tensor,
            prev_latent: TensorDict | Dict[str, torch.Tensor],
        ) -> TensorDict:
        """Compute the prior latent for the next step without an observation.

        Args:
            prev_actions: Actions from the previous step.
            prev_latent: Mapping providing the previous ``"deter"`` and
                ``"stoch"`` tensors.

        Returns:
            TensorDict with ``"deter"``, ``"logits"`` and ``"stoch"``, batched
            over ``deter.shape[:-1]``.
        """
        # 1. generate deterministic embedding at next timestep
        deter = self._next_deter(prev_actions, prev_latent)

        # 2. generate stochastic embedding at next timestep
        prior_stoch = self.imagine_stoch_from_deter(deter, num_batch_dims=1)

        return TensorDict(
            {
                "deter": deter,
                "logits": prior_stoch["logits"],
                "stoch": prior_stoch["stoch"],
            },
            batch_size=deter.shape[:-1],
            device=self.device,
        )

    def imagine_stoch_from_deter(
        self,
        deter: torch.Tensor,
        num_batch_dims: int = 1,
    ) -> TensorDict:
        """Produce prior logits and a sampled stochastic latent from ``deter``.

        Args:
            deter: Deterministic state; the last axis is the feature axis and
                all leading axes are kept. The original note gives the shape as
                (ts, bs, n_agents, n_deterministic_dim), while ``imagine``
                calls this with ``num_batch_dims=1`` -- NOTE(review): confirm
                which leading layout callers actually pass.
            num_batch_dims: Forwarded to the Transformer prior head only; not
                used when the prior head is an MLP.

        Returns:
            TensorDict with ``"logits"`` and ``"stoch"``, batched over
            ``deter.shape[:-1]``.
        """
        if self.config.world_model.rssm.use_img_stoch_transformer:
            flat_logits = self._img_stoch_logits(deter, num_batch_dims=num_batch_dims)
        else:
            flat_logits = self._img_stoch_logits(deter)
        stoch_logits = self._split_logits(flat_logits, deter.shape[:-1])
        # stoch_logits.shape = (*deter.shape[:-1], n_stochastic_dim, n_classes)

        return TensorDict(
            {
                "logits": stoch_logits,
                "stoch": self._sample_stoch(stoch_logits),
            },
            batch_size=deter.shape[:-1],
            device=self.device,
        )

class ObsEncoder(nn.Module):
    """Encode observations with a network chosen from the observation rank.

    A rank-1 observation shape builds an ``MLP`` base network and a rank-3
    shape builds a ``CNN``; both use ``config.world_model.encoder.hidden_dim``
    as their width / output size. Any other rank raises ``ValueError``.
    """

    def __init__(self, config, obs_shape, device):
        super().__init__()
        self.config = config
        # Kept exactly as passed in (an int stays an int on the attribute).
        self.obs_shape = obs_shape

        # A bare int is treated as a one-element shape.
        obs_shape = (obs_shape,) if isinstance(obs_shape, int) else obs_shape
        rank = len(obs_shape)
        if rank not in (1, 3):
            raise ValueError(f"Observation shape {obs_shape} not supported")

        enc_cfg = config.world_model.encoder
        if rank == 1:
            # Flat vector observation.
            self._base = MLP(
                in_dim=obs_shape[0],
                hidden_dim=enc_cfg.hidden_dim,
                hidden_layers=enc_cfg.hidden_layers,
                act=enc_cfg.act,
                use_layernorm=enc_cfg.use_layernorm,
                use_symlog=enc_cfg.use_symlog,
                device=device,
            )
        else:
            # Rank-3 observation, handed to the CNN as ``input_shape``.
            self._base = CNN(
                input_shape=obs_shape,
                depth=enc_cfg.depth,
                mults=enc_cfg.mults,
                kernel=enc_cfg.kernel,
                act=enc_cfg.act,
                output_dim=enc_cfg.hidden_dim,
                use_layernorm=enc_cfg.use_layernorm,
                device=device,
            )

    def forward(self, obs: torch.Tensor):
        """Return the base network's output for ``obs``."""
        base_net = self._base
        return base_net(obs)
