import os
import torch
from diffusers import StableDiffusionPipeline, DiffusionPipeline
# from rtpt import RTPT
from torch import autocast
from transformers import CLIPTextModel, set_seed
import pandas as pd
from argparse import ArgumentParser
from utils.misc import fix_seed
from const import *


# Legacy settings (not referenced anywhere in this file):
# HF_TOKEN = 'INSERT_HF_TOKEN'
# OUTPUT_FOLDER = 'images'
# Number of images generated per prompt. Each one is saved as "<k>.jpg" for
# k in range(NUM_SAMPLES), and a prompt pair is skipped only when all of them
# already exist.
NUM_SAMPLES = 9
# Runs at import time, before any pipeline is built. fix_seed is defined in
# utils/misc.py (not shown here); it presumably seeds the Python/NumPy/torch
# RNGs, so confirm that there.
fix_seed(42)

# Superseded by the --dialect CLI argument. The dialect type used to be taken
# from the second-to-last component of the prompt file path:
# DIALECT_TYPE = DATA_FILE.split("/")[-2]


def _load_pipeline(model_name):
    """Load the diffusion pipeline for ``model_name`` and move it to CUDA.

    Any model name containing ``"stable-diffusion-xl"`` is loaded through the
    generic ``DiffusionPipeline`` from safetensors weights. Every other model
    goes through ``StableDiffusionPipeline``.
    """
    if "stable-diffusion-xl" in model_name:
        # torch_dtype=torch.float16 and variant="fp16" can be passed here to
        # load half-precision SDXL weights if GPU memory is tight.
        return DiffusionPipeline.from_pretrained(
            model_name,
            use_safetensors=True,
        ).to("cuda")
    return StableDiffusionPipeline.from_pretrained(model_name).to("cuda")


def _output_model_name(model_name, encoder_path, swap):
    """Return the output sub-directory that identifies the model (and encoder).

    Without ``swap``, this is the last path component of ``model_name``.
    With ``swap``, it is ``<model>/<encoder>``. Here ``<encoder>`` is the last
    component of ``encoder_path``, or the last two components when the path
    contains ``"best"`` or ``"last"``. Presumably this keeps checkpoint
    sub-folders of different training runs apart.
    """
    model_dir = model_name.split("/")[-1]
    if not swap:
        return model_dir
    # Ignore trailing slashes: "run/best/" must name the same folder as
    # "run/best" rather than yielding an empty last component.
    encoder = encoder_path.rstrip("/")
    parts = encoder.split("/")
    # NOTE(review): this is a substring test on the whole path, so "best" or
    # "last" anywhere in a directory name triggers the two-component form.
    if "best" in encoder or "last" in encoder:
        encoder_name = "/".join(parts[-2:])
    else:
        encoder_name = parts[-1]
    return f"{model_dir}/{encoder_name}"


def _samples_complete(directory):
    """Return True if ``directory`` already holds ``0.jpg`` … ``<NUM_SAMPLES-1>.jpg``."""
    return os.path.isdir(directory) and all(
        os.path.isfile(os.path.join(directory, f"{k}.jpg")) for k in range(NUM_SAMPLES)
    )


