import os
import json
import torch
from PIL import Image
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
from diffusers import AutoPipelineForText2Image, DiffusionPipeline

# ================================
# PARAMETERS TO SET HERE
# ================================
# Models to run, in this order; each name must be one that get_generator()
# recognizes, otherwise it raises ValueError.
model_names = [
    "Koala",
    "Sana",
    "LCM",
    "Unidiffuser",
    "SDXL-Turbo",
    "SSD-1B"
]
# Number of generation attempts made for every prompt.
num_generations = 5
# Upper bound on the number of prompts processed per model in one run.
T = 1000
# Root directory holding one image sub-folder per model plus metadata.json.
output_dir = "flowers"
# JSON file with the candidate prompts.
# NOTE(review): presumably a flat list of prompt strings — confirm against the file.
prompts_path = "prompts/flowers.json"

# Initialize device and CLIP
# Use CUDA when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# CLIP model and its matching processor, used by compute_clip_score().
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def compute_clip_score(prompt, image):
    """Return the CLIP similarity between ``prompt`` and ``image``.

    The score is 100 times the cosine similarity of the CLIP text and image
    embeddings, computed with the module-level ``clip_model`` and
    ``clip_processor`` on ``device``.

    Args:
        prompt: Text the image was generated from.
        image: Image accepted by ``clip_processor`` (the callers pass the
            image returned by a diffusers pipeline).

    Returns:
        The score as a Python float.
    """
    inputs = clip_processor(
        text=[prompt],
        images=image,
        return_tensors="pt",
        padding=True,
        # Without truncation a prompt longer than the tokenizer's maximum
        # length makes the text encoder raise instead of yielding a score.
        truncation=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # Normalize both embeddings so the dot product below is a cosine similarity.
    img_emb = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
    txt_emb = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
    return 100 * (img_emb * txt_emb).sum(dim=-1).item()

def is_black_image(image, threshold=5):
    """Tell whether ``image`` is (almost) entirely black.

    The image is converted to 8-bit grayscale and its mean pixel value is
    derived from the histogram; the image counts as black when that mean is
    strictly below ``threshold``.
    """
    counts = image.convert("L").histogram()
    pixel_total = sum(counts)
    # Each histogram bin index is the gray level, so level * count summed
    # over all bins is the sum of every pixel value.
    weighted_total = sum(level * count for level, count in enumerate(counts))
    mean_level = weighted_total / pixel_total
    return mean_level < threshold

def get_generator(model_name):
    """Build the image generator for ``model_name``.

    Loads the matching diffusers pipeline onto ``device``, silences its
    progress bar and returns a one-argument callable: given a prompt it runs
    the pipeline with that model's fixed settings and returns the first
    generated image.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """

    def _load(pipeline_cls, repo_id, **load_kwargs):
        # Shared loading step: fetch the checkpoint, then move it to `device`.
        return pipeline_cls.from_pretrained(repo_id, **load_kwargs).to(device)

    def _as_generator(pipeline, **call_kwargs):
        # Disable the per-call progress bar and bind the fixed call settings.
        pipeline.set_progress_bar_config(disable=True)

        def generate(prompt):
            return pipeline(prompt=prompt, **call_kwargs).images[0]

        return generate

    if model_name == "Sana":
        # Imported only when this model is requested.
        from diffusers import SanaPipeline

        sana = _load(
            SanaPipeline,
            "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers",
            torch_dtype=torch.bfloat16,
        )
        # The VAE and text encoder are explicitly cast to bfloat16 as well.
        sana.vae.to(torch.bfloat16)
        sana.text_encoder.to(torch.bfloat16)
        return _as_generator(
            sana,
            height=1024,
            width=1024,
            guidance_scale=4.5,
            num_inference_steps=10,
        )

    if model_name == "LCM":
        lcm = _load(
            DiffusionPipeline,
            "SimianLuo/LCM_Dreamshaper_v7",
            torch_dtype=torch.float16,
        )
        return _as_generator(
            lcm,
            num_inference_steps=4,
            guidance_scale=8.0,
            lcm_origin_steps=50,
            output_type="pil",
        )

    if model_name == "Unidiffuser":
        unidiffuser = _load(
            DiffusionPipeline,
            "thu-ml/unidiffuser-v1",
            torch_dtype=torch.float16,
        )
        return _as_generator(
            unidiffuser,
            height=512,
            width=512,
            num_inference_steps=10,
        )

    if model_name == "SDXL-Turbo":
        turbo = _load(
            AutoPipelineForText2Image,
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
        )
        return _as_generator(
            turbo,
            num_inference_steps=2,
            guidance_scale=0.0,
        )

    if model_name == "SSD-1B":
        ssd = _load(
            DiffusionPipeline,
            "segmind/SSD-1B",
            torch_dtype=torch.float16,
        )
        return _as_generator(
            ssd,
            num_inference_steps=10,
            guidance_scale=7.5,
        )

    if model_name == "Koala":
        koala = _load(
            DiffusionPipeline,
            "etri-vilab/koala-lightning-700m",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        koala.set_progress_bar_config(disable=True)

        def generate_koala(prompt):
            # This pipeline receives the prompt positionally.
            return koala(prompt, num_inference_steps=8).images[0]

        return generate_koala

    raise ValueError(f"Unrecognized model: {model_name}")


def _save_metadata(metadata_path, all_metadata):
    """Write ``all_metadata`` to ``metadata_path`` as JSON.

    The data is written to a sibling ``.tmp`` file and then moved into place
    with ``os.replace``, so an interrupted write does not leave a truncated
    metadata file behind.
    """
    tmp_path = metadata_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(all_metadata, f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, metadata_path)


def _next_image_index(model_dir):
    """Return one past the highest ``NNNNN`` among ``img_NNNNN.png`` files.

    Files that do not follow the ``img_<number>.png`` pattern are ignored
    instead of aborting the run. Returns 0 when no such file exists.
    """
    prefix, suffix = "img_", ".png"
    highest = -1
    for name in os.listdir(model_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        try:
            highest = max(highest, int(name[len(prefix):-len(suffix)]))
        except ValueError:
            continue
    return highest + 1


def _select_prompts(model_name, all_metadata, seen_prompts):
    """Pick at most ``T`` prompts for ``model_name`` to generate.

    Prompts already generated by other models but not by this one come first
    (sorted). If that leaves fewer than ``T``, the list is topped up with
    prompts from ``prompts_path`` that no model has generated yet
    (``seen_prompts``).
    """
    done_by_model = {e["prompt"] for e in all_metadata if e["model"] == model_name}
    done_by_others = {e["prompt"] for e in all_metadata if e["model"] != model_name}

    selected = sorted(done_by_others - done_by_model)
    if len(selected) < T:
        with open(prompts_path, "r", encoding="utf-8") as f:
            new_prompts = json.load(f)
        selected.extend(p for p in new_prompts if p not in seen_prompts)
    # Always cap at T so the reported count matches what is processed.
    return selected[:T]


def main():
    """Generate images for every configured model and record CLIP scores.

    For each name in ``model_names`` the matching pipeline is loaded, up to
    ``T`` prompts are selected, and ``num_generations`` images are attempted
    per prompt. Black images and failed attempts are skipped. Each kept image
    is scored, saved under ``output_dir/<model>/img_NNNNN.png`` and recorded
    in ``output_dir/metadata.json``.

    Metadata is written after every model (also when the run is interrupted)
    and once more at the end, so finished work is not lost.
    """
    import gc  # local import: only needed here to release finished pipelines

    torch.cuda.empty_cache()
    os.makedirs(output_dir, exist_ok=True)
    metadata_path = os.path.join(output_dir, "metadata.json")

    # Load or initialize metadata
    if os.path.exists(metadata_path):
        with open(metadata_path, "r", encoding="utf-8") as f:
            all_metadata = json.load(f)
    else:
        all_metadata = []

    # Every prompt that already has an entry, whatever the model.
    existing_prompts_global = {entry["prompt"] for entry in all_metadata}

    for model_name in model_names:
        torch.cuda.empty_cache()
        print(f"\n▶ Starting generation for model: {model_name}")
        generator = get_generator(model_name)
        try:
            model_dir = os.path.join(output_dir, model_name)
            os.makedirs(model_dir, exist_ok=True)

            missing_prompts = _select_prompts(
                model_name, all_metadata, existing_prompts_global
            )
            print(f"→ {len(missing_prompts)} prompts to generate for {model_name}")

            # Continue file numbering after the images already on disk.
            idx = _next_image_index(model_dir)

            for prompt in tqdm(missing_prompts, desc=model_name):
                filenames, scores = [], []
                for _ in range(num_generations):
                    try:
                        img = generator(prompt)
                        if is_black_image(img):
                            continue
                        # Score before saving so a scoring failure does not
                        # leave an image on disk that metadata never mentions.
                        score = compute_clip_score(prompt, img)
                        fname = f"{model_name}/img_{idx:05d}.png"
                        img.save(os.path.join(output_dir, fname))
                    except Exception as e:
                        # Best effort: report the failed attempt and move on.
                        print(f"Error ({model_name}) on \"{prompt}\": {e}")
                        continue
                    filenames.append(fname)
                    scores.append(round(score, 2))
                    idx += 1
                if filenames:
                    all_metadata.append({
                        "prompt": prompt,
                        "model": model_name,
                        "filenames": filenames,
                        "clip_scores": scores
                    })
                    # Track the prompt itself (not the filenames) so later
                    # models do not treat it as a brand-new candidate.
                    existing_prompts_global.add(prompt)
        finally:
            # Drop the pipeline before the next one is loaded; the closure
            # returned by get_generator() is what keeps it alive.
            del generator
            gc.collect()
            torch.cuda.empty_cache()
            # Checkpoint after every model, including on error or Ctrl-C.
            _save_metadata(metadata_path, all_metadata)

    # Save updated metadata (also covers an empty model list)
    _save_metadata(metadata_path, all_metadata)

    print(f"\n✅ Generation completed for all models. Total entries: {len(all_metadata)}")

# Run the generation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()