# %% [markdown]
# Evaluate in-distribution and out-of-distribution metrics for SDDOIA and the other supported datasets

# %%
import torch
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
import sys
from utils.train import convert_to_categories, compute_coverage
from datasets.boia import BOIA
from datasets.sddoia import SDDOIA
from datasets.minikandinsky import MiniKandinsky
from datasets.kandinsky import Kandinsky
from datasets.shortcutmnist import SHORTMNIST
from datasets.addmnist import ADDMNIST
from datasets.clipkandinsky import CLIPKandinsky
from datasets.clipshortcutmnist import CLIPSHORTMNIST
from datasets.clipboia import CLIPBOIA
from datasets.clevr import CLEVR
from models.boiadpl import BoiaDPL
from models.boialtn import BOIALTN
from models.boiann import BOIAnn
from models.boiacbm import BoiaCBM
from models.mnistcbm import MnistCBM
from models.mnistdpl import MnistDPL
from models.mnistdsl import MnistDSL
from models.mnistltn import MnistLTN
from models.mnistnn import MNISTnn
from models.mnistdsldpl import MnistDSLDPL
from models.minikanddpl import MiniKandDPL
from models.kanddpl import KandDPL
from models.kandcbm import KandCBM
from models.kandltn import KANDltn
from models.kandnn import KANDnn
from models.clevrcbm import ClevrCBM
from models.clevrdsl import ClevrDSL
from models.clevrdsldpl import ClevrDSLDPL
from models.clevrdpl import CLEVRDPL
from utils.hungarian import permutation_matrix_from_predictions
from sklearn.metrics import confusion_matrix
from argparse import Namespace
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import torch.nn as nn

# %% [markdown]
# #### CBM model

# %%
class MNISTCBM(nn.Module):
    """Concept-bottleneck model over two MNIST digits placed side by side.

    A shared encoder (``cnn`` -> ``flatten`` -> ``fc_individual``) maps each
    28-pixel-wide half of the input to 10 concept logits. The softmaxed concept
    vectors of both halves are concatenated and fed to a bias-free linear layer;
    the softmax of that layer is the 2-way prediction.
    """

    def __init__(self):
        super().__init__()

        # Three conv -> ReLU -> 2x2 max-pool stages. For a 1x28x28 input the
        # spatial size shrinks 28 -> 14 -> 7 -> 3, ending at [128, 3, 3].
        channels = [1, 32, 64, 128]
        conv_stages = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            conv_stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        self.cnn = nn.Sequential(*conv_stages)
        self.flatten = nn.Flatten()

        # Per-image MLP: 1152 -> 256 -> 128 -> 128 -> 10 concept logits (no softmax here).
        widths = [128 * 3 * 3, 256, 128, 128]
        mlp = []
        for in_f, out_f in zip(widths[:-1], widths[1:]):
            mlp += [nn.Linear(in_f, out_f), nn.ReLU()]
        mlp.append(nn.Linear(128, 10))
        self.fc_individual = nn.Sequential(*mlp)

        # Maps the two concatenated 10-way concept distributions (20 values) to 2 outputs.
        self.fc_aggregate = nn.Sequential(nn.Linear(20, 2, bias=False))

    def forward(self, x):
        """Return raw concept logits ``CS``, concept probabilities ``pCS`` and label probabilities ``YS``.

        NOTE(review): assumes ``x`` is (batch, 1, 28, 56), i.e. two digits side
        by side along the last axis — confirm against the data loader.
        """
        halves = (x[:, :, :, :28], x[:, :, :, 28:])
        concept_logits = [
            self.fc_individual(self.flatten(self.cnn(half))) for half in halves
        ]
        concept_probs = [
            torch.nn.functional.softmax(logits, dim=-1) for logits in concept_logits
        ]

        label_probs = torch.softmax(
            self.fc_aggregate(torch.cat(concept_probs, dim=1)), dim=-1
        )
        return {
            "CS": torch.stack(concept_logits, dim=1),
            "YS": label_probs,
            "pCS": torch.stack(concept_probs, dim=1),
        }

# %% [markdown]
# Classes that hold the evaluation metrics: a base class plus dataset-specific extensions

# %%
class Metrics:
    """Plain container for the concept/label scores, collapse values and average NLL.

    Attributes are stored in constructor order; ``to_string`` relies on that order.
    """

    def __init__(
        self,
        concept_accuracy,
        label_accuracy,
        concept_f1_macro,
        concept_f1_micro,
        concept_f1_weighted,
        label_f1_macro,
        label_f1_micro,
        label_f1_weighted,
        collapse,
        collapse_hard,
        avg_nll,
    ):
        # Accuracies first (concept, then label) ...
        self.concept_accuracy = concept_accuracy
        self.label_accuracy = label_accuracy
        # ... then the three F1 averages for concepts ...
        self.concept_f1_macro = concept_f1_macro
        self.concept_f1_micro = concept_f1_micro
        self.concept_f1_weighted = concept_f1_weighted
        # ... and for labels.
        self.label_f1_macro = label_f1_macro
        self.label_f1_micro = label_f1_micro
        self.label_f1_weighted = label_f1_weighted
        # Concept collapse (soft / hard) and the mean negative log-likelihood.
        self.collapse = collapse
        self.collapse_hard = collapse_hard
        self.avg_nll = avg_nll

    def to_string(self):
        """Return every instance attribute as ``name: value``, comma-separated, in insertion order."""
        pieces = [f"{name}: {score}" for name, score in vars(self).items()]
        return ", ".join(pieces)

class ExtendedMetrics(Metrics):
    """``Metrics`` plus two extra scores, ``beta_f1`` and ``beta_acc``, stored after the base fields."""

    def __init__(
        self,
        concept_accuracy,
        label_accuracy,
        concept_f1_macro,
        concept_f1_micro,
        concept_f1_weighted,
        label_f1_macro,
        label_f1_micro,
        label_f1_weighted,
        collapse,
        collapse_hard,
        avg_nll,
        beta_f1,
        beta_acc
    ):
        super().__init__(
            concept_accuracy=concept_accuracy,
            label_accuracy=label_accuracy,
            concept_f1_macro=concept_f1_macro,
            concept_f1_micro=concept_f1_micro,
            concept_f1_weighted=concept_f1_weighted,
            label_f1_macro=label_f1_macro,
            label_f1_micro=label_f1_micro,
            label_f1_weighted=label_f1_weighted,
            collapse=collapse,
            collapse_hard=collapse_hard,
            avg_nll=avg_nll,
        )
        # Extra scores, appended after the base attributes.
        self.beta_f1 = beta_f1
        self.beta_acc = beta_acc

    @staticmethod
    def fromMetric(metric, beta_f1, beta_acc):
        """Copy the base fields of ``metric`` into a new ``ExtendedMetrics`` with the given beta scores."""
        base_fields = (
            "concept_accuracy",
            "label_accuracy",
            "concept_f1_macro",
            "concept_f1_micro",
            "concept_f1_weighted",
            "label_f1_macro",
            "label_f1_micro",
            "label_f1_weighted",
            "collapse",
            "collapse_hard",
            "avg_nll",
        )
        copied = [getattr(metric, field) for field in base_fields]
        return ExtendedMetrics(*copied, beta_f1, beta_acc)

