# projector/train_projector_dataset_builder.py
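"""
Build a dataset of (latent, clip_emb) pairs by sampling latents and computing CLIP
embeddings of the decoded images.

This script will:
 - sample latents (build_pairs currently draws i.i.d. standard-normal vectors;
   sampling xt from a model or loading saved latents is not implemented here)
 - decode each latent to RGB via VAE.decode
 - compute a CLIP embedding of each decoded image
 - save each pair as a compressed .npz file in the output directory.

Important: building the dataset is expensive (every pair needs a decode and a CLIP pass).
After the dataset is built, use the projector (trained on these pairs) to avoid paying
this cost repeatedly.
"""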
import os
import numpy as np
from pathlib import Path
from models.vae_decoder import DummyVAEDecoder
from clip.clip_embedder import CLIPEmbedder

def build_pairs(out_dir="projector/pairs", n_pairs=1000, latent_dim=512, *,
                seed=None, out_shape=(3, 64, 64)):
    """Sample random latents, decode and CLIP-embed them, and save each pair to disk.

    Each pair is written to ``<out_dir>/pair_<i:06d>.npz`` holding two arrays:
    ``latent`` (float32, shape ``(latent_dim,)``) and ``clip_emb`` (the first
    row of the embedder's output, converted to NumPy). Files with the same
    names are overwritten.

    Args:
        out_dir: Directory for the .npz files; created (with parents) if missing.
        n_pairs: Number of pairs to generate. Values <= 0 write no files.
        latent_dim: Length of each sampled latent; also passed to the decoder.
        seed: Optional seed for a private ``np.random.Generator``. When ``None``
            (the default) latents are drawn from NumPy's global RNG, so a prior
            ``np.random.seed(...)`` call still makes the run reproducible.
        out_shape: Image shape passed to ``DummyVAEDecoder`` (default ``(3, 64, 64)``).
    """
    # Local import kept (as in the original) but hoisted out of the per-pair loop.
    import torch

    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    vae = DummyVAEDecoder(latent_dim=latent_dim, out_shape=out_shape)
    clip = CLIPEmbedder()
    rng = np.random.default_rng(seed) if seed is not None else None

    # Inference only: no_grad avoids building an autograd graph for every
    # decode/embed call (the results are detached anyway).
    with torch.no_grad():
        for i in range(n_pairs):
            if rng is None:
                latent = np.random.randn(latent_dim).astype(np.float32)
            else:
                latent = rng.standard_normal(latent_dim, dtype=np.float32)
            z = torch.from_numpy(latent).unsqueeze(0)  # batch of one: [1, latent_dim]
            img = vae.decode(z)  # expected to be a [1, 3, H, W] tensor
            emb = clip.embed_image_from_rgb(img)  # expected to be a [1, D] tensor
            emb_np = emb.detach().cpu().numpy()[0]
            np.savez_compressed(out_path / f"pair_{i:06d}.npz",
                                latent=latent, clip_emb=emb_np)
    print(f"Saved {n_pairs} pairs to {out_dir}")
