import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import math
import torch


def dice_coefficient_with_threshold(
    array1, array2, threshold1, threshold2, device="cuda:0"
):
    """Return the Dice coefficient of two arrays after thresholding.

    Each input is binarised with ``array >= threshold``; the score is
    ``2 * |A & B| / (|A| + |B|)``, where ``|.|`` counts ``True`` entries.

    Args:
        array1: First array-like (NumPy array, tensor, nested list, ...).
        array2: Second array-like, broadcast-compatible with ``array1``.
        threshold1: Cut-off applied to ``array1``.
        threshold2: Cut-off applied to ``array2``.
        device: Torch device used for the computation. A CUDA device is
            replaced by the CPU when CUDA is not available, so the default
            also works on machines without a GPU.

    Returns:
        The Dice score as a ``float``; ``0.0`` when neither array has any
        entry at or above its threshold.
    """
    device = torch.device(device)
    if device.type == "cuda" and not torch.cuda.is_available():
        # The result does not depend on the device, so fall back rather
        # than fail on CPU-only machines.
        device = torch.device("cpu")

    # as_tensor avoids the copy (and the warning) torch.tensor produces when
    # the input is already a tensor.
    mask1 = torch.as_tensor(array1, device=device) >= torch.as_tensor(
        threshold1, device=device
    )
    mask2 = torch.as_tensor(array2, device=device) >= torch.as_tensor(
        threshold2, device=device
    )

    # Count with integers (exact for any mask size) and bring all three
    # counts back to the host in a single transfer.
    overlap, size1, size2 = torch.stack(
        [(mask1 & mask2).sum(), mask1.sum(), mask2.sum()]
    ).tolist()

    denominator = size1 + size2
    if denominator == 0:
        # Both masks are empty: the ratio is undefined, report no overlap.
        return 0.0
    return 2.0 * overlap / denominator

# --- Run configuration -------------------------------------------------------
# These values are interpolated into the folder and file names of the saved
# arrays, so they must match the names used when the arrays were written.
train_step = 500

# Adjacent string literals are concatenated by the parser; the pieces below
# form one unbroken run name.
run_name = (
    "apply_nesim_every_n_steps_10"
    "_nesim_config_scale_0.08"
    "_shrink_factor_[9.0]"
    "_layer_names_index_10"
    "_checkpoint_every_n_steps_10"
    "_num_warmup_steps_300"
    "_batch_size_128"
    "_context_length_256"
)
layer_name = "transformer.h.10.mlp.c_fc"

# One saved array is expected per category; the order here is also the
# row/column order of the resulting score matrix.
target_categories = [
    # politics
    "politics.left_wing",
    "politics.right_wing",
    # science
    "science.physics",
    "science.chemistry",
    "science.biology",
    "science.math",
    # sports
    "sports.football",
    "sports.olympics",
    "sports.tennis",
    # technology
    "technology.software",
    "technology.artificial_intelligence",
    "technology.blockchain",
    "technology.space_exploration",
    # musicians
    "musicians.rock",
    "musicians.pop",
    "musicians.hip_hop",
    # actors
    "actors.drama",
    "actors.action",
    "actors.comedy",
    # history
    "history.ancient",
    "history.medieval",
    "history.modern",
]

# Folder holding the saved arrays for this step / layer / run combination.
output_folder_numpy_arrays = (
    f"maps/hierarchial_numpy_arrays/{train_step}_{layer_name}_{run_name}"
)

# One ``.npy`` path per category, in the same order as ``target_categories``.
filenames = [
    os.path.join(
        output_folder_numpy_arrays,
        f"{target_category}_{train_step}_{layer_name}_{run_name}.npy",
    )
    for target_category in target_categories
]

# Load every saved array up front, in the order given by ``filenames``.
all_selectivity_maps = [np.load(path) for path in filenames]

# The same cut-off is applied to both arrays of every pair.
selectivity_threshold = 0.06

# Score every ordered pair of maps (each map against itself included).
# The flat list is in row-major order: the outer loop picks the row map,
# the inner loop the column map. tqdm reports progress per row.
all_dice_scores = [
    dice_coefficient_with_threshold(
        array1=row_map,
        array2=column_map,
        threshold1=selectivity_threshold,
        threshold2=selectivity_threshold,
    )
    for row_map in tqdm(all_selectivity_maps)
    for column_map in all_selectivity_maps
]

# Fold the flat list into a square (n_maps, n_maps) matrix.
num_maps = len(all_selectivity_maps)
all_dice_scores = np.array(all_dice_scores).reshape(num_maps, num_maps)

# Draw the pairwise score matrix as a heatmap on an explicit axes object and
# write it to disk. Rows and columns are labelled with the category names;
# tick-label rotation is left to the plotting libraries.
fig, heatmap_axes = plt.subplots(figsize=(15, 15))

sns.heatmap(
    all_dice_scores,
    xticklabels=target_categories,
    yticklabels=target_categories,
    ax=heatmap_axes,
)
fig.savefig("dice_score_matrix.png")