class BOIAMetrics(Metrics):
    """``Metrics`` extended with per-group concept collapse for BOIA/SDDOIA-style datasets.

    Stores soft and hard collapse for the forward, stop, left and right concept
    groups, plus their means, after the base attributes.
    """

    def __init__(
        self,
        concept_accuracy,
        label_accuracy,
        concept_f1_macro,
        concept_f1_micro,
        concept_f1_weighted,
        label_f1_macro,
        label_f1_micro,
        label_f1_weighted,
        collapse,
        collapse_hard,
        collapse_forward,
        collapse_stop,
        collapse_left,
        collapse_right,
        collapse_hard_forward,
        collapse_hard_stop,
        collapse_hard_left,
        collapse_hard_right,
        mean_collapse,
        mean_hard_collapse,
        avg_nll,
    ):
        super().__init__(
            concept_accuracy=concept_accuracy,
            label_accuracy=label_accuracy,
            concept_f1_macro=concept_f1_macro,
            concept_f1_micro=concept_f1_micro,
            concept_f1_weighted=concept_f1_weighted,
            label_f1_macro=label_f1_macro,
            label_f1_micro=label_f1_micro,
            label_f1_weighted=label_f1_weighted,
            collapse=collapse,
            collapse_hard=collapse_hard,
            avg_nll=avg_nll,
        )
        # Soft collapse per concept group.
        self.collapse_forward = collapse_forward
        self.collapse_stop = collapse_stop
        self.collapse_left = collapse_left
        self.collapse_right = collapse_right
        # Hard collapse per concept group.
        self.collapse_hard_forward = collapse_hard_forward
        self.collapse_hard_stop = collapse_hard_stop
        self.collapse_hard_left = collapse_hard_left
        self.collapse_hard_right = collapse_hard_right
        # Averages over the four groups.
        self.mean_collapse = mean_collapse
        self.mean_hard_collapse = mean_hard_collapse


class KandMetrics(Metrics):
    """``Metrics`` extended with shape/color collapse values for Kandinsky-style datasets."""

    def __init__(
        self,
        concept_accuracy,
        label_accuracy,
        concept_f1_macro,
        concept_f1_micro,
        concept_f1_weighted,
        label_f1_macro,
        label_f1_micro,
        label_f1_weighted,
        collapse,
        collapse_hard,
        avg_nll,
        collapse_shapes,
        collapse_hard_shapes,
        collapse_color,
        collapse_hard_color,
        mean_collapse,
        mean_collapse_hard,
    ):
        super().__init__(
            concept_accuracy=concept_accuracy,
            label_accuracy=label_accuracy,
            concept_f1_macro=concept_f1_macro,
            concept_f1_micro=concept_f1_micro,
            concept_f1_weighted=concept_f1_weighted,
            label_f1_macro=label_f1_macro,
            label_f1_micro=label_f1_micro,
            label_f1_weighted=label_f1_weighted,
            collapse=collapse,
            collapse_hard=collapse_hard,
            avg_nll=avg_nll,
        )
        # Shape collapse (soft, hard).
        self.collapse_shapes = collapse_shapes
        self.collapse_hard_shapes = collapse_hard_shapes
        # Color collapse (soft, hard).
        self.collapse_color = collapse_color
        self.collapse_hard_color = collapse_hard_color
        # Means over shape and color.
        self.mean_collapse = mean_collapse
        self.mean_collapse_hard = mean_collapse_hard


class ClevrMetrics(Metrics):
    """``Metrics`` extended with per-attribute collapse (shape, color, material, size) for CLEVR."""

    def __init__(
        self,
        concept_accuracy,
        label_accuracy,
        concept_f1_macro,
        concept_f1_micro,
        concept_f1_weighted,
        label_f1_macro,
        label_f1_micro,
        label_f1_weighted,
        collapse,
        collapse_hard,
        avg_nll,
        collapse_shapes,
        collapse_hard_shapes,
        collapse_color,
        collapse_hard_color,
        collapse_materials,
        collapse_hard_materials,
        collapse_sizes,
        collapse_hard_sizes,
        mean_collapse,
        mean_collapse_hard,
    ):
        super().__init__(
            concept_accuracy=concept_accuracy,
            label_accuracy=label_accuracy,
            concept_f1_macro=concept_f1_macro,
            concept_f1_micro=concept_f1_micro,
            concept_f1_weighted=concept_f1_weighted,
            label_f1_macro=label_f1_macro,
            label_f1_micro=label_f1_micro,
            label_f1_weighted=label_f1_weighted,
            collapse=collapse,
            collapse_hard=collapse_hard,
            avg_nll=avg_nll,
        )
        # (soft, hard) collapse pairs, one per attribute.
        self.collapse_shapes = collapse_shapes
        self.collapse_hard_shapes = collapse_hard_shapes
        self.collapse_color = collapse_color
        self.collapse_hard_color = collapse_hard_color
        self.collapse_materials = collapse_materials
        self.collapse_hard_materials = collapse_hard_materials
        self.collapse_sizes = collapse_sizes
        self.collapse_hard_sizes = collapse_hard_sizes
        # Means over the four attributes.
        self.mean_collapse = mean_collapse
        self.mean_collapse_hard = mean_collapse_hard

# %% [markdown]
# Functions used to compute concept collapse (soft and hard variants)

# %%
def compute_concept_collapse(true_concepts, predicted_concepts, multilabel=False):
    """Return ``1 - compute_coverage(confusion_matrix(true, predicted))``.

    When ``multilabel`` is true, both arrays are cast to ``int`` and passed through
    ``convert_to_categories`` before building the confusion matrix (presumably
    turning each row of binary concepts into one category id — see ``utils.train``).
    """
    y_true, y_pred = true_concepts, predicted_concepts
    if multilabel:
        y_true, y_pred = (
            convert_to_categories(values.astype(int))
            for values in (true_concepts, predicted_concepts)
        )

    coverage = compute_coverage(confusion_matrix(y_true, y_pred))
    return 1 - coverage


def compute_hard_concept_collapse(true_concepts, predicted_concepts, multilabel=False):
    """Return ``1 - compute_coverage_hard(confusion_matrix(true, predicted))``.

    Same preprocessing as ``compute_concept_collapse``: with ``multilabel`` true,
    both arrays are cast to ``int`` and passed through ``convert_to_categories``.

    NOTE(review): ``compute_coverage_hard`` is not imported at the top of this
    file (only ``convert_to_categories`` and ``compute_coverage`` are) — confirm
    it is defined somewhere in scope, otherwise this raises NameError.
    """
    y_true, y_pred = true_concepts, predicted_concepts
    if multilabel:
        y_true, y_pred = (
            convert_to_categories(values.astype(int))
            for values in (true_concepts, predicted_concepts)
        )

    hard_coverage = compute_coverage_hard(confusion_matrix(y_true, y_pred))
    return 1 - hard_coverage

# %% [markdown]
# Function used to plot a confusion matrix

