import numpy as np
import pickle
import random
from tqdm import tqdm
import os
from collections import defaultdict


# Root directory of the CC12M metadata: CLIP/SigLIP embeddings, clustering
# pickles and the caption dictionaries read/written below. It is a relative
# path, so it resolves against the process's current working directory, not
# against the location of this file.
DATA_PATH = "../metadata/cc12m/"


def get_threshold(num_complexity_levels=4, percentile=77, data_path=None):
    """Compute one image-text similarity threshold per caption complexity.

    For every level ``c`` in ``range(num_complexity_levels)`` the row-wise dot
    product between the image embeddings and the level-``c`` text embeddings
    is taken (row ``i`` of one file is paired with row ``i`` of the other),
    and the ``percentile``-th percentile of those scores becomes the
    threshold for that level.

    NOTE(review): treating the dot product as a similarity presumably assumes
    the stored embeddings are L2-normalised -- confirm against the script that
    produced the ``.npy`` files.

    Args:
        num_complexity_levels: Number of ``..._text_c{k}.npy`` files to read.
        percentile: Percentile (0-100) of the similarity distribution used as
            the threshold. Defaults to 77, the previously hard-coded value.
        data_path: Metadata root directory; ``None`` means ``DATA_PATH``.

    Returns:
        List of thresholds (NumPy scalars), one per level, in level order.
    """
    if data_path is None:
        data_path = DATA_PATH
    embed_dir = os.path.join(data_path, "clip_embeddings")
    prefix = "ViT-SO400M-14-SigLIP-384_webli"

    img_embed = np.load(os.path.join(embed_dir, f"{prefix}_img.npy"))

    threshold_list = []
    for complexity in range(num_complexity_levels):
        txt_embed = np.load(
            os.path.join(embed_dir, f"{prefix}_text_c{complexity}.npy")
        )
        # Reduce in at least float32: casting *after* a float16 einsum (as
        # before) cannot recover precision already lost in the accumulation.
        acc_dtype = np.result_type(img_embed.dtype, txt_embed.dtype,
                                   np.float32)
        similarity = np.einsum(
            "ij,ij->i", img_embed, txt_embed, dtype=acc_dtype
        ).astype(np.float32)

        threshold_list.append(np.percentile(similarity, percentile))

    return threshold_list


def filter_cluster_size_too_small(
        data, complexity, distances, labels, threshold_list):
    """Keep the queries whose thresholded neighbour list has > 20 images.

    Query ``index`` is the position of an element while iterating ``data``.
    Its neighbours are the entries of ``labels[index]`` whose score in
    ``distances[index]`` exceeds ``threshold_list[complexity]``; the query is
    kept when more than 20 such neighbours remain.

    Args:
        data: Either the metadata dict loaded in ``main`` (image name ->
            entry whose ``"caps"`` list is indexed by complexity level) or a
            plain sequence of captions.
        complexity: Caption complexity level. Selects the threshold and, for
            dict input, which element of ``entry["caps"]`` is returned.
        distances: Per-query neighbour scores, parallel to ``labels``.
        labels: Per-query neighbour image ids.
        threshold_list: One similarity threshold per complexity level.

    Returns:
        List of ``(caption, index)`` tuples for the kept queries, in input
        order.
    """
    threshold = threshold_list[complexity]
    # Iterating a dict yields its keys (image names), not captions, so the
    # caption of the requested level must be looked up explicitly.
    # Previously the image name itself was stored as the "caption".
    is_metadata_dict = isinstance(data, dict)

    sampled_data_with_index = []
    for index, element in enumerate(tqdm(data)):
        neighbors_cap = distances[index] > threshold
        # NOTE(review): strictly more than 20 here, while align_eval_sets
        # keeps clusters with >= 20 images -- confirm which bound is meant.
        if len(labels[index][neighbors_cap]) > 20:
            if is_metadata_dict:
                caption = data[element]["caps"][complexity]
            else:
                caption = element
            sampled_data_with_index.append((caption, index))

    return sampled_data_with_index


