# datasets/dataset_loader.py
"""
Dataset loader utilities for COCO-val / prompt datasets and custom prompt lists.
This module assumes datasets are available locally. It provides:
  - loaders that return lists of prompts, optional target captions, and config splits.
  - utility to load a directory of prompt text files.

You must supply dataset files locally per paper protocol.
"""
import os
import json
import random

def load_prompt_file(path):
    """
    Read prompts from a UTF-8 text file, one prompt per line.

    Leading and trailing whitespace is removed from every line, and lines
    that are empty after stripping are dropped.
    Returns list[str]
    """
    prompts = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text:
                prompts.append(text)
    return prompts

def load_coco_val_prompts(coco_dir):
    """
    Load caption strings from ``<coco_dir>/captions_val2017.json``.

    The JSON file is expected to hold an ``"annotations"`` list whose entries
    carry a ``"caption"`` string. Each caption is stripped of surrounding
    whitespace; annotations whose caption is missing, not a string, or blank
    are skipped, so the result never contains empty prompts (matching what
    ``load_prompt_file`` does for text files).

    For demo purposes, if the captions file is not present a small list of
    placeholder prompts is returned instead of raising.

    Returns list[str]
    """
    captions_file = os.path.join(coco_dir, "captions_val2017.json")
    if not os.path.exists(captions_file):
        # fallback synthetic prompts
        return ["A photo of a cat.", "A scenic landscape.", "An astronaut riding a horse."]

    with open(captions_file, "r", encoding="utf-8") as fh:
        data = json.load(fh)

    prompts = []
    # `or []` also covers an explicit null "annotations" value in the JSON.
    for ann in data.get("annotations") or []:
        caption = ann.get("caption")
        if not isinstance(caption, str):
            continue
        caption = caption.strip()
        if caption:
            prompts.append(caption)
    return prompts

def load_custom_prompt_set(prompt_path):
    """
    Load prompts from a single text file or from a directory of text files.

    If ``prompt_path`` is a directory, every entry whose name ends in ".txt"
    (case-insensitive) is read with ``load_prompt_file`` in sorted filename
    order and the results are concatenated. Otherwise ``prompt_path`` itself
    is read as a prompt file.
    """
    # Single-file case: hand the path straight to the line-based loader.
    if not os.path.isdir(prompt_path):
        return load_prompt_file(prompt_path)

    collected = []
    for entry in sorted(os.listdir(prompt_path)):
        if not entry.lower().endswith(".txt"):
            continue
        collected += load_prompt_file(os.path.join(prompt_path, entry))
    return collected

def prepare_generation_tasks(prompts, n_per_prompt=1, rng_seed=42):
    """
    Prepare a shuffled list of generation tasks, each with a unique id.

    Every prompt is repeated ``n_per_prompt`` times. Ids are assigned
    sequentially (0, 1, 2, ...) in input order before shuffling, so the id
    identifies the task's original position.

    The shuffle uses a private ``random.Random(rng_seed)`` instance, so the
    order is reproducible for a given seed and the process-wide ``random``
    state is left untouched.

    Returns list of dicts: [{'id': idx, 'prompt': text}, ...]
    """
    tasks = [
        {"id": idx, "prompt": prompt}
        for idx, prompt in enumerate(
            prompt for prompt in prompts for _ in range(n_per_prompt)
        )
    ]
    # A dedicated generator yields the same order as seeding the global one,
    # without reseeding the RNG shared by the rest of the program.
    random.Random(rng_seed).shuffle(tasks)
    return tasks
