
import os
import pickle
import copy
import random
import argparse
import itertools
import numpy as np
from tqdm import tqdm
from vendi_score import vendi

# Directory holding the CC12M metadata (embedding shards, cluster files).
DATA_PATH = '../../metadata/cc12m/'
# Number of images in every resampled group that gets scored.
MIN_SIZE = 20
# Root directory under which the pickled scores are written.
SAVE_PATH = "../../outputs/evaluations/diversity/"

# Per-model settings keyed by model name.
# NOTE(review): these look like cluster job-submission settings (node
# constraint, partition, time limit); nothing in this file reads them, so
# confirm against whatever launcher consumes them before changing values.
constraint = dict(
    SD35L="volta32gb",
    SD35M="volta32gb",
    SDXL="volta32gb",
    SD15="volta32gb",
)

partition = dict(
    SD35L="learnlab",
    SD35M="learnlab",
    SDXL="learnlab",
    SD15="learnlab",
)

# 48 * 60 == 2880 -- presumably a limit in minutes (48 hours); TODO confirm.
time = dict(
    SD35L=48 * 60,
    SD35M=48 * 60,
    SDXL=48 * 60,
    SD15=48 * 60,
)


def get_embedding_from_index(total_img_index, loaded_embedding,
                             expected_size=20):
    """Gather the embedding rows selected by each group of image indexes.

    Args:
        total_img_index: Iterable of index collections. Every collection is
            used as-is to index the first axis of ``loaded_embedding``.
        loaded_embedding: Object supporting ``loaded_embedding[indexes]``
            (a NumPy array in this script).
        expected_size: Number of indexes every group must contain. Defaults
            to 20, the size this function has always enforced; pass ``None``
            to accept groups of any size.

    Returns:
        A list with one ``loaded_embedding[indexes]`` result per group, in
        the same order as ``total_img_index``.

    Raises:
        AssertionError: If ``expected_size`` is not ``None`` and a group has
            a different number of indexes.
    """
    embedding_list = []
    for img_indexes in total_img_index:
        if expected_size is not None:
            assert len(img_indexes) == expected_size, (
                f"expected {expected_size} image indexes per group, "
                f"got {len(img_indexes)}"
            )
        embedding_list.append(loaded_embedding[img_indexes])
    return embedding_list


def vendi_cal(X, q):
    """Return the Vendi score of order ``q`` for the embedding matrix ``X``.

    A one-dimensional ``X`` is treated as a single sample and reshaped to
    one row. The scorer is chosen from the matrix shape: ``vendi.score_X``
    when there are fewer rows than columns, ``vendi.score_dual`` otherwise
    (presumably so the smaller of the two possible matrices is decomposed;
    confirm against the vendi_score documentation).
    """
    samples = X.reshape((1, -1)) if len(X.shape) == 1 else X
    num_samples, num_features = samples.shape
    scorer = vendi.score_X if num_samples < num_features else vendi.score_dual
    return scorer(samples, q)


def _resample_fixed_size_chunks(cluster_img_index, chunk_size, rounds=5):
    """Draw ``rounds`` random partitions of a cluster into equal chunks.

    Every round shuffles ``cluster_img_index`` in place, keeps the longest
    prefix whose length is a multiple of ``chunk_size`` and splits it into
    groups of ``chunk_size`` indexes.

    Returns:
        A flat list of index arrays, ``rounds * (len // chunk_size)`` long.

    Raises:
        AssertionError: If the cluster has fewer than ``chunk_size`` indexes.
    """
    assert len(cluster_img_index) >= chunk_size, (
        f"cluster has {len(cluster_img_index)} images, "
        f"need at least {chunk_size}"
    )
    chunks_per_round = len(cluster_img_index) // chunk_size
    end_idx = chunks_per_round * chunk_size
    chunks = []
    for _ in range(rounds):
        np.random.shuffle(cluster_img_index)
        # Copy before splitting: basic slicing and np.split both return
        # views, so without the copy the next in-place shuffle would rewrite
        # the chunks collected in earlier rounds and all rounds would end up
        # as identical copies of the last permutation.
        prefix = cluster_img_index[:end_idx].copy()
        chunks.extend(np.split(prefix, chunks_per_round))
    return chunks