def filter_captions(data, complexity, threshold_list):
    """Build the eval set for one caption complexity level.

    Loads the precomputed neighbour pickles for ``complexity``, keeps the
    queries accepted by ``filter_cluster_size_too_small`` and records, for
    each kept query, the neighbour image ids whose score exceeds the level's
    threshold.

    Args:
        data: Caption metadata forwarded to ``filter_cluster_size_too_small``.
        complexity: Caption complexity level (selects pickles and threshold).
        threshold_list: One similarity threshold per complexity level.

    Returns:
        Dict keyed ``0..n-1`` in selection order. Each value is
        ``{"caption": ..., "image_ids": <ids of neighbours above threshold>}``.
    """
    print("Loading distances and labels...")
    cluster_dir = os.path.join(
        DATA_PATH, "clip_embeddings", "clustering_clip_dbscan_IP"
    )
    prefix = "ViT-SO400M-14-SigLIP-384_webli"
    distances_path = os.path.join(
        cluster_dir, f"{prefix}_distances_c{complexity}.pkl"
    )
    labels_path = os.path.join(
        cluster_dir, f"{prefix}_labels_c{complexity}.pkl"
    )

    # pickle.load can execute arbitrary code: these must be trusted,
    # locally produced files.
    with open(distances_path, "rb") as f:
        distances = pickle.load(f)
    with open(labels_path, "rb") as f:
        labels = pickle.load(f)

    selected_captions = filter_cluster_size_too_small(
        data, complexity, distances, labels, threshold_list
    )

    threshold = threshold_list[complexity]
    # Same neighbour mask filter_cluster_size_too_small applied; the kept
    # queries are re-keyed densely from 0.
    return {
        new_index: {
            "caption": caption,
            "image_ids": labels[index][distances[index] > threshold],
        }
        for new_index, (caption, index) in enumerate(selected_captions)
    }


def get_common_images(eval_set_list):
    """Return the image ids referenced by every eval set in the list.

    Each eval set maps an entry key to metadata with an ``"image_ids"``
    iterable. Per eval set, the union of all its entries' ids is formed;
    the result is the intersection of those unions across eval sets.

    Args:
        eval_set_list: Sequence of eval-set dicts.

    Returns:
        Set of image ids present in every eval set. Empty when
        ``eval_set_list`` is empty (``set.intersection`` with no arguments
        would otherwise raise ``TypeError``).
    """
    image_ids_list = [
        {
            image_id
            for metadata in eval_set.values()
            for image_id in metadata["image_ids"]
        }
        for eval_set in eval_set_list
    ]
    if not image_ids_list:
        return set()
    return set.intersection(*image_ids_list)


def align_eval_sets(eval_set_list, min_cluster_size=20):
    """Restrict every eval set to the images shared by all eval sets.

    Each entry's ``"image_ids"`` is intersected with the ids common to all
    eval sets (``get_common_images``). Entries left with at least
    ``min_cluster_size`` ids are kept, with their original keys.

    The input dicts and their metadata are not modified; kept entries are
    shallow copies carrying the reduced ``"image_ids"`` list.

    Args:
        eval_set_list: Sequence of eval-set dicts.
        min_cluster_size: Minimum number of common image ids an entry needs
            to be kept. Defaults to 20.

    Returns:
        ``(new_eval_set_list, is_same_length)``. ``is_same_length`` is True
        when filtering did not shrink the number of common image ids; since
        the new common set is a subset of the old one, that means the
        common set is unchanged.
    """
    common_image_ids = get_common_images(eval_set_list)

    new_eval_set_list = []
    for eval_set in eval_set_list:
        new_eval_set = defaultdict(dict)
        for index, metadata in tqdm(eval_set.items()):
            image_ids_val = list(
                set(metadata["image_ids"]).intersection(common_image_ids)
            )
            # Keep clusters with at least `min_cluster_size` common images.
            if len(image_ids_val) >= min_cluster_size:
                # Copy rather than mutate the caller's metadata in place.
                new_eval_set[index] = {**metadata, "image_ids": image_ids_val}
        new_eval_set_list.append(new_eval_set)

    new_common_image_ids = get_common_images(new_eval_set_list)
    is_same_length = len(common_image_ids) == len(new_common_image_ids)
    return new_eval_set_list, is_same_length


def random_sample_eval_set(eval_set, sample_size):
    """Draw ``sample_size`` entries of ``eval_set`` without replacement.

    Uses the module-level ``random`` generator. The drawn metadata values
    are re-keyed ``0..sample_size-1`` in the order ``random.sample`` returned
    them; the original keys are discarded.

    Raises:
        ValueError: from ``random.sample`` if ``sample_size`` is negative or
            larger than ``len(eval_set)``.
    """
    population = list(eval_set.items())
    picked = random.sample(population, sample_size)
    return dict(enumerate(metadata for _key, metadata in picked))


