"""
causal_modules.py
Causal modular blocks for latent dynamics.
Implements parent–child factorization and disentangled representations.
"""

import torch
import torch.nn as nn
import torch.distributions as D


class CausalModule(nn.Module):
    """
    Gaussian head for a single latent factor z^(m).

    A two-layer MLP turns a conditioning vector (whatever the caller
    concatenates, e.g. parent latents and actions) into the mean and
    log-variance of an element-wise Normal distribution over the factor.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        """
        Args:
            input_dim: Size of the last dimension of the conditioning input.
            hidden_dim: Width of the single hidden layer.
            latent_dim: Size of the latent factor; the network emits
                ``2 * latent_dim`` values (mean and log-variance).
        """
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            # First half of the outputs is the mean, second half the log-variance.
            nn.Linear(hidden_dim, latent_dim * 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        Predict the factor's distribution and draw a sample from it.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            A pair ``(sample, dist)``: ``dist`` is the predicted
            ``torch.distributions.Normal`` and ``sample`` is a
            reparameterized draw from it, so gradients reach the network.
        """
        mean, logvar = torch.chunk(self.net(x), 2, dim=-1)
        # exp(0.5 * logvar) converts log-variance to standard deviation.
        posterior = D.Normal(loc=mean, scale=(0.5 * logvar).exp())
        return posterior.rsample(), posterior


class ModularDynamics(nn.Module):
    """
    Modular latent dynamics with M modules.
    Factorized conditional distribution p(z_{t+1} | z_t, a_t).

    Each module m predicts its own factor z_{t+1}^(m) from the current
    latents of its parents (as listed in ``parent_map[m]``) concatenated
    with the action. With the default ``parent_map`` every module has no
    parents and therefore conditions on the action only.

    The per-factor networks live in ``self.causal_modules``. They cannot be
    stored as ``self.modules`` because that name is a method of
    ``nn.Module``; the attribute lookup would return the method rather
    than the list and iterating over it would fail.
    """

    def __init__(self, num_modules, latent_dim, hidden_dim, parent_map=None, action_dim=None):
        """
        Args:
            num_modules: Number of latent factors M.
            latent_dim: Size of each latent factor.
            hidden_dim: Hidden width of every per-factor network.
            parent_map: Mapping from module index to the indices of its
                parent factors. ``None`` (or an empty mapping) gives every
                module an empty parent list.
            action_dim: Size of the last dimension of the action. Defaults
                to ``latent_dim``, the width the input layer was previously
                hard-coded for.
        """
        super().__init__()
        self.num_modules = num_modules
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.action_dim = latent_dim if action_dim is None else action_dim
        self.parent_map = parent_map or {m: [] for m in range(num_modules)}

        # Input of module m = all parent factors followed by the action.
        self.causal_modules = nn.ModuleList(
            [
                CausalModule(
                    len(self.parent_map[m]) * latent_dim + self.action_dim,
                    hidden_dim,
                    latent_dim,
                )
                for m in range(num_modules)
            ]
        )

    @staticmethod
    def _select_factor(z_t, index):
        """Return factor ``index`` from a stacked tensor or a per-module sequence."""
        if torch.is_tensor(z_t):
            # Same layout as the tensor returned by forward(): module axis is 1.
            return z_t[:, index]
        return z_t[index]

    def forward(self, z_t, a_t):
        """
        Predict the next latent state, one factor per module.

        Args:
            z_t: Current latents, either a tensor with the module axis at
                dimension 1 (the layout this method returns, so outputs can
                be fed back in) or a list/tuple holding one tensor per module.
            a_t: Action tensor; its last dimension must be ``action_dim``.

        Returns:
            A pair ``(z_next, dists)`` where ``z_next`` stacks the sampled
            factors along dimension 1 and ``dists`` is the list of
            per-module distributions in module order.
        """
        z_next = []
        dists = []
        for m, module in enumerate(self.causal_modules):
            parents = [self._select_factor(z_t, p) for p in self.parent_map[m]]
            inputs = torch.cat(parents + [a_t], dim=-1)
            z_m, dist = module(inputs)
            z_next.append(z_m)
            dists.append(dist)
        return torch.stack(z_next, dim=1), dists