# %%
def plot_confusion_matrix(
    true_labels,
    predicted_labels,
    classes,
    normalize=False,
    title=None,
    is_boia=False,
    cmap=plt.cm.Oranges,
):
    """Build, plot and optionally save a confusion-matrix heatmap.

    Args:
        true_labels: sequence whose items are used as row indices into the matrix.
        predicted_labels: sequence whose items are used as column indices; only the
            first ``len(true_labels)`` items are used.
        classes: tick labels; ``len(classes)`` sets the matrix size.
        normalize: if true, each row is divided by its sum (all-zero rows stay 0).
        title: if truthy, the figure is saved to this path in PDF format.
        is_boia: unused; kept for backward compatibility with existing callers.
        cmap: unused — the heatmap always uses seaborn's "OrRd" palette; kept for
            backward compatibility.
    """
    n_classes = len(classes)
    cm = np.zeros((n_classes, n_classes))
    for true_label, predicted_label in zip(true_labels, predicted_labels):
        cm[true_label, predicted_label] += 1

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Divide only where the row has samples: avoids NaN and divide-by-zero warnings.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    plt.figure(figsize=(8, 6))
    sns.set(font_scale=1.8)
    red_yellow_palette = sns.color_palette("OrRd", as_cmap=True)
    sns.heatmap(
        cm,
        annot=False,
        # cm is always float, so an integer "d" format would fail if annotations were enabled.
        fmt=".2f" if normalize else ".0f",
        cmap=red_yellow_palette,
        cbar=True,
        xticklabels=classes,
        yticklabels=classes,
    )
    plt.xticks(rotation=0)
    plt.yticks(rotation=0)
    plt.tight_layout()
    if title:
        # Save after the tick/layout adjustments so the PDF matches the displayed figure.
        plt.savefig(title, format="pdf")
    plt.show()

# %% [markdown]
# Function used to compute the metrics

# %%
# Dataset-name groups that select the dataset-specific logic in compute_metrics.
_BOIA_DATASETS = ("boia", "sddoia", "clipboia", "clipSDDOIA")
_KANDINSKY_DATASETS = ("kandinsky", "minikandinsky", "clipkandinsky")
# Column slices of the BOIA/SDDOIA concept matrix, one per concept group.
_BOIA_CONCEPT_GROUPS = {
    "forward": slice(0, 3),
    "stop": slice(3, 9),
    "left": slice(9, 15),
    "right": slice(15, 21),
}
# CLEVR attribute order along the last axis of the concept array.
_CLEVR_ATTRIBUTES = ("color", "shape", "material", "size")


def _summary_scores(true_values, predicted_values):
    """Return ``(accuracy, f1_macro, f1_micro, f1_weighted)`` for one categorical target."""
    return (
        accuracy_score(true_values, predicted_values),
        f1_score(true_values, predicted_values, average="macro"),
        f1_score(true_values, predicted_values, average="micro"),
        f1_score(true_values, predicted_values, average="weighted"),
    )


def _average_scores(score_tuples):
    """Element-wise mean of several ``_summary_scores`` tuples (order preserved)."""
    return tuple(np.mean(list(column)) for column in zip(*score_tuples))


def _column_scores(true_values, predicted_values):
    """Score each column of two 2-D arrays as a separate target and average the results."""
    return _average_scores(
        _summary_scores(true_values[:, i], predicted_values[:, i])
        for i in range(true_values.shape[1])
    )


def _score_kwargs(concept_scores, label_scores):
    """Map concept/label score tuples onto the keyword names used by the Metrics classes."""
    names = ("accuracy", "f1_macro", "f1_micro", "f1_weighted")
    kwargs = {f"concept_{name}": value for name, value in zip(names, concept_scores)}
    kwargs.update({f"label_{name}": value for name, value in zip(names, label_scores)})
    return kwargs


def _kandinsky_joint_codes(concepts):
    """Encode per-object (columns 0-2, columns 3-5) pairs as one integer ``first * 3 + second``."""
    chunks = torch.split(torch.tensor(concepts), 3, dim=1)
    return (chunks[0].flatten() * 3 + chunks[1].flatten()).detach().numpy()


def _clevr_attribute_pairs(true_concepts, predicted_concepts):
    """Return ``{attribute: (true, predicted)}`` flattened over objects, dropping entries whose true value is -1."""
    pairs = {}
    for index, attribute in enumerate(_CLEVR_ATTRIBUTES):
        true_flat = true_concepts[:, :, index].reshape(-1)
        predicted_flat = predicted_concepts[:, :, index].reshape(-1)
        keep = true_flat != -1  # presumably -1 marks padding for absent objects — confirm
        pairs[attribute] = (true_flat[keep], predicted_flat[keep])
    return pairs


def _plot_concept_confusions(true_concepts, predicted_concepts, dataset_name, model_name, seed):
    """Show and save concept confusion matrices for the datasets that define a plot."""
    prefix = f"{model_name}_{dataset_name}_{seed}"
    if dataset_name in ("shortmnist", "mnistAddition"):
        plot_confusion_matrix(
            true_concepts,
            predicted_concepts,
            classes=list(range(10)),
            normalize=True,
            title=f"{prefix}.pdf",
            is_boia=True,
        )
    elif dataset_name in ("boia", "sddoia"):
        # One plot per concept group, over all 2**k binary configurations of the group.
        for group, columns in _BOIA_CONCEPT_GROUPS.items():
            n_concepts = columns.stop - columns.start
            plot_confusion_matrix(
                convert_to_categories(true_concepts[:, columns].astype(int)),
                convert_to_categories(predicted_concepts[:, columns].astype(int)),
                [""] * 2**n_concepts,
                True,
                f"{prefix}_{group}.pdf",
            )
    elif dataset_name in ("kandinsky", "minikandinsky"):
        plot_confusion_matrix(
            true_concepts,
            predicted_concepts,
            classes=list(range(10)),
            normalize=True,
            title=f"{prefix}.pdf",
        )
    # TODO: no confusion-matrix plot is produced for CLEVR yet.