def main(args):
    """Generate paired dialect / SAE images for every prompt in the test split.

    Reads ``<args.data_dir>/<args.dialect>/test.csv`` (columns ``Dialect_Prompt``
    and ``SAE_Prompt``). For each row it saves ``NUM_SAMPLES`` images per prompt,
    named ``0.jpg`` … ``<NUM_SAMPLES-1>.jpg``, under::

        <base_dir>/<model name>/<args.mode>/<args.dialect>/<dialect prompt>/
        <base_dir>/<model name>/<args.mode>/<args.dialect>_sae/<SAE prompt>/

    ``base_dir`` is ``BASE_SWAP_DIR`` when ``args.swap`` is truthy, in which case
    the pipeline's ``text_encoder`` is replaced by the ``CLIPTextModel`` loaded
    from ``args.encoder``. Otherwise it is ``BASE_ORIG_DIR``. Rows whose two
    folders already contain all ``NUM_SAMPLES`` images are skipped, so an
    interrupted run can be resumed.

    Args:
        args: Namespace produced by ``parse_arguments``.
    """
    pipe = _load_pipeline(args.model)
    # pipe.safety_checker = None  # uncomment to disable the safety checker
    if args.swap:
        base_dir = BASE_SWAP_DIR
        pipe.text_encoder = CLIPTextModel.from_pretrained(
            args.encoder, use_safetensors=True, device_map="auto"
        )
    else:
        base_dir = BASE_ORIG_DIR

    data_path = os.path.join(args.data_dir, args.dialect, "test.csv")
    # NOTE(review): the "unicode_escape" codec decodes bytes as Latin-1 and
    # interprets backslash escapes, so non-ASCII UTF-8 text gets mangled.
    # It is kept as-is because prompt text becomes the names of existing
    # output folders.
    df = pd.read_csv(data_path, encoding="unicode_escape")
    dialect_prompts = df["Dialect_Prompt"].tolist()
    sae_prompts = df["SAE_Prompt"].tolist()

    print(">>> encoder name: " + args.encoder.rstrip("/").split("/")[-1])

    # The output location depends only on CLI arguments, so build it once
    # instead of once per prompt.
    model_base_name = _output_model_name(args.model, args.encoder, args.swap)
    run_dir = os.path.join(base_dir, model_base_name, args.mode)
    dialect_root = os.path.join(run_dir, args.dialect)
    sae_root = os.path.join(run_dir, f"{args.dialect}_sae")

    for dialect_prompt, sae_prompt in zip(dialect_prompts, sae_prompts):
        # NOTE(review): prompts are used verbatim as folder names. A "/" in a
        # prompt creates nested folders, and a leading "/" would make
        # os.path.join discard run_dir entirely.
        dialect_dir = os.path.join(dialect_root, dialect_prompt)
        sae_dir = os.path.join(sae_root, sae_prompt)
        if _samples_complete(dialect_dir) and _samples_complete(sae_dir):
            continue
        os.makedirs(dialect_dir, exist_ok=True)
        os.makedirs(sae_dir, exist_ok=True)

        # Keep the original dialect-then-SAE interleaving so seeded runs stay
        # comparable with earlier outputs. Sampling presumably draws from the
        # RNG seeded by fix_seed.
        for k in range(NUM_SAMPLES):
            pipe(dialect_prompt).images[0].save(os.path.join(dialect_dir, f"{k}.jpg"))
            pipe(sae_prompt).images[0].save(os.path.join(sae_dir, f"{k}.jpg"))


DEFAULT_DATA_ROOT = "./multimodal-dialectal-bias/data/text/train_val_test/4-1-1"


def parse_arguments(argv=None):
    """Parse command-line options for image generation.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``model``, ``encoder``, ``swap``, ``data_dir``,
        ``mode`` and ``dialect``. When ``--data_dir`` is not given, it is derived
        from ``--mode`` as
        ``./multimodal-dialectal-bias/data/text/train_val_test/4-1-1/<mode>``.
        An explicitly passed ``--data_dir`` is honoured; previously it was
        silently overwritten.
    """
    parser = ArgumentParser(description="Generate images using a stable diffusion model.")
    parser.add_argument("--model", type=str, default="stable-diffusion-v1-5/stable-diffusion-v1-5",
                        choices=["stabilityai/stable-diffusion-2-1", "stable-diffusion-v1-5/stable-diffusion-v1-5",
                                 "stabilityai/stable-diffusion-xl-base-1.0"])
    parser.add_argument("--encoder", type=str, default="models/sge/singlish_kl_iac_20ep",
                        help="Path of the trained CLIP text encoder (used with --swap).")
    # Integer flag: pass `--swap 1` to enable. main() only checks truthiness.
    parser.add_argument("--swap", type=int, default=0, help="Swap in the trained text encoder.")
    parser.add_argument("--data_dir", type=str, default=None,
                        help="Directory containing <dialect>/test.csv "
                             "(default: derived from --mode).")
    parser.add_argument("--mode", type=str, default="concise")
    parser.add_argument("--dialect", type=str, default="sge")
    args = parser.parse_args(argv)
    if args.data_dir is None:
        # Keep the data split in sync with the prompt mode unless overridden.
        args.data_dir = f"{DEFAULT_DATA_ROOT}/{args.mode}"
    return args


if __name__ == "__main__":
    main(parse_arguments())
