# models/vae_decoder.py
"""
VAE decoder interface: latent -> RGB image.
In real SD, the VAE decoder takes latent (z) and decodes to 3xHxW image.
This file provides a tiny placeholder decoder for debugging and to support gradient flows.

Interface:
  decoder.decode(latent_tensor) -> rgb_tensor in [0,1], shape [B,3,H,W]
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class DummyVAEDecoder(nn.Module):
    """Tiny MLP stand-in for a VAE decoder: flat latent -> RGB image.

    A two-layer MLP expands a latent of size ``latent_dim`` to
    ``out_shape[0] * out_shape[1] * out_shape[2]`` values, which are
    reshaped to ``[B, *out_shape]`` and passed through a sigmoid. Every step
    is a differentiable torch op, so gradients flow back to the latent.

    Args:
        latent_dim: Size of the flat latent vector fed to the MLP.
        out_shape: Output image shape ``(C, H, W)``.
        device: Device the MLP is placed on at construction time.

    Attributes:
        latent_dim: The ``latent_dim`` passed to the constructor.
        out_shape: The ``out_shape`` passed to the constructor.
        device: The ``device`` passed to the constructor. This is only a
            record of the initial placement; it is not updated when the
            module is later moved with ``.to()`` / ``.cuda()``.
        net: The MLP mapping ``[B, latent_dim]`` to ``[B, C*H*W]``.
    """

    def __init__(self, latent_dim: int = 512, out_shape=(3, 64, 64), device="cpu"):
        super().__init__()
        self.latent_dim = latent_dim
        self.out_shape = out_shape
        self.device = device
        # Number of scalars in one output image (C * H * W).
        out_size = out_shape[0] * out_shape[1] * out_shape[2]
        # Small MLP that expands the latent to a flattened image.
        self.net = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, out_size),
        ).to(self.device)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a batch of latents into images.

        Args:
            z: Latent tensor of shape ``[B, latent_dim]``. It is moved/cast
                to the device and dtype of the MLP weights before use.

        Returns:
            Tensor of shape ``[B, *out_shape]`` with values in [0, 1].
        """
        # Align with the live weights rather than the construction-time
        # ``self.device``: the module may have been moved or re-cast since
        # (``.to()``, ``.cuda()``, ``.double()``), and a stale target would
        # cause a device/dtype mismatch inside the first Linear layer.
        ref = self.net[0].weight
        z = z.to(device=ref.device, dtype=ref.dtype)
        batch = z.shape[0]
        flat = self.net(z)
        image = flat.view(batch, *self.out_shape)
        return torch.sigmoid(image)  # map to [0, 1]

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`decode` so the module can be called directly."""
        return self.decode(z)
