import os
import pickle
import random
from tqdm import tqdm
import submitit

import torch
from transformers import set_seed

from SDXLPipeline import StableDiffusionXLPipeline
from SD15Pipeline import StableDiffusionPipeline
from SD3Pipeline import StableDiffusion3Pipeline
from diffusers import EulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler

import argparse
from PIL import Image
import numpy as np
import shutil
import zipfile

# Both paths are relative, so they resolve against the current working
# directory of the process (not the location of this file).
DATA_PATH = "../metadata/cc12m/"  # directory holding the caption .pkl metadata
SAVE_PATH = "../outputs/SDinference/"  # root for generated zips and previews

# Scheduler classes keyed by a short name.
# NOTE(review): nothing in the visible code reads this mapping; inference()
# chooses the scheduler from the model name instead. Confirm before removing.
scheduler_type = dict(
    EulerDiscrete=EulerDiscreteScheduler,
    FMEulerDiscrete=FlowMatchEulerDiscreteScheduler,
)

# Pretrained identifier passed to ``from_pretrained`` for each --model choice.
model_path_dict = dict(
    SD15="stable-diffusion-v1-5/stable-diffusion-v1-5",
    SDXL="stabilityai/stable-diffusion-xl-base-1.0",
    SD35L="stabilityai/stable-diffusion-3.5-large",
    SD35M="stabilityai/stable-diffusion-3.5-medium",
)


def resize_img(foo, size=(256, 256)):
    """Return ``foo`` resized to ``size`` using Lanczos resampling.

    Args:
        foo: PIL image to resize.
        size: target ``(width, height)``. Defaults to 256x256, the size every
            saved generation is written at.

    Returns:
        The resized image produced by ``foo.resize``.
    """
    return foo.resize(size, Image.LANCZOS)


def get_common_element(lst):
    """Return the most frequent element of ``lst``.

    Counting is O(n) via ``collections.Counter``. Ties go to the element that
    occurs first in ``lst``, so the result is deterministic and does not
    depend on set/hash iteration order.

    Args:
        lst: iterable of hashable elements.

    Raises:
        ValueError: if ``lst`` is empty.
    """
    from collections import Counter

    counts = Counter(lst)
    if not counts:
        raise ValueError("get_common_element() arg is an empty sequence")
    # most_common orders equal counts by first insertion.
    return counts.most_common(1)[0][0]


# Output PNGs are named ``key * _IMAGES_PER_KEY + img_index``, so a single
# caption may yield at most this many images before it would overwrite the
# files of the next key.
_IMAGES_PER_KEY = 20

# Denoising steps used for every model.
_NUM_INFERENCE_STEPS = 28


def _seed_everything(seed):
    """Seed ``random``, NumPy, PyTorch (CPU and all CUDA devices) and transformers."""
    set_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _load_pipeline(model, dtype):
    """Load the pipeline for ``model``, install its Euler scheduler and move it to CUDA.

    SD3.5 models get ``FlowMatchEulerDiscreteScheduler``; the others get
    ``EulerDiscreteScheduler``. Both are built from the pipeline's own
    scheduler config. The safety checker is disabled.

    Raises:
        ValueError: if ``model`` is not SDXL, SD15 or an SD35* variant.
    """
    if model == "SDXL":
        pipeline_cls = StableDiffusionXLPipeline
    elif model.startswith("SD35"):
        pipeline_cls = StableDiffusion3Pipeline
    elif model == "SD15":
        pipeline_cls = StableDiffusionPipeline
    else:
        raise ValueError(f"Unsupported model {model!r}")

    pipe = pipeline_cls.from_pretrained(
        model_path_dict[model],
        torch_dtype=dtype,
    )
    pipe.safety_checker = None

    if "SD35" in model:
        scheduler_cls = FlowMatchEulerDiscreteScheduler
    else:
        scheduler_cls = EulerDiscreteScheduler
    pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
    return pipe.to("cuda")