def compute_metrics(
    true_labels,
    predicted_labels,
    true_concepts,
    predicted_concepts,
    avg_nll,
    dataset_name,
    model_name,
    seed,
):
    """Compute concept/label scores and concept collapse, plot confusions, and return a Metrics object.

    Dataset-specific handling (selected by ``dataset_name``):

    * BOIA/SDDOIA (and CLIP variants): concepts and labels are 2-D arrays with one
      binary target per column. Each column is scored separately and the scores are
      averaged. Collapse is computed per concept group (forward ``[:, 0:3]``, stop
      ``[:, 3:9]``, left ``[:, 9:15]``, right ``[:, 15:21]``). The overall
      ``collapse``/``collapse_hard`` are the means over those groups. Returns
      ``BOIAMetrics``.
    * Kandinsky variants: concept columns ``[:, :3]`` are treated as shapes and
      ``[:, 3:6]`` as colors. Scores are the mean over the two. The overall collapse
      uses joint ``shape * 3 + color`` codes. Returns ``KandMetrics``.
    * ``"clevr"``: concepts are indexed ``[:, :, attribute]`` for color, shape,
      material and size. Entries whose true value is -1 are ignored. Scores and
      collapse are averaged over the four attributes. Hard collapse is reported as
      0.0. Returns ``ClevrMetrics``.
    * anything else: single-target scores, and ``collapse_hard`` equals ``collapse``.
      Returns ``Metrics``.

    ``model_name`` and ``seed`` are only used to name the saved confusion-matrix PDFs.
    """
    if dataset_name in _BOIA_DATASETS:
        group_collapse, group_collapse_hard = {}, {}
        for group, columns in _BOIA_CONCEPT_GROUPS.items():
            group_true = true_concepts[:, columns]
            group_pred = predicted_concepts[:, columns]
            group_collapse[group] = compute_concept_collapse(group_true, group_pred, True)
            group_collapse_hard[group] = compute_hard_concept_collapse(group_true, group_pred, True)
        mean_collapse = np.mean(list(group_collapse.values()))
        mean_hard_collapse = np.mean(list(group_collapse_hard.values()))

        metrics = BOIAMetrics(
            **_score_kwargs(
                _column_scores(true_concepts, predicted_concepts),
                _column_scores(true_labels, predicted_labels),
            ),
            # No collapse is computed over the whole 21-concept vector; report the group means.
            collapse=mean_collapse,
            collapse_hard=mean_hard_collapse,
            collapse_forward=group_collapse["forward"],
            collapse_stop=group_collapse["stop"],
            collapse_left=group_collapse["left"],
            collapse_right=group_collapse["right"],
            collapse_hard_forward=group_collapse_hard["forward"],
            collapse_hard_stop=group_collapse_hard["stop"],
            collapse_hard_left=group_collapse_hard["left"],
            collapse_hard_right=group_collapse_hard["right"],
            mean_collapse=mean_collapse,
            mean_hard_collapse=mean_hard_collapse,
            avg_nll=avg_nll,
        )

    elif dataset_name in _KANDINSKY_DATASETS:
        # Overall collapse on the joint (shape, color) code of every object.
        joint_true = _kandinsky_joint_codes(true_concepts)
        joint_pred = _kandinsky_joint_codes(predicted_concepts)
        collapse = compute_concept_collapse(joint_true, joint_pred, False)
        collapse_hard = compute_hard_concept_collapse(joint_true, joint_pred, False)

        shapes_true = true_concepts[:, :3].reshape(-1)
        shapes_pred = predicted_concepts[:, :3].reshape(-1)
        colors_true = true_concepts[:, 3:6].reshape(-1)
        colors_pred = predicted_concepts[:, 3:6].reshape(-1)

        collapse_color = compute_concept_collapse(colors_true, colors_pred, False)
        collapse_hard_color = compute_hard_concept_collapse(colors_true, colors_pred, False)
        collapse_shapes = compute_concept_collapse(shapes_true, shapes_pred, False)
        collapse_hard_shapes = compute_hard_concept_collapse(shapes_true, shapes_pred, False)

        concept_scores = _average_scores(
            [
                _summary_scores(colors_true, colors_pred),
                _summary_scores(shapes_true, shapes_pred),
            ]
        )
        metrics = KandMetrics(
            **_score_kwargs(concept_scores, _summary_scores(true_labels, predicted_labels)),
            collapse=collapse,
            collapse_hard=collapse_hard,
            avg_nll=avg_nll,
            collapse_shapes=collapse_shapes,
            collapse_hard_shapes=collapse_hard_shapes,
            collapse_color=collapse_color,
            collapse_hard_color=collapse_hard_color,
            mean_collapse=np.mean([collapse_color, collapse_shapes]),
            mean_collapse_hard=np.mean([collapse_hard_color, collapse_hard_shapes]),
        )

    elif dataset_name == "clevr":
        pairs = _clevr_attribute_pairs(true_concepts, predicted_concepts)
        attribute_collapse = {
            attribute: compute_concept_collapse(attr_true, attr_pred, False)
            for attribute, (attr_true, attr_pred) in pairs.items()
        }
        concept_scores = _average_scores(
            [_summary_scores(attr_true, attr_pred) for attr_true, attr_pred in pairs.values()]
        )
        metrics = ClevrMetrics(
            **_score_kwargs(concept_scores, _summary_scores(true_labels, predicted_labels)),
            # Overall and hard collapse are not computed for CLEVR.
            collapse=0.0,
            collapse_hard=0.0,
            avg_nll=avg_nll,
            collapse_shapes=attribute_collapse["shape"],
            collapse_hard_shapes=0.0,
            collapse_color=attribute_collapse["color"],
            collapse_hard_color=0.0,
            collapse_materials=attribute_collapse["material"],
            collapse_hard_materials=0.0,
            collapse_sizes=attribute_collapse["size"],
            collapse_hard_sizes=0.0,
            mean_collapse=np.mean(list(attribute_collapse.values())),
            mean_collapse_hard=0.0,
        )

    else:
        collapse = compute_concept_collapse(true_concepts, predicted_concepts, False)
        metrics = Metrics(
            **_score_kwargs(
                _summary_scores(true_concepts, predicted_concepts),
                _summary_scores(true_labels, predicted_labels),
            ),
            collapse=collapse,
            # Hard collapse is intentionally not computed here; the soft value is reused.
            collapse_hard=collapse,
            avg_nll=avg_nll,
        )

    _plot_concept_confusions(true_concepts, predicted_concepts, dataset_name, model_name, seed)
    return metrics

# %% [markdown]
# Helpers that instantiate the requested dataset and model by name

# %%
def get_dataset(datasetname, args):
    """Instantiate the dataset wrapper registered under ``datasetname``.

    The lookup is case-insensitive.

    Args:
        datasetname: Name of the dataset (e.g. ``"boia"``, ``"clevr"``).
        args: Namespace forwarded unchanged to the dataset constructor.

    Returns:
        The constructed dataset object.

    Raises:
        NotImplementedError: If no dataset is registered under that name.
    """
    registry = {
        "boia": BOIA,
        "sddoia": SDDOIA,
        "minikandinsky": MiniKandinsky,
        "kandinsky": Kandinsky,
        "shortmnist": SHORTMNIST,
        "clipkandinsky": CLIPKandinsky,
        "clipshortmnist": CLIPSHORTMNIST,
        "clipboia": CLIPBOIA,
        "addmnist": ADDMNIST,
        "clevr": CLEVR,
    }
    # CLIPSDDOIA is not among the imports at the top of this file, and the old
    # check compared the lower-cased name against "clipSDDOIA", so it could
    # never match. Register it only when the class is actually in scope.
    clip_sddoia_cls = globals().get("CLIPSDDOIA")
    if clip_sddoia_cls is not None:
        registry["clipsddoia"] = clip_sddoia_cls

    dataset_cls = registry.get(datasetname.lower())
    if dataset_cls is None:
        raise NotImplementedError(f"Dataset {datasetname} missing")
    return dataset_cls(args)

# %%
def get_model(modelname, encoder, args):
    """Instantiate the model registered under ``modelname``.

    The lookup is case-insensitive. Most models are built as
    ``Cls(encoder=encoder, args=args)``; ``mnistcbm`` and ``clevrcbm`` are
    built without arguments.

    Args:
        modelname: Name of the model (e.g. ``"mnistdpl"``, ``"clevrdsl"``).
        encoder: Backbone passed to models that take one.
        args: Namespace passed to models that take one.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: If no model is registered under that name.
    """
    # Models sharing the (encoder=..., args=...) constructor signature.
    standard = {
        "boiadpl": BoiaDPL,
        "boialtn": BOIALTN,
        "boiann": BOIAnn,
        "boiacbm": BoiaCBM,
        "minikanddpl": MiniKandDPL,
        "kandltn": KANDltn,
        "kandnn": KANDnn,
        "kanddpl": KandDPL,
        "kandcbm": KandCBM,
        "mnistdpl": MnistDPL,
        "mnistdsl": MnistDSL,
        "mnistltn": MnistLTN,
        "mnistnn": MNISTnn,
        "mnistdsldpl": MnistDSLDPL,
        "clevrdsldpl": ClevrDSLDPL,
        "clevrdsl": ClevrDSL,
        "clevrdpl": CLEVRDPL,
    }
    # The SDDOIA model classes are not among this file's imports, and the old
    # checks compared the lower-cased name against mixed-case literals such as
    # "SDDOIAdpl", so they could never match. Register them only if present.
    for key, cls_name in (
        ("sddoiadpl", "SDDOIADPL"),
        ("sddoialtn", "SDDOIALTN"),
        ("sddoiann", "SDDOIAnn"),
        ("sddoiacbm", "SDDOIACBM"),
    ):
        cls = globals().get(cls_name)
        if cls is not None:
            standard[key] = cls

    name = modelname.lower()
    if name in standard:
        return standard[name](encoder=encoder, args=args)
    # These CBMs are constructed without the encoder/args pair.
    if name == "mnistcbm":
        return MNISTCBM()
    if name == "clevrcbm":
        return ClevrCBM()

    raise NotImplementedError(f"Model {modelname} missing")

