import os
import json
import zipfile
from typing import Any, List, Tuple, Union

import torch
from tqdm.auto import tqdm
from diffusers import StableDiffusion3Pipeline
from PIL import Image


def repeat_ntimes(lst: List[Any], n: int) -> List[Any]:
    """
    Return a new list where each element of ``lst`` appears ``n`` times in a row.

    Original order is kept and elements are not copied (the same object is
    repeated). A non-positive ``n`` produces an empty list.

    Example:
        repeat_ntimes(["a", "b"], 2) -> ["a", "a", "b", "b"]
    """
    repeated: List[Any] = []
    for element in lst:
        repeated.extend([element] * n)
    return repeated


def update_concept_dict() -> dict:
    """
    Build the concept-name -> integer-id lookup written to concept_dict.json.

    Ids are assigned in list order starting at 0. The concept list is carried
    over unchanged from the FLUX version of this generator.
    """
    ordered_concepts = (
        "female",
        "male",
        "young",
        "old",
        "white-race",
        "black-race",
        "anti-sexual",
        "smile",
        "jump",
        "van-gogh",
    )
    return dict(zip(ordered_concepts, range(len(ordered_concepts))))


class SD35DataCreator:
    """
    Data generator for Stable Diffusion 3.5 (large / large-turbo).

    - Generates images from prompts.
    - Saves images as {idx}.jpg.
    - Saves labels.json and concept_dict.json
      in the same format as your FLUX pipeline.

    Image ``{idx}.jpg`` is meant to pair with entry ``idx`` of labels.json, so
    ``cfg.image_prompt`` and ``cfg.input_prompt_and_target_concept`` should
    have the same length (a warning is printed otherwise).
    """

    # Hugging Face model ids for the supported ``model_variant`` values.
    _MODEL_IDS = {
        "large": "stabilityai/stable-diffusion-3.5-large",
        "large-turbo": "stabilityai/stable-diffusion-3.5-large-turbo",
    }

    def __init__(self, cfg):
        """
        Args:
            cfg: Config object/class exposing ``root_dir``, ``model_variant``,
                ``num_samples``, ``image_prompt`` and
                ``input_prompt_and_target_concept``.

        Raises:
            RuntimeError: if CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. SD3.5 must run on GPU only.")

        self.device = "cuda"
        self.root_dir = cfg.root_dir
        self.model_variant = cfg.model_variant  # "large" or "large-turbo"

        # Prompts for image generation (possibly per-concept), repeated num_samples times.
        self.image_prompt: List[Union[str, Tuple[str, str], List[str]]] = repeat_ntimes(
            cfg.image_prompt, cfg.num_samples
        )

        # Labels (prompt + concept) repeated in the same way.
        self.input_prompt_and_target_concept: List[Any] = repeat_ntimes(
            cfg.input_prompt_and_target_concept, cfg.num_samples
        )

        # Image {idx}.jpg is paired with labels[idx] purely by position, so a
        # length mismatch would silently produce a misaligned dataset.
        n_images = len(self.image_prompt)
        n_labels = len(self.input_prompt_and_target_concept)
        if n_images != n_labels:
            print(
                f"⚠️ image_prompt yields {n_images} entries but "
                f"input_prompt_and_target_concept yields {n_labels}; "
                "images and labels.json entries will be misaligned."
            )

    # ------------------------------------------------------------------
    # Pipeline setup
    # ------------------------------------------------------------------
    def setup_pipeline(self) -> StableDiffusion3Pipeline:
        """
        Load the Stable Diffusion 3.5 pipeline for ``self.model_variant``.

        Weights are loaded in bfloat16. On CUDA, model CPU offload plus VAE
        slicing/tiling are enabled to reduce peak GPU memory. The pipeline's
        own progress bar is disabled (an outer tqdm bar is used instead).

        Returns:
            The configured ``StableDiffusion3Pipeline``.

        Raises:
            ValueError: if ``self.model_variant`` is not a known variant.
        """
        model_id = self._MODEL_IDS.get(self.model_variant)
        if model_id is None:
            raise ValueError(f"Unknown model_variant: {self.model_variant}")

        hf_token = os.getenv("HF_TOKEN")
        if hf_token is None:
            print("⚠️ HF_TOKEN is not set; if the model is gated, loading may fail.")

        pipe = StableDiffusion3Pipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            token=hf_token,
        )

        if self.device == "cuda":
            # enable_model_cpu_offload() manages device placement itself,
            # moving each sub-model to the GPU only while it runs. Calling
            # pipe.to("cuda") first would load the whole pipeline onto the GPU
            # (risking OOM for SD3.5 Large) only to move it back to the CPU.
            pipe.enable_model_cpu_offload()
            if hasattr(pipe, "vae"):
                pipe.vae.enable_slicing()
                pipe.vae.enable_tiling()
        else:
            pipe = pipe.to(self.device)

        if hasattr(pipe, "set_progress_bar_config"):
            pipe.set_progress_bar_config(disable=True)

        return pipe

    # ------------------------------------------------------------------
    # Image generation
    # ------------------------------------------------------------------
    def _resolve_sampling_params(
        self,
        num_inference_steps: Union[int, None],
        guidance_scale: Union[float, None],
    ) -> Tuple[int, float, int]:
        """Fill unset sampler settings with per-variant defaults.

        Returns ``(num_inference_steps, guidance_scale, max_sequence_length)``.
        Only ``None`` means "use the default"; explicit values (including 0)
        are passed through unchanged.
        """
        if self.model_variant == "large-turbo":
            # Few-step, distilled model (CFG baked in).
            default_steps, default_guidance = 4, 1.0
        else:  # "large"
            default_steps, default_guidance = 40, 4.5

        if num_inference_steps is None:
            num_inference_steps = default_steps
        if guidance_scale is None:
            guidance_scale = default_guidance

        max_seq_length = 512  # same text-encoder length for both variants
        return num_inference_steps, guidance_scale, max_seq_length

    @staticmethod
    def _split_prompt(prompt) -> Tuple[str, Union[str, None]]:
        """Return ``(prompt_text, negative_prompt_or_None)`` for one prompt entry.

        Accepts a plain string, a single-item list/tuple, or a
        ``(prompt, negative_prompt, ...)`` list/tuple. An empty list/tuple
        raises ``IndexError``.
        """
        if isinstance(prompt, (list, tuple)):
            if len(prompt) >= 2:
                return prompt[0], prompt[1]
            return prompt[0], None
        return prompt, None

    def create_images(
        self,
        num_inference_steps: Union[int, None] = None,
        guidance_scale: Union[float, None] = None,
        height: int = 1024,
        width: int = 1024,
        seed: int = 42,
    ) -> None:
        """
        Generate one image per entry of ``self.image_prompt`` into ``self.root_dir``.

        Defaults (used when the argument is ``None``):
        - large:       40 steps, guidance 4.5
        - large-turbo: 4 steps, guidance 1.0

        Args:
            num_inference_steps: Denoising steps; ``None`` -> variant default.
            guidance_scale: CFG scale; ``None`` -> variant default.
            height: Output image height in pixels.
            width: Output image width in pixels.
            seed: Seed for the single generator shared across all images,
                so a full run is reproducible.

        Failures for individual images are logged and skipped (best effort);
        the final summary reports how many succeeded.
        """
        num_inference_steps, guidance_scale, max_seq_length = self._resolve_sampling_params(
            num_inference_steps, guidance_scale
        )

        pipe = self.setup_pipeline()
        os.makedirs(self.root_dir, exist_ok=True)

        generator = torch.Generator(device=self.device).manual_seed(seed)

        print(
            f"Using model_variant={self.model_variant}, "
            f"steps={num_inference_steps}, guidance={guidance_scale}, "
            f"max_seq_length={max_seq_length}, HxW={height}x{width}"
        )

        # Arguments shared by every pipeline call (loop-invariant).
        common_kwargs = dict(
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            max_sequence_length=max_seq_length,
            generator=generator,
            height=height,
            width=width,
            return_dict=True,
        )

        total = len(self.image_prompt)
        failed: List[int] = []

        for idx, prompt in tqdm(
            enumerate(self.image_prompt),
            total=total,
            desc="Generating SD3.5 images",
        ):
            try:
                prompt_text, negative_prompt = self._split_prompt(prompt)
                call_kwargs = dict(common_kwargs, prompt=prompt_text)
                if negative_prompt is not None:
                    call_kwargs["negative_prompt"] = negative_prompt

                output = pipe(**call_kwargs)
                image: Image.Image = output.images[0]
                image.save(os.path.join(self.root_dir, f"{idx}.jpg"))
            except Exception as e:
                # Deliberate best effort: one bad prompt should not abort a long
                # run, but record it so the summary is accurate.
                print(f"[ERROR] Failed to generate image {idx}: {e}")
                failed.append(idx)
                continue

            if idx % 50 == 0 and self.device == "cuda":
                torch.cuda.empty_cache()

        succeeded = total - len(failed)
        print(f"✅ Finished generating {succeeded}/{total} images into {self.root_dir}")
        if failed:
            shown = ", ".join(str(i) for i in failed[:20])
            more = " ..." if len(failed) > 20 else ""
            print(f"⚠️ {len(failed)} image(s) failed: {shown}{more}")

    # ------------------------------------------------------------------
    # Labels & concept dict
    # ------------------------------------------------------------------
    def create_labels(self) -> None:
        """
        Save labels.json and concept_dict.json in the same format
        as your FLUX data generator.

        labels.json holds ``self.input_prompt_and_target_concept`` (one entry
        per image index); concept_dict.json holds ``update_concept_dict()``.
        """
        labels_path = os.path.join(self.root_dir, "labels.json")
        concept_dict_path = os.path.join(self.root_dir, "concept_dict.json")

        os.makedirs(self.root_dir, exist_ok=True)

        with open(labels_path, "w", encoding="utf-8") as f:
            json.dump(self.input_prompt_and_target_concept, f, indent=2)

        concept_dict = update_concept_dict()
        with open(concept_dict_path, "w", encoding="utf-8") as f:
            json.dump(concept_dict, f, indent=2)

        print(f"✅ Saved labels to {labels_path}")
        print(f"✅ Saved concept dict to {concept_dict_path}")

    # ------------------------------------------------------------------
    # Zipping
    # ------------------------------------------------------------------
    def zip_dataset(self) -> None:
        """
        Create <root_dir>.zip containing the entire dataset folder.

        Archive member names are relative to the parent of ``root_dir``, so
        they start with the dataset folder name. Files are added in sorted
        order so repeated runs produce archives with the same layout.
        """
        # Normalise first: with a trailing separator ("datasets/foo/") the
        # archive would otherwise be "datasets/foo/.zip" -- inside the folder
        # being walked -- and dirname() would drop "foo" from member names.
        root = os.path.normpath(self.root_dir)
        zip_path = f"{root}.zip"
        print(f"Creating zip archive: {zip_path}")

        parent = os.path.dirname(root)
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for foldername, dirnames, filenames in os.walk(root):
                dirnames.sort()  # in-place sort makes os.walk descend in order
                for filename in sorted(filenames):
                    file_path = os.path.join(foldername, filename)
                    arcname = os.path.relpath(file_path, parent)
                    zipf.write(file_path, arcname)

        size_mb = os.path.getsize(zip_path) / (1024 * 1024)
        print(f"✅ Created zip archive: {zip_path} ({size_mb:.2f} MB)")

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------
    def run(
        self,
        num_inference_steps: Union[int, None] = None,
        guidance_scale: Union[float, None] = None,
        create_zip: bool = True,
        height: int = 1024,
        width: int = 1024,
        seed: int = 42,
    ) -> None:
        """
        Full pipeline: labels -> images -> optional zip.

        Arguments other than ``create_zip`` are forwarded to ``create_images``.
        """
        self.create_labels()
        self.create_images(
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            seed=seed,
        )
        if create_zip:
            self.zip_dataset()


# ----------------------------------------------------------------------
# Configurations (mirroring your Flux configs)
# ----------------------------------------------------------------------


class CfgLargeSingle:
    """
    Config: a single concept ("female") rendered with the high-quality
    SD3.5 "large" variant.
    """

    # Where and how to generate.
    root_dir = "datasets/person_sd35_large_female"
    model_variant = "large"
    num_samples = 100  # images generated per prompt below

    # Generation prompt(s).
    image_prompt = [
        "a high quality photo of a woman, natural light, realistic skin texture",
    ]

    # Labels, one per prompt, in the layout [[[prompt, [concepts]]], ...].
    input_prompt_and_target_concept = [
        [
            ["a photo of a person", ["female"]],
        ],
    ]

    validation_prompt_and_target_concept = [
        "a photo of a person",
        ["female"],
    ]


class CfgLargeBatch:
    """
    Config: gender and age concepts (female, male, young, old) rendered
    with the SD3.5 "large" variant.
    """

    model_variant = "large"
    root_dir = "datasets/person_batch_sd35_large"
    num_samples = 100  # images per concept prompt

    validation_concepts = ["female", "male", "young", "old"]

    # One generation prompt per concept, index-aligned with validation_concepts.
    image_prompt = [
        "a high quality photo of a woman, natural light, realistic skin texture",  # female
        "a high quality photo of a man, natural light, realistic skin texture",  # male
        "a high quality photo of a young adult, natural light, realistic skin texture",  # young
        "a high quality photo of an elderly person, natural light, realistic skin texture",  # old
    ]

    # Every concept shares the neutral label prompt "a photo of a person".
    input_prompt_and_target_concept = [
        [["a photo of a person", [concept]]] for concept in validation_concepts
    ]


class CfgLargeRace:
    """
    Config: skin-tone concepts ("white-race", "black-race") rendered with
    the SD3.5 "large" variant.
    """

    model_variant = "large"
    root_dir = "datasets/person_race_sd35_large"
    num_samples = 100

    validation_concepts = ["white-race", "black-race"]

    # Generation prompts, index-aligned with validation_concepts.
    image_prompt = [
        "a high quality photorealistic portrait of a person with light skin tone",  # white-race
        "a high quality photorealistic portrait of a person with dark skin tone",  # black-race
    ]

    input_prompt_and_target_concept = [
        [["a photo of a person", [concept]]] for concept in validation_concepts
    ]


class CfgLargeFull:
    """
    Config: the full demographic set (gender, age, skin tone) plus "smile",
    rendered with the SD3.5 "large" variant.
    """

    model_variant = "large"
    root_dir = "datasets/person_full_sd35_large"
    num_samples = 100

    validation_concepts = [
        "female",
        "male",
        "young",
        "old",
        "white-race",
        "black-race",
        "smile",
    ]

    # Generation prompts, index-aligned with validation_concepts.
    image_prompt = [
        "a high quality photo of a woman, natural light, realistic skin texture",  # female
        "a high quality photo of a man, natural light, realistic skin texture",  # male
        "a high quality photo of a young adult, natural light, realistic skin texture",  # young
        "a high quality photo of an elderly person, natural light, realistic skin texture",  # old
        "a high quality photorealistic portrait of a person with light skin tone",  # white-race
        "a high quality photorealistic portrait of a person with dark skin tone",  # black-race
        "a high quality photo of a person, big smile, showing teeth, natural light",  # smile
    ]

    # Every concept shares the neutral label prompt "a photo of a person".
    input_prompt_and_target_concept = [
        [["a photo of a person", [concept]]] for concept in validation_concepts
    ]


class CfgTurboSingle:
    """
    Config: one concept ("female") using the fast, few-step SD3.5
    "large-turbo" variant, with a large sample count.
    """

    model_variant = "large-turbo"
    # NOTE(review): the directory name says "nude", but the prompt and label
    # below target "female" -- confirm this output path is intended.
    root_dir = "datasets/person_sd35_turbo_nude"
    num_samples = 3000

    image_prompt = ["a photo of a woman"]

    input_prompt_and_target_concept = [[["a photo of a person", ["female"]]]]

    validation_prompt_and_target_concept = ["a photo of a person", ["female"]]


if __name__ == "__main__":
    # Select the dataset to build. Alternatives defined above:
    #   CfgLargeSingle, CfgLargeBatch, CfgLargeRace, CfgLargeFull, CfgTurboSingle
    selected_config = CfgTurboSingle

    # num_inference_steps / guidance_scale are left unset so SD35DataCreator's
    # per-variant defaults apply; pass them to run() to override.
    SD35DataCreator(selected_config).run(
        create_zip=True,
        height=1024,
        width=1024,
    )