def _guidance_kwargs(model, t_guidance):
    """Return the extra keyword arguments passed to the pipeline call.

    Args:
        model: one of "SD15", "SDXL", "SD35L", "SD35M".
        t_guidance: "APG", "Interval", "CADS"; anything else disables all
            three guidance variants.

    Raises:
        ValueError: if ``model`` has no settings for the requested guidance.
    """
    disabled = {
        "adaptive_projected_guidance": False,
        "interval_guidance": False,
        "cads_guidance": False,
    }

    if t_guidance == "APG":
        # model -> (momentum, rescale_factor)
        apg_params = {
            "SD15": (-0.75, 7.5),
            "SDXL": (-0.5, 15),
            "SD35L": (-0.5, 10),
            "SD35M": (-0.5, 10),
        }
        if model not in apg_params:
            raise ValueError("SD15, XL, 35L and 35M only")
        momentum, rescale_factor = apg_params[model]
        return {
            **disabled,
            "adaptive_projected_guidance": True,
            "adaptive_projected_guidance_momentum": momentum,
            "adaptive_projected_guidance_rescale_factor": rescale_factor,
        }

    if t_guidance == "Interval":
        # The interval bounds are passed through the cads_tau_* keyword names.
        if "SD35" in model:
            if model not in ("SD35L", "SD35M"):
                raise ValueError(f"Unsupported SD3.5 variant {model!r}")
            tau_1, tau_2 = 0.3, 0.95
        else:
            tau_1, tau_2 = 0.08, 0.81
        return {
            **disabled,
            "interval_guidance": True,
            "cads_tau_1": tau_1,
            "cads_tau_2": tau_2,
        }

    if t_guidance == "CADS":
        # model -> (cads_tau_1, cads_tau_2, cads_noise_scale)
        cads_params = {
            "SD15": (0.8, 1.3, 0.1),
            "SDXL": (0.6, 1.0, 0.3),
            "SD35M": (0.85, 1.25, 0.3),
            "SD35L": (0.85, 1.25, 0.3),
        }
        if model not in cads_params:
            raise ValueError("SD15, XL, 35L and 35M only")
        tau_1, tau_2, noise_scale = cads_params[model]
        return {
            **disabled,
            "cads_guidance": True,
            "cads_rescale": True,
            "cads_tau_1": tau_1,
            "cads_tau_2": tau_2,
            "cads_noise_scale": noise_scale,
            "cads_mixing_factor": 1.0,
            "cads_hr_fix_active": False,
        }

    return disabled


def _generations_folder(args):
    """Return the shared output folder for this (scheduler, guidance, model, complexity) run."""
    return (
        f"{SAVE_PATH}/"
        f"SDinference_gemma3_siglip_clip_{args.scheduler}_{args.t_guidance}/"
        f"guidance_{args.guidance:.01f}"
        f"/model_{args.model}/"
        f"complexity_{args.complexity}/"
    )