# %%
args = Namespace(
    backbone="conceptizer",  #
    preprocess=0,
    finetuning=0,
    batch_size=256,
    n_epochs=20,
    validate=1,
    dataset="clevr",
    lr=0.001,
    exp_decay=0.99,
    warmup_steps=1,
    wandb=None,
    task="clevr",
    boia_model="ce",
    model="clevrdsl",
    c_sup=1,
    which_c=[-1],
    joint=False,
    boia_ood_knowledge=True,
    splitted=False,
    eps_sym=0.5,
    eps_rul=0.5
)

# get dataset
dataset = get_dataset(args.dataset, args)
# get model
model = get_model(modelname=args.model, encoder=dataset.get_backbone()[0], args=args)

# Use the first GPU when one is available, otherwise fall back to the CPU.
# (The previous comment said "cpu" while the code hard-coded "cuda:0", which
# crashed on machines without CUDA.)
model.device = "cuda:0" if torch.cuda.is_available() else "cpu"

model.to(model.device)
# Move `encoder` / `net` explicitly as well, in case they are not registered
# submodules of `model`.
if hasattr(model, "encoder"):
    model.encoder.to(model.device)
if hasattr(model, "net"):
    model.net.to(model.device)

model

# %% [markdown]
# Seeds of the trained model checkpoints to evaluate

# %%
# One checkpoint per seed; `evaluate` derives each file name from these.
seeds = [1011, 1213, 1415, 1617, 1819, 2021, 2223]
# Placeholder checkpoint path prefix; the seed (and a suffix for CBMs) is appended.
model_path = "path"


# %% [markdown]
# Loop through the dataset and retrieve concepts and labels

# %%
def _pairwise_argmax(scores):
    """Arg-max over consecutive column pairs of a 2-D ``(batch, width)`` tensor."""
    batch_size, width = scores.shape[0], scores.shape[1]
    n_pairs = width // 2
    paired = scores[:, : 2 * n_pairs].reshape(batch_size, n_pairs, 2)
    picked = torch.argmax(paired, dim=-1)
    if width % 2:
        # A lone trailing column forms a group of one, whose arg-max is always 0
        # (this is what torch.split(row, 2) + argmax produced before).
        picked = torch.cat([picked, picked.new_zeros(batch_size, 1)], dim=1)
    return picked.cpu()


def get_concepts_and_labels_boia(out_labels, out_concepts):
    """Turn BOIA-style paired scores into hard binary predictions.

    Both inputs are expected to be 2-D ``(batch, width)`` tensors in which
    consecutive column pairs hold the two scores of one binary variable. For
    each pair the arg-max (0 or 1) is taken, vectorized over the whole batch
    instead of looping over samples in Python.

    Args:
        out_labels: Label scores, shape ``(batch, n_label_columns)``.
        out_concepts: Concept scores, shape ``(batch, n_concept_columns)``.

    Returns:
        ``(predicted_labels, predicted_concepts)``: ``torch.long`` CPU tensors
        of shape ``(batch, ceil(width / 2))``.
    """
    return _pairwise_argmax(out_labels), _pairwise_argmax(out_concepts)

# %%
def get_concepts_and_labels_mnist(
    out_labels, out_concepts, true_concepts, is_ood=False
):
    """Hard-decode MNIST-style outputs and flatten the concept tensors.

    Args:
        out_labels: Label scores; the arg-max is taken over the last dim.
        out_concepts: Concept scores; the arg-max is taken over the last dim.
        true_concepts: Ground-truth concept indices (any shape).
        is_ood: Accepted for API compatibility but currently unused (a former,
            now disabled, step masked label columns for in-distribution data).

    Returns:
        ``(predicted_labels, predicted_concepts, true_concepts)`` where both
        concept tensors are flattened to 1-D.
    """
    hard_labels = out_labels.argmax(dim=-1)
    hard_concepts = out_concepts.argmax(dim=-1)
    return (
        hard_labels,
        hard_concepts.view(hard_concepts.numel()),
        true_concepts.view(true_concepts.numel()),
    )

# %%
def get_concepts_and_labels_kand(out_labels, out_concepts, true_concepts):
    """Flatten Kandinsky predictions and ground truth to one row per figure.

    Args:
        out_labels: Label scores; the arg-max is taken over dim 1.
        out_concepts: Concept scores of shape ``(batch, n_figures, d)`` whose
            last dim holds consecutive groups of 3 scores (a shorter trailing
            group is tolerated, as with ``torch.split``).
        true_concepts: Ground-truth concepts with ``n_figures`` on dim 1.

    Returns:
        ``(predicted_labels, predicted_concepts, true_concepts)``. Both concept
        tensors have their batch and figure dims merged figure-major (row
        ``f * batch + b`` is figure ``f`` of sample ``b``); the predicted
        concepts are additionally passed through ``torch.squeeze``.
    """

    def figures_first(t):
        # (batch, n_figures, ...) -> (n_figures * batch, ...): the same
        # ordering as concatenating t[:, 0], t[:, 1], ... along dim 0.
        n_batch, n_fig = t.shape[0], t.shape[1]
        return t.transpose(0, 1).reshape(n_fig * n_batch, *t.shape[2:])

    label_pred = out_labels.argmax(dim=1)

    # Arg-max inside every group of 3 scores -> (batch, n_figures, n_groups).
    group_pred = torch.stack(
        [grp.argmax(dim=2) for grp in out_concepts.split(3, dim=2)], dim=2
    )

    return (
        label_pred,
        torch.squeeze(figures_first(group_pred)),
        figures_first(true_concepts),
    )

