import json

import numpy as np
from tqdm import tqdm

# Run configuration: which dataset's output directory to read, the directory
# that contains `out/`, and how many class indices to create buckets for.
dataset_name = "imagenet"
root = "."
num_classes = 1000


# Read the per-slice records. Each record is used below as a dict that
# carries at least a "selected_class" entry.
# NOTE(review): the "aligment" spelling in the file name is reproduced
# verbatim; presumably it matches what the upstream step writes -- confirm
# before renaming.
with open(f"{root}/out/{dataset_name}/aligment_scores_per_class.json", "r") as f:
    alignment_scores = json.load(f)

# Bucket the records by class. Every class index in range(num_classes) gets a
# bucket up front, so a class without records maps to an empty list, and a
# record whose class is outside that range raises KeyError.
slice_by_class = dict((class_idx, []) for class_idx in range(num_classes))
for slice_info in alignment_scores:
    slice_by_class[slice_info["selected_class"]].append(slice_info)


# Per-class partitions of the slices, keyed by class index:
#   context_dict -- slices whose normalized alignment is <= bias_threshold
#   target_dict  -- slices whose normalized alignment is >  target_threshold
#   bias_dict    -- the context slices whose bias strength is more than one
#                   standard deviation above the mean of that class's
#                   context slices
context_dict = {i: [] for i in range(num_classes)}
target_dict = {i: [] for i in range(num_classes)}
bias_dict = {i: [] for i in range(num_classes)}


# Cut-offs applied to the per-class z-scored alignment. With both at zero,
# every slice with a finite score lands in exactly one of context/target.
bias_threshold = -0.0
target_threshold = 0.0

for selected_class in tqdm(range(num_classes)):
    class_slices = slice_by_class[selected_class]
    if not class_slices:
        # No slices for this class: nothing to partition, and taking the
        # mean/std of an empty array would only produce NaN and warnings.
        continue

    # Normalize once per class instead of recomputing mean/std per slice.
    # A dedicated name is used so the module-level `alignment_scores`
    # (the loaded records) is not overwritten.
    class_scores = np.array(
        [slice_info["alignment_score"] for slice_info in class_slices],
        dtype=float,
    )
    score_mean = class_scores.mean()
    score_std = class_scores.std()
    if score_std == 0:
        # All scores are identical (e.g. a single slice): every deviation is
        # zero, so use 0.0 rather than dividing by zero and getting NaN.
        norm_scores = np.zeros_like(class_scores)
    else:
        norm_scores = (class_scores - score_mean) / score_std

    context_slices = context_dict[selected_class]
    bias_latent_indices = []
    bias_scores = []
    for slice_info, score in zip(class_slices, norm_scores.tolist()):
        if score <= bias_threshold:
            context_slices.append(slice_info)
            bias_latent_indices.append(slice_info["latent_idx"])
            # NOTE(review): bias strength is taken as the negated normalized
            # alignment, so "high bias" below means "far below the class
            # mean". Previously nothing was recorded here, which left
            # bias_dict permanently empty -- confirm this is the intended
            # definition.
            bias_scores.append(-score)
        elif score > target_threshold:
            target_dict[selected_class].append(slice_info)
        slice_info["alignment_score_norm"] = score

    # Collected per class but not consumed in this section; kept so the
    # module-level name remains available.
    bias_latent_indices = np.array(bias_latent_indices)
    bias_scores = np.array(bias_scores)
    if bias_scores.size == 0:
        # No context slices for this class, hence no high-bias candidates.
        continue

    # bias_scores[k] corresponds to context_slices[k] because both were
    # appended in the same branch above.
    threshold = bias_scores.mean() + 1 * bias_scores.std()
    high_bias = np.where(bias_scores > threshold)[0]
    for i in high_bias:
        bias_dict[selected_class].append(context_slices[i])
