# experiments/benchmarks.py
"""
Run batched experiments across multiple model variants and collect metrics.
This is a lightweight orchestrator used to compare SD-2.1, SD-3.5, SD-XL, Flux, EGP and EGP ablations.
For now, generation uses placeholder models; replace them with real model wrappers.

Outputs:
  - CSV with columns: model, prompt_id, prompt, time_s, clip_score, neg_acc, notes
  - Aggregated summary printed.
"""
import time
import csv
import os
from pathlib import Path
import torch
import numpy as np

from configs import Config
from datasets.dataset_loader import prepare_generation_tasks, load_custom_prompt_set
from clip.clip_embedder import CLIPEmbedder
from sampling.ddim_wrapper import DDIMWrapper
from sampling.egp_sampler import None  # placeholder import guard; EGP sampler should be in place
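# The import above is intentionally invalid (`None` cannot be imported, so this file raises
# SyntaxError) as a reminder to plug in the sampling/egp_sampler.py implementation.
# The demo below uses a simplified stand-in sampler; the file will not run until that import is fixed.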

def simple_generate(prompt, model_variant="dummy", embedder=None, cfg=None):
    """
    Simplified placeholder generator used by the benchmark harness.

    Draws a random latent and a random RGB image, embeds the image with
    ``embedder`` and scores it against ``prompt`` with a dot product.
    Replace with an actual call into sampling/egp_sampler.py.

    Args:
        prompt: text prompt the generated image is scored against.
        model_variant: model name; currently unused by this placeholder.
        embedder: optional object exposing ``embed_image_from_rgb(rgb)`` and
            ``_placeholder_text_to_vec(list_of_prompts)``. When None, no
            embedding is computed and ``clip_score`` is 0.0.
        cfg: optional config object with a ``device`` attribute. When None
            (or lacking ``device``), tensors are created on "cpu".

    Returns:
        dict with keys:
            "latent": numpy array of shape (1, 512),
            "clip_score": float (0.0 without an embedder),
            "neg_acc": float placeholder (always 0.95),
            "time_s": elapsed wall-clock seconds for this call.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted mid-benchmark.
    t0 = time.perf_counter()
    batch = 1
    latent_dim = 512
    # cfg defaults to None, so don't dereference it unconditionally.
    device = getattr(cfg, "device", "cpu")
    latent = torch.randn(batch, latent_dim, device=device)
    # Stand-in for a VAE decode: a random 64x64 RGB image with values in [0, 1).
    rgb = torch.rand(batch, 3, 64, 64)
    if embedder is not None:
        img_emb = embedder.embed_image_from_rgb(rgb)
        text_emb = embedder._placeholder_text_to_vec([prompt]).to(img_emb.device)
        # Row-wise dot product between image and text embeddings, averaged over the batch.
        clip_score = float((img_emb * text_emb).sum(dim=1).mean().item())
    else:
        clip_score = 0.0
    neg_acc = 0.95  # placeholder
    t_elapsed = time.perf_counter() - t0
    return {"latent": latent.cpu().numpy(), "clip_score": clip_score, "neg_acc": neg_acc, "time_s": t_elapsed}

_CSV_HEADER = ["model", "prompt_id", "prompt", "time_s", "clip_score", "neg_acc", "notes"]


def _print_summary(rows):
    """Print per-model mean time_s, clip_score and neg_acc over the collected result rows."""
    per_model = {}
    for model_name, _pid, _prompt, time_s, clip_score, neg_acc, _notes in rows:
        per_model.setdefault(model_name, []).append((time_s, clip_score, neg_acc))
    if not per_model:
        print("No results to summarize.")
        return
    print("Summary (means per model):")
    for model_name, metrics in per_model.items():
        n = len(metrics)
        mean_time = sum(m[0] for m in metrics) / n
        mean_clip = sum(m[1] for m in metrics) / n
        mean_neg = sum(m[2] for m in metrics) / n
        print(
            f"  {model_name}: n={n} time_s={mean_time:.4f} "
            f"clip_score={mean_clip:.4f} neg_acc={mean_neg:.4f}"
        )


def run_bench(prompt_list, out_csv="experiments/results.csv", models=None, cfg=None, flush_every=10):
    """
    Generate every prompt with every model variant, stream metrics to CSV and print a summary.

    Args:
        prompt_list: iterable of prompt strings.
        out_csv: output CSV path; parent directories are created if missing.
        models: list of model names; falls back to the default comparison set
            when None or empty.
        cfg: config object with a ``device`` attribute; a fresh ``Config()`` is
            used when not given.
        flush_every: flush the CSV to disk after every this-many rows so partial
            results survive an interrupted run (0/None disables explicit flushes).

    Returns:
        list of result rows, each ordered as the CSV header
        (model, prompt_id, prompt, time_s, clip_score, neg_acc, notes).
    """
    cfg = cfg or Config()
    embedder = CLIPEmbedder(device=cfg.device)
    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
    models = models or ["SD-2.1", "SD-3.5", "SD-XL", "Flux", "EGP", "EGP-Ablation"]
    rows = []
    # Write the file once, appending rows as they are produced, instead of
    # truncating and rewriting the whole CSV every few rows (quadratic I/O).
    with open(out_csv, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(_CSV_HEADER)
        for model_name in models:
            for pid, p in enumerate(prompt_list):
                print(f"[{model_name}] generating prompt {pid}: {p[:60]}...")
                res = simple_generate(p, model_variant=model_name, embedder=embedder, cfg=cfg)
                row = [model_name, pid, p, res["time_s"], res["clip_score"], res["neg_acc"], "placeholder"]
                rows.append(row)
                writer.writerow(row)
                # Checkpoint occasionally so partial results are on disk if the run dies.
                if flush_every and len(rows) % flush_every == 0:
                    fh.flush()
    print(f"Saved results to {out_csv}")
    _print_summary(rows)
    return rows

if __name__ == "__main__":
    # Small smoke-test run over two hand-written prompts with the default model set.
    demo_prompts = [
        "A photograph of a cat on a skateboard.",
        "A portrait of a woman in renaissance style.",
    ]
    demo_cfg = Config()
    run_bench(demo_prompts, out_csv="experiments/results_demo.csv", cfg=demo_cfg)
