import os
import pickle
import copy
from PIL import Image
import numpy as np
import torch
from transformers import AutoModel
import random
import argparse
import itertools
from tqdm import tqdm
import submitit
import shutil
import zipfile
from vendi_score import vendi
from get_gen_embeddings import compute_embeddings

# Root directory (relative to the working directory) under which
# get_vendi_score writes its Vendi-score pickles.
SAVE_PATH = '../../outputs/evaluations/diversity/'
# Group size for scoring: the embeddings are split into
# len(embeddings) // MIN_CLUSTER_SIZE equal groups, each scored separately.
MIN_CLUSTER_SIZE = 20
# Batch size forwarded to compute_embeddings.
batch_size = 4096

# Per-model SLURM settings, keyed by model name.
# NOTE(review): none of these three tables is read by the visible code (the
# launcher hard-codes its own partition/timeout) -- confirm they are needed.
_SD_MODEL_NAMES = ("SD35L", "SD35M", "SDXL", "SD15")

constraint = dict.fromkeys(_SD_MODEL_NAMES, "volta32gb")

partition = dict.fromkeys(_SD_MODEL_NAMES, "learnfair")

# Presumably minutes (12 h), matching submitit's timeout_min unit -- confirm.
time = dict.fromkeys(_SD_MODEL_NAMES, 12 * 60)


def vendi_cal(X, q):
    """Return the order-``q`` Vendi score of the samples in ``X``.

    ``X`` holds one sample per row; a 1-D input is treated as a single
    sample. ``vendi.score_X`` is used when there are fewer samples than
    feature dimensions and ``vendi.score_dual`` otherwise (presumably to
    work with whichever kernel matrix is smaller -- see vendi_score docs).
    """
    if X.ndim == 1:
        X = X.reshape((1, -1))
    num_samples, num_features = X.shape
    score_fn = (
        vendi.score_X if num_samples < num_features else vendi.score_dual
    )
    return score_fn(X, q)


def _load_feature_extractor(extractor_name):
    """Return the model object passed to compute_embeddings for this extractor."""
    if extractor_name == "dino":
        return AutoModel.from_pretrained('facebook/dinov2-base')
    if extractor_name == "inception":
        # No model is passed for inception; presumably compute_embeddings
        # builds its own network from extractor_name -- TODO confirm.
        return None
    raise ValueError(f"Unknown feature extractor: {extractor_name}")


def _load_generated_images(gen_folder, num_images, size=(256, 256)):
    """Load ``0.png`` .. ``{num_images - 1}.png`` as RGB images resized to ``size``."""
    images = []
    for idx in range(num_images):
        # The context manager closes the file handle as soon as the pixels
        # have been copied into the converted/resized image.
        with Image.open(os.path.join(gen_folder, f"{idx}.png")) as img:
            images.append(img.convert("RGB").resize(size, Image.LANCZOS))
    return images


