{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Library imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../../../../\")\n",
    "from utils.train import convert_to_categories, compute_coverage\n",
    "from datasets.boia import BOIA\n",
    "from datasets.sddoia import SDDOIA\n",
    "from datasets.minikandinsky import MiniKandinsky\n",
    "from datasets.shortcutmnist import SHORTMNIST\n",
    "from datasets.kandinsky import Kandinsky\n",
    "from datasets.clipshortcutmnist import CLIPSHORTMNIST\n",
    "from datasets.clipboia import CLIPBOIA\n",
    "from datasets.clipsddoia import CLIPSDDOIA\n",
    "from datasets.clipkandinsky import CLIPKandinsky\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from argparse import Namespace\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the saved NumPy score files and decode the predicted concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_and_compute_mean(filepaths):\n",
    "    matrices = [np.load(filepath) for filepath in filepaths]\n",
    "\n",
    "    # Check if all matrices have the same shape\n",
    "    shape = matrices[0].shape\n",
    "    for matrix in matrices:\n",
    "        if matrix.shape != shape:\n",
    "            raise ValueError(\"All matrices must have the same shape\")\n",
    "\n",
    "    mean_matrix = np.mean(matrices, axis=0)\n",
    "\n",
    "    return mean_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_two_largest_indices(carr, dataset_name):\n",
    "\n",
    "    if \"clip\" in dataset_name:\n",
    "        tensor_part1 = carr[:, :10]\n",
    "        tensor_part2 = carr[:, 10:]\n",
    "\n",
    "        argmax_part1 = np.argmax(tensor_part1, axis=1)\n",
    "        argmax_part2 = np.argmax(tensor_part2, axis=1)\n",
    "\n",
    "        top_two_indices = np.stack((argmax_part1, argmax_part2), axis=1)\n",
    "    else:\n",
    "        sorted_indices = np.argsort(carr, axis=1)[:, ::-1]\n",
    "        top_two_indices = sorted_indices[:, :2]\n",
    "\n",
    "    return top_two_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_concepts_kandinsky(carr):\n",
    "    split_18 = np.split(carr, 3, axis=1)  # 3 x [100, 18]\n",
    "\n",
    "    argmax_results = []\n",
    "\n",
    "    for part_18 in split_18:\n",
    "        split_6 = np.split(part_18, 3, axis=1)  # 3 x [100, 6]\n",
    "\n",
    "        shapes = []\n",
    "        colors = []\n",
    "\n",
    "        for part_6 in split_6:\n",
    "            # Shape argmax\n",
    "            argmax_first3 = np.argmax(part_6[:, :3], axis=1)  # 100\n",
    "            shapes.append(argmax_first3)\n",
    "\n",
    "            # Color argmax\n",
    "            argmax_last3 = np.argmax(part_6[:, 3:], axis=1)  # 100\n",
    "            colors.append(argmax_last3)\n",
    "\n",
    "        # stack shaeps\n",
    "        shapes = np.stack(shapes, axis=1)  # 100, 3\n",
    "        colors = np.stack(colors, axis=1)  # 100, 3\n",
    "\n",
    "        # Concatenated argmax\n",
    "        pred_image = np.concatenate((shapes, colors), axis=1)  # [100, 6]\n",
    "        argmax_results.append(pred_image)\n",
    "\n",
    "    return np.stack(argmax_results, axis=1)  # [100, 3, 6]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_concepts(mean_matrix, dataset_name):\n",
    "    if dataset_name in [\"shortmnist\", \"clipshortmnist\"]:\n",
    "        return find_two_largest_indices(mean_matrix, dataset_name)\n",
    "    elif dataset_name in [\"boia\", \"sddoia\", \"clipboia\", \"clipsddoia\"]:\n",
    "        return (mean_matrix > 0).astype(float)\n",
    "    elif dataset_name in [\"kandinsky\", \"clipkandinsky\"]:\n",
    "        return find_concepts_kandinsky(mean_matrix)\n",
    "    else:\n",
    "        print(mean_matrix.shape)\n",
    "        raise NotImplementedError(\"Dataset not present\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Retrieve ground-truth and predicted concepts\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_concepts_mnist(concepts):\n",
    "    return concepts.reshape(concepts.shape[0] * concepts.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_concepts_kand(concepts):\n",
    "    split_concepts = np.split(concepts, concepts.shape[1], axis=1)\n",
    "    concatenated_concepts = np.concatenate(split_concepts, axis=0)\n",
    "    squeezed_concepts = np.squeeze(concatenated_concepts, axis=1)\n",
    "    return squeezed_concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrive_concepts_and_labels(dataset, dataset_name, model_name, seed, layer, add):\n",
    "    true_concepts = []\n",
    "    pred_c_str = f\"../output/concept_presence_{dataset_name}_{model_name}_{seed}_{layer}{add}.npy\"\n",
    "\n",
    "    for _, _, concepts in tqdm(dataset):\n",
    "        true_concepts.append(concepts.cpu().numpy())\n",
    "\n",
    "    # concatenate\n",
    "    true_concepts = np.concatenate(true_concepts, axis=0)\n",
    "\n",
    "    predicted_concepts = np.load(pred_c_str)\n",
    "\n",
    "    predicted_concepts = predict_concepts(predicted_concepts, dataset_name)\n",
    "\n",
    "    if dataset_name in [\"shortmnist\", \"clipshortmnist\"]:\n",
    "        true_concepts = process_concepts_mnist(true_concepts)\n",
    "        predicted_concepts = process_concepts_mnist(predicted_concepts)\n",
    "\n",
    "    if dataset_name in [\"kandinsky\", \"clipkandinsky\"]:\n",
    "        true_concepts = process_concepts_kand(true_concepts)\n",
    "        predicted_concepts = process_concepts_kand(predicted_concepts)\n",
    "\n",
    "    assert (\n",
    "        true_concepts.shape == predicted_concepts.shape\n",
    "    ), f\" {true_concepts.shape}, {predicted_concepts.shape}\"\n",
    "\n",
    "    return true_concepts, predicted_concepts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Metrics:\n",
    "    def __init__(\n",
    "        self,\n",
    "        concept_accuracy,\n",
    "        concept_f1_macro,\n",
    "        concept_f1_micro,\n",
    "        concept_f1_weighted,\n",
    "        collapse,\n",
    "    ):\n",
    "        self.concept_accuracy = concept_accuracy\n",
    "        self.concept_f1_macro = concept_f1_macro\n",
    "        self.concept_f1_micro = concept_f1_micro\n",
    "        self.concept_f1_weighted = concept_f1_weighted\n",
    "        self.collapse = collapse\n",
    "\n",
    "\n",
    "class BOIAMetrics(Metrics):\n",
    "    def __init__(\n",
    "        self,\n",
    "        concept_accuracy,\n",
    "        concept_f1_macro,\n",
    "        concept_f1_micro,\n",
    "        concept_f1_weighted,\n",
    "        collapse,\n",
    "        collapse_forward,\n",
    "        collapse_stop,\n",
    "        collapse_left,\n",
    "        collapse_right,\n",
    "        mean_collapse,\n",
    "    ):\n",
    "        super(BOIAMetrics, self).__init__(\n",
    "            concept_accuracy,\n",
    "            concept_f1_macro,\n",
    "            concept_f1_micro,\n",
    "            concept_f1_weighted,\n",
    "            collapse,\n",
    "        )\n",
    "        self.collapse_forward = collapse_forward\n",
    "        self.collapse_stop = collapse_stop\n",
    "        self.collapse_left = collapse_left\n",
    "        self.collapse_right = collapse_right\n",
    "        self.mean_collapse = mean_collapse\n",
    "\n",
    "\n",
    "class KandMetrics(Metrics):\n",
    "    def __init__(\n",
    "        self,\n",
    "        concept_accuracy,\n",
    "        concept_f1_macro,\n",
    "        concept_f1_micro,\n",
    "        concept_f1_weighted,\n",
    "        collapse,\n",
    "        collapse_shapes,\n",
    "        collapse_color,\n",
    "        mean_collapse,\n",
    "    ):\n",
    "        super(KandMetrics, self).__init__(\n",
    "            concept_accuracy,\n",
    "            concept_f1_macro,\n",
    "            concept_f1_micro,\n",
    "            concept_f1_weighted,\n",
    "            collapse,\n",
    "        )\n",
    "        self.collapse_shapes = collapse_shapes\n",
    "        self.collapse_color = collapse_color\n",
    "        self.mean_collapse = mean_collapse"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Confusion matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_confusion_matrix(\n",
    "    true_labels,\n",
    "    predicted_labels,\n",
    "    classes,\n",
    "    normalize=False,\n",
    "    title=None,\n",
    "    cmap=plt.cm.Oranges,\n",
    "):\n",
    "    \"\"\"\n",
    "    This function prints and plots the confusion matrix.\n",
    "    Normalization can be applied by setting `normalize=True`.\n",
    "    \"\"\"\n",
    "\n",
    "    cm = np.zeros((len(classes), len(classes)))\n",
    "    for i in range(len(true_labels)):\n",
    "        cm[true_labels[i], predicted_labels[i]] += 1\n",
    "\n",
    "    if normalize:\n",
    "        cm = cm.astype(\"float\") / cm.sum(axis=1)[:, np.newaxis]\n",
    "\n",
    "    plt.figure(figsize=(8, 6))\n",
    "    sns.set(font_scale=1.8)\n",
    "    sns.heatmap(\n",
    "        cm,\n",
    "        annot=False,\n",
    "        fmt=\".2f\" if normalize else \"d\",\n",
    "        cmap=cmap,\n",
    "        cbar=True,\n",
    "        xticklabels=classes,\n",
    "        yticklabels=classes,\n",
    "    )\n",
    "    if title:\n",
    "        plt.savefig(title, format=\"pdf\")\n",
    "    plt.xticks(rotation=0)\n",
    "    plt.yticks(rotation=0)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_concept_collapse(true_concepts, predicted_concepts, multilabel=False):\n",
    "    if multilabel:\n",
    "        true_concepts = convert_to_categories(true_concepts.astype(int))\n",
    "        predicted_concepts = convert_to_categories(predicted_concepts.astype(int))\n",
    "\n",
    "    return 1 - compute_coverage(confusion_matrix(true_concepts, predicted_concepts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(true_concepts, predicted_concepts, dataset_name):\n",
    "\n",
    "    # multilabel or not\n",
    "    multilabel_concept = False\n",
    "\n",
    "    if dataset_name in [\"boia\", \"sddoia\", \"clipboia\", \"clipsddoia\"]:\n",
    "        multilabel_concept = True\n",
    "\n",
    "    if dataset_name in [\"kandinsky\", \"minikandinsky\", \"clipkandinsky\"]:\n",
    "        collapse_true_concepts_list = torch.tensor(true_concepts)\n",
    "        collapse_true_concepts_list = torch.split(collapse_true_concepts_list, 3, dim=1)\n",
    "        collapse_pred_concepts_list = torch.tensor(predicted_concepts)\n",
    "        collapse_pred_concepts_list = torch.split(collapse_pred_concepts_list, 3, dim=1)\n",
    "\n",
    "        collapse_true_concepts_1 = collapse_true_concepts_list[0].flatten()\n",
    "        collapse_true_concepts_2 = collapse_true_concepts_list[1].flatten()\n",
    "        collapse_true_concepts = torch.stack(\n",
    "            (collapse_true_concepts_1, collapse_true_concepts_2), dim=1\n",
    "        )\n",
    "        # to int\n",
    "        collapse_true_concepts = (\n",
    "            collapse_true_concepts[:, 0] * 3 + collapse_true_concepts[:, 1]\n",
    "        )\n",
    "        collapse_true_concepts = collapse_true_concepts.detach().numpy()\n",
    "\n",
    "        collapse_pred_concepts_1 = collapse_pred_concepts_list[0].flatten()\n",
    "        collapse_pred_concepts_2 = collapse_pred_concepts_list[1].flatten()\n",
    "        collapse_pred_concepts = torch.stack(\n",
    "            (collapse_pred_concepts_1, collapse_pred_concepts_2), dim=1\n",
    "        )\n",
    "        # to int\n",
    "        collapse_pred_concepts = (\n",
    "            collapse_pred_concepts[:, 0] * 3 + collapse_pred_concepts[:, 1]\n",
    "        )\n",
    "        collapse_pred_concepts = collapse_pred_concepts.detach().numpy()\n",
    "\n",
    "        # total collapse\n",
    "        collapse = compute_concept_collapse(\n",
    "            collapse_true_concepts, collapse_pred_concepts, multilabel_concept\n",
    "        )\n",
    "    else:\n",
    "        # total collapse\n",
    "        collapse = compute_concept_collapse(\n",
    "            true_concepts, predicted_concepts, multilabel_concept\n",
    "        )\n",
    "\n",
    "    if dataset_name in [\"boia\", \"sddoia\", \"clipboia\", \"clipsddoia\"]:\n",
    "        # additional metrics for boia and sddoia\n",
    "        collapse_forward = compute_concept_collapse(\n",
    "            true_concepts[:, :3], predicted_concepts[:, :3], True\n",
    "        )\n",
    "        collapse_stop = compute_concept_collapse(\n",
    "            true_concepts[:, 3:9], predicted_concepts[:, 3:9], True\n",
    "        )\n",
    "        collapse_left = compute_concept_collapse(\n",
    "            true_concepts[:, 9:15], predicted_concepts[:, 9:15], True\n",
    "        )\n",
    "        collapse_right = compute_concept_collapse(\n",
    "            true_concepts[:, 15:21], predicted_concepts[:, 15:21], True\n",
    "        )\n",
    "\n",
    "        mean_collapse = np.mean(\n",
    "            [collapse_forward, collapse_stop, collapse_left, collapse_right]\n",
    "        )\n",
    "\n",
    "    elif dataset_name in [\"minikandinsky\", \"kandinsky\", \"clipkandinsky\"]:\n",
    "        # additional metrics for boia and sddoia\n",
    "        collapse_color = compute_concept_collapse(\n",
    "            true_concepts[:, 3:6].reshape(-1),\n",
    "            predicted_concepts[:, 3:6].reshape(-1),\n",
    "            False,\n",
    "        )\n",
    "        collapse_shapes = compute_concept_collapse(\n",
    "            true_concepts[:, :3].reshape(-1),\n",
    "            predicted_concepts[:, :3].reshape(-1),\n",
    "            False,\n",
    "        )\n",
    "\n",
    "        mean_collapse = np.mean([collapse_color, collapse_shapes])\n",
    "\n",
    "    if multilabel_concept:\n",
    "        concept_accuracy, concept_f1_macro, concept_f1_micro, concept_f1_weighted = (\n",
    "            0,\n",
    "            0,\n",
    "            0,\n",
    "            0,\n",
    "        )\n",
    "\n",
    "        for i in range(true_concepts.shape[1]):\n",
    "            concept_accuracy += accuracy_score(true_concepts[i], predicted_concepts[i])\n",
    "            concept_f1_macro += f1_score(\n",
    "                true_concepts[i], predicted_concepts[i], average=\"macro\"\n",
    "            )\n",
    "            concept_f1_micro += f1_score(\n",
    "                true_concepts[i], predicted_concepts[i], average=\"micro\"\n",
    "            )\n",
    "            concept_f1_weighted += f1_score(\n",
    "                true_concepts[i], predicted_concepts[i], average=\"weighted\"\n",
    "            )\n",
    "\n",
    "        concept_accuracy = concept_accuracy / true_concepts.shape[1]\n",
    "        concept_f1_macro = concept_f1_macro / true_concepts.shape[1]\n",
    "        concept_f1_micro = concept_f1_micro / true_concepts.shape[1]\n",
    "        concept_f1_weighted = concept_f1_weighted / true_concepts.shape[1]\n",
    "\n",
    "        label_accuracy, label_f1_macro, label_f1_micro, label_f1_weighted = 0, 0, 0, 0\n",
    "    elif dataset_name in [\"kandinsky\", \"minikandinsky\", \"clipkandinsky\"]:\n",
    "        concept_accuracy_color = accuracy_score(\n",
    "            true_concepts[:, 3:6].reshape(-1), predicted_concepts[:, 3:6].reshape(-1)\n",
    "        )\n",
    "        concept_f1_macro_color = f1_score(\n",
    "            true_concepts[:, 3:6].reshape(-1),\n",
    "            predicted_concepts[:, 3:6].reshape(-1),\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_micro_color = f1_score(\n",
    "            true_concepts[:, 3:6].reshape(-1),\n",
    "            predicted_concepts[:, 3:6].reshape(-1),\n",
    "            average=\"micro\",\n",
    "        )\n",
    "        concept_f1_weighted_color = f1_score(\n",
    "            true_concepts[:, 3:6].reshape(-1),\n",
    "            predicted_concepts[:, 3:6].reshape(-1),\n",
    "            average=\"weighted\",\n",
    "        )\n",
    "\n",
    "        concept_accuracy_shape = accuracy_score(\n",
    "            true_concepts[:, :3].reshape(-1), predicted_concepts[:, :3].reshape(-1)\n",
    "        )\n",
    "        concept_f1_macro_shape = f1_score(\n",
    "            true_concepts[:, :3].reshape(-1),\n",
    "            predicted_concepts[:, :3].reshape(-1),\n",
    "            average=\"macro\",\n",
    "        )\n",
    "        concept_f1_micro_shape = f1_score(\n",
    "            true_concepts[:, :3].reshape(-1),\n",
    "            predicted_concepts[:, :3].reshape(-1),\n",
    "            average=\"micro\",\n",
    "        )\n",
    "        concept_f1_weighted_shape = f1_score(\n",
    "            true_concepts[:, :3].reshape(-1),\n",
    "            predicted_concepts[:, :3].reshape(-1),\n",
    "            average=\"weighted\",\n",
    "        )\n",
    "\n",
    "        concept_accuracy = np.mean([concept_accuracy_color, concept_accuracy_shape])\n",
    "        concept_f1_macro = np.mean([concept_f1_macro_color, concept_f1_macro_shape])\n",
    "        concept_f1_micro = np.mean([concept_f1_micro_color, concept_f1_micro_shape])\n",
    "        concept_f1_weighted = np.mean(\n",
    "            [concept_f1_weighted_color, concept_f1_weighted_shape]\n",
    "        )\n",
    "\n",
    "    else:\n",
    "        concept_accuracy = accuracy_score(true_concepts, predicted_concepts)\n",
    "        concept_f1_macro = f1_score(true_concepts, predicted_concepts, average=\"macro\")\n",
    "        concept_f1_micro = f1_score(true_concepts, predicted_concepts, average=\"micro\")\n",
    "        concept_f1_weighted = f1_score(\n",
    "            true_concepts, predicted_concepts, average=\"weighted\"\n",
    "        )\n",
    "\n",
    "    if dataset_name in [\"boia\", \"sddoia\", \"clipboia\", \"clipsddoia\"]:\n",
    "        metrics = BOIAMetrics(\n",
    "            concept_accuracy=concept_accuracy,\n",
    "            concept_f1_macro=concept_f1_macro,\n",
    "            concept_f1_micro=concept_f1_micro,\n",
    "            concept_f1_weighted=concept_f1_weighted,\n",
    "            collapse=collapse,\n",
    "            collapse_forward=collapse_forward,\n",
    "            collapse_stop=collapse_stop,\n",
    "            collapse_right=collapse_right,\n",
    "            collapse_left=collapse_left,\n",
    "            mean_collapse=mean_collapse,\n",
    "        )\n",
    "    elif dataset_name in [\"minikandinsky\", \"kandinsky\", \"clipkandinsky\"]:\n",
    "        metrics = KandMetrics(\n",
    "            concept_accuracy=concept_accuracy,\n",
    "            concept_f1_macro=concept_f1_macro,\n",
    "            concept_f1_micro=concept_f1_micro,\n",
    "            concept_f1_weighted=concept_f1_weighted,\n",
    "            collapse=collapse,\n",
    "            collapse_color=collapse_color,\n",
    "            mean_collapse=mean_collapse,\n",
    "        )\n",
    "    else:\n",
    "        metrics = Metrics(\n",
    "            concept_accuracy=concept_accuracy,\n",
    "            concept_f1_macro=concept_f1_macro,\n",
    "            concept_f1_micro=concept_f1_micro,\n",
    "            concept_f1_weighted=concept_f1_weighted,\n",
    "            collapse=collapse,\n",
    "        )\n",
    "\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(datasetname, args):\n",
    "    if datasetname.lower() == \"boia\":\n",
    "        return BOIA(args)\n",
    "    if datasetname.lower() == \"sddoia\":\n",
    "        return SDDOIA(args)\n",
    "    if datasetname.lower() == \"minikandinsky\":\n",
    "        return MiniKandinsky(args)\n",
    "    if datasetname.lower() == \"shortmnist\":\n",
    "        return SHORTMNIST(args)\n",
    "    if datasetname.lower() == \"kandinsky\":\n",
    "        return Kandinsky(args)\n",
    "    if datasetname.lower() == \"clipkandinsky\":\n",
    "        return CLIPKandinsky(args)\n",
    "    if datasetname.lower() == \"clipboia\":\n",
    "        return CLIPBOIA(args)\n",
    "    if datasetname.lower() == \"clipsddoia\":\n",
    "        return CLIPSDDOIA(args)\n",
    "    if datasetname.lower() == \"clipshortmnist\":\n",
    "        return CLIPSHORTMNIST(args)\n",
    "\n",
    "    raise NotImplementedError(f\"Dataset {datasetname} missing\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = Namespace(\n",
    "    backbone=\"neural\",  #\n",
    "    preprocess=0,\n",
    "    finetuning=0,\n",
    "    batch_size=1,\n",
    "    n_epochs=20,\n",
    "    validate=1,\n",
    "    dataset=\"clipsddoia\",\n",
    "    lr=0.001,\n",
    "    exp_decay=0.99,\n",
    "    warmup_steps=1,\n",
    "    wandb=None,\n",
    "    task=\"boia\",\n",
    "    boia_model=\"ce\",\n",
    "    model=\"sddoiann\",\n",
    "    c_sup=1,\n",
    "    which_c=[-1],\n",
    "    joint=True,\n",
    ")\n",
    "\n",
    "# get dataset\n",
    "dataset = get_dataset(args.dataset, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "def evaluate(test_set, dataset_name, model_name, ood_set=None):\n",
    "\n",
    "    n_files = 0\n",
    "\n",
    "    seeds = [123, 456, 789, 1011, 1213, 1415, 1617, 1819, 2021, 2223]\n",
    "    # MNIST ADD layers = [\"conv1\", \"conv2\", \"fc1\", \"fc2\"]\n",
    "    # layers = [\"conv1\", \"conv2\", \"fc1\", \"fc2\"]\n",
    "    # BOIA\n",
    "    # layers = [\"fc1\", \"fc2\", \"fc3\", \"fc4\"]\n",
    "    # KAND\n",
    "    layers = [\"conv1\", \"conv2\", \"conv3\", \"conv4\", \"conv5\", \"fc1\", \"fc2\"]\n",
    "    layers = [\"fc1\", \"fc2\"]\n",
    "    # SDDOIA\n",
    "    # layers = [\"conv1\"] #[\"conv2\", \"conv3\", \"conv4\", \"conv5\", \"conv6\", \"fc1\", \"fc2\"] # \"conv1\" [, \"conv2\", \"conv3\", \"conv4\", \"conv5\", \"conv6\", \"fc1\", \"fc2\"]\n",
    "    add = \"\"  # \"_padd_random\"\n",
    "\n",
    "    for layer in layers:\n",
    "\n",
    "        print(f\"\\n LAYER: {layer}\\n\")\n",
    "\n",
    "        # List of metics\n",
    "        in_metrics_list = []\n",
    "        ood_metrics_list = []\n",
    "        # Loop through seeds\n",
    "        for seed in seeds:\n",
    "\n",
    "            if not os.path.exists(\n",
    "                f\"../output/concept_presence_{dataset_name}_{model_name}_{seed}_{layer}{add}.npy\"\n",
    "            ):\n",
    "                print(\n",
    "                    f\"../output/concept_presence_{dataset_name}_{model_name}_{seed}_{layer}{add}.npy does not exists...\"\n",
    "                )\n",
    "                continue\n",
    "\n",
    "            n_files += 1\n",
    "\n",
    "            ind_data = retrive_concepts_and_labels(\n",
    "                test_set, dataset_name, model_name, seed, layer, add\n",
    "            )\n",
    "\n",
    "            if dataset_name in [\"sddoia\", \"boia\"] and False:\n",
    "                plot_confusion_matrix(\n",
    "                    convert_to_categories(ind_data[0][:, :3].astype(int)),\n",
    "                    convert_to_categories(ind_data[1][:, :3].astype(int)),\n",
    "                    [i for i in range(2**3)],\n",
    "                    True,\n",
    "                    \"Forward\",\n",
    "                )\n",
    "                plot_confusion_matrix(\n",
    "                    convert_to_categories(ind_data[0][:, 3:9].astype(int)),\n",
    "                    convert_to_categories(ind_data[1][:, 3:9].astype(int)),\n",
    "                    [i for i in range(2**6)],\n",
    "                    True,\n",
    "                    \"Stop\",\n",
    "                )\n",
    "                plot_confusion_matrix(\n",
    "                    convert_to_categories(ind_data[0][:, 9:15].astype(int)),\n",
    "                    convert_to_categories(ind_data[1][:, 9:15].astype(int)),\n",
    "                    [i for i in range(2**6)],\n",
    "                    True,\n",
    "                    \"Left\",\n",
    "                )\n",
    "                plot_confusion_matrix(\n",
    "                    convert_to_categories(ind_data[0][:, 15:21].astype(int)),\n",
    "                    convert_to_categories(ind_data[1][:, 15:21].astype(int)),\n",
    "                    [i for i in range(2**6)],\n",
    "                    True,\n",
    "                    \"Right\",\n",
    "                )\n",
    "            elif False:  # TODO\n",
    "                plot_confusion_matrix(\n",
    "                    ind_data[0],\n",
    "                    ind_data[1],\n",
    "                    [i for i in range(10)],\n",
    "                    True,\n",
    "                    title=f\"{dataset_name}_{model_name}_{seed}.pdf\",\n",
    "                )\n",
    "\n",
    "            if ood_set is not None:\n",
    "                out_data = retrive_concepts_and_labels(\n",
    "                    ood_set, dataset_name, model_name, seed, layer, \"ood\"\n",
    "                )\n",
    "\n",
    "            in_metrics = compute_metrics(*ind_data, dataset_name)\n",
    "            in_metrics_list.append(in_metrics)\n",
    "\n",
    "            if ood_set is not None:\n",
    "                ood_metrics = compute_metrics(*out_data, dataset_name)\n",
    "                ood_metrics_list.append(ood_metrics)\n",
    "\n",
    "        assert n_files > 1, \"At least 2 files to compare\"\n",
    "\n",
    "        # Compute standard deviation for each metric\n",
    "        for key in vars(in_metrics_list[0]):  # the key are always the same\n",
    "            # skip hidden elements\n",
    "            if not key.startswith(\"_\"):\n",
    "                # retrieve the list of values\n",
    "                in_metric_values = [\n",
    "                    getattr(metrics, key) for metrics in in_metrics_list\n",
    "                ]\n",
    "                ood_metric_values = [\n",
    "                    getattr(metrics, key) for metrics in ood_metrics_list\n",
    "                ]\n",
    "\n",
    "                # convert lists to NumPy arrays\n",
    "                in_metric_values_arr = np.array(in_metric_values)\n",
    "                ood_metric_values_arr = np.array(ood_metric_values)\n",
    "\n",
    "                # Compute the standard deviation\n",
    "                in_metric_std_dev = np.std(in_metric_values_arr)\n",
    "                ood_metric_std_dev = np.std(ood_metric_values_arr)\n",
    "\n",
    "                # Compute the mean\n",
    "                in_metric_std_mean = np.mean(in_metric_values_arr)\n",
    "                ood_metric_std_mean = np.mean(ood_metric_values_arr)\n",
    "\n",
    "                print(\n",
    "                    \"\\n{} (In): ${:.2f} \\pm {:.2f}$\".format(\n",
    "                        key.replace(\"_\", \" \").title(),\n",
    "                        round(in_metric_std_mean, 2),\n",
    "                        round(in_metric_std_dev, 2),\n",
    "                    )\n",
    "                )\n",
    "\n",
    "                if ood_set is not None:\n",
    "                    print(\n",
    "                        \"{} (OOD): ${:.2f} \\pm {:.2f}$\".format(\n",
    "                            key.replace(\"_\", \" \").title(),\n",
    "                            round(ood_metric_std_mean, 2),\n",
    "                            round(ood_metric_std_dev, 2),\n",
    "                        )\n",
    "                    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get loaders\n",
    "_, _, test_loader = dataset.get_data_loaders()\n",
    "# ood loader\n",
    "ood_loader = None  # getattr(dataset, \"ood_loader\", None)\n",
    "# Evaluate\n",
    "evaluate(test_loader, args.dataset, args.model, ood_loader)  # ood_set=ood_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "reasoning-shortcuts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