# %%
def get_concepts_and_labels_clevr(out_dict, true_concepts, is_dsl):
    """Decode CLEVR labels and per-object attribute indices.

    Args:
        out_dict: Model outputs; reads ``"PRED"`` (DSL models) or ``"YS"``
            (arg-max over dim 1) for labels, and ``"pCS"`` for concepts.
        true_concepts: Ground-truth concepts, reshaped to ``(batch, 4, -1)``.
        is_dsl: Whether labels come from ``out_dict["PRED"]``.

    Returns:
        ``(predicted_labels, predicted_concepts, true_concepts)``; both
        concept tensors have shape ``(batch, 4, 4)`` holding one index per
        object and attribute (colour, shape, material, size).
    """
    if is_dsl:
        predicted_labels = out_dict["PRED"]
    else:
        predicted_labels = out_dict["YS"].argmax(dim=1)

    # Per-object feature layout: 8 colours, 3 shapes, 2 materials, 2 sizes.
    attribute_slices = (slice(0, 8), slice(8, 11), slice(11, 13), slice(13, 15))

    def masked_argmax(scores):
        # Arg-max over the last dim; entries whose maximum equals -1
        # (presumably all-(-1) padding for absent objects) are reported as -1.
        best, index = torch.max(scores, dim=-1)
        index[best == -1] = -1
        return index

    def per_object_attributes(flat):
        # (batch, 4 * features) -> (batch, 4, 4): one index per object/attribute.
        objects = flat.view(flat.shape[0], 4, -1)
        return torch.stack(
            [masked_argmax(objects[:, :, sl]) for sl in attribute_slices], dim=-1
        )

    true_out = per_object_attributes(true_concepts)
    pred_out = per_object_attributes(out_dict["pCS"])
    return predicted_labels, pred_out, true_out

# %%
_BOIA_DATASETS = ("boia", "sddoia", "clipboia", "clipSDDOIA", "clipsddoia")
_MNIST_DATASETS = ("shortmnist", "clipshortmnist", "addmnist")
_KAND_DATASETS = ("kandinsky", "minikandinsky", "clipkandinsky")
_CLEVR_DATASETS = ("clevr",)


def _batch_nll(out_dict, labels, dataset_name, model_name, criterion):
    """Return the summed loss of one batch as a Python float."""
    if dataset_name in _BOIA_DATASETS:
        # Four tasks, each scored by a pair of columns in YS.
        class_predictions = torch.split(out_dict["YS"], 2, dim=1)
        assert len(class_predictions) == 4

        loss = 0
        for task_idx, task_pred in enumerate(class_predictions):
            loss += criterion(task_pred.float().cpu(), labels[:, task_idx].long().cpu())
        loss /= len(class_predictions)
    elif model_name in ["mnistdsl", "mnistdsldpl"]:
        # Index out_dict["KNOWLEDGE"] with the arg-max of the first two concept slots.
        c1 = torch.argmax(out_dict["pCS"][:, 0, :], dim=-1)
        c2 = torch.argmax(out_dict["pCS"][:, 1, :], dim=-1)
        loss = torch.nn.functional.nll_loss(
            out_dict["KNOWLEDGE"][c1, c2].float().cpu(), labels.long().cpu(), reduction="sum"
        )
    else:
        # TODO: nll_loss expects log-probabilities; confirm YS holds them for every model.
        loss = torch.nn.functional.nll_loss(
            out_dict["YS"].float().cpu(), labels.long().cpu(), reduction="sum"
        )
    return loss.item()


def _postprocess_batch(out_dict, concepts, dataset_name, is_ood, is_dsl):
    """Return ``(out_label, out_concept, concepts)`` decoded for one batch."""
    if dataset_name in _BOIA_DATASETS:
        out_label, out_concept = get_concepts_and_labels_boia(out_dict["YS"], out_dict["pCS"])
        return out_label, out_concept, concepts
    if dataset_name in _MNIST_DATASETS:
        if is_dsl:
            _, out_concept, concepts = get_concepts_and_labels_mnist(
                out_dict["PRED"], out_dict["pCS"], concepts, is_ood
            )
            return out_dict["PRED"].cpu().squeeze(), out_concept, concepts
        return get_concepts_and_labels_mnist(out_dict["YS"], out_dict["pCS"], concepts, is_ood)
    if dataset_name in _KAND_DATASETS:
        return get_concepts_and_labels_kand(out_dict["YS"], out_dict["pCS"], concepts)
    # Only CLEVR remains (the caller validates dataset_name up front).
    return get_concepts_and_labels_clevr(out_dict, concepts, is_dsl)


def retrive_concepts_and_labels(model, dataset, dataset_name, model_name, is_ood=False, is_dsl=False):
    """Run ``model`` over ``dataset`` and collect labels, concepts and the average loss.

    Args:
        model: Model to evaluate; must expose a ``device`` attribute.
        dataset: Iterable of ``(images, labels, concepts)`` batches that also
            exposes ``dataset.dataset`` (e.g. a DataLoader).
        dataset_name: Selects the loss and the label/concept decoding.
        model_name: Selects the loss for the DSL MNIST models.
        is_ood: Forwarded to the MNIST decoding.
        is_dsl: If True the model is called with ``eval=True`` and, for the
            MNIST datasets, labels are read from ``out_dict["PRED"]``.

    Returns:
        ``(true_labels, predicted_labels, true_concepts, predicted_concepts,
        avg_nll)``. The first four are NumPy arrays concatenated over batches;
        ``avg_nll`` is the summed loss divided by ``len(dataset.dataset)``.

    Raises:
        NotImplementedError: If ``dataset_name`` has no decoding rule (this
            previously surfaced later as an AttributeError on ``None``).
    """
    supported = _BOIA_DATASETS + _MNIST_DATASETS + _KAND_DATASETS + _CLEVR_DATASETS
    if dataset_name not in supported:
        raise NotImplementedError(f"Dataset {dataset_name} not supported")

    true_labels, predicted_labels, true_concepts, predicted_concepts = [], [], [], []

    nll_loss = 0.0
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    # Pure inference: no autograd graph is needed for any of the outputs.
    with torch.no_grad():
        for images, labels, concepts in tqdm(dataset):
            images = images.to(model.device)
            labels = labels.to(model.device)
            concepts = concepts.to(model.device)

            # Drop the intermediate-rule supervision; keep only the final label.
            if dataset_name in _KAND_DATASETS:
                labels = labels[:, -1]

            out_dict = model(images, eval=True) if is_dsl else model(images)

            nll_loss += _batch_nll(out_dict, labels, dataset_name, model_name, criterion)

            out_label, out_concept, concepts = _postprocess_batch(
                out_dict, concepts, dataset_name, is_ood, is_dsl
            )

            true_labels.append(labels.cpu().numpy())
            true_concepts.append(concepts.cpu().numpy())
            predicted_labels.append(out_label.detach().cpu().numpy())
            predicted_concepts.append(out_concept.cpu().numpy())

    true_labels = np.concatenate(true_labels, axis=0)
    predicted_labels = np.concatenate(predicted_labels, axis=0)
    true_concepts = np.concatenate(true_concepts, axis=0)
    predicted_concepts = np.concatenate(predicted_concepts, axis=0)

    print(nll_loss, len(dataset.dataset))
    avg_nll = nll_loss / len(dataset.dataset)

    assert true_labels.shape == predicted_labels.shape
    assert true_concepts.shape == predicted_concepts.shape, f"{true_concepts.shape} {predicted_concepts.shape}"

    return true_labels, predicted_labels, true_concepts, predicted_concepts, avg_nll

# %%
def _checkpoint_path(seed):
    """Return the checkpoint file for ``seed``, built from the global ``model_path``."""
    # CBM checkpoints carry an extra suffix (alternatives: "_partial_sup", "_joint", "").
    to_add = "_False_20" if "cbm" in model_path else ""
    print("TO ADD:", to_add)
    if "cbm" in model_path:
        return f"{model_path}_{seed}{to_add}.pth"
    return f"{model_path}{seed}.pth"