def align_complexities(data, num_complexity_levels, sample_size=5000,
                       seed=42):
    """Build aligned, equally sized eval sets for every complexity level.

    Steps:
      1. Seed ``random`` and ``np.random`` for reproducibility.
      2. Compute one similarity threshold per level (``get_threshold``).
      3. Filter captions per level (``filter_captions``).
      4. Call ``align_eval_sets`` repeatedly until it reports that the
         number of image ids shared by all levels did not change.
      5. Draw ``sample_size`` entries from every aligned eval set.

    Args:
        data: Caption metadata forwarded to ``filter_captions``.
        num_complexity_levels: Number of complexity levels to process. It is
            now honoured by the filtering step too, which previously always
            processed exactly 4 levels.
        sample_size: Entries drawn per level. Defaults to 5000.
        seed: Seed for ``random`` and ``np.random``. Defaults to 42.

    Returns:
        List with one sampled eval set per level, each keyed
        ``0..sample_size-1``.

    Raises:
        ValueError: from ``random.sample`` when an aligned eval set has
            fewer than ``sample_size`` entries.
    """
    random.seed(seed)
    np.random.seed(seed)

    threshold_list = get_threshold(num_complexity_levels)

    eval_set_list = [
        filter_captions(data, complexity, threshold_list)
        for complexity in range(num_complexity_levels)
    ]

    # Each pass can drop entries, which can in turn shrink the common image
    # set, so iterate to a fixed point.
    converged = False
    iteration = 0
    while not converged:
        print(f"cnt: {iteration}")
        eval_set_list, converged = align_eval_sets(eval_set_list)
        iteration += 1

    return [
        random_sample_eval_set(eval_set, sample_size)
        for eval_set in eval_set_list
    ]


def _build_save_eval_set(sampled_eval_set_list, repeats=20):
    """Group the k-th sampled caption of every level under shared keys.

    The k-th entry (iteration order) of each sampled eval set contributes its
    ``"caption"`` to a list ordered by complexity level. That list is stored
    ``repeats`` times under keys ``k * repeats + r`` together with the
    placeholder image name ``"<key>.jpg"``.
    """
    if not sampled_eval_set_list:
        return {}
    caption_columns = [
        [item["caption"] for item in eval_set.values()]
        for eval_set in sampled_eval_set_list
    ]
    save_eval_set = {}
    for pos in range(len(caption_columns[0])):
        caps = [column[pos] for column in caption_columns]
        for r in range(repeats):
            key = pos * repeats + r
            # A fresh list per key, so entries do not share one list object.
            save_eval_set[key] = {
                "image_name": f"{key}.jpg",
                "caps": list(caps),
            }
    return save_eval_set


def _build_real_eval_set_list(data, image_names, sampled_eval_set_list):
    """Map each level's referenced real images to their caption at that level.

    Prints the number of distinct image ids per level. Image ids are
    positions in ``image_names``.
    """
    unique_image_ids_list = []
    for eval_set in sampled_eval_set_list:
        unique_image_ids = set()
        for item in eval_set.values():
            unique_image_ids.update(item["image_ids"])
        print(len(unique_image_ids))
        unique_image_ids_list.append(unique_image_ids)

    real_eval_set_list = []
    for complexity, unique_image_ids in enumerate(unique_image_ids_list):
        real_eval_set = {}
        for image_id in unique_image_ids:
            image_name = image_names[image_id]
            real_eval_set[image_name] = {
                "caps": data[image_name]["caps"][complexity]
            }
        real_eval_set_list.append(real_eval_set)
    return real_eval_set_list


def main():
    """Build and pickle the aligned SigLIP evaluation sets.

    Reads ``full_dict_gemma3_eval_clean_4caps.pkl`` (image name -> entry with
    a ``"caps"`` list indexed by complexity level), aligns and samples one
    eval set per level via ``align_complexities``, and writes under
    ``DATA_PATH``:

    * ``full_dict_gemma3_eval_clean_siglip_5k_4caps.pkl``: a dict
      ``key -> {"image_name": "<key>.jpg", "caps": [one per level]}``, where
      every sampled position is repeated 20 times.
    * ``full_dict_gemma3_eval_clean_siglip_real_5k_4caps.pkl``: a list with
      one dict per level, mapping every real image referenced by that level's
      sampled clusters to ``{"caps": <its caption at that level>}``.

    NOTE(review): the eval sets of the different levels are filtered and
    sampled independently, so the captions grouped under one key appear to
    come from unrelated source entries -- confirm this pairing is intended.
    """
    num_complexity_levels = 4

    with open(f"{DATA_PATH}/full_dict_gemma3_eval_clean_4caps.pkl", "rb") as f:
        data = pickle.load(f)

    # Image ids stored in the clusters are positions in this key list.
    image_names = list(data.keys())

    sampled_eval_set_list = align_complexities(data, num_complexity_levels)

    save_eval_set = _build_save_eval_set(sampled_eval_set_list)
    with open(
        f"{DATA_PATH}/full_dict_gemma3_eval_clean_siglip_5k_4caps.pkl", "wb"
    ) as f:
        pickle.dump(save_eval_set, f)

    real_eval_set_list = _build_real_eval_set_list(
        data, image_names, sampled_eval_set_list
    )
    with open(f"{DATA_PATH}/"
              f"full_dict_gemma3_eval_clean_siglip_real_5k_4caps.pkl",
              "wb") as f:
        pickle.dump(real_eval_set_list, f)

# Run the pipeline only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
