{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluate in-distribution and out-of-distribution metrics for SDDOIA and the other supported datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "\n",
    "from utils.train import convert_to_categories, compute_coverage\n",
    "from datasets.boia import BOIA\n",
    "from datasets.sddoia import SDDOIA\n",
    "from datasets.minikandinsky import MiniKandinsky\n",
    "from datasets.kandinsky import Kandinsky\n",
    "from datasets.shortcutmnist import SHORTMNIST\n",
    "from datasets.addmnist import ADDMNIST\n",
    "from datasets.clipkandinsky import CLIPKandinsky\n",
    "from datasets.clipshortcutmnist import CLIPSHORTMNIST\n",
    "from datasets.clipboia import CLIPBOIA\n",
    "from datasets.clevr import CLEVR\n",
    "from models.boiadpl import BoiaDPL\n",
    "from models.boialtn import BOIALTN\n",
    "from models.boiann import BOIAnn\n",
    "from models.boiacbm import BoiaCBM\n",
    "from models.mnistcbm import MnistCBM\n",
    "from models.mnistdpl import MnistDPL\n",
    "from models.mnistdsl import MnistDSL\n",
    "from models.mnistltn import MnistLTN\n",
    "from models.mnistnn import MNISTnn\n",
    "from models.mnistdsldpl import MnistDSLDPL\n",
    "from models.minikanddpl import MiniKandDPL\n",
    "from models.kanddpl import KandDPL\n",
    "from models.kandcbm import KandCBM\n",
    "from models.kandltn import KANDltn\n",
    "from models.kandnn import KANDnn\n",
    "from models.clevrcbm import ClevrCBM\n",
    "from models.clevrdsl import ClevrDSL\n",
    "from models.clevrdsldpl import ClevrDSLDPL\n",
    "from models.clevrdpl import CLEVRDPL\n",
    "from utils.hungarian import permutation_matrix_from_predictions\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from utils.jrs_utils import find_boia_permutation, sample_boia_config\n",
    "from argparse import Namespace\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Concept Bottleneck Model (CBM) for two-digit MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MNISTCBM(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MNISTCBM, self).__init__()\n",
    "        self.cnn = nn.Sequential(\n",
    "            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # [32, 28, 28]\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),  # [32, 14, 14]\n",
    "            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # [64, 14, 14]\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),  # [64, 7, 7]\n",
    "            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # [128, 7, 7]\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),  # [128, 3, 3]\n",
    "        )\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.fc_individual = nn.Sequential(\n",
    "            nn.Linear(128 * 3 * 3, 256),  # Processed features for each image\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(256, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 10),\n",
    "            # nn.Softmax(dim=1)\n",
    "        )\n",
    "        self.fc_aggregate = nn.Sequential(\n",
    "            # nn.Linear(20, 19, bias=False), # Output range: 0-18 (max sum of two MNIST digits)\n",
    "            nn.Linear(20, 2, bias=False)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        features1 = self.fc_individual(self.flatten(self.cnn(x[:, :, :, :28])))\n",
    "        features2 = self.fc_individual(self.flatten(self.cnn(x[:, :, :, 28:])))\n",
    "        cs = torch.stack([features1, features2], dim=1)\n",
    "        combined_features = torch.cat([torch.nn.functional.softmax(features1, dim=-1), torch.nn.functional.softmax(features2, dim=-1)], dim=1)\n",
    "        # combined_features = torch.nn.functional.softmax(features1, dim=-1).unsqueeze(2).multiply(torch.nn.functional.softmax(features2, dim=-1).unsqueeze(1)).view(features1.shape[0], -1)\n",
    "        output = torch.softmax(self.fc_aggregate(combined_features), dim=-1)\n",
    "        pCs = torch.stack([torch.nn.functional.softmax(features1, dim=-1), torch.nn.functional.softmax(features2, dim=-1)], dim=1)\n",
    "        return {\"CS\": cs, \"YS\": output, \"pCS\": pCs}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Classes holding the evaluated metrics: a shared base class, dataset-specific variants (BOIA, Kandinsky, CLEVR) and extended versions that add beta scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Metrics:\n",
    "    def __init__(\n",
    "        self,\n",
    "        concept_accuracy,\n",
    "        label_accuracy,\n",
    "        concept_f1_macro,\n",
    "        concept_f1_micro,\n",
    "        concept_f1_weighted,\n",
    "        label_f1_macro,\n",
    "        label_f1_micro,\n",
    "        label_f1_weighted,\n",
    "        collapse,\n",
    "        collapse_hard,\n",
    "        avg_nll,\n",
    "    ):\n",
    "        self.concept_accuracy = concept_accuracy\n",
    "        self.label_accuracy = label_accuracy\n",
    "        self.concept_f1_macro = concept_f1_macro\n",
    "        self.concept_f1_micro = concept_f1_micro\n",
    "        self.concept_f1_weighted = concept_f1_weighted\n",
    "        self.label_f1_macro = label_f1_macro\n",
    "        self.label_f1_micro = label_f1_micro\n",
    "        self.label_f1_weighted = label_f1_weighted\n",
    "        self.collapse = collapse\n",
    "        self.collapse_hard = collapse_hard\n",
    "        self.avg_nll = avg_nll\n",
    "\n",
    "    def to_string(self):\n",
    "        return \", \".join(f\"{key}: {value}\" for key, value in self.__dict__.items())\n",
    "\n",
    "class ExtendedMetrics(Metrics):\n",
    "    def __init__(\n",
    "        self,\n",
    "        concept_accuracy,\n",
    "        label_accuracy,\n",
    "        concept_f1_macro,\n",
    "        concept_f1_micro,\n",
    "        concept_f1_weighted,\n",
    "        label_f1_macro,\n",
    "        label_f1_micro,\n",
    "        label_f1_weighted,\n",
    "        collapse,\n",
    "        collapse_hard,\n",
    "        avg_nll,\n",
    "        beta_f1,\n",
    "        beta_acc\n",
    "    ):\n",
    "        super(ExtendedMetrics, self).__init__(\n",
    "            concept_accuracy,\n",
    "            label_accuracy,\n",
    "            concept_f1_macro,\n",
    "            concept_f1_micro,\n",
    "            concept_f1_weighted,\n",
    "            label_f1_macro,\n",
    "            label_f1_micro,\n",
    "            label_f1_weighted,\n",
    "            collapse,\n",
    "            collapse_hard,\n",
    "            avg_nll,\n",
    "        )\n",
    "        self.beta_f1 = beta_f1\n",
    "        self.beta_acc = beta_acc\n",
    "\n",
    "    @staticmethod\n",
    "    def fromMetric(metric, beta_f1, beta_acc):\n",
    "        return ExtendedMetrics(\n",
    "            metric.concept_accuracy,\n",
    "            metric.label_accuracy,\n",
    "            metric.concept_f1_macro,\n",
    "            metric.concept_f1_micro,\n",
    "            metric.concept_f1_weighted,\n",
    "            metric.label_f1_macro,\n",
    "            metric.label_f1_micro,\n",
    "            metric.label_f1_weighted,\n",
    "            metric.collapse,\n",
    "            metric.collapse_hard,\n",
    "            metric.avg_nll,\n",
    "            beta_f1, \n",
    "            beta_acc\n",
    "        )\n",
    "    \n",
    "\n",
    "class BOIAMetrics(Metrics):\n",
    "    def __init__(\n",
    "        self,\n",
    "        concept_accuracy,\n",
    "        label_accuracy,\n",
    "        concept_f1_macro,\n",
    "        concept_f1_micro,\n",
    "        concept_f1_weighted,\n",
    "        label_f1_macro,\n",
    "        label_f1_micro,\n",
    "        label_f1_weighted,\n",
    "        collapse,\n",
    "        collapse_hard,\n",
    "        collapse_forward,\n",
    "        collapse_stop,\n",
    "        collapse_left,\n",
    "        collapse_right,\n",
    "        collapse_hard_forward,\n",
    "        collapse_hard_stop,\n",
    "        collapse_hard_left,\n",
    "        collapse_hard_right,\n",
    "        mean_collapse,\n",
    "        mean_hard_collapse,\n",
    "        avg_nll,\n",
    "    ):\n",
    "        super(BOIAMetrics, self).__init__(\n",
    "            concept_accuracy,\n",
    "            label_accuracy,\n",
    "            concept_f1_macro,\n",
    "            concept_f1_micro,\n",
    "            concept_f1_weighted,\n",
    "            label_f1_macro,\n",
    "            label_f1_micro,\n",
    "            label_f1_weighted,\n",
    "            collapse,\n",
    "            collapse_hard,\n",
    "            avg_nll,\n",
    "        )\n",
    "        self.collapse_forward = collapse_forward\n",
    "        self.collapse_stop = collapse_stop\n",
    "        self.collapse_left = collapse_left\n",
    "        self.collapse_right = collapse_right\n",
    "        self.collapse_hard_forward = collapse_hard_forward\n",
    "        self.collapse_hard_stop = collapse_hard_stop\n",
    "        self.collapse_hard_left = collapse_hard_left\n",
    "        self.collapse_hard_right = collapse_hard_right\n",
    "        self.mean_collapse = mean_collapse\n",
    "        self.mean_hard_collapse = mean_hard_collapse\n",
    "\n",
    "\n",
    "class BOIAExtendedMetrics(BOIAMetrics):\n",
    "    def __init__(\n",
    "        self,\n",
    "        concept_accuracy,\n",
    "        label_accuracy,\n",
    "        concept_f1_macro,\n",
    "        concept_f1_micro,\n",
    "        concept_f1_weighted,\n",
    "        label_f1_macro,\n",
    "        label_f1_micro,\n",
    "        label_f1_weighted,\n",
    "        collapse,\n",
    "        collapse_hard,\n",
    "        collapse_forward,\n",
    "        collapse_stop,\n",
    "        collapse_left,\n",
    "        collapse_right,\n",
    "        collapse_hard_forward,\n",
    "        collapse_hard_stop,\n",
    "        collapse_hard_left,\n",
    "        collapse_hard_right,\n",
    "        mean_collapse,\n",
    "        mean_hard_collapse,\n",
    "        avg_nll,\n",
    "        beta_f1,\n",
    "        beta_acc\n",
    "    ):\n",
    "        super(BOIAExtendedMetrics, self).__init__(\n",
    "            concept_accuracy,\n",
    "            label_accuracy,\n",
    "            concept_f1_macro,\n",
    "            concept_f1_micro,\n",
    "            concept_f1_weighted,\n",
    "            label_f1_macro,\n",
    "            label_f1_micro,\n",
    "            label_f1_weighted,\n",
    "            collapse,\n",
    "            collapse_hard,\n",
    "            collapse_forward,\n",
    "            collapse_stop,\n",
    "            collapse_left,\n",
    "            collapse_right,\n",
    "            collapse_hard_forward,\n",
    "            collapse_hard_stop,\n",
    "            collapse_hard_left,\n",
    "            collapse_hard_right,\n",
    "            mean_collapse,\n",
    "            mean_hard_collapse,\n",
    "            avg_nll,\n",
    "        )\n",
    "        self.beta_f1 = beta_f1\n",
    "        self.beta_acc = beta_acc\n",
    "\n",
    "    @staticmethod\n",
    "    def fromMetric(metric, beta_f1, beta_acc):\n",
    "        return BOIAExtendedMetrics(\n",
    "            metric.concept_accuracy,\n",
    "            metric.label_accuracy,\n",
    "            metric.concept_f1_macro,\n",
    "            metric.concept_f1_micro,\n",
    "            metric.concept_f1_weighted,\n",
    "            metric.label_f1_macro,\n",
    "            metric.label_f1_micro,\n",
    "            metric.label_f1_weighted,\n",
    "            metric.collapse,\n",
    "            metric.collapse_hard,\n",
    "            metric.collapse_forward,\n",
    "            metric.collapse_stop,\n",
    "            metric.collapse_left,\n",
    "            metric.collapse_right,\n",
    "            metric.collapse_hard_forward,\n",
    "            metric.collapse_hard_stop,\n",
    "            metric.collapse_hard_left,\n",
    "            metric.collapse_hard_right,\n",
    "            metric.mean_collapse,\n",
    "            metric.mean_hard_collapse,\n",
    "            metric.avg_nll,\n",
    "            beta_f1, \n",
    "            beta_acc\n",
    "        )\n",
    "\n",
    "class KandMetrics(Metrics):\n",
    "    def __init__(\n",
    "        self,\n",
    "        concept_accuracy,\n",
    "        label_accuracy,\n",
    "        concept_f1_macro,\n",
    "        concept_f1_micro,\n",
    "        concept_f1_weighted,\n",
    "        label_f1_macro,\n",
    "        label_f1_micro,\n",
    "        label_f1_weighted,\n",
    "        collapse,\n",
    "        collapse_hard,\n",
    "        avg_nll,\n",
    "        collapse_shapes,\n",
    "        collapse_hard_shapes,\n",
    "        collapse_color,\n",
    "        collapse_hard_color,\n",
    "        mean_collapse,\n",
    "        mean_collapse_hard,\n",
    "    ):\n",
    "        super(KandMetrics, self).__init__(\n",
    "            concept_accuracy,\n",
    "            label_accuracy,\n",
    "            concept_f1_macro,\n",
    "            concept_f1_micro,\n",
    "            concept_f1_weighted,\n",
    "            label_f1_macro,\n",
    "            label_f1_micro,\n",
    "            label_f1_weighted,\n",
    "            collapse,\n",
    "            collapse_hard,\n",
    "            avg_nll,\n",
    "        )\n",
    "        self.collapse_shapes = collapse_shapes\n",
    "        self.collapse_hard_shapes = collapse_hard_shapes\n",
    "        self.collapse_color = collapse_color\n",
    "        self.collapse_hard_color = collapse_hard_color\n",
    "        self.mean_collapse = mean_collapse\n",
    "        self.mean_collapse_hard = mean_collapse_hard\n",
    "\n",
    "\n",
    "class ClevrMetrics(Metrics):\n",
    "    def __init__(\n",
    "        self,\n",
    "        concept_accuracy,\n",
    "        label_accuracy,\n",
    "        concept_f1_macro,\n",
    "        concept_f1_micro,\n",
    "        concept_f1_weighted,\n",
    "        label_f1_macro,\n",
    "        label_f1_micro,\n",
    "        label_f1_weighted,\n",
    "        collapse,\n",
    "        collapse_hard,\n",
    "        avg_nll,\n",
    "        collapse_shapes,\n",
    "        collapse_hard_shapes,\n",
    "        collapse_color,\n",
    "        collapse_hard_color,\n",
    "        collapse_materials,\n",
    "        collapse_hard_materials,\n",
    "        collapse_sizes,\n",
    "        collapse_hard_sizes,\n",
    "        mean_collapse,\n",
    "        mean_collapse_hard,\n",
    "    ):\n",
    "        super(ClevrMetrics, self).__init__(\n",
    "            concept_accuracy,\n",
    "            label_accuracy,\n",
    "            concept_f1_macro,\n",
    "            concept_f1_micro,\n",
    "            concept_f1_weighted,\n",
    "            label_f1_macro,\n",
    "            label_f1_micro,\n",
    "            label_f1_weighted,\n",
    "            collapse,\n",
    "            collapse_hard,\n",
    "            avg_nll,\n",
    "        )\n",
    "        self.collapse_shapes = collapse_shapes\n",
    "        self.collapse_hard_shapes = collapse_hard_shapes\n",
    "        self.collapse_color = collapse_color\n",
    "        self.collapse_hard_color = collapse_hard_color\n",
    "        self.collapse_materials = collapse_materials\n",
    "        self.collapse_hard_materials = collapse_hard_materials\n",
    "        self.collapse_sizes = collapse_sizes\n",
    "        self.collapse_hard_sizes = collapse_hard_sizes\n",
    "        self.mean_collapse = mean_collapse\n",
    "        self.mean_collapse_hard = mean_collapse_hard"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Functions used to compute concept collapse (standard and hard variants)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_concept_collapse(true_concepts, predicted_concepts, multilabel=False):\n",
    "    \"\"\"Return concept collapse, i.e. ``1 - compute_coverage`` of the confusion matrix.\n",
    "\n",
    "    Args:\n",
    "        true_concepts: ground-truth concept labels. When ``multilabel`` is True they\n",
    "            must support ``.astype(int)`` (e.g. numpy arrays) and are first mapped\n",
    "            to single categories with ``convert_to_categories``.\n",
    "        predicted_concepts: predicted concept labels, same format as ``true_concepts``.\n",
    "        multilabel: whether the inputs need the multi-label to category conversion.\n",
    "\n",
    "    Returns:\n",
    "        ``1 - compute_coverage(confusion_matrix(true_concepts, predicted_concepts))``.\n",
    "    \"\"\"\n",
    "    if multilabel:\n",
    "        true_concepts = convert_to_categories(true_concepts.astype(int))\n",
    "        predicted_concepts = convert_to_categories(predicted_concepts.astype(int))\n",
    "\n",
    "    return 1 - compute_coverage(confusion_matrix(true_concepts, predicted_concepts))\n",
    "\n",
    "\n",
    "def compute_hard_concept_collapse(true_concepts, predicted_concepts, multilabel=False):\n",
    "    \"\"\"Hard variant of ``compute_concept_collapse``.\n",
    "\n",
    "    NOTE(review): no dedicated hard coverage is computed yet, so this returns\n",
    "    exactly the value of ``compute_concept_collapse``. Delegating (instead of\n",
    "    duplicating the body) keeps the two in sync until a hard metric exists.\n",
    "    \"\"\"\n",
    "    return compute_concept_collapse(true_concepts, predicted_concepts, multilabel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Function used to plot a confusion matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_confusion_matrix(\n",
    "    true_labels,\n",
    "    predicted_labels,\n",
    "    classes,\n",
    "    normalize=False,\n",
    "    title=None,\n",
    "    is_boia=False,\n",
    "    cmap=plt.cm.Oranges,\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot a confusion matrix as a seaborn heatmap and optionally save it as PDF.\n",
    "\n",
    "    Args:\n",
    "        true_labels: integer class indices used as row indices.\n",
    "        predicted_labels: integer class indices used as column indices; indexed\n",
    "            at the same positions as ``true_labels``.\n",
    "        classes: tick labels; ``len(classes)`` fixes the matrix size.\n",
    "        normalize: if True, divide each row by its sum (rows with no samples stay 0).\n",
    "        title: if truthy, the *file path* the figure is saved to (PDF format).\n",
    "        is_boia: unused; kept for backward compatibility.\n",
    "        cmap: unused; the heatmap always uses the seaborn OrRd palette.\n",
    "    \"\"\"\n",
    "    n_classes = len(classes)\n",
    "    # Integer counts so that the non-normalized fmt=\"d\" matches the data.\n",
    "    cm = np.zeros((n_classes, n_classes), dtype=int)\n",
    "    for i in range(len(true_labels)):\n",
    "        cm[true_labels[i], predicted_labels[i]] += 1\n",
    "\n",
    "    if normalize:\n",
    "        row_sums = cm.sum(axis=1, keepdims=True)\n",
    "        # Divide only non-empty rows: same values as masking afterwards, but no\n",
    "        # divide-by-zero RuntimeWarning for classes without samples.\n",
    "        cm = np.divide(\n",
    "            cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0\n",
    "        )\n",
    "\n",
    "    plt.figure(figsize=(8, 6))\n",
    "    sns.set(font_scale=1.8)\n",
    "    red_yellow_palette = sns.color_palette(\"OrRd\", as_cmap=True)\n",
    "    sns.heatmap(\n",
    "        cm,\n",
    "        annot=False,\n",
    "        fmt=\".2f\" if normalize else \"d\",\n",
    "        cmap=red_yellow_palette,\n",
    "        cbar=True,\n",
    "        xticklabels=classes,\n",
    "        yticklabels=classes,\n",
    "    )\n",
    "    plt.xticks(rotation=0)\n",
    "    plt.yticks(rotation=0)\n",
    "    plt.tight_layout()\n",
    "    # Save only after tick rotation and tight_layout so the PDF matches the display.\n",
    "    if title:\n",
    "        plt.savefig(title, format=\"pdf\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Function used to compute all the evaluation metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(\n",
    "    true_labels,\n",
    "    predicted_labels,\n",
    "    true_concepts,\n",
    "    predicted_concepts,\n",
    "    avg_nll,\n",
    "    dataset_name,\n",
    "    model_name,\n",
    "    seed,\n",
    "):\n",
    "\n",
    "    # multilabel or not\n",
    "    multilabel_concept = False\n",
    "    multilabel_label = False\n",
    "\n",
    "    if dataset_name in [\"boia\", \"sddoia\", \"clipboia\", \"clipSDDOIA\"]:\n",
    "        multilabel_concept = True\n",
    "        multilabel_label = True\n",
    "\n",
    "    if dataset_name in [\"kandinsky\", \"minikandinsky\", \"clipkandinsky\"]:\n",
    "        collapse_true_concepts_list = torch.tensor(true_concepts)\n",
    "        collapse_true_concepts_list = torch.split(collapse_true_concepts_list, 3, dim=1)\n",
    "        collapse_pred_concepts_list = torch.tensor(predicted_concepts)\n",
    "        collapse_pred_concepts_list = torch.split(collapse_pred_concepts_list, 3, dim=1)\n",
    "\n",
    "        collapse_true_concepts_1 = collapse_true_concepts_list[0].flatten()\n",
    "        collapse_true_concepts_2 = collapse_true_concepts_list[1].flatten()\n",
    "        collapse_true_concepts = torch.stack(\n",
    "            (collapse_true_concepts_1, collapse_true_concepts_2), dim=1\n",
    "        )\n",
    "        # to int\n",
    "        collapse_true_concepts = (\n",
    "            collapse_true_concepts[:, 0] * 3 + collapse_true_concepts[:, 1]\n",
    "        )\n",
    "        collapse_true_concepts = collapse_true_concepts.detach().numpy()\n",
    "\n",
    "        collapse_pred_concepts_1 = collapse_pred_concepts_list[0].flatten()\n",
    "        collapse_pred_concepts_2 = collapse_pred_concepts_list[1].flatten()\n",
    "        collapse_pred_concepts = torch.stack(\n",
    "            (collapse_pred_concepts_1, collapse_pred_concepts_2), dim=1\n",
    "        )\n",
    "        # to int\n",
    "        collapse_pred_concepts = (\n",
    "            collapse_pred_concepts[:, 0] * 3 + collapse_pred_concepts[:, 1]\n",
    "        )\n",
    "        collapse_pred_concepts = collapse_pred_concepts.detach().numpy()\n",
    "\n",
    "        # total collapse\n",
    "        collapse = compute_concept_collapse(\n",
    "            collapse_true_concepts, collapse_pred_concepts, multilabel_concept\n",
    "        )\n",
    "\n",
    "        collapse_hard = compute_hard_concept_collapse(\n",
    "            collapse_true_concepts, collapse_pred_concepts, multilabel_concept\n",
    "        )\n",
    "    elif dataset_name in [\"boia\", \"sddoia\", \"clipboia\", \"clipSDDOIA\"]:\n",
    "        # additional metrics for boia and sddoia\n",
    "        collapse_forward, collapse_hard_forward = compute_concept_collapse(\n",
    "            true_concepts[:, :3], predicted_concepts[:, :3], True\n",
    "        ), compute_hard_concept_collapse(\n",
    "            true_concepts[:, :3], predicted_concepts[:, :3], True\n",
    "        )\n",
    "        collapse_stop, collapse_hard_stop = compute_concept_collapse(\n",
    "            true_concepts[:, 3:9], predicted_concepts[:, 3:9], True\n",
    "        ), compute_hard_concept_collapse(\n",
    "            true_concepts[:, 3:9], predicted_concepts[:, 3:9], True\n",
    "        )\n",
    "        collapse_left, collapse_hard_left = compute_concept_collapse(\n",
    "            true_concepts[:, 9:15], predicted_concepts[:, 9:15], True\n",
    "        ), compute_hard_concept_collapse(\n",
    "            true_concepts[:, 9:15], predicted_concepts[:, 9:15], True\n",
    "        )\n",
    "        collapse_right, collapse_hard_right = compute_concept_collapse(\n",
    "            true_concepts[:, 15:21], predicted_concepts[:, 15:21], True\n",
    "        ), compute_hard_concept_collapse(\n",
    "            true_concepts[:, 15:21], predicted_concepts[:, 15:21], True\n",
    "        )\n",
    "\n",
    "        mean_collapse, mean_hard_collapse = np.mean(\n",
    "            [collapse_forward, collapse_stop, collapse_left, collapse_right]\n",
    "        ), np.mean(\n",
    "            [\n",
    "                collapse_hard_forward,\n",
    "                collapse_hard_stop,\n",
    "                collapse_hard_left,\n",
    "                collapse_hard_right,\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        collapse = mean_collapse\n",
    "        collapse_hard = mean_collapse\n",
    "\n",
    "    elif dataset_name in [\"minikandinsky\", \"kandinsky\", \"clipkandinsky\"]:\n",
    "        # additional metrics for boia and sddoia\n",
    "        collapse_color, collapse_hard_color = compute_concept_collapse(\n",
    "            true_concepts[:, 3:6].reshape(-1),\n",
    "            predicted_concepts[:, 3:6].reshape(-1),\n",
    "            False,\n",
    "        ), compute_hard_concept_collapse(\n",
    "            true_concepts[:, 3:6].reshape(-1),\n",
    "            predicted_concepts[:, 3:6].reshape(-1),\n",
    "            False,\n",
    "        )\n",
    "        collapse_shapes, collapse_hard_shapes = compute_concept_collapse(\n",
    "            true_concepts[:, :3].reshape(-1),\n",
    "            predicted_concepts[:, :3].reshape(-1),\n",
    "            False,\n",
    "        ), compute_hard_concept_collapse(\n",
    "            true_concepts[:, :3].reshape(-1),\n",
    "            predicted_concepts[:, :3].reshape(-1),\n",
    "            False,\n",
    "        )\n",
    "\n",
    "        mean_collapse, mean_collapse_hard = np.mean(\n",
    "            [collapse_color, collapse_shapes]\n",
    "        ), np.mean([collapse_hard_color, collapse_hard_shapes])\n",
    "    elif dataset_name in [\"clevr\"]:\n",
    "        mask_color = true_concepts[:, :, 0].reshape(-1) != -1\n",
    "        mask_shapes = true_concepts[:, :, 1].reshape(-1) != -1\n",
    "        mask_materials = true_concepts[:, :, 2].reshape(-1) != -1\n",
    "        mask_sizes = true_concepts[:, :, 3].reshape(-1) != -1\n",
    "\n",
    "        filtered_true_colors = true_concepts[:, :, 0].reshape(-1)[mask_color]\n",
    "        filtered_predicted_colors = predicted_concepts[:, :, 0].reshape(-1)[mask_color]\n",
    "\n",
    "        filtered_true_shapes = true_concepts[:, :, 1].reshape(-1)[mask_shapes]\n",
    "        filtered_predicted_shapes = predicted_concepts[:, :, 1].reshape(-1)[mask_shapes]\n",
    "\n",
    "        filtered_true_materials = true_concepts[:, :, 2].reshape(-1)[mask_materials]\n",
    "        filtered_predicted_materials = predicted_concepts[:, :, 2].reshape(-1)[mask_materials]\n",
    "\n",
    "        filtered_true_sizes = true_concepts[:, :, 3].reshape(-1)[mask_sizes]\n",
    "        filtered_predicted_sizes = predicted_concepts[:, :, 3].reshape(-1)[mask_sizes]\n",
    "\n",
    "        # Compute collapses\n",
    "        collapse_color = compute_concept_collapse(\n",
    "            filtered_true_colors,\n",
    "            filtered_predicted_colors,\n",
    "            False,\n",
    "        )\n",
    "        collapse_shapes = compute_concept_collapse(\n",
    "            filtered_true_shapes,\n",
    "            filtered_predicted_shapes,\n",
    "            False,\n",
    "        )\n",
    "        collapse_materials = compute_concept_collapse(\n",
    "            filtered_true_materials,\n",
    "            filtered_predicted_materials,\n",
    "            False,\n",
    "        )\n",
    "        collapse_sizes = compute_concept_collapse(\n",
    "            filtered_true_sizes,\n",
    "            filtered_predicted_sizes,\n",
    "            False,\n",
    "        )\n",
    "\n",
    "        mean_collapse, mean_collapse_hard = np.mean(\n",
    "            [collapse_color, collapse_shapes, collapse_materials, collapse_sizes]\n",
    "        ), 0\n",
    "    else:\n",
    "        # total collapse\n",
    "        collapse = compute_concept_collapse(\n",
    "            true_concepts, predicted_concepts, multilabel_concept\n",
    "        )\n",
    "\n",
    "        collapse_hard = collapse #compute_hard_concept_collapse(\n",
    "        #    true_concepts, predicted_concepts, multilabel_concept\n",
    "        # )\n",
    "\n",
    "    if multilabel_concept:\n",
    "        concept_accuracy, concept_f1_macro, concept_f1_micro, concept_f1_weighted = (\n",
    "            0,\n",
    "            0,\n",
    "            0,\n",
    "            0,\n",
    "        )\n",
    "\n",
    "        for i in range(true_concepts.shape[1]):\n",
    "            concept_accuracy += accuracy_score(true_concepts[i], predicted_concepts[i])\n",
    "            concept_f1_macro += f1_score(\n",
    "                true_concepts[i], predicted_concepts[i], average=\"macro\"\n",
    "            )\n",
    "            concept_f1_micro += f1_score(\n",
    "                true_concepts[i], predicted_concepts[i], average=\"micro\"\n",
    "            )\n",
    "            concept_f1_weighted += f1_score(\n",
    "                true_concepts[i], predicted_concepts[i], average=\"weighted\"\n",
    "            )\n",
    "\n",
    "        concept_accuracy = concept_accuracy / true_concepts.shape[1]\n",
    "        concept_f1_macro = concept_f1_macro / true_concepts.shape[1]\n",
    "        concept_f1_micro = concept_f1_micro / true_concepts.shape[1]\n",
    "        concept_f1_weighted = concept_f1_weighted / true_concepts.shape[1]\n",
    "\n",
    "        label_accuracy, label_f1_macro, label_f1_micro, label_f1_weighted = 0, 0, 0, 0\n",
    "    elif dataset_name in [\"kandinsky\", \"minikandinsky\", \"clipkandinsky\"]:\n",
    "        concept_accuracy_color = accuracy_score(\n",
    "            true_concepts[:, 3:6].reshape(-1), predicted_concepts[:, 3:6].reshape(-1)\n",
    "        )\n",
    "        concept_f1_macro_color = f1_score(\n",
    "            true_concepts[:, 3:6].reshape(-1),\n",
    "            predicted_concepts[:, 3:6].reshape(-1),\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_micro_color = f1_score(\n",
    "            true_concepts[:, 3:6].reshape(-1),\n",
    "            predicted_concepts[:, 3:6].reshape(-1),\n",
    "            average=\"micro\",\n",
    "        )\n",
    "        concept_f1_weighted_color = f1_score(\n",
    "            true_concepts[:, 3:6].reshape(-1),\n",
    "            predicted_concepts[:, 3:6].reshape(-1),\n",
    "            average=\"weighted\",\n",
    "        )\n",
    "\n",
    "        concept_accuracy_shape = accuracy_score(\n",
    "            true_concepts[:, :3].reshape(-1), predicted_concepts[:, :3].reshape(-1)\n",
    "        )\n",
    "        concept_f1_macro_shape = f1_score(\n",
    "            true_concepts[:, :3].reshape(-1),\n",
    "            predicted_concepts[:, :3].reshape(-1),\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_micro_shape = f1_score(\n",
    "            true_concepts[:, :3].reshape(-1),\n",
    "            predicted_concepts[:, :3].reshape(-1),\n",
    "            average=\"micro\",\n",
    "        )\n",
    "        concept_f1_weighted_shape = f1_score(\n",
    "            true_concepts[:, :3].reshape(-1),\n",
    "            predicted_concepts[:, :3].reshape(-1),\n",
    "            average=\"weighted\",\n",
    "        )\n",
    "\n",
    "        concept_accuracy = np.mean([concept_accuracy_color, concept_accuracy_shape])\n",
    "        concept_f1_macro = np.mean([concept_f1_macro_color, concept_f1_macro_shape])\n",
    "        concept_f1_micro = np.mean([concept_f1_micro_color, concept_f1_micro_shape])\n",
    "        concept_f1_weighted = np.mean(\n",
    "            [concept_f1_weighted_color, concept_f1_weighted_shape]\n",
    "        )\n",
    "    elif dataset_name in [\"clevr\"]:\n",
    "\n",
    "        mask_color = true_concepts[:, :, 0].reshape(-1) != -1\n",
    "        mask_shapes = true_concepts[:, :, 1].reshape(-1) != -1\n",
    "        mask_materials = true_concepts[:, :, 2].reshape(-1) != -1\n",
    "        mask_sizes = true_concepts[:, :, 3].reshape(-1) != -1\n",
    "\n",
    "        filtered_true_colors = true_concepts[:, :, 0].reshape(-1)[mask_color]\n",
    "        filtered_predicted_colors = predicted_concepts[:, :, 0].reshape(-1)[mask_color]\n",
    "\n",
    "        filtered_true_shapes = true_concepts[:, :, 1].reshape(-1)[mask_shapes]\n",
    "        filtered_predicted_shapes = predicted_concepts[:, :, 1].reshape(-1)[mask_shapes]\n",
    "\n",
    "        filtered_true_materials = true_concepts[:, :, 2].reshape(-1)[mask_materials]\n",
    "        filtered_predicted_materials = predicted_concepts[:, :, 2].reshape(-1)[mask_materials]\n",
    "\n",
    "        filtered_true_sizes = true_concepts[:, :, 3].reshape(-1)[mask_sizes]\n",
    "        filtered_predicted_sizes = predicted_concepts[:, :, 3].reshape(-1)[mask_sizes]\n",
    "        \n",
    "        concept_accuracy_color = accuracy_score(\n",
    "            filtered_true_colors, filtered_predicted_colors,\n",
    "        )\n",
    "        concept_f1_macro_color = f1_score(\n",
    "            filtered_true_colors, filtered_predicted_colors,\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_micro_color = f1_score(\n",
    "            filtered_true_colors, filtered_predicted_colors,\n",
    "            average=\"micro\",\n",
    "        )\n",
    "        concept_f1_weighted_color = f1_score(\n",
    "            filtered_true_colors, filtered_predicted_colors,\n",
    "            average=\"weighted\",\n",
    "        )\n",
    "\n",
    "        concept_accuracy_shape = accuracy_score(\n",
    "            filtered_true_shapes, filtered_predicted_shapes,\n",
    "        )\n",
    "        concept_f1_macro_shape = f1_score(\n",
    "            filtered_true_shapes, filtered_predicted_shapes,\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_macro_shape = f1_score(\n",
    "            filtered_true_shapes, filtered_predicted_shapes,\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_micro_shape = f1_score(\n",
    "            filtered_true_shapes, filtered_predicted_shapes,\n",
    "            average=\"micro\",\n",
    "        )\n",
    "        concept_f1_weighted_shape = f1_score(\n",
    "            filtered_true_shapes, filtered_predicted_shapes,\n",
    "            average=\"weighted\",\n",
    "        )\n",
    "\n",
    "        concept_accuracy_sizes = accuracy_score(\n",
    "            filtered_true_sizes, filtered_predicted_sizes,\n",
    "        )\n",
    "        concept_f1_macro_sizes = f1_score(\n",
    "            filtered_true_sizes, filtered_predicted_sizes,\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_macro_sizes = f1_score(\n",
    "            filtered_true_sizes, filtered_predicted_sizes,\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_micro_sizes = f1_score(\n",
    "            filtered_true_sizes, filtered_predicted_sizes,\n",
    "            average=\"micro\",\n",
    "        )\n",
    "        concept_f1_weighted_sizes = f1_score(\n",
    "            filtered_true_sizes, filtered_predicted_sizes,\n",
    "            average=\"weighted\",\n",
    "        )\n",
    "\n",
    "        concept_accuracy_materials = accuracy_score(\n",
    "            filtered_true_materials, filtered_predicted_materials,\n",
    "        )\n",
    "        concept_f1_macro_materials = f1_score(\n",
    "            filtered_true_materials, filtered_predicted_materials,\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_macro_materials = f1_score(\n",
    "            filtered_true_materials, filtered_predicted_materials,\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_micro_materials = f1_score(\n",
    "            filtered_true_materials, filtered_predicted_materials,\n",
    "            average=\"micro\",\n",
    "        )\n",
    "        concept_f1_weighted_materials = f1_score(\n",
    "            filtered_true_shapes, filtered_predicted_shapes,\n",
    "            average=\"weighted\",\n",
    "        )\n",
    "\n",
    "        concept_accuracy = np.mean([concept_accuracy_color, concept_accuracy_shape, concept_accuracy_materials, concept_accuracy_sizes])\n",
    "        concept_f1_macro = np.mean([concept_f1_macro_color, concept_f1_macro_shape, concept_f1_macro_materials, concept_f1_macro_sizes])\n",
    "        concept_f1_micro = np.mean([concept_f1_micro_color, concept_f1_micro_shape, concept_f1_micro_materials, concept_f1_micro_sizes])\n",
    "        concept_f1_weighted = np.mean(\n",
    "            [concept_f1_weighted_color, concept_f1_weighted_shape, concept_f1_weighted_materials, concept_f1_weighted_sizes]\n",
    "        )\n",
    "    else:\n",
    "        concept_accuracy = accuracy_score(true_concepts, predicted_concepts)\n",
    "        concept_f1_macro = f1_score(true_concepts, predicted_concepts, average=\"macro\")\n",
    "        concept_f1_micro = f1_score(true_concepts, predicted_concepts, average=\"micro\")\n",
    "        concept_f1_weighted = f1_score(\n",
    "            true_concepts, predicted_concepts, average=\"weighted\"\n",
    "        )\n",
    "\n",
    "    if multilabel_label:\n",
    "        for i in range(true_labels.shape[1]):\n",
    "            label_accuracy += accuracy_score(true_labels[i], predicted_labels[i])\n",
    "            label_f1_macro += f1_score(\n",
    "                true_labels[i], predicted_labels[i], average=\"macro\"\n",
    "            )\n",
    "            label_f1_micro += f1_score(\n",
    "                true_labels[i], predicted_labels[i], average=\"micro\"\n",
    "            )\n",
    "            label_f1_weighted += f1_score(\n",
    "                true_labels[i], predicted_labels[i], average=\"weighted\"\n",
    "            )\n",
    "\n",
    "        label_accuracy = label_accuracy / true_labels.shape[1]\n",
    "        label_f1_macro = label_f1_macro / true_labels.shape[1]\n",
    "        label_f1_micro = label_f1_micro / true_labels.shape[1]\n",
    "        label_f1_weighted = label_f1_weighted / true_labels.shape[1]\n",
    "    else:\n",
    "        label_accuracy = accuracy_score(true_labels, predicted_labels)\n",
    "        label_f1_macro = f1_score(true_labels, predicted_labels, average=\"macro\")\n",
    "        label_f1_micro = f1_score(true_labels, predicted_labels, average=\"micro\")\n",
    "        label_f1_weighted = f1_score(true_labels, predicted_labels, average=\"weighted\")\n",
    "\n",
    "    if dataset_name in [\"boia\", \"sddoia\", \"clipboia\", \"clipSDDOIA\"]:\n",
    "        metrics = BOIAMetrics(\n",
    "            concept_accuracy=concept_accuracy,\n",
    "            label_accuracy=label_accuracy,\n",
    "            concept_f1_macro=concept_f1_macro,\n",
    "            concept_f1_micro=concept_f1_micro,\n",
    "            concept_f1_weighted=concept_f1_weighted,\n",
    "            label_f1_macro=label_f1_macro,\n",
    "            label_f1_micro=label_f1_micro,\n",
    "            label_f1_weighted=label_f1_weighted,\n",
    "            collapse=collapse,\n",
    "            collapse_hard=collapse_hard,\n",
    "            collapse_forward=collapse_forward,\n",
    "            collapse_stop=collapse_stop,\n",
    "            collapse_right=collapse_right,\n",
    "            collapse_left=collapse_left,\n",
    "            collapse_hard_forward=collapse_hard_forward,\n",
    "            collapse_hard_stop=collapse_hard_stop,\n",
    "            collapse_hard_right=collapse_hard_right,\n",
    "            collapse_hard_left=collapse_hard_left,\n",
    "            mean_collapse=mean_collapse,\n",
    "            mean_hard_collapse=mean_hard_collapse,\n",
    "            avg_nll=avg_nll,\n",
    "        )\n",
    "    elif dataset_name in [\"minikandinsky\", \"kandinsky\", \"clipkandinsky\"]:\n",
    "        metrics = KandMetrics(\n",
    "            concept_accuracy=concept_accuracy,\n",
    "            label_accuracy=label_accuracy,\n",
    "            concept_f1_macro=concept_f1_macro,\n",
    "            concept_f1_micro=concept_f1_micro,\n",
    "            concept_f1_weighted=concept_f1_weighted,\n",
    "            label_f1_macro=label_f1_macro,\n",
    "            label_f1_micro=label_f1_micro,\n",
    "            label_f1_weighted=label_f1_weighted,\n",
    "            collapse=collapse,\n",
    "            collapse_hard=collapse_hard,\n",
    "            avg_nll=avg_nll,\n",
    "            collapse_shapes=collapse_shapes,\n",
    "            collapse_color=collapse_color,\n",
    "            collapse_hard_shapes=collapse_hard_shapes,\n",
    "            mean_collapse_hard=mean_collapse_hard,\n",
    "            mean_collapse=mean_collapse,\n",
    "            collapse_hard_color=collapse_hard_color,\n",
    "        )\n",
    "    elif dataset_name in [\"clevr\"]:\n",
    "        metrics = ClevrMetrics(\n",
    "            concept_accuracy=concept_accuracy,\n",
    "            label_accuracy=label_accuracy,\n",
    "            concept_f1_macro=concept_f1_macro,\n",
    "            concept_f1_micro=concept_f1_micro,\n",
    "            concept_f1_weighted=concept_f1_weighted,\n",
    "            label_f1_macro=label_f1_macro,\n",
    "            label_f1_micro=label_f1_micro,\n",
    "            label_f1_weighted=label_f1_weighted,\n",
    "            collapse=0.0,\n",
    "            collapse_hard=0.0,\n",
    "            avg_nll=avg_nll,\n",
    "            collapse_shapes=collapse_shapes,\n",
    "            collapse_color=collapse_color,\n",
    "            collapse_materials=collapse_materials,\n",
    "            collapse_sizes=collapse_sizes,\n",
    "            collapse_hard_shapes=0.0,\n",
    "            collapse_hard_color=0.0,\n",
    "            collapse_hard_materials=0.0,\n",
    "            collapse_hard_sizes=0.0,\n",
    "            mean_collapse_hard=0.0,\n",
    "            mean_collapse=mean_collapse,\n",
    "        )\n",
    "    else:\n",
    "        metrics = Metrics(\n",
    "            concept_accuracy=concept_accuracy,\n",
    "            label_accuracy=label_accuracy,\n",
    "            concept_f1_macro=concept_f1_macro,\n",
    "            concept_f1_micro=concept_f1_micro,\n",
    "            concept_f1_weighted=concept_f1_weighted,\n",
    "            label_f1_macro=label_f1_macro,\n",
    "            label_f1_micro=label_f1_micro,\n",
    "            label_f1_weighted=label_f1_weighted,\n",
    "            collapse=collapse,\n",
    "            collapse_hard=collapse_hard,\n",
    "            avg_nll=avg_nll,\n",
    "        )\n",
    "\n",
    "    if dataset_name in [\"shortmnist\", \"mnistAddition\"]:\n",
    "        plot_confusion_matrix(\n",
    "            true_concepts,\n",
    "            predicted_concepts,\n",
    "            classes=[i for i in range(10)],\n",
    "            normalize=True,\n",
    "            title=f\"{model_name}_{dataset_name}_{seed}.pdf\",\n",
    "            is_boia=True,\n",
    "        )\n",
    "    elif dataset_name in [\"boia\", \"sddoia\"]:\n",
    "\n",
    "        plot_confusion_matrix(\n",
    "            convert_to_categories(true_concepts[:, :3].astype(int)),\n",
    "            convert_to_categories(predicted_concepts[:, :3].astype(int)),\n",
    "            [\"\" for i in range(2**3)],\n",
    "            True,\n",
    "            f\"{model_name}_{dataset_name}_{seed}_forward.pdf\",\n",
    "        )\n",
    "        plot_confusion_matrix(\n",
    "            convert_to_categories(true_concepts[:, 3:9].astype(int)),\n",
    "            convert_to_categories(predicted_concepts[:, 3:9].astype(int)),\n",
    "            [\"\" for i in range(2**6)],\n",
    "            True,\n",
    "            f\"{model_name}_{dataset_name}_{seed}_stop.pdf\",\n",
    "        )\n",
    "        plot_confusion_matrix(\n",
    "            convert_to_categories(true_concepts[:, 9:15].astype(int)),\n",
    "            convert_to_categories(predicted_concepts[:, 9:15].astype(int)),\n",
    "            [\"\" for i in range(2**6)],\n",
    "            True,\n",
    "            f\"{model_name}_{dataset_name}_{seed}_left.pdf\",\n",
    "        )\n",
    "        plot_confusion_matrix(\n",
    "            convert_to_categories(true_concepts[:, 15:21].astype(int)),\n",
    "            convert_to_categories(predicted_concepts[:, 15:21].astype(int)),\n",
    "            [\"\" for i in range(2**6)],\n",
    "            True,\n",
    "            f\"{model_name}_{dataset_name}_{seed}_right.pdf\",\n",
    "        )\n",
    "    elif dataset_name in [\"kandinsky\", \"minikandinsky\"]:\n",
    "        plot_confusion_matrix(\n",
    "            true_concepts,\n",
    "            predicted_concepts,\n",
    "            classes=[i for i in range(10)],\n",
    "            normalize=True,\n",
    "            title=f\"{model_name}_{dataset_name}_{seed}.pdf\",\n",
    "        )\n",
    "    elif dataset_name in [\"clevr\"]:\n",
    "        # TODO\n",
    "        pass\n",
    "        # plot_confusion_matrix(\n",
    "        #     true_concepts,\n",
    "        #     predicted_concepts,\n",
    "        #     classes=[i for i in range(10)],\n",
    "        #     normalize=True,\n",
    "        #     title=f\"{model_name}_{dataset_name}_{seed}.pdf\",\n",
    "        # )\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the right dataset and the right model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(datasetname, args):\n",
    "    \"\"\"Instantiate the dataset wrapper registered under ``datasetname``.\n",
    "\n",
    "    The lookup is case-insensitive: the name is lower-cased once and compared\n",
    "    against lower-case keys only.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if no dataset is registered under that name.\n",
    "    \"\"\"\n",
    "    name = datasetname.lower()\n",
    "    if name == \"boia\":\n",
    "        return BOIA(args)\n",
    "    if name == \"sddoia\":\n",
    "        return SDDOIA(args)\n",
    "    if name == \"minikandinsky\":\n",
    "        return MiniKandinsky(args)\n",
    "    if name == \"kandinsky\":\n",
    "        return Kandinsky(args)\n",
    "    if name == \"shortmnist\":\n",
    "        return SHORTMNIST(args)\n",
    "    if name == \"clipkandinsky\":\n",
    "        return CLIPKandinsky(args)\n",
    "    if name == \"clipshortmnist\":\n",
    "        return CLIPSHORTMNIST(args)\n",
    "    if name == \"clipboia\":\n",
    "        return CLIPBOIA(args)\n",
    "    # The key must be lower-case: a mixed-case key could never equal a lower-cased name.\n",
    "    # NOTE(review): confirm CLIPSDDOIA is imported in the imports cell.\n",
    "    if name == \"clipsddoia\":\n",
    "        return CLIPSDDOIA(args)\n",
    "    if name == \"addmnist\":\n",
    "        return ADDMNIST(args)\n",
    "    if name == \"clevr\":\n",
    "        return CLEVR(args)\n",
    "\n",
    "    raise NotImplementedError(f\"Dataset {datasetname} missing\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model(modelname, encoder, args):\n",
    "    \"\"\"Instantiate the model registered under ``modelname`` (case-insensitive).\n",
    "\n",
    "    Every model is constructed as ``Cls(encoder=encoder, args=args)``.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if no model is registered under that name.\n",
    "    \"\"\"\n",
    "    # All keys are lower-case because they are compared against a lower-cased\n",
    "    # name; mixed-case keys such as \"SDDOIAdpl\" could never match.\n",
    "    name = modelname.lower()\n",
    "    if name == \"boiadpl\":\n",
    "        return BoiaDPL(encoder=encoder, args=args)\n",
    "    # NOTE(review): confirm the SDDOIA* model classes are imported in the imports cell.\n",
    "    if name == \"sddoiadpl\":\n",
    "        return SDDOIADPL(encoder=encoder, args=args)\n",
    "    if name == \"boialtn\":\n",
    "        return BOIALTN(encoder=encoder, args=args)\n",
    "    if name == \"sddoialtn\":\n",
    "        return SDDOIALTN(encoder=encoder, args=args)\n",
    "    if name == \"boiann\":\n",
    "        return BOIAnn(encoder=encoder, args=args)\n",
    "    if name == \"sddoiann\":\n",
    "        return SDDOIAnn(encoder=encoder, args=args)\n",
    "    if name == \"boiacbm\":\n",
    "        return BoiaCBM(encoder=encoder, args=args)\n",
    "    if name == \"sddoiacbm\":\n",
    "        return SDDOIACBM(encoder=encoder, args=args)\n",
    "    if name == \"minikanddpl\":\n",
    "        return MiniKandDPL(encoder=encoder, args=args)\n",
    "    if name == \"kandltn\":\n",
    "        return KANDltn(encoder=encoder, args=args)\n",
    "    if name == \"kandnn\":\n",
    "        return KANDnn(encoder=encoder, args=args)\n",
    "    if name == \"kanddpl\":\n",
    "        return KandDPL(encoder=encoder, args=args)\n",
    "    if name == \"kandcbm\":\n",
    "        return KandCBM(encoder=encoder, args=args)\n",
    "    if name == \"mnistdpl\":\n",
    "        return MnistDPL(encoder=encoder, args=args)\n",
    "    if name == \"mnistdsl\":\n",
    "        return MnistDSL(encoder=encoder, args=args)\n",
    "    if name == \"mnistltn\":\n",
    "        return MnistLTN(encoder=encoder, args=args)\n",
    "    if name == \"mnistnn\":\n",
    "        return MNISTnn(encoder=encoder, args=args)\n",
    "    if name == \"mnistcbm\":\n",
    "        # The imported class is MnistCBM (MNISTCBM is undefined); build it like every other model.\n",
    "        return MnistCBM(encoder=encoder, args=args)\n",
    "    if name == \"mnistdsldpl\":\n",
    "        return MnistDSLDPL(encoder=encoder, args=args)\n",
    "    if name == \"clevrcbm\":\n",
    "        return ClevrCBM(encoder=encoder, args=args)\n",
    "    if name == \"clevrdsldpl\":\n",
    "        return ClevrDSLDPL(encoder=encoder, args=args)\n",
    "    if name == \"clevrdsl\":\n",
    "        return ClevrDSL(encoder=encoder, args=args)\n",
    "    if name == \"clevrdpl\":\n",
    "        return CLEVRDPL(encoder=encoder, args=args)\n",
    "\n",
    "    raise NotImplementedError(f\"Model {modelname} missing\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration consumed by get_dataset / get_model below.\n",
    "args = Namespace(\n",
    "    backbone=\"conceptizer\",\n",
    "    preprocess=0,\n",
    "    finetuning=0,\n",
    "    batch_size=512,  # alternative: 256\n",
    "    n_epochs=20,\n",
    "    validate=1,\n",
    "    dataset=\"boia\",\n",
    "    lr=0.001,\n",
    "    exp_decay=0.99,\n",
    "    warmup_steps=1,\n",
    "    wandb=None,\n",
    "    task=\"boia\",\n",
    "    boia_model=\"ce\",\n",
    "    model=\"boiadpl\",\n",
    "    c_sup=1,\n",
    "    which_c=[-1],\n",
    "    joint=False,\n",
    "    boia_ood_knowledge=True,\n",
    "    splitted=False,\n",
    "    eps_sym=0.5,\n",
    "    eps_rul=0.5\n",
    ")\n",
    "\n",
    "# get dataset\n",
    "dataset = get_dataset(args.dataset, args)\n",
    "# get model\n",
    "model = get_model(modelname=args.model, encoder=dataset.get_backbone()[0], args=args)\n",
    "\n",
    "# Use the first GPU when one is available, otherwise run on the CPU.\n",
    "model.device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Move the model and, when present, its encoder / net submodules to that device.\n",
    "model.to(model.device)\n",
    "if hasattr(model, \"encoder\"):\n",
    "    model.encoder.to(model.device)\n",
    "if hasattr(model, \"net\"):\n",
    "    model.net.to(model.device)\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the seeds of the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seeds = [1011, 1213, 1415, 1617, 1819]\n",
    "\n",
    "# Checkpoint path prefix (placeholder): evaluate() appends the seed and \".pth\".\n",
    "model_path = \"path\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loop through the dataset and retrieve concepts and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_concepts_and_labels_boia(out_labels, out_concepts):\n",
    "    \"\"\"Turn BOIA/SDDOIA score vectors into hard per-pair predictions.\n",
    "\n",
    "    Each input is split along dim 1 into consecutive chunks of 2 scores (the\n",
    "    last chunk is shorter when the width is odd); every chunk is reduced with\n",
    "    argmax, giving the index of the winning score within that chunk.\n",
    "\n",
    "    Returns:\n",
    "        (predicted_labels, predicted_concepts): int64 CPU tensors of shape\n",
    "        (batch, number_of_chunks).\n",
    "    \"\"\"\n",
    "\n",
    "    def pairwise_argmax(scores):\n",
    "        winners = [chunk.argmax(dim=1) for chunk in torch.split(scores, 2, dim=1)]\n",
    "        return torch.stack(winners, dim=1).cpu()\n",
    "\n",
    "    return pairwise_argmax(out_labels), pairwise_argmax(out_concepts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_concepts_and_labels_mnist(\n",
    "    out_labels, out_concepts, true_concepts, is_ood=False\n",
    "):\n",
    "    \"\"\"Hard MNIST-style predictions plus flattened ground-truth concepts.\n",
    "\n",
    "    Labels are the argmax over the last dim of ``out_labels``. Concept\n",
    "    predictions are the argmax over the last dim of ``out_concepts``, flattened\n",
    "    to 1-D; ``true_concepts`` is flattened the same way so the two can be\n",
    "    compared element-wise.\n",
    "\n",
    "    ``is_ood`` is accepted for interface compatibility but currently unused:\n",
    "    a filter that zeroed every label score except classes 6, 10 and 12 for\n",
    "    in-distribution data is disabled.\n",
    "\n",
    "    Returns:\n",
    "        (predicted_labels, flat_predicted_concepts, flat_true_concepts)\n",
    "    \"\"\"\n",
    "    predicted_labels = out_labels.argmax(dim=-1)\n",
    "    flat_predicted_concepts = out_concepts.argmax(dim=-1).view(-1)\n",
    "    flat_true_concepts = true_concepts.view(-1)\n",
    "    return predicted_labels, flat_predicted_concepts, flat_true_concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_concepts_and_labels_kand(out_labels, out_concepts, true_concepts):\n",
    "    \"\"\"Hard Kandinsky predictions with dim 1 unrolled into the batch axis.\n",
    "\n",
    "    ``true_concepts`` of shape (B, K, ...) is reordered into (K * B, ...): all\n",
    "    B rows for index 0 of dim 1, then index 1, and so on. ``out_concepts`` of\n",
    "    shape (B, K, width) is split along its last dim into groups of 3 scores,\n",
    "    each group is reduced with argmax, and the resulting (B, K, n_groups)\n",
    "    indices are reordered the same way.\n",
    "\n",
    "    Returns:\n",
    "        (predicted_labels, predicted_concepts, refactored_true_concepts), where\n",
    "        predicted_concepts has every singleton dim squeezed out.\n",
    "    \"\"\"\n",
    "\n",
    "    def dim1_major(tensor):\n",
    "        # (B, K, *rest) -> (K * B, *rest); row k * B + b holds tensor[b, k]\n",
    "        return tensor.transpose(0, 1).reshape(-1, *tensor.shape[2:])\n",
    "\n",
    "    predicted_labels = out_labels.argmax(dim=1)\n",
    "    refactored_true_concepts = dim1_major(true_concepts)\n",
    "\n",
    "    group_indices = [group.argmax(dim=2) for group in torch.split(out_concepts, 3, dim=2)]\n",
    "    predicted_concepts = dim1_major(torch.stack(group_indices, dim=2))\n",
    "\n",
    "    return predicted_labels, torch.squeeze(predicted_concepts), refactored_true_concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_concepts_and_labels_clevr(out_dict, true_concepts, is_dsl):\n",
    "    \"\"\"Decode CLEVR labels and per-object attribute indices.\n",
    "\n",
    "    Labels come from ``out_dict[\"PRED\"]`` for DSL models, otherwise from the\n",
    "    argmax over dim 1 of ``out_dict[\"YS\"]``. Both ``true_concepts`` and\n",
    "    ``out_dict[\"pCS\"]`` are viewed as (batch, 4, -1); in each of the 4 rows,\n",
    "    features 0-7, 8-10, 11-12 and 13-14 form the color, shape, material and\n",
    "    size groups, each reduced with argmax. A group whose maximum is exactly -1\n",
    "    decodes to -1 (compute_metrics masks such entries out).\n",
    "\n",
    "    Returns:\n",
    "        (predicted_labels, predicted_concepts, refactored_true_concepts), the\n",
    "        concept tensors being shaped (batch, 4, 4).\n",
    "    \"\"\"\n",
    "    if is_dsl:\n",
    "        predicted_labels = out_dict[\"PRED\"]\n",
    "    else:\n",
    "        predicted_labels = torch.argmax(out_dict[\"YS\"], dim=1)\n",
    "\n",
    "    # feature ranges of the color, shape, material and size groups\n",
    "    attribute_slices = (slice(0, 8), slice(8, 11), slice(11, 13), slice(13, 15))\n",
    "\n",
    "    def decode(per_object):\n",
    "        columns = []\n",
    "        for attribute in attribute_slices:\n",
    "            best_value, best_index = torch.max(per_object[:, :, attribute], dim=-1)\n",
    "            best_index[best_value == -1] = -1\n",
    "            columns.append(best_index)\n",
    "        return torch.stack(columns, dim=-1)\n",
    "\n",
    "    true_per_object = true_concepts.view(true_concepts.shape[0], 4, -1)\n",
    "    predicted_per_object = out_dict['pCS'].view(out_dict['pCS'].shape[0], 4, -1)\n",
    "\n",
    "    refactored_true_concepts = decode(true_per_object)\n",
    "    predicted_concepts = decode(predicted_per_object)\n",
    "\n",
    "    return predicted_labels, predicted_concepts, refactored_true_concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def retrive_concepts_and_labels(model, dataset, dataset_name, model_name, is_ood=False, is_dsl=False):\n",
    "    \"\"\"Run ``model`` over every batch of ``dataset`` and collect hard predictions.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing ``device``; called as ``model(images)`` (or\n",
    "            ``model(images, eval=True)`` when ``is_dsl``) and returning a dict\n",
    "            read through keys such as \"YS\", \"pCS\", \"PRED\" and \"KNOWLEDGE\".\n",
    "        dataset: iterable of ``(images, labels, concepts)`` batches with a\n",
    "            ``.dataset`` attribute whose length is the sample count.\n",
    "        dataset_name: selects the label loss and the prediction decoder.\n",
    "        model_name: selects the knowledge-based label loss for mnistdsl models.\n",
    "        is_ood: forwarded to the MNIST decoder.\n",
    "        is_dsl: whether the model is called with ``eval=True`` / exposes \"PRED\".\n",
    "\n",
    "    Returns:\n",
    "        (true_labels, predicted_labels, true_concepts, predicted_concepts, avg_nll):\n",
    "        the first four are numpy arrays concatenated over batches; avg_nll is\n",
    "        the accumulated label NLL divided by ``len(dataset.dataset)``.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if ``dataset_name`` has no prediction decoder.\n",
    "\n",
    "    Runs under ``torch.no_grad()``: this is pure evaluation, so building\n",
    "    autograd graphs would only cost memory and time.\n",
    "    \"\"\"\n",
    "\n",
    "    true_labels, predicted_labels, true_concepts, predicted_concepts = [], [], [], []\n",
    "\n",
    "    nll_loss = 0.0\n",
    "    criterion = torch.nn.CrossEntropyLoss(reduction=\"sum\")\n",
    "\n",
    "    for images, labels, concepts in tqdm(dataset):\n",
    "        images, labels, concepts = (\n",
    "            images.to(model.device),\n",
    "            labels.to(model.device),\n",
    "            concepts.to(model.device),\n",
    "        )\n",
    "\n",
    "        # filtering out the middle rules supervision: keep only the last label column\n",
    "        if dataset_name in [\"kandinsky\", \"minikandinsky\", \"clipkandinsky\"]:\n",
    "            labels = labels[:, -1]\n",
    "\n",
    "        if is_dsl:\n",
    "            out_dict = model(images, eval=True)\n",
    "        else:\n",
    "            out_dict = model(images)\n",
    "\n",
    "        if dataset_name in [\"boia\", \"sddoia\", \"clipboia\", \"clipSDDOIA\"]:\n",
    "            # \"YS\" holds 4 heads of 2 scores each; average their summed cross-entropies\n",
    "            class_predictions = torch.split(out_dict[\"YS\"], 2, dim=1)\n",
    "            assert len(class_predictions) == 4\n",
    "\n",
    "            loss = 0\n",
    "            for head, _pred in enumerate(class_predictions):\n",
    "                loss += criterion(_pred.float().cpu(), labels[:, head].long().cpu())\n",
    "            loss /= len(class_predictions)\n",
    "        else:\n",
    "            if model_name in [\"mnistdsl\", \"mnistdsldpl\"]:\n",
    "                # label scores looked up in KNOWLEDGE at the two argmax concepts\n",
    "                c1 = torch.argmax(out_dict[\"pCS\"][:, 0, :], dim=-1)\n",
    "                c2 = torch.argmax(out_dict[\"pCS\"][:, 1, :], dim=-1)\n",
    "                pred_y = out_dict[\"KNOWLEDGE\"][c1, c2].float() + 1e-6\n",
    "            else:\n",
    "                pred_y = out_dict[\"YS\"].float() + 1e-6\n",
    "            # renormalise (the 1e-6 keeps log() finite) and sum the NLL over the batch\n",
    "            Z = torch.sum(pred_y, dim=1, keepdim=True)\n",
    "            pred_y /= Z\n",
    "            loss = torch.nn.functional.nll_loss(pred_y.log().cpu(), labels.long().cpu(), reduction=\"sum\")\n",
    "\n",
    "        nll_loss += loss.item()\n",
    "\n",
    "        if dataset_name in [\"boia\", \"sddoia\", \"clipboia\", \"clipSDDOIA\"]:\n",
    "            out_label, out_concept = get_concepts_and_labels_boia(\n",
    "                out_dict[\"YS\"], out_dict[\"pCS\"]\n",
    "            )\n",
    "        elif dataset_name in [\"shortmnist\", \"clipshortmnist\", \"addmnist\"]:\n",
    "            if is_dsl:\n",
    "                _, out_concept, concepts = get_concepts_and_labels_mnist(\n",
    "                    out_dict[\"PRED\"], out_dict[\"pCS\"], concepts, is_ood\n",
    "                )\n",
    "                out_label = out_dict[\"PRED\"].cpu().squeeze()\n",
    "            else:\n",
    "                out_label, out_concept, concepts = get_concepts_and_labels_mnist(\n",
    "                    out_dict[\"YS\"], out_dict[\"pCS\"], concepts, is_ood\n",
    "                )\n",
    "        elif dataset_name in [\"kandinsky\", \"minikandinsky\", \"clipkandinsky\"]:\n",
    "            out_label, out_concept, concepts = get_concepts_and_labels_kand(\n",
    "                out_dict[\"YS\"], out_dict[\"pCS\"], concepts\n",
    "            )\n",
    "        elif dataset_name in [\"clevr\"]:\n",
    "            out_label, out_concept, concepts = get_concepts_and_labels_clevr(\n",
    "                out_dict, concepts, is_dsl\n",
    "            )\n",
    "        else:\n",
    "            # No decoder for this dataset: fail loudly instead of collecting None predictions.\n",
    "            raise NotImplementedError(f\"Dataset {dataset_name} not supported\")\n",
    "\n",
    "        true_labels.append(labels.cpu().numpy())\n",
    "        true_concepts.append(concepts.cpu().numpy())\n",
    "\n",
    "        predicted_labels.append(out_label.detach().cpu().numpy())\n",
    "        predicted_concepts.append(out_concept.cpu().numpy())\n",
    "\n",
    "    # concatenate the per-batch arrays\n",
    "    true_labels = np.concatenate(true_labels, axis=0)\n",
    "    predicted_labels = np.concatenate(predicted_labels, axis=0)\n",
    "    true_concepts = np.concatenate(true_concepts, axis=0)\n",
    "    predicted_concepts = np.concatenate(predicted_concepts, axis=0)\n",
    "\n",
    "    print(nll_loss, len(dataset.dataset))\n",
    "    avg_nll = nll_loss / len(dataset.dataset)\n",
    "\n",
    "    assert true_labels.shape == predicted_labels.shape\n",
    "    assert true_concepts.shape == predicted_concepts.shape, f\"{true_concepts.shape} {predicted_concepts.shape}\"\n",
    "\n",
    "    return true_labels, predicted_labels, true_concepts, predicted_concepts, avg_nll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(\n",
    "    model, test_set, dataset_name, model_name, ood_set=None, ood_set_2=None, hungarian=False, train_set=None, is_dsl=False\n",
    "):  # TODO: define attributes\n",
    "\n",
    "    # List of metics\n",
    "    in_metrics_list = []\n",
    "    ood_metrics_list = []\n",
    "    ood_metrics_2_list = []\n",
    "\n",
    "    n_files = 0\n",
    "\n",
    "    # Loop through seeds\n",
    "    for seed in seeds:\n",
    "        print(\"Doing\", seed, \"...\")\n",
    "\n",
    "        to_add = \"\"\n",
    "        if \"cbm\" in model_path:\n",
    "            to_add = \"_multi_linear\"\n",
    "\n",
    "\n",
    "        print(\"TO ADD:\", to_add)\n",
    "\n",
    "        if \"cbm\" in model_path: \n",
    "            current_model_path = f\"{model_path}_{seed}{to_add}.pth\"\n",
    "        else:\n",
    "            current_model_path = f\"{model_path}{seed}.pth\"\n",
    "\n",
    "        # current_model_path = f\"{model_path}{seed}.pth\"\n",
    "        print(current_model_path)\n",
    "\n",
    "        if not os.path.exists(current_model_path):\n",
    "            print(f\"{current_model_path} is missing...\")\n",
    "            continue\n",
    "        else:\n",
    "            print(f\"Loading {current_model_path}...\")\n",
    "\n",
    "        n_files += 1\n",
    "\n",
    "        try:\n",
    "            # retrieve the status dict\n",
    "            model_state_dict = torch.load(current_model_path)\n",
    "            # Load the model status dict\n",
    "            model.load_state_dict(model_state_dict)\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            continue\n",
    "\n",
    "        if dataset_name == \"shortmnist\":\n",
    "            model = model.float()\n",
    "\n",
    "        model.eval()\n",
    "\n",
    "        w_acc, w_f1 = None, None\n",
    "\n",
    "        if hungarian:\n",
    "            # TODO\n",
    "\n",
    "\n",
    "            pi = get_hungarian_permutation(model, train_set, dataset_name, model_name, metric=\"correlation\", is_dsl=is_dsl)\n",
    "            ind_data = retrive_concepts_and_labels_hungarian(model, pi, test_set, dataset_name, model_name, is_dsl=is_dsl)\n",
    "            if model_name not in [\"mnistdpl\", \"clevrdpl\", \"boiadpl\"]:\n",
    "                if dataset_name in [\"addmnist\"]:\n",
    "\n",
    "                    if \"cbm\" in model_name:\n",
    "                        w = get_cbm_knowledge(model.fc_aggregate, model.device, dataset_name)\n",
    "                    else:\n",
    "                        w = torch.argmax(torch.nn.functional.softmax(model.weights, dim=2), dim=2)\n",
    "                    w_aligned = np.dot(pi.T, np.dot(w.cpu().numpy(), pi)).flatten()\n",
    "                    w_gt = get_gt_knowledge(dataset_name).flatten()\n",
    "                    w_acc = accuracy_score(w_gt, w_aligned)\n",
    "                    w_f1 = f1_score(w_gt, w_aligned, average=\"macro\")\n",
    "                elif dataset_name in [\"boia\"]:\n",
    "                    w_acc, w_f1 = evaluate_knowledge_boia(model, pi, model_name)\n",
    "                else:\n",
    "                    w_acc, w_f1 = evaluate_knowledge_clevr(model, pi, ind_data)\n",
    "        else:\n",
    "            ind_data = retrive_concepts_and_labels(model, test_set, dataset_name, model_name, is_dsl=is_dsl)\n",
    "\n",
    "        if ood_set is not None:\n",
    "            out_data = retrive_concepts_and_labels(\n",
    "                model, ood_set, dataset_name, model_name, is_ood=True\n",
    "            )\n",
    "\n",
    "        if ood_set_2 is not None:\n",
    "            out_data_2 = retrive_concepts_and_labels(\n",
    "                model, ood_set_2, dataset_name, model_name, is_ood=True\n",
    "            )\n",
    "\n",
    "        in_metrics = compute_metrics(*ind_data, dataset_name, model_name, seed)\n",
    "        if w_acc is not None and w_f1 is not None:\n",
    "            if dataset_name == \"boia\":\n",
    "                in_metrics = BOIAExtendedMetrics.fromMetric(in_metrics, w_f1, w_acc)\n",
    "            else:\n",
    "                in_metrics = ExtendedMetrics.fromMetric(in_metrics, w_f1, w_acc)\n",
    "        in_metrics_list.append(in_metrics)\n",
    "\n",
    "        if ood_set is not None:\n",
    "            ood_metrics = compute_metrics(*out_data, dataset_name, model_name, seed)\n",
    "            ood_metrics_list.append(ood_metrics)\n",
    "\n",
    "        if ood_set_2 is not None:\n",
    "            ood_metrics_2 = compute_metrics(*out_data_2, dataset_name, model_name, seed)\n",
    "            ood_metrics_2_list.append(ood_metrics_2)\n",
    "\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    if n_files == 1:\n",
    "        print(\"IN\", in_metrics.to_string())\n",
    "        print(\"OOD\", ood_metrics.to_string())\n",
    "\n",
    "    assert n_files > 1, \"At least 2 files to compare\"\n",
    "\n",
    "    # Compute standard deviation for each metric\n",
    "    print(in_metrics_list)\n",
    "    for key in vars(in_metrics_list[0]):  # the key are always the same\n",
    "        # skip hidden elements\n",
    "        if not key.startswith(\"_\"):\n",
    "            # retrieve the list of values\n",
    "            in_metric_values = [getattr(metrics, key) for metrics in in_metrics_list]\n",
    "            ood_metric_values = [getattr(metrics, key) for metrics in ood_metrics_list]\n",
    "            ood_metric_2_values = [\n",
    "                getattr(metrics, key) for metrics in ood_metrics_2_list\n",
    "            ]\n",
    "\n",
    "            # convert lists to NumPy arrays\n",
    "            in_metric_values_arr = np.array(in_metric_values)\n",
    "            ood_metric_values_arr = np.array(ood_metric_values)\n",
    "            ood_metric_values_2_arr = np.array(ood_metric_2_values)\n",
    "\n",
    "            # Compute the standard deviation\n",
    "            in_metric_std_dev = np.std(in_metric_values_arr)\n",
    "            ood_metric_std_dev = np.std(ood_metric_values_arr)\n",
    "            ood_metric_2_std_dev = np.std(ood_metric_values_2_arr)\n",
    "\n",
    "            # Compute the mean\n",
    "            in_metric_std_mean = np.mean(in_metric_values_arr)\n",
    "            ood_metric_std_mean = np.mean(ood_metric_values_arr)\n",
    "            ood_metric_2_std_mean = np.mean(ood_metric_values_2_arr)\n",
    "\n",
    "            print(\n",
    "                \"\\n{} (In): ${:.2f} \\pm {:.2f}$\".format(\n",
    "                    key.replace(\"_\", \" \").title(),\n",
    "                    round(in_metric_std_mean, 2),\n",
    "                    round(in_metric_std_dev, 2),\n",
    "                )\n",
    "            )\n",
    "\n",
    "            if ood_set is not None:\n",
    "                print(\n",
    "                    \"{} (OOD): ${:.2f} \\pm {:.2f}$\".format(\n",
    "                        key.replace(\"_\", \" \").title(),\n",
    "                        round(ood_metric_std_mean, 2),\n",
    "                        round(ood_metric_std_dev, 2),\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            if ood_set_2 is not None:\n",
    "                print(\n",
    "                    \"{} (OOD 2): ${:.2f} \\pm {:.2f}$\".format(\n",
    "                        key.replace(\"_\", \" \").title(),\n",
    "                        round(ood_metric_2_std_mean, 2),\n",
    "                        round(ood_metric_2_std_dev, 2),\n",
    "                    )\n",
    "                )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hungarian Gamma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "def rearrange_predictions_with_confusion_clevr(pred, gt):\n",
    "    N, _, _ = pred.shape\n",
    "\n",
    "    global_cost_matrix = np.zeros((4, 4))\n",
    "\n",
    "    for i in range(N):\n",
    "\n",
    "        local_cost_matrix = np.zeros((4, 4))\n",
    "        for j in range(4):\n",
    "            for k in range(4):\n",
    "                matches = (pred[i, :, j] == gt[i, :, k]) & (gt[i, :, k] != -1)\n",
    "                local_cost_matrix[j, k] = matches.sum()\n",
    "\n",
    "        global_cost_matrix += local_cost_matrix\n",
    "\n",
    "    row_ind, col_ind = linear_sum_assignment(-global_cost_matrix)\n",
    "    pred_rearranged = np.copy(pred)\n",
    "    for i in range(N):\n",
    "        pred_rearranged[i] = pred[i, :, col_ind]\n",
    "\n",
    "    return pred_rearranged, col_ind"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hungarian_permutation(model, dataset, dataset_name, model_name, metric=\"correlation\", is_dsl=False):\n",
    "    _, _, true_concepts, predicted_concepts, _ = retrive_concepts_and_labels(model, dataset, dataset_name, model_name, False, is_dsl=is_dsl)\n",
    "\n",
    "    if dataset_name == \"clevr\":\n",
    "        _, perm_idx = rearrange_predictions_with_confusion_clevr(predicted_concepts, true_concepts)\n",
    "        \n",
    "        perm_color = permutation_matrix_from_predictions(\n",
    "            predicted_concepts[:, :, 0].flatten(), true_concepts[:, :, 0].flatten(), 8\n",
    "        ).numpy()\n",
    "        perm_shapes = permutation_matrix_from_predictions(\n",
    "            predicted_concepts[:, :, 1].flatten(), true_concepts[:, :, 1].flatten(), 3\n",
    "        ).numpy()\n",
    "        perm_material = permutation_matrix_from_predictions(\n",
    "            predicted_concepts[:, :, 2].flatten(), true_concepts[:, :, 2].flatten(), 2\n",
    "        ).numpy()\n",
    "        perm_sizes = permutation_matrix_from_predictions(\n",
    "            predicted_concepts[:, :, 3].flatten(), true_concepts[:, :, 3].flatten(), 2\n",
    "        ).numpy()\n",
    "\n",
    "        return (perm_idx, perm_color, perm_shapes, perm_material, perm_sizes)\n",
    "    elif dataset_name == \"boia\":\n",
    "        return find_boia_permutation(\n",
    "            predicted_concepts, true_concepts, 21\n",
    "        )\n",
    "    else:\n",
    "        n_classes = 10 if dataset_name == \"addmnist\" else 2\n",
    "\n",
    "        return permutation_matrix_from_predictions(\n",
    "            predicted_concepts, true_concepts, n_classes\n",
    "        ).numpy()\n",
    "\n",
    "def retrive_concepts_and_labels_hungarian(model, perm_matrix, dataset, dataset_name, model_name, is_dsl=False):\n",
    "    true_labels, predicted_labels, true_concepts, predicted_concepts, avg_nll = retrive_concepts_and_labels(model, dataset, dataset_name, model_name, False, is_dsl=is_dsl)\n",
    "    \n",
    "    if dataset_name == \"clevr\":\n",
    "        (perm_idx, perm_color, perm_shapes, perm_material, perm_sizes) = perm_matrix\n",
    "\n",
    "        for i in range(predicted_concepts.shape[0]):\n",
    "            predicted_concepts[i] = predicted_concepts[i, :, perm_idx]\n",
    "\n",
    "        predicted_colors = perm_color[predicted_concepts[:, :, 0]]\n",
    "        predicted_colors = np.argmax(predicted_colors, axis=-1)\n",
    "\n",
    "        predicted_shapes = perm_shapes[predicted_concepts[:, :, 1]]\n",
    "        predicted_shapes = np.argmax(predicted_shapes, axis=-1)\n",
    "\n",
    "        predicted_materials = perm_material[predicted_concepts[:, :, 2]]\n",
    "        predicted_materials = np.argmax(predicted_materials, axis=-1)\n",
    "\n",
    "        predicted_sizes = perm_sizes[predicted_concepts[:, :, 3]]\n",
    "        predicted_sizes = np.argmax(predicted_sizes, axis=-1)\n",
    "\n",
    "        predicted_concepts = np.stack(\n",
    "            [predicted_colors, predicted_shapes, predicted_materials, predicted_sizes],\n",
    "            axis=-1\n",
    "        )\n",
    "    elif dataset_name == \"boia\":\n",
    "        predicted_concepts = predicted_concepts[:, perm_matrix]\n",
    "    else:\n",
    "\n",
    "        predicted_concepts = perm_matrix[predicted_concepts]\n",
    "        predicted_concepts = np.argmax(predicted_concepts, axis=1)\n",
    "\n",
    "    return true_labels, predicted_labels, true_concepts, predicted_concepts, avg_nll    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gt_knowledge(dataset_name):\n",
    "    if dataset_name == \"addmnist\":\n",
    "        w = []\n",
    "        for i in range(10):\n",
    "            for j in range(10):\n",
    "                w.append((i + j) % 2)\n",
    "        return np.array(w)\n",
    "    else: \n",
    "        pass\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cbm_knowledge(w, device, dataset_name):\n",
    "    knowledge = torch.zeros((10, 10))\n",
    "    for i in range(10):\n",
    "        for j in range(10):\n",
    "            x = torch.nn.functional.one_hot(torch.tensor([i]), num_classes=10).float().to(device)\n",
    "            y = torch.nn.functional.one_hot(torch.tensor([j]), num_classes=10).float().to(device)\n",
    "            xy = torch.cat([x, y], dim=-1)\n",
    "            # xy = x.unsqueeze(2).multiply(y.unsqueeze(1)).view(x.shape[0], -1)\n",
    "            knowledge[i, j] = torch.argmax(w(xy), dim=-1)\n",
    "    return knowledge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad(tensor, target_size = 8):\n",
    "    current_size = tensor.size(1)  # Get the current size of the second dimension\n",
    "    padding_size = max(0, target_size - current_size)  # Calculate how much padding is needed\n",
    "\n",
    "    # Pad with zeros if needed\n",
    "    if padding_size > 0:\n",
    "        padded_tensor = torch.nn.functional.pad(tensor, (0, padding_size))\n",
    "    else:\n",
    "        padded_tensor = tensor[:, :target_size]\n",
    "    return padded_tensor\n",
    "\n",
    "def clevr_logic(vector):\n",
    "    class_1_found = {'large_cube': False, 'large_cylinder': False}\n",
    "    class_2_found = {'small_metal_cube': False, 'small_sphere': False}\n",
    "    class_3_found = {'large_blue_sphere': False, 'small_yellow_sphere': False}\n",
    "\n",
    "    for obj in vector:\n",
    "        presence, color, shape, material, size = obj\n",
    "        \n",
    "        if presence == 0:\n",
    "            continue\n",
    "\n",
    "        colors = [\"gray\", \"red\", \"blue\", \"green\", \"brown\", \"purple\", \"cyan\", \"yellow\"]\n",
    "        shapes = [\"cube\", \"sphere\", \"cylinder\"]\n",
    "        materials = [\"rubber\", \"metal\"]\n",
    "        sizes = [\"large\", \"small\"]\n",
    "\n",
    "        color = colors[color]\n",
    "        shape = shapes[shape]\n",
    "        material = materials[material]\n",
    "        size = sizes[size]\n",
    "\n",
    "        if size == 'large' and shape == 'cube' and color == 'gray':\n",
    "            class_1_found['large_cube'] = True\n",
    "        if size == 'large' and shape == 'cylinder':\n",
    "            class_1_found['large_cylinder'] = True\n",
    "\n",
    "        if size == 'small' and material == 'metal' and shape == 'cube':\n",
    "            class_2_found['small_metal_cube'] = True\n",
    "        if size == 'small' and shape == 'sphere' and material == 'metal':\n",
    "            class_2_found['small_sphere'] = True\n",
    "\n",
    "        # Check for Class 3 objects\n",
    "        if size == 'large' and color == 'blue' and shape == 'sphere':\n",
    "            class_3_found['large_blue_sphere'] = True\n",
    "        if size == 'small' and color == 'yellow' and shape == 'sphere':\n",
    "            class_3_found['small_yellow_sphere'] = True\n",
    "\n",
    "    class_1 = all(class_1_found.values())\n",
    "    class_2 = all(class_2_found.values())\n",
    "    class_3 = all(class_3_found.values())\n",
    "\n",
    "    if sum([class_1, class_2, class_3]) == 1:\n",
    "        if class_1:\n",
    "            return 0\n",
    "        elif class_2:\n",
    "            return 1\n",
    "        elif class_3:\n",
    "            return 2\n",
    "    return 3 # no found or not interesting\n",
    "\n",
    "def get_concepts_label():\n",
    "    samples, labels = [], []\n",
    "\n",
    "    for _ in range(100):\n",
    "        presence = np.random.randint(0, 2, size=4)\n",
    "        color = np.random.randint(0, 8, size=4)\n",
    "        shape = np.random.randint(0, 3, size=4)\n",
    "        material = np.random.randint(0, 2, size=4)\n",
    "        size = np.random.randint(0, 2, size=4)\n",
    "        \n",
    "        objects = np.stack([presence, color, shape, material, size], axis=-1)\n",
    "        y = clevr_logic(objects)\n",
    "        labels.append(torch.tensor([y]).to(\"cpu\").long())\n",
    "\n",
    "        logits = []\n",
    "        for obj in objects:\n",
    "            l = []\n",
    "            presence, color, shape, material, size = torch.tensor(obj)\n",
    "\n",
    "            l.append(torch.tensor([presence]).to(\"cpu\").float())\n",
    "            l.append(torch.nn.functional.one_hot(color, 8).to(\"cpu\").float())\n",
    "            l.append(torch.nn.functional.one_hot(shape, 3).to(\"cpu\").float())\n",
    "            l.append(torch.nn.functional.one_hot(material, 2).to(\"cpu\").float())\n",
    "            l.append(torch.nn.functional.one_hot(size, 2).to(\"cpu\").float())\n",
    "\n",
    "            l = torch.cat(l, dim=0)\n",
    "            logits.append(l)\n",
    "        logits = torch.stack(logits, dim=1)\n",
    "        samples.append(logits)\n",
    "\n",
    "    samples = torch.stack(samples, dim=0)\n",
    "    labels = torch.cat(labels, dim=0)\n",
    "    return samples, labels\n",
    "\n",
    "def evaluate_knowledge_clevr(model, pi, indata):\n",
    "    samples, labels = get_concepts_label()\n",
    "\n",
    "    perm_idx, perm_color, perm_shapes, perm_material, perm_sizes = pi\n",
    "    perm_idx, perm_color, perm_shapes, perm_material, perm_sizes = perm_idx.astype(int), perm_color.astype(int) , perm_shapes.astype(int) , perm_material.astype(int) , perm_sizes.astype(int) \n",
    "\n",
    "    true_concepts = torch.tensor(samples)\n",
    "    full_concept_vector = torch.zeros((true_concepts.shape[0], 4, 16), dtype=float)\n",
    "\n",
    "    for idx in range(4):\n",
    "        concept_img = true_concepts[:, idx, :]\n",
    "        mask = (concept_img == -1).all(dim=1)\n",
    "        mask = ~mask\n",
    "\n",
    "        concept_vector = torch.zeros((concept_img.shape[0], 16), dtype=float)\n",
    "    \n",
    "        colors = concept_img[:, 0].to(int)\n",
    "        shapes = concept_img[:, 1].to(int)\n",
    "        materials = concept_img[:, 2].to(int)\n",
    "        sizes = concept_img[:, 3].to(int)\n",
    "        \n",
    "        if mask.sum() != 0:\n",
    "            concept_vector[mask, 0] = torch.tensor(mask, dtype=float)\n",
    "            concept_vector[mask, 1:9] = torch.tensor(perm_color.T[colors[mask]], dtype=float)\n",
    "            concept_vector[mask, 9:12] = torch.tensor(perm_shapes.T[shapes[mask]], dtype=float)\n",
    "            concept_vector[mask, 12:14] = torch.tensor(perm_material.T[materials[mask]], dtype=float)\n",
    "            concept_vector[mask, 14:] = torch.tensor(perm_sizes.T[sizes[mask]], dtype=float)\n",
    "        else:\n",
    "            concept_vector[:, 0] = torch.tensor(mask, dtype=float)\n",
    "            \n",
    "        full_concept_vector[:, idx, :] = concept_vector\n",
    "        \n",
    "    y = model.get_pred_from_prob(full_concept_vector.to(model.device).to(torch.float32), True).detach().cpu().numpy()\n",
    "    y = np.argmax(y, axis=-1)\n",
    "    \n",
    "    return accuracy_score(labels.numpy(), y), f1_score(labels.numpy(), y, average=\"macro\")\n",
    "\n",
    "def evaluate_knowledge_boia(model, pi, model_name):\n",
    "    def invert_permutation(permutation):\n",
    "        inv = np.zeros_like(permutation)\n",
    "        inv[permutation] = np.arange(len(permutation))\n",
    "        return inv\n",
    "    \n",
    "    sampled_configurations, labels = sample_boia_config()\n",
    "\n",
    "    inv_perm = invert_permutation(pi)\n",
    "    predicted = sampled_configurations[:, inv_perm]\n",
    "\n",
    "    y = model.get_pred_from_prob(predicted.to(model.device).to(torch.float32), False).detach().cpu()\n",
    "\n",
    "    if 'cbm' in model_name:\n",
    "        y = (y > 0.5).long()\n",
    "    else:\n",
    "        y_pred_split = torch.split(y, 2, dim=1)\n",
    "        y = torch.stack([pred.argmax(dim=1) for pred in y_pred_split], dim=1)\n",
    "\n",
    "    acc = accuracy_score(labels.numpy().flatten(), y.numpy().flatten())\n",
    "    f1 = f1_score(labels.numpy().flatten(), y.numpy().flatten(), average=\"macro\")\n",
    "    \n",
    "    return acc, f1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run all the things"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get loaders\n",
    "train_loader, val_loader, test_loader = dataset.get_data_loaders()\n",
    "# Get ood set if it exists\n",
    "ood_loader = getattr(dataset, \"ood_loader\", None)\n",
    "# ood_ambulance = getattr(\n",
    "# dataset, \"ood_loader_2\", None) # getattr(dataset, \"ood_loader_ambulance\", None)\n",
    "\n",
    "# Evaluate\n",
    "evaluate(\n",
    "    model,\n",
    "    test_loader,\n",
    "    args.dataset,\n",
    "    model_name=args.model,\n",
    "    ood_set=ood_loader,\n",
    "    ood_set_2=None,\n",
    "    hungarian=True,#True,\n",
    "    train_set=train_loader,\n",
    "    is_dsl=False#True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# model = MNISTCBM()\n",
    "\n",
    "# model.load_state_dict(torch.load(f\"{model_path}_{1011}_True.pth\"))\n",
    "\n",
    "# layer_weights = model.fc_aggregate[0].state_dict()['weight']\n",
    "\n",
    "# torch.save(layer_weights, \"linganguliguliguli_linear.pth\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