def get_vendi_score(args, selected_clusters, chunk):
    """Compute per-cluster Vendi scores for one extractor and pickle them.

    Loads the 20 embedding shards of ``args.extractor``, builds groups of
    image indexes for every cluster, scores each group with ``vendi_cal``
    and writes the nested score list to a pickle file under ``SAVE_PATH``.

    Args:
        args: Namespace providing ``extractor``, ``q`` (float),
            ``complexity`` and ``savepath``. An optional boolean
            ``nonresample`` (treated as ``False`` when missing) scores each
            cluster as one group instead of resampling it.
        selected_clusters: Mapping from cluster key to a sequence of row
            indexes into the concatenated embedding matrix.
        chunk: Value that only appears in the output file name.

    Raises:
        AssertionError: If the concatenated embeddings do not have 978894
            rows, if a cluster is smaller than ``MIN_SIZE`` while
            resampling, or if ``get_embedding_from_index`` rejects a group.
    """
    print(args)

    print("Loading embeddings...")
    shard_dir = f"{DATA_PATH}/CC12M_vendiscore_clip_full/"
    loaded_embedding = np.concatenate(
        [
            np.load(f"{shard_dir}embedding_imgs_{args.extractor}_{i}.npy")
            for i in range(20)
        ],
        axis=0,
    )
    assert loaded_embedding.shape[0] == 978894
    print("Embeddings loaded")

    # The flag is optional so that a namespace built without it still works.
    skip_resampling = getattr(args, "nonresample", False)

    print("Start Resampling...")
    # total_embed_set[c][g] is the embedding matrix of group g of cluster c.
    total_embed_set = []
    for centroid in tqdm(selected_clusters):
        # np.array copies, so shuffling never touches the caller's data.
        cluster_img_index = np.array(selected_clusters[centroid])
        if skip_resampling:
            # NOTE: get_embedding_from_index checks the group size, so a
            # non-resampled cluster must match the size it expects.
            resampled_img_index = [cluster_img_index]
        else:
            resampled_img_index = _resample_fixed_size_chunks(
                cluster_img_index, MIN_SIZE
            )
        total_embed_set.append(
            get_embedding_from_index(resampled_img_index, loaded_embedding)
        )

    print("Start Vendi Score Calculation...")
    inception_vs = [
        [vendi_cal(embeds, args.q) for embeds in caption_clusters]
        for caption_clusters in tqdm(total_embed_set)
    ]
    print(inception_vs)
    assert len(inception_vs) == len(selected_clusters)

    savefolder = f"{args.savepath}/"

    save_path = (
        f"{SAVE_PATH}/"
        f"{savefolder}/"
        f"vsdata_q_{args.q:.1f}_compl_{args.complexity}"
        f"_ext_{args.extractor}_chunk_{chunk}.pkl"
    )
    # Make sure the target folder exists when called outside the CLI.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "wb") as f:
        pickle.dump(inception_vs, f)


if __name__ == "__main__":
    # Command-line entry point: parse the options, load the evaluation
    # clusters for the requested complexity level and compute Vendi scores
    # once per feature extractor in ``extractor_to_check``.
    parser = argparse.ArgumentParser(
        description='Compute Vendi scores for CC12M image clusters'
    )
    parser.add_argument('-c', '--complexity', help='complexity level 1, 2, 3',
                        default=1, type=int, required=True)
    # NOTE: the value of --extractor is overwritten by the sweep below.
    parser.add_argument('-ex', '--extractor', help='feature_extractor',
                        default='dreamsim', type=str, required=False)
    parser.add_argument('-m', '--model', help='SD model',
                        default='SD35L', type=str, required=False)
    parser.add_argument('-g', '--guidance', help='guidance scale',
                        default=3, type=float, required=False)
    parser.add_argument('-q', '--q', help='vendi score scale',
                        default=1.0, type=float, required=False)
    parser.add_argument('-n', '--num_jobs', help='num_jobs',
                        default=1, type=int, required=False)
    parser.add_argument('-rd', '--random', help='random seed',
                        default=42, type=int, required=False)
    parser.add_argument('-s', '--scheduler', help='scheduler type',
                        default=None, type=str, required=False)
    parser.add_argument('-dedup', '--deduplicate', help='whether deduplicate',
                        action="store_true")
    parser.add_argument('-sp', '--savepath', help='folder to save vendi score',
                        default="data_vendiscore", type=str, required=False)
    # get_vendi_score reads args.nonresample, so the flag has to exist.
    parser.add_argument('--nonresample',
                        help='score each cluster as one group, no resampling',
                        action="store_true")

    args = parser.parse_args()
    print(args)

    # The resampling shuffles with np.random, so NumPy has to be seeded too
    # for --random to make runs reproducible.
    random.seed(args.random)
    np.random.seed(args.random)

    os.makedirs(f"{SAVE_PATH}/{args.savepath}", exist_ok=True)

    # One independent copy of the arguments per extractor to evaluate.
    extractor_to_check = ['dino', 'inception']
    experimental_args = []
    for extractor in extractor_to_check:
        run_args = copy.deepcopy(args)
        run_args.extractor = extractor
        experimental_args.append(run_args)

    centroid_path = (
        f"{DATA_PATH}/siglip_clusters/eval_set_c{args.complexity}.pkl"
    )
    # NOTE: pickle executes code on load; only point this at trusted files.
    with open(centroid_path, 'rb') as f:
        data = pickle.load(f)

    # Map every cluster key to its image ids. With --deduplicate only the
    # first cluster seen for each caption is kept (captions are assumed to
    # be hashable, e.g. strings).
    centroid_dict = {}
    seen_captions = set()
    for key, value in data.items():
        caption = value["caption"]
        if args.deduplicate:
            if caption in seen_captions:
                continue
            seen_captions.add(caption)
        centroid_dict[key] = value["image_ids"]

    for run_args in experimental_args:
        get_vendi_score(run_args, centroid_dict, 0)
