# models/denoiser_wrapper.py
"""
Denoiser wrapper providing a uniform interface for different U-Net / denoiser models.
This file contains a lightweight placeholder denoiser for demo purposes.

Required interface:
  denoiser.predict_noise(x_t, t_idx, cond=None) -> predicted_noise (same shape as x_t)
Replace with a real diffusion U-Net for production (e.g., a Stable Diffusion U-Net).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class DummyDenoiser:
    """Placeholder denoiser: a two-layer MLP mapping a flat latent to a same-size output.

    Intended for demos only. It ignores the timestep and conditioning
    arguments; replace it with a real diffusion U-Net for production.

    Args:
        latent_dim: Width of the MLP's input, hidden and output layers. For
            inputs with more than two dims, the product of all non-batch dims
            must equal this value.
        device: Device the MLP parameters are moved to. Inputs are moved to
            this device before the forward pass.
    """

    def __init__(self, latent_dim=512, device="cpu"):
        self.latent_dim = latent_dim
        self.device = device
        # small MLP as a stand-in: Linear -> ReLU -> Linear, width latent_dim throughout
        self.net = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
        )
        self.net.to(self.device)

    def predict_noise(self, x_t, t_idx, cond=None):
        """Return the MLP's "predicted noise" for ``x_t``, with the same shape as ``x_t``.

        Args:
            x_t: Tensor of shape [B, D] (or [D]), or a higher-rank tensor such
                as [B, C, H, W] whose non-batch dims are flattened before the
                MLP. The flattened per-sample size must equal ``latent_dim``.
            t_idx: Timestep index. Unused by this placeholder.
            cond: Conditioning input. Unused by this placeholder.

        Returns:
            Tensor with the same shape as ``x_t``, located on ``self.device``.
        """
        x = x_t.to(self.device)
        if x.ndim <= 2:
            return self.net(x)
        # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
        # latents, are flattened (copying if needed) instead of raising.
        flat = x.reshape(x.shape[0], -1)
        # The Linear output is freshly allocated and contiguous, so view is safe here.
        return self.net(flat).view_as(x_t)