def inference(args, selected_captions, batch_size, chunk_index):
    """Generate images for one chunk of captions and publish them as a zip.

    Runs inside a SLURM job. Images are rendered into a per-job scratch
    directory under ``/tmp``, resized with ``resize_img``, zipped, and the
    archive is copied into the shared generations folder derived from
    ``args``. For ``chunk_index == 0`` the first image of keys 0-4 is also
    copied there as a preview. The scratch directory is removed at the end,
    including when the function fails.

    Args:
        args: parsed CLI namespace. Reads ``model``, ``random``, ``guidance``,
            ``t_guidance``, ``scheduler`` and ``complexity``.
        selected_captions: iterable of ``(key, prompt)`` pairs. ``key`` is an
            int; ``prompt`` is a string or a list of strings.
        batch_size: images generated per prompt (``num_images_per_prompt``).
        chunk_index: index of this chunk. It is embedded in the archive name
            (``..._chunk<i>.zip``), which the submitter parses to skip
            finished chunks.

    Raises:
        ValueError: for an unsupported model/guidance combination, or when a
            caption would yield more than ``_IMAGES_PER_KEY`` images.
    """
    # Seed every global RNG from the base seed plus the SLURM array task id.
    slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
    _seed_everything(args.random + slurm_array_task_id)

    # Per-job scratch space. exist_ok=True keeps a requeued job (same
    # SLURM_JOB_ID) from crashing on a leftover directory on the same node.
    save_folder = f"/tmp/jobid_{os.environ['SLURM_JOB_ID']}/"
    images_folder = os.path.join(save_folder, "generations")
    os.makedirs(images_folder, exist_ok=True)

    try:
        # Resolve guidance settings first so a bad combination fails before
        # the (slow) model load.
        guidance_kwargs = _guidance_kwargs(args.model, args.t_guidance)
        pipe = _load_pipeline(args.model, torch.float16)

        g_gpu = torch.Generator(device="cuda")
        # NOTE(review): seeded with the base seed only (not the per-task seed
        # above), so g_gpu yields the same sequence in every chunk. Confirm
        # this is intended.
        g_gpu.manual_seed(args.random)

        for key, prompt in tqdm(selected_captions):
            prompts = prompt if isinstance(prompt, list) else [prompt]
            if len(prompts) * batch_size > _IMAGES_PER_KEY:
                raise ValueError(
                    f"key {key} would produce {len(prompts) * batch_size} "
                    f"images; at most {_IMAGES_PER_KEY} fit the file naming"
                )
            images = pipe(
                prompt=prompts,
                num_images_per_prompt=batch_size,
                guidance_scale=args.guidance,
                num_inference_steps=_NUM_INFERENCE_STEPS,
                generator=g_gpu,
                **guidance_kwargs,
            ).images
            for img_index, img in enumerate(images):
                file_name = f"{img_index + key * _IMAGES_PER_KEY}.png"
                resize_img(img).save(os.path.join(images_folder, file_name))

        generations_folder = _generations_folder(args)
        os.makedirs(generations_folder, exist_ok=True)

        if chunk_index == 0:
            print("Copying images to generations folder...")
            for preview_key in range(5):
                file_name = f"{preview_key * _IMAGES_PER_KEY}.png"
                preview = os.path.join(images_folder, file_name)
                # A missing preview must not abort before the zip is published.
                if os.path.exists(preview):
                    shutil.copy(preview,
                                os.path.join(generations_folder, file_name))
                else:
                    print(f"Preview {preview} not found, skipping")

        zip_file_path = (
            f"{save_folder}/"
            f"{args.model}_c{args.complexity}_tg{args.t_guidance}_"
            f"g{args.guidance}_chunk{chunk_index}.zip"
        )

        print(f"Zipping images to {zip_file_path}...")
        image_names = sorted(os.listdir(images_folder))
        with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for name in image_names:
                # Archive layout: generations/<name>.png
                zipf.write(os.path.join(images_folder, name),
                           arcname=os.path.join("generations", name))
        print(f"Images zipped to {zip_file_path}")
        print("Copying zip file to generations folder...")
        shutil.copy(zip_file_path, generations_folder)
    finally:
        shutil.rmtree(save_folder, ignore_errors=True)