def _load_checkpoint(model, path):
    """Load the state dict at ``path`` into ``model``.

    Returns True on success. A missing file or a failed load is reported on
    stdout and returns False so the caller can skip that seed.
    """
    if not os.path.exists(path):
        print(f"{path} is missing...")
        return False
    print(f"Loading {path}...")
    try:
        # map_location lets checkpoints saved on a GPU load onto model.device.
        state_dict = torch.load(path, map_location=getattr(model, "device", None))
        model.load_state_dict(state_dict)
    except Exception as e:  # best effort: report and skip this seed
        print(e)
        return False
    return True


def _knowledge_scores(model, pi, dataset_name, model_name):
    """Return ``(accuracy, macro_f1)`` of the permuted learned knowledge.

    Returns ``(None, None)`` for models without comparable knowledge, or
    when ``get_gt_knowledge`` has no table for ``dataset_name`` (calling
    ``.flatten()`` on its ``None`` result used to crash here).
    """
    if model_name in ["mnistdpl", "clevrdpl"]:
        return None, None
    w_gt = get_gt_knowledge(dataset_name)
    if w_gt is None:
        return None, None
    if "cbm" in model_name:
        w = get_cbm_knowledge(model.fc_aggregate, model.device, dataset_name)
    else:
        w = torch.argmax(torch.nn.functional.softmax(model.weights, dim=2), dim=2)
    # Permute rows and columns of the knowledge table with pi.
    w_aligned = np.dot(pi.T, np.dot(w.cpu().numpy(), pi)).flatten()
    w_gt = w_gt.flatten()
    return accuracy_score(w_gt, w_aligned), f1_score(w_gt, w_aligned, average="macro")


def _format_mean_std(key, tag, values):
    """Render metric ``key`` over seeds as ``Key Name (tag): $mean pm std$`` (LaTeX)."""
    arr = np.array(values)
    return r"{} ({}): ${:.2f} \pm {:.2f}$".format(
        key.replace("_", " ").title(),
        tag,
        round(np.mean(arr), 2),
        round(np.std(arr), 2),
    )


def evaluate(
    model, test_set, dataset_name, model_name, ood_set=None, ood_set_2=None, hungarian=False, train_set=None, is_dsl=False
):
    """Evaluate the checkpoint of every seed in the global ``seeds`` and print mean/std per metric.

    For each seed the checkpoint derived from the global ``model_path`` is
    loaded into ``model`` (missing or unloadable files are skipped). Metrics
    are computed with ``compute_metrics`` on ``test_set`` and, when given, on
    ``ood_set`` / ``ood_set_2``. Each metric is then printed as a LaTeX
    mean +/- std string across the evaluated seeds.

    Args:
        model: Model whose weights are replaced by each checkpoint.
        test_set: In-distribution evaluation loader.
        dataset_name: Dataset identifier forwarded to the helpers.
        model_name: Model identifier forwarded to the helpers.
        ood_set: Optional first out-of-distribution loader.
        ood_set_2: Optional second out-of-distribution loader.
        hungarian: If True, align predicted concepts (computed on
            ``train_set``) before scoring the in-distribution set.
        train_set: Loader used to compute the Hungarian alignment.
        is_dsl: Forwarded to every retrieval call for DSL models.

    Raises:
        AssertionError: If fewer than two checkpoints could be evaluated.
    """
    in_metrics_list = []
    ood_metrics_list = []
    ood_metrics_2_list = []

    for seed in seeds:
        print("Doing", seed, "...")

        current_model_path = _checkpoint_path(seed)
        print(current_model_path)
        if not _load_checkpoint(model, current_model_path):
            continue

        if dataset_name == "shortmnist":
            model = model.float()

        model.eval()

        w_acc, w_f1 = None, None

        if hungarian:
            pi = get_hungarian_permutation(model, train_set, dataset_name, model_name, metric="correlation", is_dsl=is_dsl)
            ind_data = retrive_concepts_and_labels_hungarian(model, pi, test_set, dataset_name, model_name, is_dsl=is_dsl)
            w_acc, w_f1 = _knowledge_scores(model, pi, dataset_name, model_name)
        else:
            ind_data = retrive_concepts_and_labels(model, test_set, dataset_name, model_name, is_dsl=is_dsl)

        # is_dsl is forwarded to the OOD retrievals too (it used to be dropped,
        # so DSL models were run without eval=True on the OOD sets).
        # NOTE(review): the OOD sets are not Hungarian-aligned even when
        # hungarian=True — confirm this is intended.
        if ood_set is not None:
            out_data = retrive_concepts_and_labels(
                model, ood_set, dataset_name, model_name, is_ood=True, is_dsl=is_dsl
            )

        if ood_set_2 is not None:
            out_data_2 = retrive_concepts_and_labels(
                model, ood_set_2, dataset_name, model_name, is_ood=True, is_dsl=is_dsl
            )

        in_metrics = compute_metrics(*ind_data, dataset_name, model_name, seed)
        if w_acc is not None and w_f1 is not None:
            in_metrics = ExtendedMetrics.fromMetric(in_metrics, w_f1, w_acc)
        in_metrics_list.append(in_metrics)

        if ood_set is not None:
            ood_metrics_list.append(compute_metrics(*out_data, dataset_name, model_name, seed))

        if ood_set_2 is not None:
            ood_metrics_2_list.append(compute_metrics(*out_data_2, dataset_name, model_name, seed))

        torch.cuda.empty_cache()

    # Count checkpoints that were actually evaluated, not merely found on disk.
    n_files = len(in_metrics_list)
    if n_files == 1:
        print("IN", in_metrics_list[0].to_string())
        if ood_metrics_list:
            print("OOD", ood_metrics_list[0].to_string())

    assert n_files > 1, "At least 2 files to compare"

    # Mean and standard deviation of every public metric across seeds.
    for key in vars(in_metrics_list[0]):  # every metrics object has the same fields
        if key.startswith("_"):
            continue
        print("\n" + _format_mean_std(key, "In", [getattr(m, key) for m in in_metrics_list]))
        if ood_set is not None:
            print(_format_mean_std(key, "OOD", [getattr(m, key) for m in ood_metrics_list]))
        if ood_set_2 is not None:
            print(_format_mean_std(key, "OOD 2", [getattr(m, key) for m in ood_metrics_2_list]))

# %% [markdown]
# ### Hungarian Gamma

# %%
import numpy as np
from scipy.optimize import linear_sum_assignment

def rearrange_predictions_with_confusion_clevr(pred, gt):
    """Permute the last axis of ``pred`` so that it best agrees with ``gt``.

    Both arrays have shape ``(N, n_objects, n_cols)`` (4 x 4 for CLEVR here).
    A cost matrix counts, over all samples and object slots, how often column
    ``j`` of ``pred`` equals column ``k`` of ``gt``; ``gt`` entries equal to
    -1 are ignored. The Hungarian algorithm then picks the column assignment
    with the most matches.

    Returns:
        ``(pred_rearranged, order)`` where ``order`` is the index array applied
        to the last axis, i.e. ``pred_rearranged == pred[:, :, order]``, so
        column ``k`` of the result is the ``pred`` column matched to ``gt``
        column ``k``.
    """
    pred = np.asarray(pred)
    gt = np.asarray(gt)

    # matches[i, s, j, k]: pred[i, s, j] == gt[i, s, k] where gt is valid.
    valid = gt != -1
    matches = (pred[:, :, :, None] == gt[:, :, None, :]) & valid[:, :, None, :]
    global_cost_matrix = matches.sum(axis=(0, 1)).astype(float)

    # For a square matrix row_ind is 0..n-1: pred column j is assigned to
    # gt column col_ind[j].
    _, col_ind = linear_sum_assignment(-global_cost_matrix)
    # Invert the assignment so that result column k holds the pred column
    # assigned to gt column k (indexing with col_ind itself applies the inverse).
    order = np.argsort(col_ind)

    # Index the last axis of the whole array at once. The previous per-sample
    # ``pred[i, :, col_ind]`` mixed an integer, a slice and an index array,
    # which NumPy answers with a transposed (n_cols, n_objects) block.
    pred_rearranged = pred[:, :, order]
    return pred_rearranged, order

