from __future__ import annotations
import torch
import torch.nn as nn


class LatentDiffusion(nn.Module):
    """DDPM-style noise predictor εθ(z_t, t). Stub only.

    A plain MLP over the latent with the timestep appended as one extra
    input feature: ``[z_t, t] -> Linear -> SiLU -> Linear -> SiLU -> Linear``.
    The output's last dimension is ``latent_dim``, so it has the same shape
    as ``z_t``.

    Args:
        latent_dim: Size of the last dimension of ``z_t`` (and of the output).
        hidden_dim: Width of the two hidden layers. Defaults to 512.
    """

    def __init__(self, latent_dim: int, hidden_dim: int = 512):
        super().__init__()
        # +1 input feature for the timestep column appended in ``forward``.
        self.net = nn.Sequential(
            nn.Linear(latent_dim + 1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, z_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the noise for latent ``z_t`` at timestep ``t``.

        Args:
            z_t: Latent tensor whose last dimension is ``latent_dim``.
            t: Timestep(s). Accepted forms: a 0-d scalar (shared by every
                latent in ``z_t``), a 1-D tensor with one entry per row of a
                2-D ``z_t``, or a tensor already shaped like ``z_t`` with a
                trailing dimension of size 1. Integer timesteps are allowed.

        Returns:
            Tensor with the same shape as ``z_t``.
        """
        if t.ndim == 0:
            # A single scalar timestep: broadcast it to one column per latent
            # (works for batched and unbatched z_t alike).
            t = t.expand(*z_t.shape[:-1], 1)
        elif t.ndim == 1:
            # One timestep per batch row -> (B, 1) column for concatenation.
            t = t[:, None]
        # Cast to z_t's dtype so integer or differently-typed timesteps can be
        # concatenated and fed through Linear layers of the latent's dtype.
        # NOTE(review): t is used raw (no normalization or sinusoidal
        # embedding); large integer step values may dominate the input scale.
        x = torch.cat([z_t, t.to(z_t.dtype)], dim=-1)
        return self.net(x)
