import os
import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline
# from rtpt import RTPT
from torch import autocast
from transformers import CLIPTextModel, CLIPTextModelWithProjection
import pandas as pd
from argparse import ArgumentParser
from utils.hf_captions import create_hf_coco_dataset
from utils.misc import fix_seed
from const import *

# Number of images generated per caption; they are saved as 0.jpg .. 8.jpg
# inside that caption's output directory.
NUM_SAMPLES = 9
# Runs at import time, before any model is loaded.
# NOTE(review): presumably seeds the RNGs so runs are reproducible — confirm
# against utils.misc.fix_seed.
fix_seed(42)


def main(args):
    """Generate NUM_SAMPLES images for each selected MS-COCO caption.

    Loads the diffusion pipeline named by ``args.model`` onto CUDA, optionally
    replaces its text encoder(s) with the checkpoints at ``args.encoder_1`` /
    ``args.encoder_2``, and saves the images as
    ``<base_dir>/<model sub-path>/<caption>/<k>.jpg``.

    A caption is skipped when all of its NUM_SAMPLES images already exist,
    unless ``args.replace`` is truthy. A partially rendered caption is
    regenerated in full.

    Args:
        args: Parsed command-line namespace providing ``model``, ``encoder_1``,
            ``encoder_2``, ``swap`` and ``replace``.
    """
    # Only entries 4950..4999 of the caption dataset are used.
    mscoco = create_hf_coco_dataset(CAPTION_FILE_PATH, IMAGE_FOLDER_PATH).select(range(4950, 5000))
    pipe = DiffusionPipeline.from_pretrained(
        args.model,
        # torch_dtype=torch.float16,
        use_safetensors=True,
        # variant="fp16"
    ).to("cuda")
    # pipe.safety_checker = None  # disable safety checker if desired

    # The output sub-path depends only on the CLI arguments, so it is built
    # once here instead of once per prompt.
    model_name = args.model.split("/")[-1]
    if args.swap:
        base_dir = MSCOCO_BASE_SWAP_DIR
        text_encoder = CLIPTextModel.from_pretrained(args.encoder_1, use_safetensors=True, device_map="auto")
        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.encoder_2, use_safetensors=True, device_map="auto")
        pipe.text_encoder = text_encoder
        # NOTE(review): text_encoder_2 is assigned for every --model choice;
        # presumably only the SDXL pipeline actually uses a second text
        # encoder — confirm the extra load is intended for the others.
        pipe.text_encoder_2 = text_encoder_2
        # Swapped runs are keyed by the model name plus the last two path
        # components of each encoder checkpoint.
        encoder_1_tag = "/".join(args.encoder_1.split("/")[-2:])
        encoder_2_tag = "/".join(args.encoder_2.split("/")[-2:])
        model_base_name = model_name + "/" + encoder_1_tag + "/" + encoder_2_tag
    else:
        base_dir = MSCOCO_BASE_ORIG_DIR
        model_base_name = model_name
    output_root = os.path.join(base_dir, model_base_name)

    # Use the first caption of every selected dataset entry as the prompt.
    prompts = [ct[0] for ct in mscoco["captions"]]
    sample_names = [f"{k}.jpg" for k in range(NUM_SAMPLES)]

    for prompt in prompts:
        # NOTE(review): the caption is used verbatim as a directory name, so a
        # caption containing a path separator yields nested directories —
        # confirm downstream readers build the path the same way.
        prompt_dir = os.path.join(output_root, prompt)
        if os.path.isdir(prompt_dir):
            prompt_exist = all(os.path.isfile(os.path.join(prompt_dir, name)) for name in sample_names)
            if not args.replace and prompt_exist:
                continue
        os.makedirs(prompt_dir, exist_ok=True)

        # One pipeline call per sample; each call yields a single image.
        for name in sample_names:
            image = pipe(prompt).images[0]
            image.save(os.path.join(prompt_dir, name))


def parse_arguments():
    """Define this script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``model`` (one of the supported checkpoints),
        ``encoder_1`` and ``encoder_2`` (text-encoder checkpoint paths), and
        the integer options ``swap`` and ``replace`` (both default to 0).
    """
    supported_models = [
        "stabilityai/stable-diffusion-2-1",
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        "stabilityai/stable-diffusion-xl-base-1.0",
    ]
    default_encoder = "models/sge/singlish_kl_iac_20ep"

    cli = ArgumentParser(description="Generate images using a stable diffusion model.")
    cli.add_argument(
        "--model",
        type=str,
        default="stable-diffusion-v1-5/stable-diffusion-v1-5",
        choices=supported_models,
    )
    # Both encoder options share the same type and default checkpoint.
    for encoder_flag in ("--encoder_1", "--encoder_2"):
        cli.add_argument(encoder_flag, type=str, default=default_encoder)
    # Integer options rather than store_true switches: pass e.g. "--swap 1".
    cli.add_argument("--swap", type=int, default=0, help="Swap in the trained text encoder.")
    cli.add_argument("--replace", type=int, default=0)

    parsed = cli.parse_args()
    return parsed


if __name__ == "__main__":
    # Script entry point: parse the CLI options and pass them straight to main().
    main(parse_arguments())
