# projector/mlp_projector.py
"""
Two-layer MLP projector P_phi: latent -> CLIP embedding.
This is the fast surrogate recommended in the paper to avoid repeated decode->CLIP passes.

Usage:
    projector = MLPProjector(latent_dim=512, clip_dim=768, hidden_dim=512)
    emb = projector(latents)  # returns L2-normalized embeddings

Training script for this projector should be implemented separately (train_projector.py).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPProjector(nn.Module):
    """Two-layer MLP that maps latent vectors to L2-normalized embeddings.

    The network is ``Linear(latent_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, clip_dim)``. The output of the last layer is rescaled
    with ``F.normalize`` along the feature (last) axis.

    Args:
        latent_dim: Size of the last dimension of the input latents.
        clip_dim: Size of the last dimension of the produced embeddings.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, latent_dim=512, clip_dim=768, hidden_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, clip_dim),
        )
        # The L2 normalization is applied functionally in forward(); it has no
        # learnable state, so nothing is registered for it here.

    def forward(self, latents):
        """Project ``latents`` and L2-normalize along the last dimension.

        Args:
            latents: Tensor whose last dimension has size ``latent_dim``;
                any number of leading dimensions (including none) is accepted.

        Returns:
            Tensor with the same leading dimensions as ``latents`` and a last
            dimension of size ``clip_dim``, normalized along that dimension.
        """
        projected = self.net(latents)
        # Normalize over the feature axis. dim=-1 equals the previous dim=1 for
        # 2-D [N, D] input, but also stays correct for unbatched 1-D input and
        # for inputs with extra leading dimensions such as [B, T, D].
        return F.normalize(projected, dim=-1)

    @staticmethod
    def train_from_pairs(latents_np, clip_embs_np, device="cpu", epochs=20, lr=1e-3,
                         batch_size=256, hidden_dim=512):
        """Fit a new projector on paired (latent, embedding) arrays.

        Runs Adam on the mean (over the batch) of the squared L2 distance
        between the normalized prediction and the target embedding. Targets
        are used as given: they are NOT normalized here, so pass L2-normalized
        embeddings if every pair should be weighted equally.

        Args:
            latents_np: NumPy array of shape ``[N, D]``.
            clip_embs_np: NumPy array of shape ``[N, CLIP_D]``.
            device: Device on which the model and each batch are placed.
            epochs: Number of full passes over the data.
            lr: Adam learning rate.
            batch_size: Mini-batch size; batches are reshuffled every epoch.
            hidden_dim: Width of the projector's hidden layer.

        Returns:
            The trained ``MLPProjector`` (left in training mode, on ``device``).
        """
        model = MLPProjector(
            latent_dim=latents_np.shape[1],
            clip_dim=clip_embs_np.shape[1],
            hidden_dim=hidden_dim,
        ).to(device)
        optim = torch.optim.Adam(model.parameters(), lr=lr)
        dataset = torch.utils.data.TensorDataset(
            torch.from_numpy(latents_np).float(),
            torch.from_numpy(clip_embs_np).float(),
        )
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Nothing in the loop switches the mode, so set it once up front.
        model.train()
        for _ in range(epochs):
            for lat_b, clip_b in loader:
                lat_b = lat_b.to(device)
                clip_b = clip_b.to(device)
                pred = model(lat_b)
                # Per-sample squared L2 distance, averaged over the batch.
                loss = (pred - clip_b).pow(2).sum(dim=1).mean()
                optim.zero_grad()
                loss.backward()
                optim.step()
        return model