def get_embeddings(args, num_images=100000):
    """Embed this job's generated images and split them into scoring groups.

    Images are read from ``/tmp/job<SLURM_JOB_ID>/generations/<idx>.png``
    (the folder unpacked by ``copydata``), resized to 256x256 and embedded
    with ``compute_embeddings`` (normalized, on CUDA).

    Args:
        args: namespace whose ``extractor`` is ``"dino"`` or ``"inception"``.
        num_images: number of images to load, named ``0.png`` onward.

    Returns:
        A list of ``len(embeddings) // MIN_CLUSTER_SIZE`` equal-sized groups
        from ``np.split``; with the default 100000 images that is 5000
        groups of ``MIN_CLUSTER_SIZE`` embeddings.

    Raises:
        ValueError: for an unknown extractor, or if the embeddings cannot be
            split into equal groups.
        KeyError: if SLURM_JOB_ID is not set.
        FileNotFoundError: if an expected image is missing.
    """
    feature_extractor = _load_feature_extractor(args.extractor)

    gen_folder = os.path.join(
        f"/tmp/job{os.environ['SLURM_JOB_ID']}", "generations"
    )
    images = _load_generated_images(gen_folder, num_images)
    print(f"total images: {len(images)}")
    assert len(images) == num_images, f"should have {num_images} images"

    loaded_embeddings = compute_embeddings(
        images,
        model=feature_extractor,
        transform=None,
        batch_size=batch_size,
        device=torch.device("cuda"),
        extractor_name=args.extractor,
        cache=None,
        normalize=True,
    )

    assert len(loaded_embeddings) == len(images)

    # np.split requires an exact division and raises ValueError otherwise.
    return np.split(loaded_embeddings,
                    len(loaded_embeddings) // MIN_CLUSTER_SIZE)


def get_vendi_score(args, chunk):
    """Score one experiment configuration and pickle its per-group Vendi scores.

    Copies and unzips the generated images for ``args`` into the per-job
    scratch folder ``/tmp/job<SLURM_JOB_ID>``, embeds them, removes the
    scratch folder, scores every embedding group with
    ``vendi_cal(group, args.q)`` and dumps the list of scores to a pickle
    under ``SAVE_PATH``.

    Args:
        args: namespace with ``model``, ``guidance``, ``complexity``,
            ``extractor``, ``scheduler``, ``tguidance``, ``q`` and
            ``savepath`` set.
        chunk: only used as a tag in the output file name.

    Raises:
        KeyError: if SLURM_JOB_ID is not set.
        AssertionError: if ``args.model`` is not an SD model or the number of
            embedding groups is not 5000.
    """
    print(args)

    targeted_folder = f"/tmp/job{os.environ['SLURM_JOB_ID']}"

    assert "SD" in args.model, "model should be SD15, SD35L, SD35M or SDXL"
    copydata(args.tguidance, targeted_folder, args)
    try:
        total_embed_set = get_embeddings(args)
    finally:
        # The unzipped images are only read by get_embeddings. Drop this
        # job's scratch copy (created fresh by copydata) so node-local /tmp
        # does not accumulate 100k-image folders across jobs.
        shutil.rmtree(targeted_folder, ignore_errors=True)
    assert len(total_embed_set) == 5000

    # One score per embedding group; the extractor may be dino or inception.
    scores = [vendi_cal(embeds, args.q) for embeds in total_embed_set]
    print(scores)

    save_dir = os.path.join(SAVE_PATH, f"{args.savepath}_{args.tguidance}")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(
        save_dir,
        f"vs_q_{args.q:.1f}_{args.model}_compl_{args.complexity}"
        f"_sample_{args.scheduler}"
        f"_guid_{args.guidance}_ext_{args.extractor}_chunk_{chunk}.pkl",
    )
    with open(save_path, "wb") as f:
        pickle.dump(scores, f)


def copydata(tguidances, targeted_folder, args):
    """Copy and unzip the generated-image archives for one config into a folder.

    The archives are read from
    ``../../outputs/SDinference/SDinference_gemma3_siglip_clip_EulerDiscrete_<tguidances>/
    guidance_<g>/model_<model>/complexity_<c>/``, copied into
    ``targeted_folder``, extracted there and then deleted.

    Args:
        tguidances: guidance-method tag (e.g. ``"CFG"``) used in the source
            directory and archive names.
        targeted_folder: destination directory; it must not exist yet.
        args: namespace providing ``model``, ``guidance`` and ``complexity``.

    Raises:
        FileExistsError: if ``targeted_folder`` already exists.
        AssertionError: if ``args.model`` does not contain ``"SD"``.
        KeyError: if ``args.model`` has no entry in the chunk-count table.
    """
    os.makedirs(targeted_folder, exist_ok=False)
    assert "SD" in args.model, "model should be SD15, SD35L, SD35M or SDXL"
    # Number of zip archives each model's generations were split into;
    # change to match how many chunks were used when generating images.
    num_chunks = {
        "SD15": 1,
        "SDXL": 5,
        "SD35M": 25,
        "SD35L": 50,
    }
    img_zip_path = (
        f"../../outputs/SDinference/"
        f"SDinference_gemma3_siglip_clip_EulerDiscrete_{tguidances}/"
        f"guidance_{args.guidance:.01f}/"
        f"model_{args.model}/"
        f"complexity_{args.complexity}/"
    )
    chunk_ids = list(range(num_chunks[args.model]))
    # NOTE(review): the copy order is randomised, presumably so concurrent
    # jobs do not all read the same archive at once -- confirm.
    random.shuffle(chunk_ids)
    for chunk_id in tqdm(chunk_ids):
        zip_name = (
            f"{args.model}_c{args.complexity}_"
            f"tg{tguidances}_g{args.guidance:.01f}_chunk{chunk_id}.zip"
        )
        local_zip = os.path.join(targeted_folder, zip_name)
        shutil.copy(os.path.join(img_zip_path, zip_name), targeted_folder)
        with zipfile.ZipFile(local_zip, 'r') as zip_ref:
            zip_ref.extractall(targeted_folder)
        # The archive is not needed once extracted; removing it keeps the
        # scratch disk from holding every image twice.
        os.remove(local_zip)
    print("DATA LOADED")


if __name__ == "__main__":
    # Launcher: builds one argument namespace per configuration in the
    # search space below and submits get_vendi_score for each as a SLURM
    # job array via submitit.
    parser = argparse.ArgumentParser(
        description='Compute Vendi scores of Stable Diffusion generations'
    )
    parser.add_argument('-c', '--complexity', help='complexity level 1, 2, 3',
                        default=1, type=int, required=True)
    parser.add_argument('-ex', '--extractor', help='feature_extractor',
                        default='dreamsim', type=str, required=False)
    parser.add_argument('-m', '--model', help='SD model',
                        default='SD35L', type=str, required=False)
    parser.add_argument('-g', '--guidance', help='guidance scale',
                        default=3, type=float, required=False)
    parser.add_argument('-q', '--q', help='vendi score scale',
                        default="1", type=str, required=False)
    parser.add_argument('-n', '--num_jobs', help='num_jobs',
                        default=1, type=int, required=False)
    parser.add_argument('-rd', '--random', help='random seed',
                        default=42, type=int, required=False)
    parser.add_argument('-s', '--scheduler', help='scheduler type',
                        default=None, type=str, required=False)
    parser.add_argument('-tguidance', '--tguidance', help='guidance method',
                        default="100K", type=str, required=False)
    parser.add_argument('-sp', '--savepath', help='folder to save vendi score',
                        default="SD_vendiscore", type=str, required=False)

    args = parser.parse_args()
    print(args)

    # NOTE(review): this seeds only the launcher process; the shuffle in
    # copydata runs inside the submitted jobs -- confirm intent.
    random.seed(args.random)

    args.q = float(args.q)

    models_to_check = ["SDXL", "SD15", "SD35M", "SD35L"]
    guidance_to_check = [3., 5., 7., 9., 13.]
    sampler_to_check = ["EulerDiscrete"]
    extractor_to_check = ["inception", "dino"]
    tguidance_to_check = ["CFG", "APG", "Interval", "CADS"]

    # Every combination below overrides the corresponding command-line
    # values (model, guidance, extractor, scheduler, tguidance).
    search_space = itertools.product(models_to_check,
                                     guidance_to_check,
                                     extractor_to_check,
                                     sampler_to_check,
                                     tguidance_to_check)

    experimental_args = []
    for model, guidance, extractor, scheduler, tguidance in search_space:
        exp_args = copy.deepcopy(args)
        exp_args.model = model
        exp_args.guidance = guidance
        exp_args.extractor = extractor
        exp_args.scheduler = scheduler
        exp_args.tguidance = tguidance
        experimental_args.append(exp_args)

    executor = submitit.AutoExecutor(folder="../../outputs/vendi_logs/")
    executor.update_parameters(
        timeout_min=int(60*3),
        mem_gb=150,
        name="vendical",
        slurm_array_parallelism=1,
        slurm_nodes=1,
        slurm_gpus_per_node=1,
        slurm_tasks_per_node=1,
        slurm_cpus_per_task=2,
        # NOTE(review): empty partition; the module-level `partition` table
        # is not used here -- fill in for the target cluster.
        slurm_partition="",
    )
    jobs = []
    with executor.batch():
        for i, exp_args in enumerate(experimental_args):
            print(f"{i} : {exp_args}")
            jobs.append(executor.submit(get_vendi_score, exp_args, 0))
    # Inside executor.batch() the jobs are only submitted (as one array) when
    # the context exits, so report the handles afterwards.
    for job in jobs:
        print(job)