def main():
    """Submit one SLURM job per caption chunk that has not been generated yet.

    Captions for the chosen complexity level are read from the metadata
    pickle and split into chunks of ``len(captions) // --num_jobs``. Each
    chunk without an existing ``*_chunk<i>.zip`` in the output folder is
    submitted through submitit as a call to ``inference``.
    """
    parser = argparse.ArgumentParser(
        description='Generate Images using Stable Diffusion'
    )
    parser.add_argument('-m', '--model', help='SDModel',
                        type=str, required=True,
                        choices=["SD15", "SDXL", "SD35L", "SD35M"])
    parser.add_argument('-n', '--num_jobs', help='num jobs',
                        default=50, type=int, required=False)
    parser.add_argument('-c', '--complexity', help='complexity level [0,4]',
                        default=1, type=int, required=True)
    parser.add_argument('-rd', '--random', help='random seed',
                        default=42, type=int, required=False)
    parser.add_argument('-g', '--guidance', help='guidance scale',
                        default=5.0, type=float, required=True)
    parser.add_argument('-tgpu', '--type_gpu', help='gputype only for SD35L',
                        default="high", type=str, required=False)
    parser.add_argument('-tg', '--t_guidance', help='type of guidance',
                        default="CFG", type=str, required=True,
                        choices=["CFG", "APG", "CADS", "Interval"])
    parser.add_argument('-s', '--scheduler', help='scheduler type',
                        default="EulerDiscrete", type=str, required=False)

    args = parser.parse_args()
    print(args)
    if args.num_jobs < 1:
        parser.error("argument -n/--num_jobs: must be >= 1")

    num_batch = {'SD35L': 20,
                 'SD35M': 20,
                 'SDXL': 20,
                 'SD15': 20}
    timeout_minutes = {"SD35L": 60 * 70,
                       "SD35M": 60 * 12,
                       "SDXL": 60 * 6,
                       "SD15": 60 * 6}

    # pickle.load can execute arbitrary code; only point this at trusted files.
    with open(
        f"{DATA_PATH}/full_dict_gemma3_eval_clean_siglip_5k_4caps.pkl", "rb"
    ) as f:
        img_caps_pairs = pickle.load(f)

    # Every 20th metadata entry is used (5000 captions); its key is index // 20.
    caption_list = [
        (i // 20, img_caps_pairs[i]['caps'][args.complexity])
        for i in range(0, 100000, 20)
    ]
    print(caption_list[:10])

    # Floor division keeps chunk boundaries identical to earlier runs, so
    # chunk ids parsed from existing zips still refer to the same captions.
    chunk_size = max(1, len(caption_list) // args.num_jobs)
    chunks = [caption_list[x:x + chunk_size]
              for x in range(0, len(caption_list), chunk_size)]

    save_folder = (
        f"{SAVE_PATH}/"
        f"SDinference_gemma3_siglip_clip_"
        f"{args.scheduler}_{args.t_guidance}/"
        f"guidance_{args.guidance:.01f}/"
        f"model_{args.model}/"
        f"complexity_{args.complexity}/"
    )

    # Chunks whose archive (``..._chunk<i>.zip``) already exists are skipped.
    generated_chunk_ids = set()
    try:
        existing_files = os.listdir(save_folder)
    except FileNotFoundError:
        print("No previous generations found, starting fresh.")
        existing_files = []
    for file_name in existing_files:
        if not file_name.endswith(".zip"):
            continue
        chunk_id = file_name.split("_chunk")[-1].split(".")[0]
        try:
            generated_chunk_ids.add(int(chunk_id))
        except ValueError:
            print(f"Ignoring unrecognised archive {file_name}")

    executor = submitit.AutoExecutor(folder="../logs/SDinference_jobs/")
    executor.update_parameters(
        timeout_min=int(timeout_minutes[args.model]),
        mem_gb=100,
        name=f"c{args.complexity}/{args.model}/g{args.guidance}",
        slurm_array_parallelism=1,
        slurm_nodes=1,
        slurm_gpus_per_node=1,
        slurm_tasks_per_node=1,
        slurm_cpus_per_task=4,
        slurm_partition="",
    )

    os.makedirs(save_folder, exist_ok=True)
    with executor.batch():
        for i, selected_captions in enumerate(chunks):
            if i in generated_chunk_ids:
                print(f"Chunk {i} already generated, skipping...")
                continue
            if not selected_captions:
                print(f"No clusters needed to generate in chunk {i}")
                continue

            selected_key = [key for key, _ in selected_captions]
            print(f"{i} : {selected_key}")
            job = executor.submit(inference,
                                  args,
                                  selected_captions,
                                  num_batch[args.model],
                                  i)
            print(job)


if __name__ == "__main__":
    main()
