import os
import numpy as np
from copy import deepcopy
import argparse
import pandas as pd
import json
from tqdm import tqdm
import faiss

import torch
import clip
from torch.nn import functional as F

from imagenet_classes import IMAGENET_CLASS_LABELS, IMAGENET_21K_CLASS_LABELS


# region: setup

def set_seed(seed):
    """Seed every random number generator the script may use, for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and, when available,
    all CUDA devices), and switches cuDNN to deterministic, non-benchmarking mode.

    :param seed: seed value
    :return: None
    """
    import random  # local import: the module header does not import `random`

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.enabled = True
    # benchmark mode selects kernels by timing them, which can differ between runs
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Command-line options: --seed feeds set_seed, --dataset selects which "<dataset>_probes"
# representation directories are read below.
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--dataset", type=str, required=True, choices=['SD', 'COCO', 'ImageNet', 'DeadLeaves'])
args = parser.parse_args()

set_seed(args.seed)

# Output directory for this dataset (plus an "errors" sub-directory); created if missing.
save_dir = f'experiments/retrieval/{args.dataset}/probelog'
os.makedirs(save_dir, exist_ok=True)
os.makedirs(f"{save_dir}/errors", exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model_name = "ViT-B/32"

# Input locations are stored on `args` so they are printed together with the CLI options below.
args.hf2hf_label_mapping_path = 'HF2HF_label_mapping.json'

args.hub_reps_dir = f'experiments/datasets/INet_Hub_reprs__{args.dataset}_probes'
args.hf_reps_dir = f'experiments/datasets/Big_HF_Hub_reprs__{args.dataset}_probes'
args.clip_reps_dir = f'experiments/datasets/CLIP_image_reprs__{args.dataset}_probes'

clip_representations_path = f"{args.clip_reps_dir}/image_representations.pt"
# Text prompt for each class name when building CLIP text embeddings
prompt_template = "a photo of a {}"

print('Args:')
for k, v in vars(args).items():
    print(f'\t{k}:{v}')
# The experiment name is the last path component of save_dir
print(f'\tExp. Name: {save_dir.split("/")[-1]}')

# endregion: setup

# region: Load representations and metadata

hub_reps_path = os.path.join(args.hub_reps_dir, "logit_reprs.npy")
metadata_path = os.path.join(args.hub_reps_dir, "metadata.csv")

print("Loading data: representations and metadata...")
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects; only load trusted files.
reps = np.load(hub_reps_path, allow_pickle=True)
metadata = pd.read_csv(metadata_path)

# Cast before the NaN check: np.isnan raises on object arrays, which allow_pickle can produce.
reps = reps.astype(np.float32)
# Drop hub rows that contain any NaN, keeping `metadata` row-aligned with `reps`.
rows_to_keep = ~np.isnan(reps).any(axis=1)
metadata = metadata.iloc[rows_to_keep, :]
reps = reps[rows_to_keep]

labels_class_name = metadata["class_name"].values

hf_reps_path = os.path.join(args.hf_reps_dir, "logit_reprs.npy")
hf_metadata_path = os.path.join(args.hf_reps_dir, "metadata.csv")

# load hf data
hf_reps = np.load(hf_reps_path, allow_pickle=True)
hf_reps = hf_reps.astype(np.float32)

hf_metadata = pd.read_csv(hf_metadata_path)
# NOTE(review): taken before the label filter below, so it is not row-aligned with the filtered
# hf_reps / hf_metadata; confirm this is intended wherever it is used.
hf_labels_class_name = hf_metadata["class_name"].values

clip_model, _ = clip.load(clip_model_name, device=device)
# map_location lets tensors that were saved on a GPU be loaded on a CPU-only machine.
clip_image_representations = torch.load(clip_representations_path, map_location=device).to(device)

# region: Label mappings

# Values are the hub class names that count as a match for a query class. ImageNet maps each
# (lower-cased) class to itself; a single-element list keeps `name in label_names` an exact
# comparison, as with the HF mapping (with a bare string it would be a substring test,
# e.g. "ox" in "red fox").
inet2inet_label_mapping_dict = {k.lower(): [k.lower()] for k in IMAGENET_CLASS_LABELS}
with open(args.hf2hf_label_mapping_path, "r") as mapping_file:
    hf2hf_label_mapping_dict = json.load(mapping_file)

# Remove: deletes all HF logits whose class name is not a value anywhere in the HF->HF mapping.
# dict.fromkeys de-duplicates in O(n) while keeping first-seen order.
all_possible_labels = list(dict.fromkeys(v for vs in hf2hf_label_mapping_dict.values() for v in vs))
inds_in_labels = np.where(hf_metadata["class_name"].isin(all_possible_labels).values)[0]
hf_metadata = hf_metadata.iloc[inds_in_labels, :]
hf_reps = hf_reps[inds_in_labels, :]
hf_metadata = hf_metadata.reset_index(drop=True)

all_hf_classes = list(np.unique(hf_metadata['class_name'].values))


# endregion: Label mappings

# endregion: Load representations and metadata

# region: Funcs


def get_retrieval_results__inet21k_negs(normed_hub_reps, normed_query_reps, query_metadata, hub_metadata, negs_reprs,
                                        label_map, n_probes=10, max_rank=50, verbose=True, n_negs=50):
    """Retrieve hub entries for every query using per-hub probe subsets and negative calibration.

    Only the ``n_probes`` highest-valued columns ("probes") of hub row ``i`` are used. On that
    column subset, the Euclidean distance (``torch.cdist``) is computed from the hub row to every
    query row and to every negative row. Each hub column of the query-distance matrix is then
    shifted by the mean distance of its ``n_negs`` closest negatives. Each query retrieves the hub
    entries with the smallest shifted distance.

    :param normed_hub_reps: 2-D numpy array, one row per hub entry.
    :param normed_query_reps: 2-D numpy array, one row per query; same columns as the hub.
    :param query_metadata: DataFrame row-aligned with the queries; needs ``class_name`` and ``n_classes``.
    :param hub_metadata: DataFrame row-aligned with the hub; needs ``class_name``.
    :param negs_reprs: 2-D numpy array of negative rows; same columns as the hub.
    :param label_map: lower-cased query class name -> accepted hub class names. The value is a
        list, or a single name string that is compared exactly. A query class missing from the
        dict is added with an empty list (the dict is mutated). It is printed (unless its name
        contains a comma) only the first time, and those queries are skipped.
    :param n_probes: number of probes kept per hub row.
    :param max_rank: number of ranks recorded per query; capped at the number of hub entries.
    :param verbose: show a tqdm progress bar over hub rows.
    :param n_negs: number of closest negatives averaged for each hub column's offset.
    :return: ``(metrics_df, results_df)``.
        ``results_df`` has one row per query with a non-empty label list. Its columns are
        ``q_ind``, ``n_classes`` and ``class_name``; then, for every rank k, ``score_top_k``
        (shifted distance), ``is_match_k``, ``class_k`` and ``index_k``; then ``precision_at_k``.
        ``metrics_df`` has one row per rank with ``Avg. Precision @ k``, ``Top k Acc.`` and
        ``Avg. Precision @ k (Found)``. The last is precision averaged over queries with at
        least one match in their top k.
    """
    def _clean(name):
        # lower-case and drop a single leading space
        name = name.lower()
        return name[1:] if name.startswith(' ') else name

    n_queries, n_hub = normed_query_reps.shape[0], normed_hub_reps.shape[0]
    # More ranks than hub entries would leave some result columns shorter than others
    max_rank = min(max_rank, n_hub)

    dist_matrix = np.zeros((n_queries, n_hub), dtype=np.float32)
    negs_dist_matrix = np.zeros((negs_reprs.shape[0], n_hub), dtype=np.float32)

    hub_indices = tqdm(range(n_hub)) if verbose else range(n_hub)
    for i in hub_indices:
        # the n_probes columns where this hub entry has its largest values
        selected_probes = normed_hub_reps[i].argsort()[-n_probes:]
        hub_subset = torch.from_numpy(normed_hub_reps[i:i + 1, selected_probes])
        query_subset = torch.from_numpy(normed_query_reps[:, selected_probes])
        negs_subset = torch.from_numpy(negs_reprs[:, selected_probes])
        dist_matrix[:, i] = torch.cdist(hub_subset, query_subset)[0].numpy()
        negs_dist_matrix[:, i] = torch.cdist(hub_subset, negs_subset)[0].numpy()

    # Calibrate every hub column by the mean distance of its n_negs closest negatives
    for j in range(n_hub):
        closest_negs = negs_dist_matrix[:, j].argsort()[:n_negs]
        dist_matrix[:, j] = dist_matrix[:, j] - np.mean(negs_dist_matrix[closest_negs, j])

    results = {'q_ind': [], 'n_classes': [], 'class_name': []}
    for rank in range(1, max_rank + 1):
        for prefix in ('score_top', 'is_match', 'class', 'index'):
            results[f'{prefix}_{rank}'] = []

    count_non_existing = 0
    query_class_names = query_metadata["class_name"].values
    query_n_classes = query_metadata['n_classes'].values
    hub_class_names = hub_metadata["class_name"].values

    for i in range(n_queries):
        query_cls_name = _clean(query_class_names[i])
        if query_cls_name not in label_map:
            # registering the class makes later lookups (and later calls) skip it silently
            label_map[query_cls_name] = []
            count_non_existing += 1
            if ',' not in query_cls_name:
                print(f"Class {query_cls_name} not found in label map. Count: {count_non_existing}")
        label_names = label_map[query_cls_name]
        if isinstance(label_names, str):
            # a bare string is one accepted name: compare exactly, never as a substring
            label_names = [label_names] if label_names else []
        if len(label_names) == 0:
            continue
        results['q_ind'].append(i)
        results['n_classes'].append(query_n_classes[i])
        results['class_name'].append(query_class_names[i])

        # smallest calibrated distance first
        ranked = dist_matrix[i].argsort()[:max_rank]
        for rank, hub_idx in enumerate(ranked, start=1):
            retrieval_class = _clean(hub_class_names[hub_idx])
            results[f'score_top_{rank}'].append(dist_matrix[i, hub_idx])
            results[f'is_match_{rank}'].append(retrieval_class in label_names)
            results[f'class_{rank}'].append(retrieval_class)
            results[f'index_{rank}'].append(hub_idx)

    results_df = pd.DataFrame(results)

    # hits[q, k-1] = number of matches among the top-k results of query row q
    rank_labels = list(range(1, max_rank + 1))
    ranks = np.arange(1, max_rank + 1)
    matches = results_df[[f'is_match_{k}' for k in rank_labels]].to_numpy(dtype=bool)
    hits = np.cumsum(matches, axis=1)
    precision_at_k = hits / ranks
    found_at_k = hits > 0
    precision_df = pd.DataFrame(precision_at_k, index=results_df.index,
                                columns=[f'precision_at_{k}' for k in rank_labels])
    results_df = pd.concat([results_df, precision_df], axis=1)

    metrics_df = pd.DataFrame({
        'Rank': range(1, max_rank + 1),
        'Avg. Precision @ k': precision_at_k.mean(axis=0),
        'Top k Acc.': found_at_k.mean(axis=0),
        'Avg. Precision @ k (Found)': [precision_at_k[found_at_k[:, c], c].mean() for c in range(max_rank)],
    })

    return metrics_df, results_df


def wpmi(clip_feats, target_feats, top_k=50, a=2, lam=0.6, device=None, min_prob=1e-7):
    """Weighted pointwise mutual information between target columns and ``clip_feats`` columns.

    ``clip_feats`` is squashed with ``sigmoid(a * x)`` and treated as per-row probabilities.
    For every column ``e`` of ``target_feats``, the ``top_k`` rows with the largest values are
    selected. ``log p(d | e)`` is the sum over those rows of
    ``log(sigmoid(a * clip_feats[row, d]) + min_prob)``. ``log p(d)`` is the log of the mean of
    ``p(d | e)`` over all target columns, computed with logsumexp to avoid underflow. The result
    is ``log p(d | e) - lam * log p(d)``.

    :param clip_feats: 2-D numpy array (n_rows, n_d).
    :param target_feats: 2-D numpy array (n_rows, n_e); rows aligned with ``clip_feats``.
    :param top_k: number of rows selected per target column.
    :param a: sigmoid scale.
    :param lam: weight of the ``log p(d)`` term.
    :param device: torch device. ``None`` (the default) uses CUDA when available, else the CPU.
    :param min_prob: added inside the log to avoid ``log(0)``.
    :return: numpy array of shape (n_e, n_d).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    target_feats = torch.from_numpy(target_feats).to(device)
    with torch.no_grad():
        torch.cuda.empty_cache()

        clip_feats = torch.from_numpy(clip_feats).to(device)
        # log-probabilities computed once and reused for every target column
        log_probs = torch.log(torch.sigmoid(a * clip_feats) + min_prob)

        top_rows = torch.topk(target_feats, dim=0, k=top_k)[1]  # (top_k, n_e)
        prob_d_given_e = torch.stack(
            [log_probs[top_rows[:, e]].sum(dim=0) for e in range(target_feats.shape[1])]
        )  # (n_e, n_d), in log space

        # logsumexp trick to avoid underflow: log of the mean over target columns
        prob_d = (torch.logsumexp(prob_d_given_e, dim=0, keepdim=True) -
                  torch.log(prob_d_given_e.shape[0] * torch.ones([1]).to(device)))

        mutual_info = prob_d_given_e - lam * prob_d
    return mutual_info.detach().cpu().numpy()


def softwpmi(clip_feats, target_feats, top_k=50, a=10, lam=1, device=None,
             min_prob=1e-7, p_start=0.998, p_end=0.97):
    """Soft variant of ``wpmi`` that down-weights lower-ranked examples.

    ``clip_feats`` is squashed with ``sigmoid(a * x)``. For every column ``e`` of
    ``target_feats``, the ``top_k`` rows with the largest values are taken in descending order.
    The row at rank ``r`` gets a weight ``p_r = p_start - r / top_k * (p_start - p_end)``.
    ``log p(d | e)`` is the sum over those rows of ``log(1 + p_r * (s - 1) + min_prob)``, where
    ``s`` is the sigmoid score and ``1 + p_r * (s - 1) = p_r * s + (1 - p_r)``. ``log p(d)`` is
    the logsumexp-based log-mean over target columns. The result is
    ``log p(d | e) - lam * log p(d)``.

    :param clip_feats: 2-D numpy array (n_rows, n_d).
    :param target_feats: 2-D numpy array (n_rows, n_e); rows aligned with ``clip_feats``.
    :param top_k: number of rows selected per target column.
    :param a: sigmoid scale.
    :param lam: weight of the ``log p(d)`` term.
    :param device: torch device. ``None`` (the default) uses CUDA when available, else the CPU.
    :param min_prob: added inside the log to avoid ``log(0)``.
    :param p_start: weight of the top-ranked row.
    :param p_end: weights decrease linearly from ``p_start`` toward this value.
    :return: numpy array of shape (n_e, n_d).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    target_feats = torch.from_numpy(target_feats).to(device)
    with torch.no_grad():
        torch.cuda.empty_cache()
        clip_feats = torch.sigmoid(a * torch.from_numpy(clip_feats).to(device))

        top_rows = torch.topk(target_feats, dim=0, k=top_k)[1]  # (top_k, n_e), best row first
        # per-rank weight, shape (top_k, 1) so it broadcasts across the n_d columns
        p_in_examples = p_start - (torch.arange(start=0, end=top_k) / top_k * (p_start - p_end)).unsqueeze(1).to(device)

        per_target = []
        for e in range(target_feats.shape[1]):
            curr_scores = clip_feats[top_rows[:, e]]
            curr_p_d_given_e = 1 + p_in_examples * (curr_scores - 1)
            per_target.append(torch.sum(torch.log(curr_p_d_given_e + min_prob), dim=0, keepdim=True))

        prob_d_given_e = torch.cat(per_target, dim=0)  # (n_e, n_d), in log space
        # logsumexp trick to avoid underflow: log of the mean over target columns
        prob_d = (torch.logsumexp(prob_d_given_e, dim=0, keepdim=True) -
                  torch.log(prob_d_given_e.shape[0] * torch.ones([1]).to(device)))
        mutual_info = prob_d_given_e - lam * prob_d
    return mutual_info.detach().cpu().numpy()


def get_retrieval_results__all_probes__inet21k_negs(normed_hub_reps, normed_query_reps, query_metadata, hub_metadata, negs_reprs,
                                                    label_map, max_rank=50, verbose=True, n_negs=50):
    """Retrieve hub entries for every query using all probes, with negative calibration.

    The Euclidean distance (``torch.cdist``) is computed between full rows: queries vs. hub and
    negatives vs. hub. Each hub column of the query-distance matrix is shifted by the mean
    distance of its ``n_negs`` closest negatives. Each query retrieves the hub entries with the
    smallest shifted distance.

    :param normed_hub_reps: 2-D numpy array, one row per hub entry.
    :param normed_query_reps: 2-D numpy array, one row per query; same columns as the hub.
    :param query_metadata: DataFrame row-aligned with the queries; needs ``class_name`` and ``n_classes``.
    :param hub_metadata: DataFrame row-aligned with the hub; needs ``class_name``.
    :param negs_reprs: 2-D numpy array of negative rows; same columns as the hub.
    :param label_map: lower-cased query class name -> accepted hub class names. The value is a
        list, or a single name string that is compared exactly. A query class missing from the
        dict is added with an empty list (the dict is mutated). It is printed (unless its name
        contains a comma) only the first time, and those queries are skipped.
    :param max_rank: number of ranks recorded per query; capped at the number of hub entries.
    :param verbose: show a tqdm progress bar over queries.
    :param n_negs: number of closest negatives averaged for each hub column's offset.
    :return: ``(metrics_df, results_df)``.
        ``results_df`` has one row per query with a non-empty label list. Its columns are
        ``q_ind``, ``n_classes`` and ``class_name``; then, for every rank k, ``score_top_k``,
        ``is_match_k`` and ``class_k``; then ``precision_at_k``.
        ``metrics_df`` has one row per rank with ``Avg. Precision @ k``, ``Top k Acc.`` and
        ``Avg. Precision @ k (Found)``.
    """
    def _clean(name):
        # lower-case and drop a single leading space
        name = name.lower()
        return name[1:] if name.startswith(' ') else name

    n_queries, n_hub = normed_query_reps.shape[0], normed_hub_reps.shape[0]
    # More ranks than hub entries would leave some result columns shorter than others
    max_rank = min(max_rank, n_hub)

    with torch.no_grad():
        dist_matrix = torch.cdist(torch.from_numpy(normed_query_reps), torch.from_numpy(normed_hub_reps)).numpy()
        negs_dist_matrix = torch.cdist(torch.from_numpy(negs_reprs), torch.from_numpy(normed_hub_reps)).numpy()

    # Calibrate every hub column by the mean distance of its n_negs closest negatives
    for j in range(n_hub):
        closest_negs = negs_dist_matrix[:, j].argsort()[:n_negs]
        dist_matrix[:, j] = dist_matrix[:, j] - np.mean(negs_dist_matrix[closest_negs, j])

    results = {'q_ind': [], 'n_classes': [], 'class_name': []}
    for rank in range(1, max_rank + 1):
        for prefix in ('score_top', 'is_match', 'class'):
            results[f'{prefix}_{rank}'] = []

    count_non_existing = 0
    query_class_names = query_metadata["class_name"].values
    query_n_classes = query_metadata['n_classes'].values
    hub_class_names = hub_metadata["class_name"].values

    query_indices = tqdm(range(n_queries)) if verbose else range(n_queries)
    for i in query_indices:
        query_cls_name = _clean(query_class_names[i])
        if query_cls_name not in label_map:
            # same handling as the sibling retrieval functions (previously a KeyError here)
            label_map[query_cls_name] = []
            count_non_existing += 1
            if ',' not in query_cls_name:
                print(f"Class {query_cls_name} not found in label map. Count: {count_non_existing}")
        label_names = label_map[query_cls_name]
        if isinstance(label_names, str):
            # a bare string is one accepted name: compare exactly, never as a substring
            label_names = [label_names] if label_names else []
        if len(label_names) == 0:
            continue
        results['q_ind'].append(i)
        results['n_classes'].append(query_n_classes[i])
        results['class_name'].append(query_class_names[i])

        # smallest calibrated distance first
        ranked = dist_matrix[i].argsort()[:max_rank]
        for rank, hub_idx in enumerate(ranked, start=1):
            retrieval_class = _clean(hub_class_names[hub_idx])
            results[f'score_top_{rank}'].append(dist_matrix[i, hub_idx])
            results[f'is_match_{rank}'].append(retrieval_class in label_names)
            results[f'class_{rank}'].append(retrieval_class)

    results_df = pd.DataFrame(results)

    # hits[q, k-1] = number of matches among the top-k results of query row q
    rank_labels = list(range(1, max_rank + 1))
    ranks = np.arange(1, max_rank + 1)
    matches = results_df[[f'is_match_{k}' for k in rank_labels]].to_numpy(dtype=bool)
    hits = np.cumsum(matches, axis=1)
    precision_at_k = hits / ranks
    found_at_k = hits > 0
    precision_df = pd.DataFrame(precision_at_k, index=results_df.index,
                                columns=[f'precision_at_{k}' for k in rank_labels])
    results_df = pd.concat([results_df, precision_df], axis=1)

    metrics_df = pd.DataFrame({
        'Rank': range(1, max_rank + 1),
        'Avg. Precision @ k': precision_at_k.mean(axis=0),
        'Top k Acc.': found_at_k.mean(axis=0),
        'Avg. Precision @ k (Found)': [precision_at_k[found_at_k[:, c], c].mean() for c in range(max_rank)],
    })

    return metrics_df, results_df


def get_retrieval_results__softwpmi(normed_hub_reps, normed_query_reps, query_metadata, hub_metadata, negs_reprs,
                                    label_map, n_probes=10, max_rank=50, verbose=True):
    """Retrieve hub entries for every query by soft weighted PMI score (higher is better).

    ``softwpmi`` is called with the query representations as the scored features and the hub
    representations as targets (each hub entry uses its ``n_probes`` highest-valued probes).
    Each query retrieves the hub entries with the largest scores.

    :param normed_hub_reps: 2-D numpy array, one row per hub entry.
    :param normed_query_reps: 2-D numpy array, one row per query; same columns as the hub.
    :param query_metadata: DataFrame row-aligned with the queries; needs ``class_name`` and ``n_classes``.
    :param hub_metadata: DataFrame row-aligned with the hub; needs ``class_name``.
    :param negs_reprs: unused here; accepted for signature parity with the other
        ``get_retrieval_results__*`` functions.
    :param label_map: lower-cased query class name -> accepted hub class names. The value is a
        list, or a single name string that is compared exactly. A query class missing from the
        dict is added with an empty list (the dict is mutated). It is printed (unless its name
        contains a comma) only the first time, and those queries are skipped.
    :param n_probes: ``top_k`` passed to ``softwpmi``.
    :param max_rank: number of ranks recorded per query; capped at the number of hub entries.
    :param verbose: unused here; accepted for signature parity.
    :return: ``(metrics_df, results_df)``.
        ``results_df`` has one row per query with a non-empty label list. Its columns are
        ``q_ind``, ``n_classes`` and ``class_name``; then, for every rank k, ``score_top_k``,
        ``is_match_k``, ``class_k`` and ``index_k``; then ``precision_at_k``.
        ``metrics_df`` has one row per rank with ``Avg. Precision @ k``, ``Top k Acc.`` and
        ``Avg. Precision @ k (Found)``.
    """
    def _clean(name):
        # lower-case and drop a single leading space
        name = name.lower()
        return name[1:] if name.startswith(' ') else name

    # pass the device explicitly: softwpmi's original default was 'cuda', which fails without a GPU
    mi_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # (n_queries, n_hub) score matrix
    mi_matrix = softwpmi(normed_query_reps.T, normed_hub_reps.T, top_k=n_probes, device=mi_device).T
    # More ranks than hub entries would leave some result columns shorter than others
    max_rank = min(max_rank, mi_matrix.shape[1])

    results = {'q_ind': [], 'n_classes': [], 'class_name': []}
    for rank in range(1, max_rank + 1):
        for prefix in ('score_top', 'is_match', 'class', 'index'):
            results[f'{prefix}_{rank}'] = []

    count_non_existing = 0
    query_class_names = query_metadata["class_name"].values
    query_n_classes = query_metadata['n_classes'].values
    hub_class_names = hub_metadata["class_name"].values

    for i in range(len(normed_query_reps)):
        query_cls_name = _clean(query_class_names[i])
        if query_cls_name not in label_map:
            label_map[query_cls_name] = []
            count_non_existing += 1
            if ',' not in query_cls_name:
                print(f"Class {query_cls_name} not found in label map. Count: {count_non_existing}")
        label_names = label_map[query_cls_name]
        if isinstance(label_names, str):
            # a bare string is one accepted name: compare exactly, never as a substring
            label_names = [label_names] if label_names else []
        if len(label_names) == 0:
            continue
        results['q_ind'].append(i)
        results['n_classes'].append(query_n_classes[i])
        results['class_name'].append(query_class_names[i])

        # largest score first
        ranked = mi_matrix[i].argsort()[::-1][:max_rank]
        for rank, hub_idx in enumerate(ranked, start=1):
            retrieval_class = _clean(hub_class_names[hub_idx])
            results[f'score_top_{rank}'].append(mi_matrix[i, hub_idx])
            results[f'is_match_{rank}'].append(retrieval_class in label_names)
            results[f'class_{rank}'].append(retrieval_class)
            results[f'index_{rank}'].append(hub_idx)

    results_df = pd.DataFrame(results)

    # hits[q, k-1] = number of matches among the top-k results of query row q
    rank_labels = list(range(1, max_rank + 1))
    ranks = np.arange(1, max_rank + 1)
    matches = results_df[[f'is_match_{k}' for k in rank_labels]].to_numpy(dtype=bool)
    hits = np.cumsum(matches, axis=1)
    precision_at_k = hits / ranks
    found_at_k = hits > 0
    precision_df = pd.DataFrame(precision_at_k, index=results_df.index,
                                columns=[f'precision_at_{k}' for k in rank_labels])
    results_df = pd.concat([results_df, precision_df], axis=1)

    metrics_df = pd.DataFrame({
        'Rank': range(1, max_rank + 1),
        'Avg. Precision @ k': precision_at_k.mean(axis=0),
        'Top k Acc.': found_at_k.mean(axis=0),
        'Avg. Precision @ k (Found)': [precision_at_k[found_at_k[:, c], c].mean() for c in range(max_rank)],
    })

    return metrics_df, results_df


def get_retrieval_results__wpmi(normed_hub_reps, normed_query_reps, query_metadata, hub_metadata, negs_reprs,
                                label_map, n_probes=10, max_rank=50, verbose=True):
    """Retrieve hub entries for every query by weighted PMI score (higher is better).

    ``wpmi`` is called with the query representations as the scored features and the hub
    representations as targets (each hub entry uses its ``n_probes`` highest-valued probes).
    Each query retrieves the hub entries with the largest scores.

    :param normed_hub_reps: 2-D numpy array, one row per hub entry.
    :param normed_query_reps: 2-D numpy array, one row per query; same columns as the hub.
    :param query_metadata: DataFrame row-aligned with the queries; needs ``class_name`` and ``n_classes``.
    :param hub_metadata: DataFrame row-aligned with the hub; needs ``class_name``.
    :param negs_reprs: unused here; accepted for signature parity with the other
        ``get_retrieval_results__*`` functions.
    :param label_map: lower-cased query class name -> accepted hub class names. The value is a
        list, or a single name string that is compared exactly. A query class missing from the
        dict is added with an empty list (the dict is mutated). It is printed (unless its name
        contains a comma) only the first time, and those queries are skipped.
    :param n_probes: ``top_k`` passed to ``wpmi``.
    :param max_rank: number of ranks recorded per query; capped at the number of hub entries.
    :param verbose: unused here; accepted for signature parity.
    :return: ``(metrics_df, results_df)``.
        ``results_df`` has one row per query with a non-empty label list. Its columns are
        ``q_ind``, ``n_classes`` and ``class_name``; then, for every rank k, ``score_top_k``,
        ``is_match_k`` and ``class_k``; then ``precision_at_k``.
        ``metrics_df`` has one row per rank with ``Avg. Precision @ k``, ``Top k Acc.`` and
        ``Avg. Precision @ k (Found)``.
    """
    def _clean(name):
        # lower-case and drop a single leading space
        name = name.lower()
        return name[1:] if name.startswith(' ') else name

    # pass the device explicitly: wpmi's original default was 'cuda', which fails without a GPU
    mi_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # (n_queries, n_hub) score matrix
    mi_matrix = wpmi(normed_query_reps.T, normed_hub_reps.T, top_k=n_probes, device=mi_device).T
    # More ranks than hub entries would leave some result columns shorter than others
    max_rank = min(max_rank, mi_matrix.shape[1])

    results = {'q_ind': [], 'n_classes': [], 'class_name': []}
    for rank in range(1, max_rank + 1):
        for prefix in ('score_top', 'is_match', 'class'):
            results[f'{prefix}_{rank}'] = []

    count_non_existing = 0
    query_class_names = query_metadata["class_name"].values
    query_n_classes = query_metadata['n_classes'].values
    hub_class_names = hub_metadata["class_name"].values

    for i in range(len(normed_query_reps)):
        query_cls_name = _clean(query_class_names[i])
        if query_cls_name not in label_map:
            label_map[query_cls_name] = []
            count_non_existing += 1
            if ',' not in query_cls_name:
                print(f"Class {query_cls_name} not found in label map. Count: {count_non_existing}")
        label_names = label_map[query_cls_name]
        if isinstance(label_names, str):
            # a bare string is one accepted name: compare exactly, never as a substring
            label_names = [label_names] if label_names else []
        if len(label_names) == 0:
            continue
        results['q_ind'].append(i)
        results['n_classes'].append(query_n_classes[i])
        results['class_name'].append(query_class_names[i])

        # largest score first
        ranked = mi_matrix[i].argsort()[::-1][:max_rank]
        for rank, hub_idx in enumerate(ranked, start=1):
            retrieval_class = _clean(hub_class_names[hub_idx])
            results[f'score_top_{rank}'].append(mi_matrix[i, hub_idx])
            results[f'is_match_{rank}'].append(retrieval_class in label_names)
            results[f'class_{rank}'].append(retrieval_class)

    results_df = pd.DataFrame(results)

    # hits[q, k-1] = number of matches among the top-k results of query row q
    rank_labels = list(range(1, max_rank + 1))
    ranks = np.arange(1, max_rank + 1)
    matches = results_df[[f'is_match_{k}' for k in rank_labels]].to_numpy(dtype=bool)
    hits = np.cumsum(matches, axis=1)
    precision_at_k = hits / ranks
    found_at_k = hits > 0
    precision_df = pd.DataFrame(precision_at_k, index=results_df.index,
                                columns=[f'precision_at_{k}' for k in rank_labels])
    results_df = pd.concat([results_df, precision_df], axis=1)

    metrics_df = pd.DataFrame({
        'Rank': range(1, max_rank + 1),
        'Avg. Precision @ k': precision_at_k.mean(axis=0),
        'Top k Acc.': found_at_k.mean(axis=0),
        'Avg. Precision @ k (Found)': [precision_at_k[found_at_k[:, c], c].mean() for c in range(max_rank)],
    })

    return metrics_df, results_df


# endregion: Funcs

# region: Prep Representations


def norm_reprs(r):
    """Standardize every row of a 2-D array independently.

    Each row has its mean subtracted and is then divided by its standard deviation plus 1e-4
    (the epsilon keeps constant rows finite; they come out as zeros).

    :param r: 2-D array.
    :return: array of the same shape as ``r``.
    """
    centered = r - r.mean(axis=1, keepdims=True)
    spread = centered.std(axis=1, keepdims=True) + 1e-4
    return centered / spread


# Standardize every hub / HF representation row with the shared norm_reprs helper
# (row mean removed, divided by row std + 1e-4) instead of repeating the formula inline.
normed_reps = norm_reprs(reps)
normed_hf_reps = norm_reprs(hf_reps)

n_models_in_hf = len(hf_metadata['model'].unique())
print(f'Number of HF logits: {len(hf_metadata)}, Number of HF Models: {n_models_in_hf}')

normed_reps = normed_reps.astype(np.float32)
normed_hf_reps = normed_hf_reps.astype(np.float32)

# region: CLIP Representations

clip_imagenet_prompts = [prompt_template.format(class_name) for class_name in IMAGENET_CLASS_LABELS]
# 2500 ImageNet-21K class names, drawn without replacement from NumPy's (seeded) global RNG.
# Their CLIP scores serve as the negatives in the experiment loop.
clip_inet21k_prompts = [prompt_template.format(class_name)
                        for class_name in np.random.choice(IMAGENET_21K_CLASS_LABELS, size=2500, replace=False)]
clip_hf_prompts = [prompt_template.format(class_name) for class_name in all_hf_classes]

# One metadata row per prompt. NOTE: np.array over mixed ints/strings stores 'id' as strings.
clip_imagenet_metadata = pd.DataFrame(np.array([list(range(len(IMAGENET_CLASS_LABELS))),
                                                IMAGENET_CLASS_LABELS]).T,
                                      columns=['id', 'class_name'])
clip_imagenet_metadata['n_classes'] = 1
clip_hf_metadata = pd.DataFrame(np.array([list(range(len(all_hf_classes))), all_hf_classes]).T,
                                columns=['id', 'class_name'])
clip_hf_metadata['n_classes'] = 1

with torch.no_grad():
    # Unit-normalized text embeddings for each prompt set
    text_tokens_imagenet = clip.tokenize(clip_imagenet_prompts).to(device)
    text_tokens_inet21k = clip.tokenize(clip_inet21k_prompts).to(device)
    text_tokens_hf = clip.tokenize(clip_hf_prompts).to(device)

    text_embeddings_imagenet = clip_model.encode_text(text_tokens_imagenet)
    text_embeddings_imagenet = text_embeddings_imagenet / text_embeddings_imagenet.norm(dim=-1, keepdim=True)
    text_embeddings_inet21k = clip_model.encode_text(text_tokens_inet21k)
    text_embeddings_inet21k = text_embeddings_inet21k / text_embeddings_inet21k.norm(dim=-1, keepdim=True)
    text_embeddings_hf = clip_model.encode_text(text_tokens_hf)
    text_embeddings_hf = text_embeddings_hf / text_embeddings_hf.norm(dim=-1, keepdim=True)

    # Image-text dot products (rows: images, columns: prompts) and their softmax over prompts
    clip_imagenet_logits = clip_image_representations @ text_embeddings_imagenet.T
    clip_imagenet_probs = F.softmax(clip_imagenet_logits * torch.exp(clip_model.logit_scale), dim=-1)
    clip_inet21k_logits = clip_image_representations @ text_embeddings_inet21k.T
    clip_inet21k_probs = F.softmax(clip_inet21k_logits * torch.exp(clip_model.logit_scale), dim=-1)
    clip_hf_logits = clip_image_representations @ text_embeddings_hf.T
    clip_hf_probs = F.softmax(clip_hf_logits * torch.exp(clip_model.logit_scale), dim=-1)

    # Transpose so there is one row per class prompt, then move to numpy
    clip_imagenet_logits = clip_imagenet_logits.T.detach().cpu().numpy()
    clip_imagenet_probs = clip_imagenet_probs.T.detach().cpu().numpy()
    clip_inet21k_logits = clip_inet21k_logits.T.detach().cpu().numpy()
    clip_inet21k_probs = clip_inet21k_probs.T.detach().cpu().numpy()
    clip_hf_logits = clip_hf_logits.T.detach().cpu().numpy()
    clip_hf_probs = clip_hf_probs.T.detach().cpu().numpy()

# Cast to float32 *before* standardizing. The logits come straight from CLIP and may be float16
# (e.g. when the model runs in half precision on GPU); row means/stds in float16 lose precision.
clip_hf_logits = clip_hf_logits.astype(np.float32)
clip_imagenet_logits = clip_imagenet_logits.astype(np.float32)
clip_inet21k_logits = clip_inet21k_logits.astype(np.float32)

# Row-wise standardization with the shared helper; float32 in, float32 out
normed_clip_hf_reps = norm_reprs(clip_hf_logits)
normed_clip_imagenet_reps = norm_reprs(clip_imagenet_logits)
normed_clip_inet21k_reps = norm_reprs(clip_inet21k_logits)

# Split the sampled ImageNet-21K rows into 5 contiguous, (near-)equal (start, end) chunks;
# experiment iteration i uses chunk i as its negatives.
inet21k_indxs_groups = [int(x) for x in np.linspace(0, clip_inet21k_logits.shape[0], 6).astype(int)]
inet21k_indxs_groups = [(inet21k_indxs_groups[i], inet21k_indxs_groups[i + 1]) for i in range(len(inet21k_indxs_groups) - 1)]
# inet21k_groups = [np.random.choice(range(clip_inet21k_logits.shape[0]), size=1500, replace=False).astype(int).tolist() for _ in range(5)]
# endregion: CLIP Representations

# endregion: Prep Representations

# Experiment grid: CLIP queries against the ImageNet hub with the inet2inet label map and
# 4000 hub probes, varying only the number of query probes.
# NOTE(review): the 3rd and 4th entries repeat the 1st and 2nd exactly; confirm whether they
# were meant to use another hub/label map (e.g. 'HuggingFace' / 'hf2hf').
all_performed_exps = [
    {
        'dataset': args.dataset,
        'queries': 'CLIP',
        'hub': 'Imagenet',
        'label_map': 'inet2inet',
        'n_hub_probes': 4000,
        'n_query_probes': query_probe_count,
    }
    for query_probe_count in (30, 50, 30, 50)
]


def get_repr_and_metadata(name, hub_name=None):
    """Return the row-standardized representations and metadata for a query/hub source.

    :param name: "Imagenet", "HuggingFace" or "CLIP".
    :param hub_name: only used when ``name == "CLIP"``. It names the hub whose class set the CLIP
        prompts were built from ("HuggingFace" or "Imagenet").
    :return: ``(representations, metadata)`` tuple.
    :raises ValueError: for an unknown ``name``, or for an unknown/missing ``hub_name`` when
        ``name == "CLIP"`` (previously that case silently returned ``None``).
    """
    if name == "Imagenet":
        return normed_reps, metadata
    if name == "HuggingFace":
        return normed_hf_reps, hf_metadata
    if name == "CLIP":
        if hub_name == "HuggingFace":
            return normed_clip_hf_reps, clip_hf_metadata
        if hub_name == "Imagenet":
            return normed_clip_imagenet_reps, clip_imagenet_metadata
        raise ValueError(f"Invalid hub_name for CLIP: {hub_name}")
    raise ValueError(f"Invalid name: {name}")


def get_unnormed_repr_and_metadata(name, hub_name=None):
    """Return the raw (not standardized) representations and metadata for a query/hub source.

    :param name: "Imagenet", "HuggingFace" or "CLIP".
    :param hub_name: only used when ``name == "CLIP"``. It names the hub whose class set the CLIP
        prompts were built from ("HuggingFace" or "Imagenet").
    :return: ``(representations, metadata)`` tuple.
    :raises ValueError: for an unknown ``name``, or for an unknown/missing ``hub_name`` when
        ``name == "CLIP"`` (previously that case silently returned ``None``).
    """
    if name == "Imagenet":
        return reps, metadata
    if name == "HuggingFace":
        return hf_reps, hf_metadata
    if name == "CLIP":
        if hub_name == "HuggingFace":
            return clip_hf_logits, clip_hf_metadata
        if hub_name == "Imagenet":
            return clip_imagenet_logits, clip_imagenet_metadata
        raise ValueError(f"Invalid hub_name for CLIP: {hub_name}")
    raise ValueError(f"Invalid name: {name}")


def get_label_mapping(name):
    """Return the module-level variable ``<name>_label_mapping_dict``.

    :param name: mapping prefix, e.g. "inet2inet" -> ``inet2inet_label_mapping_dict``.
    :return: the object bound to that module-level name.
    :raises ValueError: if no such module-level variable exists.
    """
    var_name = f'{name}_label_mapping_dict'
    # Direct lookup instead of eval(): identical result for valid names, but an
    # arbitrary string can no longer be executed as code.
    try:
        return globals()[var_name]
    except KeyError:
        raise ValueError(f"Unknown label mapping {name!r}: no module-level variable {var_name}") from None


# Random probe subsets per experiment. Iteration i also uses inet21k_indxs_groups[i],
# so n_iters must not exceed the number of (start, end) pairs there.
n_iters = 5

# region: ProbeLog INet21K Negs Corr Results

# For every config in all_performed_exps: draw n_iters random subsets of n_hub_probes
# probe columns, normalise the sliced hub / query / ImageNet-21K-negative
# representations with norm_reprs, score with get_retrieval_results__inet21k_negs and
# record mean/std of top-1/top-5 accuracy and precision. The per-query results of the
# first iteration go to errors/. The summary CSV is rewritten after every config so a
# partial run still leaves output.
retrieval_results_columns = ['dataset', 'queries', 'hub', 'n_hub_probes', 'n_query_probes', 'label_map',
                             'top1_acc_mean', 'top1_prec_mean', 'top5_acc_mean', 'top5_prec_mean',
                             'top1_acc_std', 'top1_prec_std', 'top5_acc_std', 'top5_prec_std']
probelog_inet21k_negs_corr_results_path = f"{save_dir}/{args.dataset}_probelog_inet21k_negs_corr_retrieval_results_final.csv"
# Rows are accumulated and the frame rebuilt with explicit columns, instead of
# pd.concat onto an empty frame (a deprecated pandas code path).
probelog_inet21k_negs_corr_results_rows = []
probelog_inet21k_negs_corr_results_df = pd.DataFrame(columns=retrieval_results_columns)

# The context manager closes the bar even if an iteration raises; tqdm.write keeps the
# per-iteration prints from corrupting the bar.
with tqdm(total=len(all_performed_exps) * n_iters) as pbar:
    for exp_dict in all_performed_exps:
        hub_repr, hub_metadata = get_unnormed_repr_and_metadata(exp_dict['hub'])
        query_repr, query_metadata = get_unnormed_repr_and_metadata(exp_dict['queries'], hub_name=exp_dict['hub'])
        label_map = get_label_mapping(exp_dict['label_map'])
        top1_accs, top5_accs = [], []
        top1_prec, top5_prec = [], []
        results_row = deepcopy(exp_dict)
        n_hub_probes = results_row['n_hub_probes']
        n_query_probes = results_row['n_query_probes']
        pbar.set_description_str(f"ProbeLog Corr. || Dataset: {args.dataset}, Queries: {exp_dict['queries']}, Hub: {exp_dict['hub']}")
        for i in range(n_iters):
            # Same random column subset for hub, queries and negatives.
            cur_inds = np.random.choice(range(hub_repr.shape[1]), size=n_hub_probes, replace=False)
            cur_hub_repr = norm_reprs(hub_repr[:, cur_inds])
            cur_query_repr = norm_reprs(query_repr[:, cur_inds])
            inet21k_start, inet21k_end = inet21k_indxs_groups[i]
            cur_negs_reprs = norm_reprs(clip_inet21k_logits[inet21k_start:inet21k_end, :][:, cur_inds])
            metrics_df, cur_res_df = get_retrieval_results__inet21k_negs(cur_hub_repr, cur_query_repr,
                                                                         query_metadata, hub_metadata,
                                                                         negs_reprs=cur_negs_reprs,
                                                                         label_map=label_map,
                                                                         n_probes=n_query_probes, max_rank=10,
                                                                         verbose=False)
            top1_accs.append(metrics_df.loc[metrics_df['Rank'] == 1, 'Top k Acc.'].iloc[0])
            top5_accs.append(metrics_df.loc[metrics_df['Rank'] == 5, 'Top k Acc.'].iloc[0])
            top1_prec.append(metrics_df.loc[metrics_df['Rank'] == 1, 'Avg. Precision @ k'].iloc[0])
            top5_prec.append(metrics_df.loc[metrics_df['Rank'] == 5, 'Avg. Precision @ k'].iloc[0])
            tqdm.write(f"Top 1 Acc: {np.round(top1_accs[-1], 2)}, Top 5 Acc: {np.round(top5_accs[-1], 2)}")

            pbar.update(1)
            if i == 0:
                cur_res_df.to_csv(f"{save_dir}/errors/{args.dataset}_{exp_dict['label_map']}_top{exp_dict['n_query_probes']}_probelog_inet21k_negs_corr_errors.csv", index=False)

        # np.std defaults to the population std (ddof=0).
        results_row['top1_acc_mean'] = np.mean(top1_accs)
        results_row['top1_acc_std'] = np.std(top1_accs)
        results_row['top5_acc_mean'] = np.mean(top5_accs)
        results_row['top5_acc_std'] = np.std(top5_accs)
        results_row['top1_prec_mean'] = np.mean(top1_prec)
        results_row['top1_prec_std'] = np.std(top1_prec)
        results_row['top5_prec_mean'] = np.mean(top5_prec)
        results_row['top5_prec_std'] = np.std(top5_prec)

        probelog_inet21k_negs_corr_results_rows.append(results_row)
        probelog_inet21k_negs_corr_results_df = pd.DataFrame(probelog_inet21k_negs_corr_results_rows, columns=retrieval_results_columns)
        probelog_inet21k_negs_corr_results_df.to_csv(probelog_inet21k_negs_corr_results_path, index=False)

# Final write also covers the no-experiments case (header-only CSV).
probelog_inet21k_negs_corr_results_df.to_csv(probelog_inet21k_negs_corr_results_path, index=False)

# endregion: ProbeLog INet21K Negs Corr Results

# region: SoftWPMI Results

# For every config in all_performed_exps: draw n_iters random subsets of n_hub_probes
# probe columns, pass the sliced (un-normalised) hub / query / ImageNet-21K-negative
# representations to get_retrieval_results__softwpmi, and record mean/std of top-1/top-5
# accuracy and precision. The per-query results of the first iteration go to errors/.
# The summary CSV is rewritten after every config.
retrieval_results_columns = ['dataset', 'queries', 'hub', 'n_hub_probes', 'n_query_probes', 'label_map',
                             'top1_acc_mean', 'top1_prec_mean', 'top5_acc_mean', 'top5_prec_mean',
                             'top1_acc_std', 'top1_prec_std', 'top5_acc_std', 'top5_prec_std']
softwpmi_results_path = f"{save_dir}/{args.dataset}_softwpmi_retrieval_results_final.csv"
# Rows are accumulated and the frame rebuilt with explicit columns, instead of
# pd.concat onto an empty frame (a deprecated pandas code path).
softwpmi_results_rows = []
softwpmi_results_df = pd.DataFrame(columns=retrieval_results_columns)

# The context manager closes the bar even if an iteration raises; tqdm.write keeps the
# per-iteration prints from corrupting the bar.
with tqdm(total=len(all_performed_exps) * n_iters) as pbar:
    for exp_dict in all_performed_exps:
        hub_repr, hub_metadata = get_unnormed_repr_and_metadata(exp_dict['hub'])
        query_repr, query_metadata = get_unnormed_repr_and_metadata(exp_dict['queries'], hub_name=exp_dict['hub'])
        label_map = get_label_mapping(exp_dict['label_map'])
        top1_accs, top5_accs = [], []
        top1_prec, top5_prec = [], []
        results_row = deepcopy(exp_dict)
        n_hub_probes = results_row['n_hub_probes']
        n_query_probes = results_row['n_query_probes']
        pbar.set_description_str(f"SoftWPMI || Dataset: {args.dataset}, Queries: {exp_dict['queries']}, Hub: {exp_dict['hub']}")
        for i in range(n_iters):
            # Same random column subset for hub, queries and negatives.
            cur_inds = np.random.choice(range(hub_repr.shape[1]), size=n_hub_probes, replace=False)
            cur_hub_repr = hub_repr[:, cur_inds]
            cur_query_repr = query_repr[:, cur_inds]
            inet21k_start, inet21k_end = inet21k_indxs_groups[i]
            cur_negs_reprs = clip_inet21k_logits[inet21k_start:inet21k_end, :][:, cur_inds]
            metrics_df, cur_res_df = get_retrieval_results__softwpmi(cur_hub_repr, cur_query_repr,
                                                                     query_metadata, hub_metadata,
                                                                     negs_reprs=cur_negs_reprs,
                                                                     label_map=label_map,
                                                                     n_probes=n_query_probes, max_rank=10,
                                                                     verbose=False)
            top1_accs.append(metrics_df.loc[metrics_df['Rank'] == 1, 'Top k Acc.'].iloc[0])
            top5_accs.append(metrics_df.loc[metrics_df['Rank'] == 5, 'Top k Acc.'].iloc[0])
            top1_prec.append(metrics_df.loc[metrics_df['Rank'] == 1, 'Avg. Precision @ k'].iloc[0])
            top5_prec.append(metrics_df.loc[metrics_df['Rank'] == 5, 'Avg. Precision @ k'].iloc[0])
            tqdm.write(f"Top 1 Acc: {np.round(top1_accs[-1], 2)}, Top 5 Acc: {np.round(top5_accs[-1], 2)}")

            pbar.update(1)
            if i == 0:
                cur_res_df.to_csv(f"{save_dir}/errors/{args.dataset}_{exp_dict['label_map']}_top{exp_dict['n_query_probes']}_softwpmi_errors.csv", index=False)

        # np.std defaults to the population std (ddof=0).
        results_row['top1_acc_mean'] = np.mean(top1_accs)
        results_row['top1_acc_std'] = np.std(top1_accs)
        results_row['top5_acc_mean'] = np.mean(top5_accs)
        results_row['top5_acc_std'] = np.std(top5_accs)
        results_row['top1_prec_mean'] = np.mean(top1_prec)
        results_row['top1_prec_std'] = np.std(top1_prec)
        results_row['top5_prec_mean'] = np.mean(top5_prec)
        results_row['top5_prec_std'] = np.std(top5_prec)

        softwpmi_results_rows.append(results_row)
        softwpmi_results_df = pd.DataFrame(softwpmi_results_rows, columns=retrieval_results_columns)
        softwpmi_results_df.to_csv(softwpmi_results_path, index=False)

# Final write also covers the no-experiments case (header-only CSV).
softwpmi_results_df.to_csv(softwpmi_results_path, index=False)

# endregion: SoftWPMI Results

# region: WPMI Results

# For every config in all_performed_exps: draw n_iters random subsets of n_hub_probes
# probe columns, pass the sliced (un-normalised) hub / query / ImageNet-21K-negative
# representations to get_retrieval_results__wpmi, and record mean/std of top-1/top-5
# accuracy and precision. The per-query results of the first iteration go to errors/.
# The summary CSV is rewritten after every config.
retrieval_results_columns = ['dataset', 'queries', 'hub', 'n_hub_probes', 'n_query_probes', 'label_map',
                             'top1_acc_mean', 'top1_prec_mean', 'top5_acc_mean', 'top5_prec_mean',
                             'top1_acc_std', 'top1_prec_std', 'top5_acc_std', 'top5_prec_std']
wpmi_results_path = f"{save_dir}/{args.dataset}_wpmi_retrieval_results_final.csv"
# Rows are accumulated and the frame rebuilt with explicit columns, instead of
# pd.concat onto an empty frame (a deprecated pandas code path).
wpmi_results_rows = []
wpmi_results_df = pd.DataFrame(columns=retrieval_results_columns)

# The context manager closes the bar even if an iteration raises; tqdm.write keeps the
# per-iteration prints from corrupting the bar.
with tqdm(total=len(all_performed_exps) * n_iters) as pbar:
    for exp_dict in all_performed_exps:
        hub_repr, hub_metadata = get_unnormed_repr_and_metadata(exp_dict['hub'])
        query_repr, query_metadata = get_unnormed_repr_and_metadata(exp_dict['queries'], hub_name=exp_dict['hub'])
        label_map = get_label_mapping(exp_dict['label_map'])
        top1_accs, top5_accs = [], []
        top1_prec, top5_prec = [], []
        results_row = deepcopy(exp_dict)
        n_hub_probes = results_row['n_hub_probes']
        n_query_probes = results_row['n_query_probes']
        pbar.set_description_str(f"WPMI || Dataset: {args.dataset}, Queries: {exp_dict['queries']}, Hub: {exp_dict['hub']}")
        for i in range(n_iters):
            # Same random column subset for hub, queries and negatives.
            cur_inds = np.random.choice(range(hub_repr.shape[1]), size=n_hub_probes, replace=False)
            cur_hub_repr = hub_repr[:, cur_inds]
            cur_query_repr = query_repr[:, cur_inds]
            inet21k_start, inet21k_end = inet21k_indxs_groups[i]
            cur_negs_reprs = clip_inet21k_logits[inet21k_start:inet21k_end, :][:, cur_inds]
            metrics_df, cur_res_df = get_retrieval_results__wpmi(cur_hub_repr, cur_query_repr,
                                                                 query_metadata, hub_metadata,
                                                                 negs_reprs=cur_negs_reprs,
                                                                 label_map=label_map,
                                                                 n_probes=n_query_probes, max_rank=10,
                                                                 verbose=False)
            top1_accs.append(metrics_df.loc[metrics_df['Rank'] == 1, 'Top k Acc.'].iloc[0])
            top5_accs.append(metrics_df.loc[metrics_df['Rank'] == 5, 'Top k Acc.'].iloc[0])
            top1_prec.append(metrics_df.loc[metrics_df['Rank'] == 1, 'Avg. Precision @ k'].iloc[0])
            top5_prec.append(metrics_df.loc[metrics_df['Rank'] == 5, 'Avg. Precision @ k'].iloc[0])
            tqdm.write(f"Top 1 Acc: {np.round(top1_accs[-1], 2)}, Top 5 Acc: {np.round(top5_accs[-1], 2)}")

            pbar.update(1)
            if i == 0:
                cur_res_df.to_csv(f"{save_dir}/errors/{args.dataset}_{exp_dict['label_map']}_top{exp_dict['n_query_probes']}_wpmi_errors.csv", index=False)

        # np.std defaults to the population std (ddof=0).
        results_row['top1_acc_mean'] = np.mean(top1_accs)
        results_row['top1_acc_std'] = np.std(top1_accs)
        results_row['top5_acc_mean'] = np.mean(top5_accs)
        results_row['top5_acc_std'] = np.std(top5_accs)
        results_row['top1_prec_mean'] = np.mean(top1_prec)
        results_row['top1_prec_std'] = np.std(top1_prec)
        results_row['top5_prec_mean'] = np.mean(top5_prec)
        results_row['top5_prec_std'] = np.std(top5_prec)

        wpmi_results_rows.append(results_row)
        wpmi_results_df = pd.DataFrame(wpmi_results_rows, columns=retrieval_results_columns)
        wpmi_results_df.to_csv(wpmi_results_path, index=False)

# Final write also covers the no-experiments case (header-only CSV).
wpmi_results_df.to_csv(wpmi_results_path, index=False)

# endregion: WPMI Results

# region: All Probes INet21k Negs Results

# For every config in all_performed_exps: draw n_iters random subsets of n_hub_probes
# probe columns from the pre-normalised (normed_*) representations, score with
# get_retrieval_results__all_probes__inet21k_negs, and record mean/std of top-1/top-5
# accuracy and precision. The per-query results of the first iteration go to errors/.
# The summary CSV is rewritten after every config.
# NOTE(review): unlike the ProbeLog section, which slices first and then calls
# norm_reprs, this section slices already-normalised matrices — confirm the
# difference is intended.
retrieval_results_columns = ['dataset', 'queries', 'hub', 'n_hub_probes', 'n_query_probes', 'label_map',
                             'top1_acc_mean', 'top1_prec_mean', 'top5_acc_mean', 'top5_prec_mean',
                             'top1_acc_std', 'top1_prec_std', 'top5_acc_std', 'top5_prec_std']
all_probes_inet21k_negs_results_path = f"{save_dir}/{args.dataset}_all_probes_inet21k_negs_retrieval_results_final.csv"
# Rows are accumulated and the frame rebuilt with explicit columns, instead of
# pd.concat onto an empty frame / all-None column (a deprecated pandas code path).
all_probes_inet21k_negs_results_rows = []
all_probes_inet21k_negs_results_df = pd.DataFrame(columns=retrieval_results_columns)

# The context manager closes the bar even if an iteration raises; tqdm.write keeps the
# per-iteration prints from corrupting the bar.
with tqdm(total=len(all_performed_exps) * n_iters) as pbar:
    for exp_dict in all_performed_exps:
        hub_repr, hub_metadata = get_repr_and_metadata(exp_dict['hub'])
        query_repr, query_metadata = get_repr_and_metadata(exp_dict['queries'], hub_name=exp_dict['hub'])
        label_map = get_label_mapping(exp_dict['label_map'])
        top1_accs, top5_accs = [], []
        top1_prec, top5_prec = [], []
        results_row = deepcopy(exp_dict)
        n_hub_probes = results_row['n_hub_probes']
        # The retrieval function below takes no n_probes argument, so the query-probe
        # budget is recorded as None for this method.
        results_row['n_query_probes'] = None
        pbar.set_description_str(f"All Probes || Dataset: {args.dataset}, Queries: {exp_dict['queries']}, Hub: {exp_dict['hub']}")
        for i in range(n_iters):
            # Same random column subset for hub, queries and negatives.
            cur_inds = np.random.choice(range(hub_repr.shape[1]), size=n_hub_probes, replace=False)
            cur_hub_repr = hub_repr[:, cur_inds]
            cur_query_repr = query_repr[:, cur_inds]
            inet21k_start, inet21k_end = inet21k_indxs_groups[i]
            cur_negs_reprs = normed_clip_inet21k_reps[inet21k_start:inet21k_end, :][:, cur_inds]
            metrics_df, cur_res_df = get_retrieval_results__all_probes__inet21k_negs(cur_hub_repr, cur_query_repr,
                                                                                     query_metadata, hub_metadata,
                                                                                     negs_reprs=cur_negs_reprs,
                                                                                     label_map=label_map,
                                                                                     max_rank=10, verbose=False)
            top1_accs.append(metrics_df.loc[metrics_df['Rank'] == 1, 'Top k Acc.'].iloc[0])
            top5_accs.append(metrics_df.loc[metrics_df['Rank'] == 5, 'Top k Acc.'].iloc[0])
            top1_prec.append(metrics_df.loc[metrics_df['Rank'] == 1, 'Avg. Precision @ k'].iloc[0])
            top5_prec.append(metrics_df.loc[metrics_df['Rank'] == 5, 'Avg. Precision @ k'].iloc[0])
            tqdm.write(f"Top 1 Acc: {np.round(top1_accs[-1], 2)}, Top 5 Acc: {np.round(top5_accs[-1], 2)}")

            pbar.update(1)
            if i == 0:
                # NOTE(review): the file name still embeds the config's n_query_probes,
                # although this method does not use it — kept for path compatibility.
                cur_res_df.to_csv(f"{save_dir}/errors/{args.dataset}_{exp_dict['label_map']}_top{exp_dict['n_query_probes']}_all_probes_inet21k_negs_errors.csv", index=False)

        # np.std defaults to the population std (ddof=0).
        results_row['top1_acc_mean'] = np.mean(top1_accs)
        results_row['top1_acc_std'] = np.std(top1_accs)
        results_row['top5_acc_mean'] = np.mean(top5_accs)
        results_row['top5_acc_std'] = np.std(top5_accs)
        results_row['top1_prec_mean'] = np.mean(top1_prec)
        results_row['top1_prec_std'] = np.std(top1_prec)
        results_row['top5_prec_mean'] = np.mean(top5_prec)
        results_row['top5_prec_std'] = np.std(top5_prec)

        all_probes_inet21k_negs_results_rows.append(results_row)
        all_probes_inet21k_negs_results_df = pd.DataFrame(all_probes_inet21k_negs_results_rows, columns=retrieval_results_columns)
        all_probes_inet21k_negs_results_df.to_csv(all_probes_inet21k_negs_results_path, index=False)

# Final write also covers the no-experiments case (header-only CSV).
all_probes_inet21k_negs_results_df.to_csv(all_probes_inet21k_negs_results_path, index=False)

# endregion: All Probes INet21k Negs Results
