# train/train_projector.py
"""
Training script for the MLP projector P_phi (latent -> CLIP embedding).
Expects training data as .npz files, each containing arrays named 'latent' and
'clip_emb' whose rows are paired (row i of 'latent' maps to row i of 'clip_emb').

Usage:
  python train/train_projector.py --data_dir /path/to/pairs --epochs 20 --out model.pt
"""
import argparse
import os
import glob
import numpy as np
import torch
from projector.mlp_projector import MLPProjector

def load_pairs_from_dir(data_dir, max_samples=None):
    """Load and concatenate (latent, clip_emb) training pairs from ``.npz`` files.

    Every ``*.npz`` file directly inside ``data_dir`` that contains both a
    ``"latent"`` and a ``"clip_emb"`` array contributes its rows; files missing
    either key are skipped. Files are read in sorted filename order so the
    concatenated result (and any ``max_samples`` truncation) is reproducible.

    Args:
        data_dir: Directory scanned (non-recursively) for ``.npz`` files.
        max_samples: If truthy, keep only the first ``max_samples`` rows.

    Returns:
        Tuple ``(latents, clip_embs)`` of 2-D arrays with the same number of
        rows; row ``i`` of one is paired with row ``i`` of the other.

    Raises:
        RuntimeError: If no file provides both arrays.
        ValueError: If a file's ``latent`` and ``clip_emb`` row counts differ.
    """
    latents = []
    clip_embs = []
    # glob order is filesystem-dependent; sort so runs are reproducible.
    for path in sorted(glob.glob(os.path.join(data_dir, "*.npz"))):
        # NpzFile keeps the archive open until closed; the context manager
        # releases the handle once the arrays have been read into memory.
        with np.load(path) as d:
            if "latent" not in d or "clip_emb" not in d:
                continue
            # vstack treats 1-D arrays as single rows; do the same here so the
            # per-file row counts can be compared before concatenation.
            lat = np.atleast_2d(d["latent"])
            emb = np.atleast_2d(d["clip_emb"])
        if lat.shape[0] != emb.shape[0]:
            # A mismatch here would silently misalign every later pair.
            raise ValueError(
                f"{path}: 'latent' has {lat.shape[0]} rows but 'clip_emb' has {emb.shape[0]}."
            )
        latents.append(lat)
        clip_embs.append(emb)
    if not latents:
        raise RuntimeError("No training pairs found in data_dir (expect .npz with 'latent' and 'clip_emb').")
    latents = np.vstack(latents)
    clip_embs = np.vstack(clip_embs)
    if max_samples:
        latents = latents[:max_samples]
        clip_embs = clip_embs[:max_samples]
    return latents, clip_embs

def main(args):
    """Train the MLP projector on saved pairs and write its ``state_dict`` to disk.

    Args:
        args: Parsed CLI namespace providing ``data_dir``, ``max_samples``,
            ``force_cpu``, ``hidden``, ``epochs``, ``lr``, ``batch`` and ``out``.
    """
    import inspect  # only used to probe train_from_pairs' keyword arguments

    lat, clip = load_pairs_from_dir(args.data_dir, max_samples=args.max_samples)
    device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"

    train_kwargs = {
        "device": device,
        "epochs": args.epochs,
        "lr": args.lr,
        "batch_size": args.batch,
    }
    # train_from_pairs is not handed a model instance and returns the trained
    # model itself, so a model constructed here would simply be discarded and
    # could never carry --hidden into training. Forward --hidden only when
    # train_from_pairs explicitly declares a ``hidden_dim`` parameter.
    try:
        params = inspect.signature(MLPProjector.train_from_pairs).parameters
    except (TypeError, ValueError):
        # Signature not introspectable; don't guess at its keywords.
        params = {}
    if "hidden_dim" in params:
        train_kwargs["hidden_dim"] = args.hidden
    else:
        print(
            f"Warning: MLPProjector.train_from_pairs has no 'hidden_dim' parameter; "
            f"--hidden={args.hidden} is ignored."
        )

    model = MLPProjector.train_from_pairs(lat, clip, **train_kwargs)
    # dirname is "" for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
    torch.save(model.state_dict(), args.out)
    print(f"Saved projector to {args.out}")

if __name__ == "__main__":
    # Command-line interface; flags are registered in the same order (and with
    # the same defaults) that --help lists them.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_dir",
        required=False,
        default="projector/pairs",
        help="dir with .npz latent/clip pairs",
    )
    # Numeric training hyper-parameters: (flag, type, default).
    for flag, kind, fallback in (
        ("--epochs", int, 20),
        ("--lr", float, 1e-3),
        ("--batch", int, 256),
        ("--hidden", int, 512),
    ):
        cli.add_argument(flag, type=kind, default=fallback)
    cli.add_argument("--out", default="projector/projector.pth")
    cli.add_argument("--max_samples", type=int, default=None)
    cli.add_argument("--force_cpu", action="store_true")
    main(cli.parse_args())
