# experiments/generate_targets.py
"""
Generate target latents x_g by ancestral sampling from positive prompts.
These targets serve as fidelity anchors in the EGP energy.

This script assumes access to a generative model (or a saved function) that can produce latents.
For demonstration, random latents are used; replace `ancestral_sample_latent` with a real sampler.
"""
import os
import json
import torch
import numpy as np
from pathlib import Path
from configs import Config
from datasets.dataset_loader import load_custom_prompt_set

def ancestral_sample_latent(prompt, n_samples=1, latent_dim=512, device="cpu", rng=None):
    """
    Placeholder ancestral sampler: return standard-normal random latents.

    Replace with a model-specific ancestral sampler to reproduce the paper's
    x_g generation.

    Args:
        prompt: Conditioning prompt. Ignored by this placeholder.
        n_samples: Number of latents to draw (rows of the result).
        latent_dim: Dimensionality of each latent vector (columns of the result).
        device: Target device for a real sampler. Ignored by this placeholder,
            which always returns a NumPy array.
        rng: Optional ``numpy.random.Generator`` used for reproducible draws.
            When ``None`` (the default) the global NumPy RNG is used, so
            seeding via ``np.random.seed`` still controls the output.

    Returns:
        ``np.ndarray`` of shape ``(n_samples, latent_dim)`` and dtype float32.
    """
    if rng is None:
        # Global-state path: identical to the historical behaviour.
        draws = np.random.randn(n_samples, latent_dim)
    else:
        # Draw in float64 and cast afterwards, mirroring the global-RNG path.
        draws = rng.standard_normal((n_samples, latent_dim))
    return draws.astype(np.float32)

def generate_targets(prompt_file, out_dir="experiments/targets", per_prompt=1, sampler=None):
    """
    Sample target latents for every prompt and save each one to disk.

    Each latent is written to ``<out_dir>/target_<i>_<j>.npz`` (``i`` = prompt
    index, ``j`` = sample index) as a compressed NumPy archive holding two
    arrays: ``prompt`` (the prompt as returned by ``load_custom_prompt_set``)
    and ``latent``.

    Args:
        prompt_file: Path handed to ``load_custom_prompt_set`` to obtain the
            prompts (presumably one prompt per entry — confirm against loader).
        out_dir: Output directory; created if missing. Files with the same
            name are overwritten; no other files in the directory are removed.
        per_prompt: Number of latents to sample per prompt.
        sampler: Optional callable invoked as ``sampler(prompt, n_samples=...)``
            that returns an iterable of latents, one per sample. Defaults to
            ``ancestral_sample_latent``; pass a model-specific sampler to plug
            in real ancestral sampling without editing this module.

    Returns:
        List of the written file paths, in the order they were saved.
    """
    if sampler is None:
        # Looked up at call time so rebinding the module-level name still takes effect.
        sampler = ancestral_sample_latent
    os.makedirs(out_dir, exist_ok=True)
    prompts = load_custom_prompt_set(prompt_file)
    saved = []
    for i, prompt in enumerate(prompts):
        latents = sampler(prompt, n_samples=per_prompt)
        for j, latent in enumerate(latents):
            fname = os.path.join(out_dir, f"target_{i}_{j}.npz")
            np.savez_compressed(fname, prompt=prompt, latent=latent)
            saved.append(fname)
    print(f"Saved {len(saved)} target latents to {out_dir}")
    return saved

if __name__ == "__main__":
    # Demo: dump two example prompts to a scratch file, then draw two target
    # latents per prompt into the default output directory.
    prompt_file = "experiments/demo_prompts.txt"
    Path("experiments").mkdir(exist_ok=True)
    Path(prompt_file).write_text(
        "\n".join(
            [
                "A closeup photograph of a sunflower.",
                "An illustration of a futuristic city skyline.",
            ]
        ),
        encoding="utf-8",
    )
    generate_targets(prompt_file, per_prompt=2)