# %%
def get_hungarian_permutation(model, dataset, dataset_name, model_name, metric="correlation", is_dsl=False):
    """Compute the permutation(s) that best align predicted with true concepts.

    Predictions are collected on ``dataset`` via ``retrive_concepts_and_labels``.

    Returns:
        For ``"clevr"``: ``(perm_idx, perm_color, perm_shapes, perm_material,
        perm_sizes)``, where ``perm_idx`` is the index array returned by
        ``rearrange_predictions_with_confusion_clevr`` and the rest are one
        permutation matrix per attribute column, all built with 8 classes.
        Otherwise: a single permutation matrix with 10 classes for
        ``"addmnist"`` and 2 classes for any other dataset.

    Note:
        ``metric`` is accepted but currently unused.
    """
    _, _, gt_concepts, pred_concepts, _ = retrive_concepts_and_labels(
        model, dataset, dataset_name, model_name, False, is_dsl=is_dsl
    )

    if dataset_name != "clevr":
        num_classes = 2
        if dataset_name == "addmnist":
            num_classes = 10
        return permutation_matrix_from_predictions(pred_concepts, gt_concepts, num_classes).numpy()

    aligned_pred, perm_idx = rearrange_predictions_with_confusion_clevr(pred_concepts, gt_concepts)
    # One matrix per attribute column (colour, shape, material, size).
    # NOTE(review): 8 classes is used for every attribute although shapes,
    # materials and sizes take fewer values — confirm this padding is intended.
    attribute_perms = [
        permutation_matrix_from_predictions(
            aligned_pred[:, :, col].flatten(), gt_concepts[:, :, col].flatten(), 8
        ).numpy()
        for col in range(4)
    ]
    return (perm_idx, *attribute_perms)

def retrive_concepts_and_labels_hungarian(model, perm_matrix, dataset, dataset_name, model_name, is_dsl=False):
    """Like ``retrive_concepts_and_labels``, but remaps predicted concepts with ``perm_matrix``.

    Args:
        model: Model to evaluate.
        perm_matrix: For ``"clevr"``, the tuple ``(perm_idx, perm_color,
            perm_shapes, perm_material, perm_sizes)`` from
            ``get_hungarian_permutation``; otherwise a single permutation
            matrix indexed by predicted class.
        dataset: Evaluation loader.
        dataset_name: Dataset identifier.
        model_name: Model identifier.
        is_dsl: Forwarded to ``retrive_concepts_and_labels``.

    Returns:
        The same 5-tuple as ``retrive_concepts_and_labels``, with the
        predicted concepts mapped through the permutation(s).
    """
    true_labels, predicted_labels, true_concepts, predicted_concepts, avg_nll = retrive_concepts_and_labels(
        model, dataset, dataset_name, model_name, False, is_dsl=is_dsl
    )

    if dataset_name == "clevr":
        perm_idx, perm_color, perm_shapes, perm_material, perm_sizes = perm_matrix

        # Reorder the attribute axis of every sample at once. The old
        # per-sample ``predicted_concepts[i, :, perm_idx]`` mixed an integer,
        # a slice and an index array, which NumPy answers with a transposed block.
        predicted_concepts = predicted_concepts[:, :, perm_idx]

        # Map each attribute through its OWN permutation matrix (row lookup +
        # arg-max). Previously all four attributes went through perm_color.
        per_attribute_perms = (perm_color, perm_shapes, perm_material, perm_sizes)
        predicted_concepts = np.stack(
            [
                np.argmax(perm[predicted_concepts[:, :, col]], axis=-1)
                for col, perm in enumerate(per_attribute_perms)
            ],
            axis=-1,
        )
    else:
        predicted_concepts = np.argmax(perm_matrix[predicted_concepts], axis=1)

    return true_labels, predicted_labels, true_concepts, predicted_concepts, avg_nll

# %%
def get_gt_knowledge(dataset_name):
    """Return the flattened ground-truth knowledge table for ``dataset_name``.

    Only ``"addmnist"`` is defined: entry ``10 * i + j`` equals
    ``(i + j) % 2`` for ``i, j`` in ``0..9``. Any other dataset yields ``None``.
    """
    if dataset_name != "addmnist":
        return None
    digits = np.arange(10)
    return (digits[:, None] + digits[None, :]).flatten() % 2


# %%
def get_cbm_knowledge(w, device, dataset_name):
    """Tabulate the class a CBM head predicts for every pair of one-hot concepts.

    For each ``i, j`` in ``0..9`` the head ``w`` is fed the concatenation of
    the one-hot encodings of ``i`` and ``j`` (a 20-dim vector), and the
    arg-max of its output is stored at ``knowledge[i, j]``. All 100 probes are
    evaluated in a single batched forward pass without autograd.

    Args:
        w: Callable head mapping a ``(batch, 20)`` float tensor to class
            scores of shape ``(batch, n_classes)``.
        device: Device on which the probe inputs are created.
        dataset_name: Not used for the computation (see TODO).

    Returns:
        A ``(10, 10)`` float32 CPU tensor of predicted class indices.
    """
    # TODO: CLEVR knowledge extraction is not implemented; "clevr" gets the
    # same 10 x 10 one-hot probe as the MNIST-style heads.
    n = 10
    eye = torch.eye(n, device=device)
    # Row i * n + j of `probes` is onehot(i) ++ onehot(j).
    left = eye.repeat_interleave(n, dim=0)
    right = eye.repeat(n, 1)
    probes = torch.cat([left, right], dim=-1)

    with torch.no_grad():
        predicted = torch.argmax(w(probes), dim=-1)

    # Keep the original return type: a float CPU tensor.
    return predicted.view(n, n).float().cpu()

# %% [markdown]
# Run the full evaluation

# %%
# Get loaders
# Build the loaders; the OOD loader is optional and only some datasets expose it.
train_loader, val_loader, test_loader = dataset.get_data_loaders()
ood_loader = getattr(dataset, "ood_loader", None)

evaluate(
    model,
    test_loader,
    args.dataset,
    model_name=args.model,
    ood_set=ood_loader,
    ood_set_2=None,  # alternative: getattr(dataset, "ood_loader_2", None)
    hungarian=False,  # set True to align concepts with the Hungarian matching
    train_set=train_loader,
    is_dsl=False,  # set True for DSL models (they are called with eval=True)
)

# %%

# model = MNISTCBM()

# model.load_state_dict(torch.load(f"{model_path}_{1011}_True.pth"))

# layer_weights = model.fc_aggregate[0].state_dict()['weight']

# torch.save(layer_weights, "linganguliguliguli_linear.pth")


