import os
import torch
from pathlib import Path
from diffusers import StableDiffusionPipeline, DiffusionPipeline
# from rtpt import RTPT
from torch import autocast
from transformers import CLIPTextModel, set_seed
import pandas as pd
import shutil
from argparse import ArgumentParser
from utils.misc import fix_seed
from const import *


# HF_TOKEN = 'INSERT_HF_TOKEN'
# OUTPUT_FOLDER = 'images'
# Number of images kept per prompt; they are stored as 0.jpg .. (NUM_SAMPLES - 1).jpg.
NUM_SAMPLES = 9
# Runs at import time with a fixed seed. NOTE(review): presumably seeds the RNGs
# used during generation for reproducibility -- confirm in utils.misc.fix_seed.
fix_seed(42)

# The dialect type is the second-to-last component of the prompt directory path
# DIALECT_TYPE = DATA_FILE.split("/")[-2]

def generate_stable_diffusion_batch(pipe, prompt: str, num_images: int, batch_size: int = 3) -> list:
    """
    Generate ``num_images`` images for ``prompt`` by calling ``pipe`` in batches.

    Args:
        pipe: Callable invoked as ``pipe([prompt] * k)``; the returned object
            must expose the generated images through an ``images`` attribute.
        prompt: Text prompt, repeated once per requested image in every call.
        num_images: Total number of images to return.
        batch_size: Maximum number of prompts passed to ``pipe`` per call.
            Defaults to 3, the value that used to be hard-coded.

    Returns:
        A list with exactly ``num_images`` images (empty if ``num_images <= 0``).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        RuntimeError: If a pipeline call yields no images; without this check
            the loop below would never terminate.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    images = []
    while len(images) < num_images:
        # Never request more than is still missing, so the last batch may be smaller.
        current_batch = min(batch_size, num_images - len(images))
        result = pipe([prompt] * current_batch).images
        # len() instead of truthiness so array-like outputs are handled as well.
        if len(result) == 0:
            raise RuntimeError(
                f"Pipeline returned no images for a batch of {current_batch} prompt(s)."
            )
        images.extend(result)
    # Guard against a pipeline that returns more images than prompts.
    return images[:num_images]

def _model_output_subdir(args) -> Path:
    """Return the model-specific sub-path placed below the output base directory."""
    model_name = args.model.split("/")[-1]
    if not args.swap:
        # Without a swapped encoder the outputs are keyed by the model name only.
        return Path(model_name)
    encoder_parts = args.encoder.split("/")
    if "best" in args.encoder or "last" in args.encoder:
        # Keep the parent folder too when the encoder path mentions "best"/"last".
        encoder_name = "/".join(encoder_parts[-2:])
    else:
        encoder_name = encoder_parts[-1]
    return Path(model_name + "/" + encoder_name)


def _generate_missing_samples(pipe, prompt: str, out_dir: Path, replace) -> None:
    """Ensure ``out_dir`` holds ``0.jpg`` .. ``NUM_SAMPLES - 1`` images for ``prompt``."""
    if replace and out_dir.exists():
        # Replace mode: discard everything generated earlier for this prompt.
        shutil.rmtree(out_dir)
    if replace or not out_dir.exists():
        out_dir.mkdir(parents=True, exist_ok=True)
        missing_indices = list(range(NUM_SAMPLES))
    else:
        # Resume mode: only regenerate the sample slots that have no file yet.
        missing_indices = [
            idx for idx in range(NUM_SAMPLES) if not (out_dir / f"{idx}.jpg").is_file()
        ]

    if not missing_indices:
        return
    new_images = generate_stable_diffusion_batch(pipe, prompt, len(missing_indices))
    for idx, image in zip(missing_indices, new_images):
        image.save(str(out_dir / f"{idx}.jpg"))


def main(args):
    """
    Generate images for every dialect/SAE prompt pair in the test split.

    Loads the diffusion pipeline named by ``args.model`` onto CUDA, optionally
    replaces its text encoder with the one at ``args.encoder`` (when
    ``args.swap`` is truthy), reads ``<data_dir>/<dialect>/test.csv`` and, for
    each row, fills one folder per prompt with ``NUM_SAMPLES`` JPEG images:

        <base>/<model sub-path>/<mode>/<dialect>/<Dialect_Prompt>/
        <base>/<model sub-path>/<mode>/<dialect>_sae/<SAE_Prompt>/

    ``<base>`` is ``BASE_SWAP_DIR`` or ``BASE_ORIG_DIR`` (not defined in this
    file; presumably provided by the ``const`` star import).

    Args:
        args: Parsed CLI namespace with ``model``, ``encoder``, ``swap``,
            ``data_dir``, ``mode``, ``dialect`` and ``replace`` attributes.
    """
    if "stable-diffusion-xl" in args.model:
        # SDXL checkpoints are loaded from their fp16 safetensors variant.
        pipe = DiffusionPipeline.from_pretrained(
            args.model,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16"
        ).to("cuda")
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            args.model,
        ).to("cuda")
    # pipe.safety_checker = None  # disable safety checker if desired
    if args.swap:
        base_dir = BASE_SWAP_DIR
        text_encoder = CLIPTextModel.from_pretrained(args.encoder, use_safetensors=True, device_map="auto")
        pipe.text_encoder = text_encoder
    else:
        base_dir = BASE_ORIG_DIR
    base_dir = Path(base_dir)

    data_path = os.path.join(args.data_dir, args.dialect, "test.csv")
    df = pd.read_csv(data_path, encoding="unicode_escape")
    dialect_prompts = df["Dialect_Prompt"].tolist()
    sae_prompts = df["SAE_Prompt"].tolist()

    print(">>> encoder name: " + args.encoder.split("/")[-1])

    # Everything above the prompt folder is identical for all rows, so build it once.
    mode_dir = base_dir / _model_output_subdir(args) / args.mode
    dialect_root = mode_dir / args.dialect
    sae_root = mode_dir / f"{args.dialect}_sae"

    for dialect_prompt, sae_prompt in zip(dialect_prompts, sae_prompts):
        # Resolve both folders before generating anything so that a malformed
        # prompt in either column fails before images are written for the row.
        dialect_dir = dialect_root / dialect_prompt
        sae_dir = sae_root / sae_prompt

        _generate_missing_samples(pipe, dialect_prompt, dialect_dir, args.replace)
        _generate_missing_samples(pipe, sae_prompt, sae_dir, args.replace)


def parse_arguments():
    """
    Parse the command-line options of the image generation script.

    ``--swap`` and ``--replace`` are integer options (default 0) rather than
    ``store_true`` flags; any non-zero value is treated as "enabled" by the
    code that consumes them.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    option_specs = (
        ("--model", {
            "type": str,
            "default": "stable-diffusion-v1-5/stable-diffusion-v1-5",
            "choices": [
                "stabilityai/stable-diffusion-2-1",
                "stable-diffusion-v1-5/stable-diffusion-v1-5",
                "stabilityai/stable-diffusion-xl-base-1.0",
            ],
        }),
        ("--encoder", {"type": str, "default": "models/sge/singlish_kl_iac_20ep"}),
        ("--swap", {"type": int, "default": 0, "help": "Swap in the trained text encoder."}),
        ("--data_dir", {
            "type": str,
            "default": "./multimodal-dialectal-bias/data/text/train_val_test/4-1-1/concise/",
        }),
        ("--mode", {"type": str, "default": "concise"}),
        ("--dialect", {"type": str, "default": "sge"}),
        ("--replace", {"type": int, "default": 0}),
    )

    cli = ArgumentParser(description="Generate images using a stable diffusion model.")
    # Register the options in table order so the --help listing stays the same.
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


# Script entry point: parse the CLI options and run the generation loop.
if __name__ == "__main__":
    args = parse_arguments()
    # NOTE(review): this unconditionally overwrites any value passed via
    # --data_dir, so the data directory is always derived from --mode.
    # Confirm whether the CLI option is meant to be ignored.
    args.data_dir = f"./multimodal-dialectal-bias/data/text/train_val_test/4-1-1/{args.mode}"
    main(args)
