{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea36c45-8458-4792-ac81-255b13310270",
   "metadata": {},
   "outputs": [],
   "source": [
    "import types\n",
    "import warnings\n",
    "from functools import reduce\n",
    "from itertools import chain\n",
    "from pathlib import Path\n",
    "\n",
    "import cv2\n",
    "import matplotlib.pylab as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.utils.data\n",
    "import torchvision\n",
    "from PIL import Image\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from matplotlib.ticker import MaxNLocator, MultipleLocator\n",
    "from tqdm import tqdm\n",
    "\n",
    "from train import get_args_parser\n",
    "from utils import utils\n",
    "from utils.tensor2image import make_grid\n",
    "\n",
    "# plt.rcParams.update(\n",
    "#     {\"text.usetex\": True, \"font.family\": \"sans-serif\", \"font.sans-serif\": [\"Helvetica\"]}\n",
    "# )\n",
    "# plt.rcParams.update(\n",
    "#     {\n",
    "#         \"text.usetex\": True,\n",
    "#         \"font.family\": \"serif\",\n",
    "#         \"font.serif\": [\"Palatino\"],\n",
    "#     }\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4d6f17a-9d19-40c2-81bb-06b0704859f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset to evaluate. It must be a key of `model_datasets[\"datasets\"]` in the\n",
    "# configuration cell that is run below; otherwise the model-loading cell raises\n",
    "# NotImplementedError.\n",
    "# dataset_name = \"cars\"\n",
    "# dataset_name = \"CUB_200\"\n",
    "# dataset_name = \"dogs\"\n",
    "dataset_name = \"imagenet\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da00b7dc-5837-4cb7-87df-a93cba42335c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration for data_type \"other\" (load_data then uses the standard\n",
    "# presets.ClassificationPresetEval preprocessing). Run only one configuration cell;\n",
    "# each one overwrites `data_type` and `model_datasets`.\n",
    "# Replace every \"?\" with a real dataset directory / checkpoint path. Values ending in\n",
    "# \"IMAGENET1K_V1\" are torchvision weight names (passed as --weights); all other\n",
    "# values are treated as checkpoint paths (passed as --resume).\n",
    "data_type = \"other\"\n",
    "\n",
    "model_datasets = {\n",
    "    \"datasets\": {\n",
    "        \"imagenet\": \"?\",  # Set the path to the dataset\n",
    "    },\n",
    "    \"models\": {\n",
    "        \"own_resnet34\": {\n",
    "            \"imagenet\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_resnet50\": {\n",
    "            \"imagenet\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_densenet121\": {\n",
    "            \"imagenet\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_convnext_large\": {\n",
    "            \"imagenet\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_vit_b_16\": {\n",
    "            \"imagenet\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_swin_v2_s\": {\n",
    "            \"imagenet\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_maxvit_t\": {\n",
    "            \"imagenet\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        ############################################################################################\n",
    "        # Baselines: torchvision architectures loaded with the pretrained weights named below.\n",
    "        \"resnet34\": {\n",
    "            \"imagenet\": \"ResNet34_Weights.IMAGENET1K_V1\",\n",
    "        },\n",
    "        \"resnet50\": {\n",
    "            \"imagenet\": \"ResNet50_Weights.IMAGENET1K_V1\",\n",
    "        },\n",
    "        \"densenet121\": {\n",
    "            \"imagenet\": \"DenseNet121_Weights.IMAGENET1K_V1\",\n",
    "        },\n",
    "        \"convnext_large\": {\n",
    "            \"imagenet\": \"ConvNeXt_Large_Weights.IMAGENET1K_V1\",\n",
    "        },\n",
    "        \"vit_b_16\": {\n",
    "            \"imagenet\": \"ViT_B_16_Weights.IMAGENET1K_V1\",\n",
    "        },\n",
    "        \"swin_v2_s\": {\n",
    "            \"imagenet\": \"Swin_V2_S_Weights.IMAGENET1K_V1\",\n",
    "        },\n",
    "        \"maxvit_t\": {\n",
    "            \"imagenet\": \"MaxVit_T_Weights.IMAGENET1K_V1\",\n",
    "        },\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4ba9785-03b0-4767-b813-4a7fe62b0f3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration for data_type \"cropped\" (load_data then uses\n",
    "# presets.ClassificationCropped). Run only one configuration cell; each one\n",
    "# overwrites `data_type` and `model_datasets`.\n",
    "# Replace every \"?\" with a real dataset directory / checkpoint path; the model\n",
    "# entries are passed to the loading cell as --resume checkpoints.\n",
    "data_type = \"cropped\"\n",
    "\n",
    "model_datasets = {\n",
    "    \"datasets\": {\n",
    "        \"CUB_200\": \"?\",  # Set the path to the dataset\n",
    "        \"cars\": \"?\",  # Set the path to the dataset\n",
    "    },\n",
    "    \"models\": {\n",
    "        \"own_resnet34\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"cars\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_resnet50\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"cars\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_densenet121\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"cars\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_convnext_tiny\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"cars\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"resnet34\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"cars\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"resnet50\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"cars\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"densenet121\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"cars\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"convnext_tiny\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"cars\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a6a04c-d2bb-47a9-9ca8-6e3ead318366",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration for data_type \"full\" (load_data then uses presets.ClassificationFull).\n",
    "# Run only one configuration cell; each one overwrites `data_type` and\n",
    "# `model_datasets`.\n",
    "# Replace every \"?\" with a real dataset directory / checkpoint path; the model\n",
    "# entries are passed to the loading cell as --resume checkpoints.\n",
    "data_type = \"full\"\n",
    "\n",
    "model_datasets = {\n",
    "    \"datasets\": {\n",
    "        \"CUB_200\": \"?\",  # Set the path to the dataset\n",
    "        \"dogs\": \"?\",  # Set the path to the dataset\n",
    "    },\n",
    "    \"models\": {\n",
    "        \"own_resnet34\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"dogs\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_resnet50\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"dogs\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"own_densenet121\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"dogs\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"resnet34\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"dogs\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"resnet50\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"dogs\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "        \"densenet121\": {\n",
    "            \"CUB_200\": \"?\",  # Set the path to the checkpoint\n",
    "            \"dogs\": \"?\",  # Set the path to the checkpoint\n",
    "        },\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ce3fba2-ed5a-42ff-8dcc-03d737985fa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms import InterpolationMode\n",
    "from utils import presets\n",
    "\n",
    "\n",
    "def load_data(args):\n",
    "    \"\"\"Build the evaluation datasets and sequential samplers described by ``args``.\n",
    "\n",
    "    The ``train``/``val`` sub-directories of ``args.data_path`` are used when they\n",
    "    exist; otherwise ``args.data_path`` itself is used for that split. Both splits\n",
    "    share one preprocessing pipeline, chosen by ``args.data_type`` (\"cropped\",\n",
    "    \"full\", or anything else -- including a missing attribute -- for the standard\n",
    "    evaluation preset).\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(dataset, dataset_test, train_sampler, test_sampler)``; both\n",
    "        samplers are ``SequentialSampler`` instances.\n",
    "    \"\"\"\n",
    "    print(\"Loading data\")\n",
    "    resize_size = args.val_resize_size\n",
    "    crop_size = args.val_crop_size\n",
    "    interpolation = InterpolationMode(args.interpolation)\n",
    "\n",
    "    root = Path(args.data_path)\n",
    "\n",
    "    def split_dir(split_name):\n",
    "        # Fall back to the root directory when the split folder does not exist.\n",
    "        candidate = root / split_name\n",
    "        return str(candidate) if candidate.is_dir() else args.data_path\n",
    "\n",
    "    train_root = split_dir(\"train\")\n",
    "    test_root = split_dir(\"val\")\n",
    "\n",
    "    mode = getattr(args, \"data_type\", \"other\")\n",
    "    if mode == \"cropped\":\n",
    "        print(\"Loading dataset test with cropped data\")\n",
    "        transform = presets.ClassificationCropped(\n",
    "            resize_size=crop_size, interpolation=interpolation, use_v2=args.use_v2\n",
    "        )\n",
    "    elif mode == \"full\":\n",
    "        print(\"Loading dataset test with full data\")\n",
    "        transform = presets.ClassificationFull(\n",
    "            resize_size=crop_size,\n",
    "            interpolation=interpolation,\n",
    "            train=False,\n",
    "            use_v2=args.use_v2,\n",
    "        )\n",
    "    else:\n",
    "        transform = presets.ClassificationPresetEval(\n",
    "            crop_size=crop_size,\n",
    "            resize_size=resize_size,\n",
    "            interpolation=interpolation,\n",
    "            backend=args.backend,\n",
    "            use_v2=args.use_v2,\n",
    "        )\n",
    "\n",
    "    train_set = torchvision.datasets.ImageFolder(train_root, transform)\n",
    "    test_set = torchvision.datasets.ImageFolder(test_root, transform)\n",
    "    return (\n",
    "        train_set,\n",
    "        test_set,\n",
    "        torch.utils.data.SequentialSampler(train_set),\n",
    "        torch.utils.data.SequentialSampler(test_set),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be953cc1ccf5d9bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build every model listed in `model_datasets[\"models\"]` for `dataset_name`, load its\n",
    "# weights or checkpoint, and keep the datasets preprocessed at that model's resolution.\n",
    "device = None\n",
    "dataset = None\n",
    "\n",
    "models = {}\n",
    "datasets = {}\n",
    "\n",
    "\n",
    "def _cli_args_for(model_name):\n",
    "    \"\"\"Return the command-line tokens that configure ``model_name`` on ``dataset_name``.\n",
    "\n",
    "    Options are written one per line; text after a ``#`` on a line is ignored and the\n",
    "    rest is split on whitespace before being handed to ``get_args_parser``.\n",
    "    \"\"\"\n",
    "    source = model_datasets[\"models\"][model_name][dataset_name]\n",
    "    # Torchvision weight names are passed as --weights, anything else as a checkpoint.\n",
    "    source_option = \"--weights\" if source.endswith(\"IMAGENET1K_V1\") else \"--resume\"\n",
    "    option_lines = [\n",
    "        f\"--data-path {model_datasets['datasets'][dataset_name]}\",\n",
    "        f\"--dataset-name {dataset_name}\",\n",
    "        f\"--gumbel-dim {1 if model_name == 'own_vit_b_16' else -1}\",\n",
    "        f\"{source_option} {source}\",\n",
    "        \"--batch-size 64  # 256\",\n",
    "        \"--workers 36\",\n",
    "        \"--use-v2\",\n",
    "        \"--output-dir /results\",\n",
    "        f\"--model {model_name}\",\n",
    "        f\"--data_type {data_type}\",\n",
    "        f\"--val-resize-size {224 if 'maxvit' in model_name else 256}\",\n",
    "        f\"--val-crop-size {256 if 'swin' in model_name else 224}\",\n",
    "        \"--print-freq 500\",\n",
    "        \"# --device cpu\",\n",
    "    ]\n",
    "    tokens = []\n",
    "    for line in \"\\n\".join(option_lines).split(\"\\n\"):\n",
    "        if line.strip() and not line.startswith(\"#\"):\n",
    "            tokens.extend(line.split(\"#\")[0].split())\n",
    "    return tokens\n",
    "\n",
    "\n",
    "for model_name in model_datasets[\"models\"].keys():\n",
    "    if dataset_name not in model_datasets[\"datasets\"]:\n",
    "        raise NotImplementedError(\"Dataset not implemented\")\n",
    "\n",
    "    args = get_args_parser().parse_args(args=_cli_args_for(model_name))\n",
    "    # print(utils.print_namespace(args))\n",
    "\n",
    "    if device is None:\n",
    "        on_gpu = torch.cuda.is_available() and args.device == \"cuda\"\n",
    "        device = torch.device(\"cuda\" if on_gpu else \"cpu\")\n",
    "        print(f\"\\033[0;1;31mDevice: {device}\\033[0m\")\n",
    "\n",
    "    dataset, dataset_test, train_sampler, test_sampler = load_data(args)\n",
    "    num_classes = len(dataset.classes)\n",
    "    print(f\"\\033[0;34mNumber of classes in dataset: {num_classes}\\033[0m\")\n",
    "    datasets[model_name] = (dataset_test, dataset)\n",
    "\n",
    "    print(f\"Creating model: {model_name}\")\n",
    "    # `gumbel_dim` is only passed to the custom \"own_\" models.\n",
    "    model_kwargs = {\"gumbel_dim\": args.gumbel_dim} if args.model.startswith(\"own_\") else {}\n",
    "    models[model_name] = torchvision.models.get_model(\n",
    "        args.model, weights=args.weights, num_classes=num_classes, **model_kwargs\n",
    "    )\n",
    "\n",
    "    if args.resume:\n",
    "        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.\n",
    "        checkpoint = torch.load(args.resume, map_location=\"cpu\", weights_only=False)\n",
    "        missing_keys, unexpected_keys = models[model_name].load_state_dict(\n",
    "            checkpoint[\"model\"], strict=False\n",
    "        )\n",
    "        warning_parts = []\n",
    "        if missing_keys:\n",
    "            warning_parts.append(f\"\\033[0;1;33mMissing keys: {missing_keys}\\033[0m \")\n",
    "        if unexpected_keys:\n",
    "            warning_parts.append(f\"\\033[0;1;36mUnexpected keys: {unexpected_keys}\\033[0m\")\n",
    "        if warning_parts:\n",
    "            warnings.warn(\"\".join(warning_parts))\n",
    "\n",
    "        trained_epochs = checkpoint[\"epoch\"] + 1\n",
    "        print(\n",
    "            f\"Resuming model from file '{args.resume}' that was trained by {trained_epochs} epochs.\"\n",
    "        )\n",
    "\n",
    "    models[model_name].to(device).eval()\n",
    "    print(\"-\" * 75)\n",
    "print(\"\\nDone\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e251986c8a098c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate each loaded model on its test split and collect its top-1 predictions.\n",
    "# NOTE: batch size, workers and print frequency come from `args` of the last model\n",
    "# configured in the loading cell.\n",
    "predictions = {key: [] for key in models.keys()}\n",
    "for model_name, model in models.items():\n",
    "    data_loader_test = torch.utils.data.DataLoader(\n",
    "        datasets[model_name][0],\n",
    "        batch_size=args.batch_size,\n",
    "        shuffle=False,\n",
    "        num_workers=args.workers,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "\n",
    "    metric_logger = utils.MetricLogger(delimiter=\"  \")\n",
    "    metric_logger.add_meter(\n",
    "        \"acc1\",\n",
    "        utils.SmoothedValue(window_size=20, fmt=\"{median:.2f} ({global_avg:.2f})\"),\n",
    "    )\n",
    "    header = \"Test:\"\n",
    "\n",
    "    with torch.inference_mode():\n",
    "        for image, target in metric_logger.log_every(\n",
    "            data_loader_test, args.print_freq, header\n",
    "        ):\n",
    "            image = image.to(device, non_blocking=True)\n",
    "            target = target.to(device, non_blocking=True)\n",
    "            output = model(image)\n",
    "            # Once the visualization cell has patched `forward`, models return\n",
    "            # (logits, heatmap); keep the logits so this cell can be re-run.\n",
    "            if isinstance(output, tuple):\n",
    "                output = output[0]\n",
    "\n",
    "            predictions[model_name].extend(torch.argmax(output, dim=1).tolist())\n",
    "\n",
    "            acc = utils.accuracy(output, target, topk=[1])\n",
    "            metric_logger.meters[\"acc1\"].update(acc[0].item(), n=image.shape[0])\n",
    "\n",
    "    print(\n",
    "        f\"\\033[0;33m{model_name} \\033[0;31mAcc@1: {metric_logger.acc1.global_avg:.2f}%\\033[0m\"\n",
    "    )\n",
    "    print(\"=\" * 100)\n",
    "print(\"Done\")\n",
    "\n",
    "predictions = {key: np.array(values) for key, values in predictions.items()}\n",
    "\n",
    "np.save(\"predictions_of_models.npy\", predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ca3bf7622a04f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The predictions were saved as a dict inside a 0-d object array, hence\n",
    "# allow_pickle=True and `.item()`; only load files produced by this notebook.\n",
    "saved_predictions = np.load(\"predictions_of_models.npy\", allow_pickle=True)\n",
    "predictions = saved_predictions.item()\n",
    "predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c30fc0c-bea7-410d-aa48-eb9868c57f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Percentage of samples on which each baseline model and its \"own_\" counterpart agree.\n",
    "for key, baseline_predictions in predictions.items():\n",
    "    if key.startswith(\"own_\"):\n",
    "        continue\n",
    "    own_key = f\"own_{key}\"\n",
    "    if own_key not in predictions:\n",
    "        # No custom counterpart was evaluated for this baseline; nothing to compare.\n",
    "        continue\n",
    "\n",
    "    agreement = (baseline_predictions == predictions[own_key]).mean() * 100\n",
    "    print(f\"{key}: {agreement:12.1f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "936e0b17-a925-4a02-a89b-96bdee30c586",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "32434998e06d1f6d",
   "metadata": {},
   "source": [
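    "## Transform Models for Visualization\n",
    "\n",
    "Patch each model's `forward` so that it returns `(logits, heatmap)`, where `heatmap` is the feature map that the classification head turns into the logits. Only models whose class name starts with `Custom` are patched."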
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90ae91640676385a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace each model's forward pass so that it returns ``(logits, heatmap)``, where\n",
    "# ``heatmap`` is the feature map the classification head turns into the logits.\n",
    "# One forward function is defined per architecture and selected by model name.\n",
    "\n",
    "\n",
    "def _signed_gumbel_heatmap(self, x: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Return ``gumbel(relu(x)) - gumbel(relu(-x))`` via ``self.apply_gumbel_softmax``.\"\"\"\n",
    "    v_positive = self.apply_gumbel_softmax(torch.relu(x))\n",
    "    v_negative = self.apply_gumbel_softmax(torch.relu(torch.neg(x)))\n",
    "    return v_positive - v_negative\n",
    "\n",
    "\n",
    "def _convnext_head_norm(self, x: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Normalise ``x`` with the eps/weight/bias of ``self.classifier[0]``.\n",
    "\n",
    "    Mean and (biased) variance are computed from the spatially averaged features,\n",
    "    so global average pooling of the result equals channel-normalising the pooled\n",
    "    features. The parameters are read from ``self`` on every call (instead of being\n",
    "    captured from loop variables) so each ConvNeXt uses its own normalisation.\n",
    "    \"\"\"\n",
    "    layer_norm = self.classifier[0]\n",
    "    mean = torch.mean(x, dim=(1, 2, 3), keepdim=True)\n",
    "    var = torch.var(\n",
    "        x.mean(dim=(2, 3), keepdim=True), dim=(1), keepdim=True, unbiased=False\n",
    "    )\n",
    "    return (\n",
    "        torch.div(x - mean, torch.sqrt(var + layer_norm.eps))\n",
    "        .mul(layer_norm.weight.view(1, -1, 1, 1))\n",
    "        .add(layer_norm.bias.view(1, -1, 1, 1))\n",
    "    )\n",
    "\n",
    "\n",
    "def _resnet_backbone(self, x: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Run the ResNet stem followed by ``layer1`` ... ``layer4``.\"\"\"\n",
    "    x = self.maxpool(self.relu(self.bn1(self.conv1(x))))\n",
    "    return self.layer4(self.layer3(self.layer2(self.layer1(x))))\n",
    "\n",
    "\n",
    "def _resnet_head(self, heatmap: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Pool ``heatmap`` and apply the fully connected classifier.\"\"\"\n",
    "    return self.fc(torch.flatten(self.avgpool(heatmap), 1))\n",
    "\n",
    "\n",
    "def _vit_encode(self, x: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Embed the patches, prepend the class token and run the encoder.\"\"\"\n",
    "    x = self._process_input(x)\n",
    "    # Expand the class token to the full batch\n",
    "    batch_class_token = self.class_token.expand(x.shape[0], -1, -1)\n",
    "    return self.encoder(torch.cat([batch_class_token, x], dim=1))\n",
    "\n",
    "\n",
    "def _maxvit_backbone(self, x: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Run the MaxViT stem and all blocks.\"\"\"\n",
    "    x = self.stem(x)\n",
    "    for block in self.blocks:\n",
    "        x = block(x)\n",
    "    return x\n",
    "\n",
    "\n",
    "# ---- custom (\"own_\") models: unitary matrix + signed Gumbel softmax before the head ----\n",
    "\n",
    "\n",
    "def _forward_own_resnet(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    heatmap = _signed_gumbel_heatmap(self, self.unitary_matrix(_resnet_backbone(self, x)))\n",
    "    return _resnet_head(self, heatmap), heatmap\n",
    "\n",
    "\n",
    "def _forward_own_densenet(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    # Unlike the baseline DenseNet forward, no ReLU is applied to `features` here.\n",
    "    heatmap = _signed_gumbel_heatmap(self, self.unitary_matrix(self.features(x)))\n",
    "    out = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(heatmap, (1, 1)), 1)\n",
    "    return self.classifier(out), heatmap\n",
    "\n",
    "\n",
    "def _forward_own_convnext(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    x = _signed_gumbel_heatmap(self, self.unitary_matrix(self.features(x)))\n",
    "    heatmap = _convnext_head_norm(self, x)\n",
    "    return self.classifier[1:](self.avgpool(heatmap)), heatmap\n",
    "\n",
    "\n",
    "def _forward_own_vit(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    x = _vit_encode(self, x)\n",
    "    x = self.unitary_matrix(x.permute(0, 2, 1)).permute(0, 2, 1)\n",
    "    heatmap = _signed_gumbel_heatmap(self, x)\n",
    "    # Average along `gumbel_dim` instead of using the class token (`x[:, 0]`).\n",
    "    return self.heads(torch.mean(heatmap, dim=self.gumbel_dim)), heatmap\n",
    "\n",
    "\n",
    "def _forward_own_swin(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    x = self.permute(self.norm(self.features(x)))  # B H W C -> B C H W\n",
    "    heatmap = _signed_gumbel_heatmap(self, self.unitary_matrix(x))\n",
    "    return self.head(self.flatten(self.avgpool(heatmap))), heatmap\n",
    "\n",
    "\n",
    "def _forward_own_maxvit(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    heatmap = _signed_gumbel_heatmap(self, self.unitary_matrix(_maxvit_backbone(self, x)))\n",
    "    return self.classifier(heatmap), heatmap\n",
    "\n",
    "\n",
    "# ---- torchvision baselines: heatmap is the last feature map consumed by the head ----\n",
    "\n",
    "\n",
    "def _forward_resnet(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    heatmap = _resnet_backbone(self, x)\n",
    "    return _resnet_head(self, heatmap), heatmap\n",
    "\n",
    "\n",
    "def _forward_densenet(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    heatmap = torch.nn.functional.relu(self.features(x), inplace=True)\n",
    "    out = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(heatmap, (1, 1)), 1)\n",
    "    return self.classifier(out), heatmap\n",
    "\n",
    "\n",
    "def _forward_convnext(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    heatmap = _convnext_head_norm(self, self.features(x))\n",
    "    return self.classifier[1:](self.avgpool(heatmap)), heatmap\n",
    "\n",
    "\n",
    "def _forward_vit(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    heatmap = _vit_encode(self, x)\n",
    "    # Classifier \"token\" as used by standard language architectures\n",
    "    return self.heads(heatmap[:, 0]), heatmap\n",
    "\n",
    "\n",
    "def _forward_swin(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    heatmap = self.permute(self.norm(self.features(x)))  # B H W C -> B C H W\n",
    "    return self.head(self.flatten(self.avgpool(heatmap))), heatmap\n",
    "\n",
    "\n",
    "def _forward_maxvit(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:\n",
    "    heatmap = _maxvit_backbone(self, x)\n",
    "    return self.classifier(heatmap), heatmap\n",
    "\n",
    "\n",
    "HEATMAP_FORWARDS = {\n",
    "    \"own_resnet34\": _forward_own_resnet,\n",
    "    \"own_resnet50\": _forward_own_resnet,\n",
    "    \"own_densenet121\": _forward_own_densenet,\n",
    "    \"own_convnext_tiny\": _forward_own_convnext,\n",
    "    \"own_convnext_large\": _forward_own_convnext,\n",
    "    \"own_vit_b_16\": _forward_own_vit,\n",
    "    \"own_swin_v2_s\": _forward_own_swin,\n",
    "    \"own_maxvit_t\": _forward_own_maxvit,\n",
    "    \"resnet34\": _forward_resnet,\n",
    "    \"resnet50\": _forward_resnet,\n",
    "    \"densenet121\": _forward_densenet,\n",
    "    \"convnext_tiny\": _forward_convnext,\n",
    "    \"convnext_large\": _forward_convnext,\n",
    "    \"vit_b_16\": _forward_vit,\n",
    "    \"swin_v2_s\": _forward_swin,\n",
    "    \"maxvit_t\": _forward_maxvit,\n",
    "}\n",
    "\n",
    "\n",
    "for model_name, model in models.items():\n",
    "    # NOTE(review): only classes named \"Custom...\" are patched. The torchvision baseline\n",
    "    # models presumably have other class names (ResNet, DenseNet, ...), so their entries\n",
    "    # in HEATMAP_FORWARDS stay unused while this filter is in place.\n",
    "    if not model.__class__.__name__.startswith(\"Custom\"):\n",
    "        continue\n",
    "\n",
    "    print(f\"{model_name=}\")\n",
    "    if model_name not in HEATMAP_FORWARDS:\n",
    "        raise NotImplementedError(f\"Model '{model_name}' not implemented\")\n",
    "    modified_fun = HEATMAP_FORWARDS[model_name]\n",
    "    model.forward = types.MethodType(modified_fun, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec0a4e2ca421fdc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import (\n",
    "    ResNet,\n",
    "    DenseNet,\n",
    "    ConvNeXt,\n",
    "    VisionTransformer,\n",
    "    SwinTransformer,\n",
    "    MaxVit,\n",
    ")\n",
    "from models import CustomVisionTransformer\n",
    "\n",
    "\n",
    "def get_last_layer(model):\n",
    "    \"\"\"Return the last classification layer of ``model``.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if ``model`` is not an instance of one of the\n",
    "            supported torchvision architectures (subclasses included).\n",
    "    \"\"\"\n",
    "    if isinstance(model, ResNet):\n",
    "        return model.fc\n",
    "    elif isinstance(model, DenseNet):\n",
    "        return model.classifier\n",
    "    elif isinstance(model, (ConvNeXt, MaxVit)):\n",
    "        return model.classifier[-1]\n",
    "    elif isinstance(model, VisionTransformer):\n",
    "        return model.heads.head\n",
    "    elif isinstance(model, SwinTransformer):\n",
    "        return model.head\n",
    "    else:\n",
    "        # Name the class of the model that was passed in rather than reading the\n",
    "        # notebook-global `model_name`, which may belong to a different model.\n",
    "        raise NotImplementedError(f\"Model '{type(model).__name__}' not implemented\")\n",
    "\n",
    "\n",
    "def find_smallest_atol(tensor1: torch.Tensor, tensor2: torch.Tensor):\n",
    "    \"\"\"\n",
    "    Find the smallest power-of-ten atol for which two tensors are considered close.\n",
    "\n",
    "    The candidates 1e-8, 1e-7, ..., 1e-2 are tried in increasing order with\n",
    "    ``torch.allclose`` (default ``rtol``).\n",
    "\n",
    "    Args:\n",
    "        tensor1 (torch.Tensor): The first tensor.\n",
    "        tensor2 (torch.Tensor): The second tensor.\n",
    "\n",
    "    Returns:\n",
    "        int: The base-10 exponent of the smallest passing atol (e.g. ``-5`` for 1e-5).\n",
    "        None: If the shapes differ or the tensors are not close for any candidate;\n",
    "            in the latter case their Euclidean distance is printed.\n",
    "    \"\"\"\n",
    "    # Check if tensors have the same shape\n",
    "    if tensor1.shape != tensor2.shape:\n",
    "        print(\"The tensors have different shapes.\")\n",
    "        return None\n",
    "\n",
    "    for exponent in range(-8, -1):\n",
    "        # Parse the literal instead of eval()-ing it; float(\"1e-5\") == 1e-5 exactly.\n",
    "        if torch.allclose(tensor1, tensor2, atol=float(f\"1e{exponent}\")):\n",
    "            return exponent\n",
    "\n",
    "    print(f\"Distance: {torch.dist(tensor1, tensor2):.4g}\")\n",
    "    return None\n",
    "\n",
    "\n",
    "def _assert_heatmap_matches_logits(logits, true_labels, reconstructed_logits):\n",
    "    \"\"\"Check that an aggregated heatmap reproduces the true-class logits.\n",
    "\n",
    "    Args:\n",
    "        logits (torch.Tensor): Logits of the kept samples, (batch, num_classes).\n",
    "        true_labels (torch.Tensor): Target class index of each kept sample.\n",
    "        reconstructed_logits (torch.Tensor): One value per sample, obtained by\n",
    "            aggregating the heatmap.\n",
    "\n",
    "    Returns:\n",
    "        int: Exponent ``k`` from ``find_smallest_atol`` (tensors agree at atol=1e{k}).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the tensors are not close for any tested atol.\n",
    "    \"\"\"\n",
    "    smallest_atol = find_smallest_atol(\n",
    "        logits[range(logits.shape[0]), true_labels], reconstructed_logits\n",
    "    )\n",
    "    assert (\n",
    "        smallest_atol is not None\n",
    "    ), \"Tensors are not close with any atol in the specified range.\"\n",
    "    return smallest_atol\n",
    "\n",
    "\n",
    "def generate_heatmap(\n",
    "    model, final_layer, input_images, true_labels, device, draw_mode=True\n",
    "):\n",
    "    \"\"\"Split the true-class logit of correctly classified images into contributions.\n",
    "\n",
    "    Only images whose argmax prediction equals their label are kept. For each of\n",
    "    them, the contribution ``weight[label, i] * feature_i`` of every channel ``i``\n",
    "    of ``final_layer`` is computed (``abs(weight)`` for models whose class name\n",
    "    starts with \"Custom\"), and the bias is spread evenly over the heatmap so that\n",
    "    aggregating the heatmap gives back the logit; this reconstruction is asserted.\n",
    "\n",
    "    Args:\n",
    "        model: Classifier returning ``(logits, feature_maps)``.\n",
    "        final_layer: Linear layer whose ``weight`` (and optional ``bias``) produce\n",
    "            the logits.\n",
    "        input_images (torch.Tensor): Batch of images.\n",
    "        true_labels (torch.Tensor): Label of each image.\n",
    "        device: Device the computation runs on.\n",
    "        draw_mode (bool): If True, keep the spatial / token resolution (for\n",
    "            drawing); otherwise pool the features to one value per channel.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(computed_heatmap, correct_predictions_mask, smallest_atol)``.\n",
    "        ``computed_heatmap`` and ``smallest_atol`` are None when no image of the\n",
    "        batch is classified correctly.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the heatmap does not reproduce the logits.\n",
    "    \"\"\"\n",
    "    smallest_atol = None\n",
    "    computed_heatmap = None\n",
    "    with torch.inference_mode():\n",
    "        input_images, true_labels = map(\n",
    "            lambda tensor: tensor.to(device, non_blocking=True),\n",
    "            [input_images, true_labels],\n",
    "        )\n",
    "        logits, feature_maps = model(input_images)\n",
    "\n",
    "        predicted_labels = torch.argmax(logits, dim=1)\n",
    "        correct_predictions_mask = true_labels == predicted_labels\n",
    "        if not correct_predictions_mask.any():\n",
    "            # Nothing to explain in this batch.\n",
    "            return computed_heatmap, correct_predictions_mask, smallest_atol\n",
    "\n",
    "        input_images, logits, feature_maps, true_labels = map(\n",
    "            lambda tensor: tensor[correct_predictions_mask],\n",
    "            [input_images, logits, feature_maps, true_labels],\n",
    "        )\n",
    "        if not draw_mode:\n",
    "            # Reduce the features to one vector per image, per backbone family.\n",
    "            if isinstance(model, (ResNet, SwinTransformer)):\n",
    "                feature_maps = model.avgpool(feature_maps).flatten(start_dim=1)\n",
    "            elif isinstance(model, (DenseNet, ConvNeXt)):\n",
    "                feature_maps = torch.nn.functional.adaptive_avg_pool2d(\n",
    "                    feature_maps, (1, 1)\n",
    "                ).flatten(start_dim=1)\n",
    "            elif type(model) is CustomVisionTransformer:\n",
    "                feature_maps = torch.mean(feature_maps, dim=model.gumbel_dim)\n",
    "            elif type(model) is VisionTransformer:\n",
    "                feature_maps = feature_maps[:, 0]\n",
    "\n",
    "        is_maxvit = isinstance(model, MaxVit)\n",
    "        if is_maxvit:\n",
    "            # Every classifier stage except the final linear layer.\n",
    "            pooled = model.classifier[:-1](feature_maps)\n",
    "            if draw_mode:\n",
    "                # Spread each pooled channel value over the grid in proportion to\n",
    "                # the channel's activation; summing over (H, W) gives `pooled` back.\n",
    "                feature_maps = pooled.unsqueeze(-1).unsqueeze(-1) * feature_maps.div(\n",
    "                    feature_maps.sum(dim=[-2, -1], keepdim=True)\n",
    "                )\n",
    "            else:\n",
    "                feature_maps = pooled\n",
    "\n",
    "        weights = final_layer.weight[true_labels]\n",
    "        if model.__class__.__name__.startswith(\"Custom\"):\n",
    "            weights = torch.abs(weights)\n",
    "\n",
    "        if isinstance(model, VisionTransformer):\n",
    "            # In draw mode feature_maps is indexed as (batch, token, dim) and the\n",
    "            # heatmap as (batch, dim, token).\n",
    "            computed_heatmap = torch.einsum(\n",
    "                \"bi,bai->bia\" if draw_mode else \"bi,bi->bi\",\n",
    "                weights,\n",
    "                feature_maps,\n",
    "            )\n",
    "            if final_layer.bias is not None:\n",
    "                bias_view_shape = (-1, 1, 1) if draw_mode else (-1, 1)\n",
    "                computed_heatmap += (\n",
    "                    final_layer.bias[true_labels].view(bias_view_shape)\n",
    "                    / computed_heatmap.shape[1]\n",
    "                )\n",
    "\n",
    "            # Channels are summed; in draw mode tokens are averaged.\n",
    "            reconstructed_logits = (\n",
    "                computed_heatmap.sum(dim=[-2, -1]).div(computed_heatmap.shape[-1])\n",
    "                if draw_mode\n",
    "                else computed_heatmap.sum(dim=-1)\n",
    "            )\n",
    "            smallest_atol = _assert_heatmap_matches_logits(\n",
    "                logits, true_labels, reconstructed_logits\n",
    "            )\n",
    "\n",
    "            if draw_mode:\n",
    "                n, _, h, w = input_images.shape\n",
    "                n_h = h // model.patch_size\n",
    "                n_w = w // model.patch_size\n",
    "                # Drop token 0 (class token) and fold the patch tokens onto the grid.\n",
    "                computed_heatmap = computed_heatmap[..., 1:].view(n, -1, n_h, n_w)\n",
    "        else:\n",
    "            computed_heatmap = torch.einsum(\n",
    "                \"bi,biwh->biwh\" if draw_mode else \"bi,bi->bi\",\n",
    "                weights,\n",
    "                feature_maps,\n",
    "            )\n",
    "            if final_layer.bias is not None:\n",
    "                bias_parts = computed_heatmap.shape[1]\n",
    "                if draw_mode and is_maxvit:\n",
    "                    # The MaxVit heatmap is summed (not averaged) over space below,\n",
    "                    # so the bias must also be split over the spatial cells to add\n",
    "                    # up to exactly one bias per image.\n",
    "                    bias_parts *= computed_heatmap.shape[-2] * computed_heatmap.shape[-1]\n",
    "                bias_view_shape = (-1, 1, 1, 1) if draw_mode else (-1, 1)\n",
    "                computed_heatmap += (\n",
    "                    final_layer.bias[true_labels].view(bias_view_shape) / bias_parts\n",
    "                )\n",
    "\n",
    "            if not draw_mode:\n",
    "                reconstructed_logits = computed_heatmap.sum(dim=-1)\n",
    "            elif is_maxvit:\n",
    "                reconstructed_logits = computed_heatmap.sum(dim=[-3, -2, -1])\n",
    "            else:\n",
    "                # Spatial mean, matching the average pooling used without draw_mode.\n",
    "                reconstructed_logits = computed_heatmap.sum(dim=[-3, -2, -1]).div(\n",
    "                    np.prod(computed_heatmap.shape[-2:])\n",
    "                )\n",
    "            smallest_atol = _assert_heatmap_matches_logits(\n",
    "                logits, true_labels, reconstructed_logits\n",
    "            )\n",
    "\n",
    "    return computed_heatmap, correct_predictions_mask, smallest_atol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc9c3277-4d3d-4b83-a248-34c6ac9fa2b6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4a5ba057-7b85-4124-a024-8d9f4b723758",
   "metadata": {},
   "source": [
    "# Ex. I. Number of channels that explain 95% of the model's decision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c0f4ef9702286a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory setup\n",
    "save_examples = \"examples\"\n",
    "Path(save_examples).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# CUDNN setup\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "# Fraction of an image's total absolute channel contribution that the selected\n",
    "# (largest) channels must cover to count as important.\n",
    "threshold = 0.95\n",
    "\n",
    "# Processing models\n",
    "for model_name, model in models.items():\n",
    "    # Built once so the existence check, the write and the message always agree.\n",
    "    save_path = (\n",
    "        f\"{save_examples}/{threshold}-explained_pred_{model_name}_{args.dataset_name}.npz\"\n",
    "    )\n",
    "    if Path(save_path).is_file():\n",
    "        continue\n",
    "\n",
    "    max_atol = -100\n",
    "    data_loader_test = torch.utils.data.DataLoader(\n",
    "        datasets[model_name][0],\n",
    "        batch_size=args.batch_size,\n",
    "        shuffle=False,  # image indices below are derived from iteration order\n",
    "        num_workers=args.workers,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "\n",
    "    data2save = {\n",
    "        key: {\n",
    "            \"idx_img_correct_label\": [],\n",
    "            \"num_important_channels\": [],\n",
    "            \"idx_important_channels\": [],\n",
    "            \"values_channels\": [],\n",
    "        }\n",
    "        for key in data_loader_test.dataset.class_to_idx.values()\n",
    "    }\n",
    "\n",
    "    last_layer = get_last_layer(model)\n",
    "\n",
    "    n = 0  # number of images seen so far (correctly classified or not)\n",
    "    for images, targets in (\n",
    "        bar := tqdm(data_loader_test, desc=\"Progress\", position=0, ncols=100)\n",
    "    ):\n",
    "        heatmap, idx, atol = generate_heatmap(\n",
    "            model, last_layer, images, targets, device, draw_mode=False\n",
    "        )\n",
    "\n",
    "        n += images.shape[0]\n",
    "        if heatmap is not None:\n",
    "            max_atol = max(atol, max_atol)\n",
    "            with torch.inference_mode():\n",
    "                idx_images = torch.arange(n - images.shape[0], n, device=device)[idx]\n",
    "\n",
    "                # Smallest number of channels, taken by decreasing absolute\n",
    "                # contribution, whose share of the total reaches `threshold`.\n",
    "                values = torch.abs(heatmap)\n",
    "                sorted_values, indices = torch.sort(values, dim=-1, descending=True)\n",
    "                num_channels_explained_percentage = (\n",
    "                    torch.argmax(\n",
    "                        (\n",
    "                            torch.cumsum(sorted_values, dim=1)\n",
    "                            / values.sum(dim=1, keepdim=True)\n",
    "                            >= threshold\n",
    "                        ).long(),\n",
    "                        dim=1,\n",
    "                    )\n",
    "                    + 1\n",
    "                )\n",
    "\n",
    "                idx_images = idx_images.cpu().numpy()\n",
    "                heatmap = heatmap.cpu().numpy()\n",
    "                num_channels_explained_percentage = (\n",
    "                    num_channels_explained_percentage.cpu().numpy()\n",
    "                )\n",
    "                targets = targets[idx.cpu()].numpy()\n",
    "                indices = indices.cpu().numpy()\n",
    "\n",
    "                for i in range(targets.shape[0]):\n",
    "                    record = data2save[targets[i]]\n",
    "                    record[\"idx_img_correct_label\"].append(idx_images[i])\n",
    "                    record[\"values_channels\"].append(heatmap[i])\n",
    "                    record[\"num_important_channels\"].append(\n",
    "                        num_channels_explained_percentage[i]\n",
    "                    )\n",
    "                    # Important channel indices, stored as one \"-\"-joined string.\n",
    "                    record[\"idx_important_channels\"].append(\n",
    "                        \"-\".join(\n",
    "                            map(str, indices[i, : num_channels_explained_percentage[i]])\n",
    "                        )\n",
    "                    )\n",
    "        bar.set_description(f\"{model_name} | epsilon=1e{max_atol}\")\n",
    "\n",
    "    for key in data2save:\n",
    "        if data2save[key][\"values_channels\"]:\n",
    "            data2save[key][\"values_channels\"] = np.stack(\n",
    "                data2save[key][\"values_channels\"]\n",
    "            )\n",
    "\n",
    "    # Flatten to \"<class>-<field>\" keys so every array can be stored by np.savez\n",
    "    flat_data_dict = {\n",
    "        f\"{outer_key}-{inner_key}\": np.array(value)\n",
    "        for outer_key, inner_dict in data2save.items()\n",
    "        for inner_key, value in inner_dict.items()\n",
    "    }\n",
    "    np.savez(save_path, **flat_data_dict)\n",
    "    print(f\"Data saved successfully in '{save_path}'\")\n",
    "\n",
    "print(\"\\nDone\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63e5aa9f-856c-4244-83de-990af90e924b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1b9ca1f3-a114-4ead-9a1b-2e3c8c550b20",
   "metadata": {},
   "source": [
    "### Charts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2959423-309a-448d-8be8-d1b7fc0ae7ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_plot = True\n",
    "# save_plot = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0ff6d0b-bac9-40b9-8b51-6fbc6d5e6f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_theme(style=\"whitegrid\", palette=\"colorblind\")\n",
    "sns.set_context(\"paper\", font_scale=2.0)\n",
    "\n",
    "\n",
    "def load_data2plot(models, dataname, file_template):\n",
    "    \"\"\"Collect the per-image number of important channels of several models.\n",
    "\n",
    "    Args:\n",
    "        models (list[str]): Model names used to fill ``file_template``.\n",
    "        dataname (str): Dataset name used to fill ``file_template``.\n",
    "        file_template (str): Path template, formatted as\n",
    "            ``file_template.format(model, dataname)``.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame: Long-format frame with one row per image and the columns\n",
    "        ``\"Value\"`` (number of important channels) and ``\"Model\"``.\n",
    "    \"\"\"\n",
    "    dict_data = {\"Value\": [], \"Model\": []}\n",
    "    for model in models:\n",
    "        # The context manager closes the underlying .npz file handle.\n",
    "        with np.load(file_template.format(model, dataname)) as loaded_data:\n",
    "            data = [\n",
    "                item\n",
    "                for key in loaded_data.files\n",
    "                if key.endswith(\"num_important_channels\")\n",
    "                for item in loaded_data[key]\n",
    "            ]\n",
    "        dict_data[\"Value\"].extend(data)\n",
    "        dict_data[\"Model\"].extend([model] * len(data))\n",
    "    return pd.DataFrame.from_dict(dict_data)\n",
    "\n",
    "\n",
    "models_dict = {\n",
    "    \"own_resnet34\": \"InfoDisent(ResNet-34)\",\n",
    "    \"own_resnet50\": \"InfoDisent(ResNet-50)\",\n",
    "    \"own_densenet121\": \"InfoDisent(Densenet-121)\",\n",
    "    \"own_convnext_tiny\": \"InfoDisent(Convnext-Tiny)\",\n",
    "    \"own_convnext_large\": \"InfoDisent(Convnext-Large)\",\n",
    "    \"own_vit_b_16\": \"InfoDisent(VisionTransformer)\",\n",
    "    \"own_swin_v2_s\": \"InfoDisent(SwinTransformer)\",\n",
    "    \"own_maxvit_t\": \"InfoDisent(MaxVit)\",\n",
    "    \"resnet34\": \"ResNet-34\",\n",
    "    \"resnet50\": \"ResNet-50\",\n",
    "    \"densenet121\": \"Densenet-121\",\n",
    "    \"convnext_tiny\": \"Convnext-Tiny\",\n",
    "    \"convnext_large\": \"Convnext-Large\",\n",
    "    \"vit_b_16\": \"VisionTransformer\",\n",
    "    \"swin_v2_s\": \"SwinTransformer\",\n",
    "    \"maxvit_t\": \"MaxVit\",\n",
    "}\n",
    "\n",
    "\n",
    "for model_name in models.keys():\n",
    "    if not model_name.startswith(\"own_\"):\n",
    "        continue\n",
    "\n",
    "    # Pair the InfoDisent model with the backbone of exactly the same name; a\n",
    "    # substring match could also pick up unrelated models.\n",
    "    baseline_name = model_name[len(\"own_\") :]\n",
    "    if baseline_name not in models:\n",
    "        print(f\"Skipping '{model_name}': baseline '{baseline_name}' is not loaded.\")\n",
    "        continue\n",
    "    # Keep the order in which both models appear in `models`.\n",
    "    group = [name for name in models.keys() if name in (model_name, baseline_name)]\n",
    "\n",
    "    df = load_data2plot(\n",
    "        group, dataset_name.lower(), \"examples/0.95-explained_pred_{}_{}.npz\"\n",
    "    )\n",
    "\n",
    "    model_mapping = {\n",
    "        model_name: models_dict[model_name],\n",
    "        baseline_name: models_dict[baseline_name],\n",
    "    }\n",
    "\n",
    "    df[\"Model\"] = df[\"Model\"].replace(model_mapping)\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(12, 3))\n",
    "    sns.kdeplot(\n",
    "        data=df,\n",
    "        x=\"Value\",\n",
    "        hue=\"Model\",\n",
    "        fill=True,\n",
    "        common_norm=False,\n",
    "        legend=True,\n",
    "        ax=ax,\n",
    "    )\n",
    "    ax.set_xlabel(\"Number of Important Channels\")\n",
    "    sns.move_legend(\n",
    "        ax, \"lower center\", bbox_to_anchor=(0.5, 0.8), ncol=3, title=None, frameon=True\n",
    "    )\n",
    "\n",
    "    # Channel counts are non-negative; hide the KDE tail below zero.\n",
    "    if ax.get_xlim()[0] < 0:\n",
    "        ax.set_xlim(left=0)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    if save_plot:\n",
    "        plt.savefig(\n",
    "            f\"examples/distribution_important_channels_{model_name}_{dataset_name.lower()}.pdf\",\n",
    "            transparent=True,\n",
    "            bbox_inches=\"tight\",\n",
    "            pad_inches=0,\n",
    "        )\n",
    "\n",
    "    plt.show()\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a18c683f-a9f0-465f-aba5-3ee1dc7a7421",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51d45092-2997-41cf-88c6-11edb8dc9894",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_theme(style=\"whitegrid\", palette=\"colorblind\")\n",
    "sns.set_context(\"paper\", font_scale=1.5)\n",
    "\n",
    "k = 0  # index of class\n",
    "idx_img = 10  # index of the image among the correctly classified images of class k\n",
    "\n",
    "for model_name in models.keys():\n",
    "    if not model_name.startswith(\"own_\"):\n",
    "        continue\n",
    "\n",
    "    # The InfoDisent model first, then the backbone of exactly the same name.\n",
    "    baseline_name = model_name[len(\"own_\") :]\n",
    "    if baseline_name not in models:\n",
    "        print(f\"Skipping '{model_name}': baseline '{baseline_name}' is not loaded.\")\n",
    "        continue\n",
    "    group = [model_name, baseline_name]\n",
    "\n",
    "    fig, axs = plt.subplots(len(group), 1, figsize=(12, 4))\n",
    "\n",
    "    # A separate loop name keeps `model_name` pointing at the InfoDisent model,\n",
    "    # which is what the file saved below is named after.\n",
    "    for n_gr, member_name in enumerate(group):\n",
    "        filename = (\n",
    "            f\"examples/0.95-explained_pred_{member_name}_{dataset_name.lower()}.npz\"\n",
    "        )\n",
    "        # Rebuild {class_idx: {field: array}} from the flat \"<class>-<field>\" keys.\n",
    "        data2save = {}\n",
    "        with np.load(filename) as loaded_data:\n",
    "            for key in loaded_data:\n",
    "                outer_key, inner_key = key.split(\"-\")\n",
    "                data2save.setdefault(int(outer_key), {})[inner_key] = loaded_data[key]\n",
    "\n",
    "        data = data2save[k][\"values_channels\"][idx_img]\n",
    "        mask = data > 0\n",
    "\n",
    "        df = pd.DataFrame({\"value\": data, \"mask\": np.array(mask, dtype=np.int32)})\n",
    "\n",
    "        ax = axs[n_gr]\n",
    "        sns.barplot(x=df.index, y=\"value\", data=df, hue=\"mask\", ax=ax, width=3)\n",
    "        ax.set_ylabel(\"\")\n",
    "        ax.set_xlabel(\"Channel\")\n",
    "        ax.tick_params(axis=\"x\", rotation=45)\n",
    "        ax.xaxis.set_major_locator(MaxNLocator(nbins=20))\n",
    "\n",
    "        if n_gr == 0:\n",
    "            sns.move_legend(\n",
    "                ax,\n",
    "                \"lower center\",\n",
    "                bbox_to_anchor=(0.5, 0.98),\n",
    "                ncol=3,\n",
    "                title=None,\n",
    "                frameon=True,\n",
    "                labels=[\"Negative\", \"Positive\"],\n",
    "            )\n",
    "            ax.set_xticklabels([])\n",
    "            ax.set_xlabel(\"\")\n",
    "            fig.text(\n",
    "                0.985,\n",
    "                0.7,\n",
    "                \"InfoDisent\",\n",
    "                ha=\"center\",\n",
    "                va=\"center\",\n",
    "                rotation=\"vertical\",\n",
    "            )\n",
    "        else:\n",
    "            ax.get_legend().remove()\n",
    "            fig.text(\n",
    "                0.985,\n",
    "                0.35,\n",
    "                models_dict[member_name],\n",
    "                ha=\"center\",\n",
    "                va=\"center\",\n",
    "                rotation=\"vertical\",\n",
    "            )\n",
    "\n",
    "    fig.text(0.0, 0.5, \"Value\", ha=\"center\", va=\"center\", rotation=\"vertical\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    if save_plot:\n",
    "        plt.savefig(\n",
    "            f\"examples/image-{idx_img}_important_channels_{model_name}_{dataset_name.lower()}.pdf\",\n",
    "            transparent=True,\n",
    "            bbox_inches=\"tight\",\n",
    "            pad_inches=0,\n",
    "        )\n",
    "\n",
    "    plt.show()\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a28d674a-b859-4b56-a143-0b595570e7f5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37d6d2a3-6a21-4b21-aa12-a9a2efbeeef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_theme(style=\"whitegrid\", palette=\"colorblind\")\n",
    "sns.set_context(\"paper\", font_scale=1.75)\n",
    "\n",
    "nrows = 4  # number of stacked panels the classes are split across\n",
    "\n",
    "for model_name in models.keys():\n",
    "    filename = f\"examples/0.95-explained_pred_{model_name}_{dataset_name.lower()}.npz\"\n",
    "    loaded_data = np.load(filename)\n",
    "\n",
    "    # Undo the \"<class>-<field>\" flattening used when the file was written.\n",
    "    data2save = {}\n",
    "    for flat_key in loaded_data:\n",
    "        class_part, field = flat_key.split(\"-\")\n",
    "        data2save.setdefault(int(class_part), {})[field] = loaded_data[flat_key]\n",
    "\n",
    "    # One row per (class, number of important channels of one image).\n",
    "    df = pd.DataFrame(\n",
    "        [\n",
    "            (class_idx, count)\n",
    "            for class_idx, fields in data2save.items()\n",
    "            for count in fields[\"num_important_channels\"]\n",
    "        ],\n",
    "        columns=[\"Class\", \"Value\"],\n",
    "    )\n",
    "\n",
    "    class_chunks = np.array_split(df[\"Class\"].unique(), nrows)\n",
    "    fig, axs = plt.subplots(\n",
    "        nrows, 1, figsize=(12, 6), sharey=True, gridspec_kw={\"wspace\": 0, \"hspace\": 0.5}\n",
    "    )\n",
    "    panel_axes = axs if nrows > 1 else [axs]\n",
    "    for ax, chunk in zip(panel_axes, class_chunks):\n",
    "        sns.violinplot(\n",
    "            data=df[df[\"Class\"].isin(chunk)],\n",
    "            x=\"Class\",\n",
    "            y=\"Value\",\n",
    "            hue=\"Class\",\n",
    "            inner=None,\n",
    "            bw_method=\"scott\",\n",
    "            split=False,\n",
    "            width=0.9,\n",
    "            palette=sns.color_palette(\"Spectral\", as_cmap=True),\n",
    "            legend=False,\n",
    "            ax=ax,\n",
    "        )\n",
    "        ax.xaxis.set_major_locator(MultipleLocator(10))\n",
    "        ax.yaxis.set_major_locator(MaxNLocator(nbins=3))\n",
    "        ax.set_ylabel(\"\")\n",
    "        ax.set_xlabel(\"\")\n",
    "\n",
    "    # Shared axis captions plus the model name on the right-hand side.\n",
    "    fig.text(0.5, 0.03, \"Class\", ha=\"center\", va=\"center\")\n",
    "    for x_pos, caption in (\n",
    "        (0.08, \"Distribution of important channels\"),\n",
    "        (0.92, models_dict[model_name]),\n",
    "    ):\n",
    "        fig.text(x_pos, 0.5, caption, ha=\"center\", va=\"center\", rotation=\"vertical\")\n",
    "\n",
    "    if save_plot:\n",
    "        plt.savefig(\n",
    "            f\"examples/important_channels_{model_name}_{dataset_name.lower()}.pdf\",\n",
    "            transparent=True,\n",
    "            bbox_inches=\"tight\",\n",
    "            pad_inches=0,\n",
    "        )\n",
    "\n",
    "    plt.show()\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c65b8127-4238-48e1-95ff-b80258137338",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a039e4f9-b97c-4ea9-a532-8797290157f2",
   "metadata": {},
   "source": [
    "# Ex. II. For a given image, select its relevant channels, then select 5-10 relevant prototypes for each channel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129fb969-54e6-4159-b83f-f1965617c937",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "1. Processes all images from the test set for which the prediction matches the target.\n",
    "2. Saves, for each such image, its top-k channel contributions (largest heatmap values)\n",
    "   and the indices of the channels where these values occurred.\n",
    "\"\"\"\n",
    "\n",
    "# Set up directories and configurations\n",
    "save_examples = \"examples\"\n",
    "Path(save_examples).mkdir(parents=True, exist_ok=True)\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "topk = 5\n",
    "\n",
    "\n",
    "# Helper function to process model output\n",
    "def process_model_output(model, data_loader_test, device, topk):\n",
    "    \"\"\"Collect the top-k channel contributions of every correctly classified image.\n",
    "\n",
    "    Args:\n",
    "        model: Model passed to ``generate_heatmap``.\n",
    "        data_loader_test: Loader whose dataset exposes ``class_to_idx``; it must not\n",
    "            shuffle, because image indices are derived from the iteration order.\n",
    "        device: Device the heatmaps are computed on.\n",
    "        topk (int): Number of channels kept per image.\n",
    "\n",
    "    Returns:\n",
    "        dict[str, np.ndarray]: Flat mapping ``\"<class>-<field>\"`` -> array for the\n",
    "        fields ``idx_img_correct_label``, ``idx_important_channels`` and\n",
    "        ``values_channels``. Classes without any correctly classified image are\n",
    "        left out.\n",
    "    \"\"\"\n",
    "    data2save = {\n",
    "        key: {\n",
    "            \"idx_img_correct_label\": [],\n",
    "            \"idx_important_channels\": [],\n",
    "            \"values_channels\": [],\n",
    "        }\n",
    "        for key in data_loader_test.dataset.class_to_idx.values()\n",
    "    }\n",
    "\n",
    "    last_layer = get_last_layer(model)\n",
    "\n",
    "    max_atol = -100\n",
    "    n = 0  # number of images seen so far (correctly classified or not)\n",
    "    for images, targets in (\n",
    "        bar := tqdm(data_loader_test, desc=\"Progress\", position=0, ncols=100)\n",
    "    ):\n",
    "        heatmap, idx, atol = generate_heatmap(\n",
    "            model, last_layer, images, targets, device, draw_mode=False\n",
    "        )\n",
    "\n",
    "        n += images.shape[0]\n",
    "        if heatmap is not None:\n",
    "            max_atol = max(atol, max_atol)\n",
    "            with torch.inference_mode():\n",
    "                images_idx = torch.arange(n - images.shape[0], n, device=device)[idx]\n",
    "\n",
    "                topk_values, indices = torch.topk(heatmap, topk, dim=-1)\n",
    "                images_idx, topk_values, indices = map(\n",
    "                    lambda x: x.cpu().numpy(),\n",
    "                    [images_idx, topk_values, indices],\n",
    "                )\n",
    "                targets = targets[idx.cpu()].numpy()\n",
    "\n",
    "                for i in range(targets.shape[0]):\n",
    "                    data2save[targets[i]][\"idx_img_correct_label\"].append(images_idx[i])\n",
    "                    data2save[targets[i]][\"values_channels\"].append(topk_values[i])\n",
    "                    data2save[targets[i]][\"idx_important_channels\"].append(indices[i])\n",
    "        bar.set_description(f\"{model.__class__.__name__} | epsilon=1e{max_atol}\")\n",
    "\n",
    "    # Stack per-class results. Empty classes are dropped by an explicit check;\n",
    "    # catching ValueError would also hide genuine shape mismatches in np.stack.\n",
    "    remove_keys = []\n",
    "    for key, val in data2save.items():\n",
    "        if not val[\"values_channels\"]:\n",
    "            print(\n",
    "                f\"Class '{key}' does not have important channels [len: {len(val['values_channels'])}]\"\n",
    "            )\n",
    "            remove_keys.append(key)\n",
    "            continue\n",
    "        val[\"values_channels\"] = np.stack(val[\"values_channels\"])\n",
    "        val[\"idx_important_channels\"] = np.stack(val[\"idx_important_channels\"])\n",
    "\n",
    "    for key in remove_keys:\n",
    "        del data2save[key]\n",
    "\n",
    "    # Flatten the dictionary\n",
    "    flat_data_dict = {\n",
    "        f\"{outer_key}-{inner_key}\": np.array(value)\n",
    "        for outer_key, inner_dict in data2save.items()\n",
    "        for inner_key, value in inner_dict.items()\n",
    "    }\n",
    "\n",
    "    return flat_data_dict\n",
    "\n",
    "\n",
    "# Main processing loop (only models whose class name starts with \"Custom\")\n",
    "for model_name, model in models.items():\n",
    "    if not model.__class__.__name__.startswith(\"Custom\"):\n",
    "        continue\n",
    "    # Built once so the existence check, the write and the message always agree.\n",
    "    filename = (\n",
    "        f\"{save_examples}/important_channels-{model_name}_{args.dataset_name}.npz\"\n",
    "    )\n",
    "    if Path(filename).is_file():\n",
    "        continue\n",
    "\n",
    "    data_loader_test = torch.utils.data.DataLoader(\n",
    "        datasets[model_name][0],\n",
    "        batch_size=args.batch_size,\n",
    "        shuffle=False,\n",
    "        num_workers=args.workers,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "\n",
    "    flat_data_dict = process_model_output(model, data_loader_test, device, topk)\n",
    "    np.savez(filename, **flat_data_dict)\n",
    "    print(f\"Data saved successfully in '{filename}'\")\n",
    "\n",
    "print(\"\\nDone\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5b31f6f-6c6c-4afd-a051-74591b29808b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "1. Processes all images from the training set for which the prediction matches the target.\n",
    "2. Saves an array (top-k, number of channels) holding, for every channel, the indices\n",
    "   of the images with the highest activations in that channel.\n",
    "\"\"\"\n",
    "\n",
    "# Set up directories and configurations\n",
    "save_examples = \"examples\"\n",
    "Path(save_examples).mkdir(parents=True, exist_ok=True)\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "topk = 10\n",
    "\n",
    "\n",
    "# Function to process and save data\n",
    "def process_and_save_data(model, data_loader, device, save_path, topk):\n",
    "    \"\"\"Store, per channel, the top-k correctly classified images and their values.\n",
    "\n",
    "    A running top-k along the image dimension is kept while iterating, so at most\n",
    "    ``topk`` rows plus one batch are held at any time.\n",
    "\n",
    "    Args:\n",
    "        model: Model passed to ``generate_heatmap``.\n",
    "        data_loader: Loader over the images; it must not shuffle, because image\n",
    "            indices are derived from the iteration order.\n",
    "        device: Device the heatmaps are computed on.\n",
    "        save_path (str): Output ``.npz`` path.\n",
    "        topk (int): Number of images kept per channel.\n",
    "\n",
    "    Writes ``prototypes_values`` and ``prototypes_indices_img`` (dataset indices),\n",
    "    both of shape (min(topk, number of correct images), number of channels).\n",
    "    Nothing is written when no image is classified correctly.\n",
    "    \"\"\"\n",
    "    values_channels = None\n",
    "    indices_channels = None\n",
    "\n",
    "    last_layer = get_last_layer(model)\n",
    "\n",
    "    max_atol = -100\n",
    "    n = 0  # number of images seen so far (correctly classified or not)\n",
    "    for images, targets in (\n",
    "        bar := tqdm(data_loader, desc=\"Progress\", position=0, ncols=100)\n",
    "    ):\n",
    "        heatmap, idx, atol = generate_heatmap(\n",
    "            model, last_layer, images, targets, device, draw_mode=False\n",
    "        )\n",
    "\n",
    "        n += images.shape[0]\n",
    "        if heatmap is not None:\n",
    "            max_atol = max(atol, max_atol)\n",
    "            with torch.inference_mode():\n",
    "                images_idx = torch.arange(n - images.shape[0], n, device=device)[idx]\n",
    "                # Dataset index of every heatmap entry, repeated across channels.\n",
    "                batch_indices = torch.tile(\n",
    "                    images_idx.view(-1, 1), (1, heatmap.shape[1])\n",
    "                )\n",
    "\n",
    "                if values_channels is None:\n",
    "                    values_channels, indices_channels = heatmap, batch_indices\n",
    "                else:\n",
    "                    values_channels = torch.cat([values_channels, heatmap], dim=0)\n",
    "                    indices_channels = torch.cat(\n",
    "                        [indices_channels, batch_indices], dim=0\n",
    "                    )\n",
    "\n",
    "                # Keep only the best `topk` images of each channel seen so far.\n",
    "                if values_channels.shape[0] > topk:\n",
    "                    values_channels, indices = torch.topk(values_channels, topk, dim=0)\n",
    "                    indices_channels = torch.gather(indices_channels, 0, indices)\n",
    "        bar.set_description(f\"{model.__class__.__name__} | epsilon=1e{max_atol}\")\n",
    "\n",
    "    if values_channels is None:\n",
    "        # No correctly classified image at all: there is nothing to save.\n",
    "        print(f\"No correctly classified images; nothing saved to '{save_path}'\")\n",
    "        return\n",
    "\n",
    "    np.savez(\n",
    "        save_path,\n",
    "        prototypes_values=values_channels.detach().cpu().numpy(),\n",
    "        prototypes_indices_img=indices_channels.detach().cpu().numpy(),\n",
    "    )\n",
    "    print(f\"Data saved successfully in '{save_path}'\")\n",
    "\n",
    "\n",
    "# Process each model (only models whose class name starts with \"Custom\")\n",
    "for model_name, model in models.items():\n",
    "    if not model.__class__.__name__.startswith(\"Custom\"):\n",
    "        continue\n",
    "    # Built once so the existence check and the write always agree.\n",
    "    save_path = (\n",
    "        f\"{save_examples}/prototypes_channels-{model_name}_{args.dataset_name}.npz\"\n",
    "    )\n",
    "    if Path(save_path).is_file():\n",
    "        continue\n",
    "\n",
    "    data_loader = torch.utils.data.DataLoader(\n",
    "        datasets[model_name][1],\n",
    "        batch_size=args.batch_size,\n",
    "        shuffle=False,\n",
    "        num_workers=args.workers,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "\n",
    "    process_and_save_data(model, data_loader, device, save_path, topk)\n",
    "\n",
    "print(\"\\nDone\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae44c7ef7a37efe9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c50a34c3-e91d-46cb-a24e-01ce6ab94a9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# List of model names and selection of a specific model\n",
    "models_name = [\n",
    "    \"own_resnet34\",\n",
    "    \"own_resnet50\",\n",
    "    \"own_densenet121\",\n",
    "    \"own_convnext_tiny\",\n",
    "    \"own_convnext_large\",\n",
    "    \"own_vit_b_16\",\n",
    "    \"own_swin_v2_s\",\n",
    "    \"own_maxvit_t\",\n",
    "]\n",
    "model_name = models_name[1]\n",
    "\n",
    "# ==================================================================\n",
    "#\n",
    "#      Here you can change the model and parameters\n",
    "#\n",
    "# ==================================================================\n",
    "\n",
    "# Check if the selected model exists\n",
    "assert (\n",
    "    model_name in models\n",
    "), f\"Model {model_name} does not exist for dataset {args.dataset_name}\"\n",
    "\n",
    "# Result files written by the two Ex. II processing cells above\n",
    "important_channels_file, prototypes_channels_file = (\n",
    "    f\"examples/{prefix}-{model_name}_{args.dataset_name}.npz\"\n",
    "    for prefix in (\"important_channels\", \"prototypes_channels\")\n",
    ")\n",
    "\n",
    "loaded_data = np.load(important_channels_file)\n",
    "data_prototypes = np.load(prototypes_channels_file)\n",
    "\n",
    "# Rebuild {class_idx: {field: array}} from the flat \"<class>-<field>\" keys\n",
    "data2save = {}\n",
    "for flat_key in loaded_data:\n",
    "    class_part, field = flat_key.split(\"-\")\n",
    "    data2save.setdefault(int(class_part), {})[field] = loaded_data[flat_key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee5fa43b-cae1-403f-87aa-d28100616866",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_areas_rect(heatmap, threshold, min_size=5, draw_for_not_found=False):\n",
    "    _, binary_heatmap = cv2.threshold(heatmap, threshold, 1, cv2.THRESH_BINARY)\n",
    "    binary_heatmap = (binary_heatmap * 255).astype(np.uint8)\n",
    "    contours, _ = cv2.findContours(\n",
    "        binary_heatmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE\n",
    "    )\n",
    "\n",
    "    rect_positions = []\n",
    "    if len(contours) > 0:\n",
    "        for rect in contours:\n",
    "            x, y, w, h = cv2.boundingRect(rect)\n",
    "\n",
    "            # Ensure minimum size\n",
    "            if w < min_size:\n",
    "                x = max(0, x - (min_size - w) // 2)\n",
    "                w = min_size\n",
    "            if h < min_size:\n",
    "                y = max(0, y - (min_size - h) // 2)\n",
    "                h = min_size\n",
    "\n",
    "            # Adjust dimensions to fit within heatmap bounds\n",
    "            w = min(w, heatmap.shape[1] - x)\n",
    "            h = min(h, heatmap.shape[0] - y)\n",
    "            rect_positions.append((x, y, w, h))\n",
    "    elif draw_for_not_found:\n",
    "        _, binary_heatmap = cv2.threshold(heatmap, 0, 1, cv2.THRESH_BINARY)\n",
    "        binary_heatmap = (binary_heatmap * 255).astype(np.uint8)\n",
    "        contours, _ = cv2.findContours(\n",
    "            binary_heatmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE\n",
    "        )\n",
    "\n",
    "        if contours:\n",
    "            max_contour = max(contours, key=cv2.contourArea)\n",
    "            x, y, w, h = cv2.boundingRect(max_contour)\n",
    "\n",
    "            max_contour = max(contours, key=cv2.contourArea)\n",
    "            x, y, w, h = cv2.boundingRect(max_contour)\n",
    "\n",
    "            center_x = x + w // 2\n",
    "            center_y = y + h // 2\n",
    "\n",
    "            rect_x = max(0, center_x - min_size // 2)\n",
    "            rect_y = max(0, center_y - min_size // 2)\n",
    "            rect_positions.append(\n",
    "                (\n",
    "                    rect_x,\n",
    "                    rect_y,\n",
    "                    min(min_size, heatmap.shape[1] - rect_x),\n",
    "                    min(min_size, heatmap.shape[0] - rect_y),\n",
    "                )\n",
    "            )\n",
    "\n",
    "    return rect_positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb19ca6c8c85c4f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters\n",
    "model = models[model_name]\n",
    "dataset_test = datasets[model_name][0]\n",
    "dataset = datasets[model_name][1]\n",
    "\n",
    "color_yellow = [0, 255, 255]\n",
    "width_box = 2\n",
    "alpha = 0.7\n",
    "# ------------------\n",
    "threshold_rect = 0.01\n",
    "min_box_size = 20  # pixels\n",
    "draw_always = True\n",
    "# ------------------\n",
    "\n",
    "# Custom color map\n",
    "colors = [(0, \"white\"), (1, \"#d11919\")]\n",
    "custom_cmap = LinearSegmentedColormap.from_list(\"custom_color\", colors)\n",
    "\n",
    "last_layer = get_last_layer(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d09d60801132074",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function to generate heatmap and process images\n",
    "def generate_heatmap_and_boxes(model, images, targets, last_layer, relevant_channel):\n",
    "    targets = torch.tensor(targets) if isinstance(targets, list) else targets\n",
    "    heatmap, idx, atol = generate_heatmap(\n",
    "        model, last_layer, images, targets, device, draw_mode=True\n",
    "    )\n",
    "\n",
    "    with torch.inference_mode():\n",
    "        # Normalize heatmap for visualization\n",
    "        heatmap = heatmap / torch.flatten(heatmap, start_dim=1, end_dim=-1).abs().max(\n",
    "            dim=-1\n",
    "        ).values.view(\n",
    "            -1, 1, 1, 1\n",
    "        )  # changing value to [-1, 1]\n",
    "\n",
    "        heatmap = heatmap * (heatmap > 0)  # changing value to [0, 1]\n",
    "\n",
    "        # Resize heatmap to image size\n",
    "        heatmap_np = (\n",
    "            torch.nn.functional.interpolate(\n",
    "                heatmap[:, relevant_channel : relevant_channel + 1],\n",
    "                images.shape[-2:],\n",
    "                mode=\"bilinear\",\n",
    "            )\n",
    "            .squeeze(1)\n",
    "            .detach()\n",
    "            .cpu()\n",
    "            .numpy()\n",
    "        )\n",
    "\n",
    "        # find the smallest rectangles that contain areas of interest within a heatmap\n",
    "        rect_positions = {}\n",
    "        for ii in range(heatmap_np.shape[0]):\n",
    "            rect_positions[ii] = find_areas_rect(\n",
    "                heatmap_np[ii],\n",
    "                threshold_rect,\n",
    "                min_size=min_box_size,\n",
    "                draw_for_not_found=draw_always,\n",
    "            )\n",
    "\n",
    "        # Convert input images to numpy array and normalize them\n",
    "        images_np = images.permute(0, 2, 3, 1).detach().cpu().numpy()\n",
    "        min_ = (\n",
    "            images_np.reshape(images_np.shape[0], -1).min(axis=1).reshape(-1, 1, 1, 1)\n",
    "        )\n",
    "        max_ = (\n",
    "            images_np.reshape(images_np.shape[0], -1).max(axis=1).reshape(-1, 1, 1, 1)\n",
    "        )\n",
    "        images_np = (images_np - min_) / (max_ - min_)\n",
    "\n",
    "        # Draw bounding boxes on images\n",
    "        images_np_with_box = np.clip(255 * images_np + 0.5, a_min=0, a_max=255).astype(\n",
    "            np.uint8\n",
    "        )\n",
    "        for idx, rects in rect_positions.items():\n",
    "            if rects:\n",
    "                img = cv2.cvtColor(images_np_with_box[idx].copy(), cv2.COLOR_BGR2RGB)\n",
    "                for x, y, w, h in rects:\n",
    "                    img = cv2.rectangle(\n",
    "                        img, (x, y), (x + w, y + h), color_yellow, width_box\n",
    "                    )\n",
    "                    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n",
    "                    images_np_with_box[idx] = img\n",
    "        images_np_with_box = images_np_with_box.astype(np.float32) / 255.0\n",
    "\n",
    "    return images_np_with_box"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f17d622-257c-440a-8787-5075843b775a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Drawing prototypes\n",
    "\"\"\"\n",
    "\n",
    "batch_size = 5\n",
    "\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "save_examples = f\"examples/{model_name}_{dataset_name}/prototypes\"\n",
    "Path(save_examples).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "for i in tqdm(\n",
    "    (tmp := np.arange(0, last_layer.in_features, step=batch_size, dtype=int)),\n",
    "    total=len(tmp),\n",
    "    desc=\"Progress prototypes\",\n",
    "    position=0,\n",
    "    ncols=100,\n",
    "):\n",
    "    idx_channels = range(i, min(i + batch_size, last_layer.in_features))\n",
    "    prototypes_indices_img = data_prototypes[\"prototypes_indices_img\"][:, idx_channels]\n",
    "    prototypes_values_img = data_prototypes[\"prototypes_values\"][:, idx_channels]\n",
    "\n",
    "    list_prototypes = []\n",
    "    for j, channel in enumerate(idx_channels):\n",
    "        images = []\n",
    "        targets = []\n",
    "        for idx_ in prototypes_indices_img[prototypes_values_img[:, j] > 0, j]:\n",
    "            image, target = dataset[idx_]\n",
    "            images.append(image)\n",
    "            targets.append(target)\n",
    "        images = torch.stack(images)\n",
    "        try:\n",
    "            list_prototypes.append(\n",
    "                generate_heatmap_and_boxes(model, images, targets, last_layer, channel)\n",
    "            )\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "\n",
    "    list_prototypes = np.concatenate(list_prototypes, axis=0)\n",
    "    arr_grid = make_grid(\n",
    "        list_prototypes,\n",
    "        nrow=len(idx_channels),\n",
    "        padding=2,\n",
    "        normalize=False,\n",
    "        pad_value=1.0,\n",
    "    )\n",
    "    arr_grid = np.clip(255 * arr_grid + 0.5, 0, 255).astype(np.uint8)\n",
    "\n",
    "    # Save the image\n",
    "    im = Image.fromarray(arr_grid)\n",
    "    im.save(\n",
    "        f\"{save_examples}/prototypes_channels-{'-'.join(map(str, idx_channels))}.png\"\n",
    "    )\n",
    "\n",
    "    # break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17e8317c-e8cd-4d2e-a40f-eb2facc3cb91",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67a19079-e120-4431-8a9c-59f59f41f1e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Drawing images from each class and their prototypes for the most active channels\n",
    "\"\"\"\n",
    "\n",
    "# Parameters\n",
    "num_choice_images_per_class = 1\n",
    "save_examples = f\"examples/{model_name}_{dataset_name}/image_prototypes\"\n",
    "Path(save_examples).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Ensure deterministic behavior\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "for idx_class in tqdm(\n",
    "    data2save.keys(),\n",
    "    desc=\"Progress classes\",\n",
    "    total=len(data2save),\n",
    "    position=0,\n",
    "    ncols=100,\n",
    "):\n",
    "    for i in np.random.choice(\n",
    "        (tmp := len(data2save[idx_class][\"idx_img_correct_label\"])),\n",
    "        size=num_choice_images_per_class,\n",
    "        replace=True if tmp < num_choice_images_per_class else False,\n",
    "    ):\n",
    "        idx_img = data2save[idx_class][\"idx_img_correct_label\"][i]\n",
    "        idx_channels = data2save[idx_class][\"idx_important_channels\"][i][\n",
    "            data2save[idx_class][\"values_channels\"][i] > 0\n",
    "        ]\n",
    "\n",
    "        prototypes_indices_img = data_prototypes[\"prototypes_indices_img\"][\n",
    "            :, idx_channels\n",
    "        ]\n",
    "        prototypes_values_img = data_prototypes[\"prototypes_values\"][:, idx_channels]\n",
    "\n",
    "        image, target = dataset_test[idx_img]\n",
    "        image = image.permute(1, 2, 0).detach().cpu().numpy()\n",
    "        min_, max_ = np.min(image), np.max(image)\n",
    "        image = (image - min_) / (max_ - min_)\n",
    "\n",
    "        arr_grid = np.clip(255 * image + 0.5, 0, 255).astype(np.uint8)\n",
    "        im = Image.fromarray(arr_grid)\n",
    "        im.save(f\"{save_examples}/img-{idx_img}_class-{idx_class}.png\")\n",
    "\n",
    "        list_prototypes = []\n",
    "        for j, channel in enumerate(idx_channels):\n",
    "            img, target = dataset_test[idx_img]\n",
    "            images = [img]\n",
    "            targets = [target]\n",
    "            for idx_ in prototypes_indices_img[prototypes_values_img[:, j] > 0, j]:\n",
    "                img, target = dataset[idx_]\n",
    "                images.append(img)\n",
    "                targets.append(target)\n",
    "            images = torch.stack(images)\n",
    "            try:\n",
    "                list_prototypes.append(\n",
    "                    generate_heatmap_and_boxes(\n",
    "                        model, images, targets, last_layer, channel\n",
    "                    )\n",
    "                )\n",
    "            except Exception as e:\n",
    "                print(e)\n",
    "\n",
    "        list_prototypes = np.concatenate(list_prototypes, axis=0)\n",
    "        arr_grid = make_grid(\n",
    "            list_prototypes,\n",
    "            nrow=len(idx_channels),\n",
    "            padding=2,\n",
    "            col_padding=30,\n",
    "            normalize=False,\n",
    "            pad_value=1.0,\n",
    "        )\n",
    "        arr_grid = np.clip(255 * arr_grid + 0.5, 0, 255).astype(np.uint8)\n",
    "        im = Image.fromarray(arr_grid)\n",
    "        im.save(\n",
    "            f\"{save_examples}/img-{idx_img}_prototypes_channels-{'-'.join(map(str, idx_channels))}.png\"\n",
    "        )\n",
    "\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1b3460fb65b6f2f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03db930a-dd8d-464a-98bf-48ea3aab9157",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Selected the intersection of relevant channels for all images from a given class and draw class prototypes for these channels\n",
    "\"\"\"\n",
    "\n",
    "type_of_selection = \"intersection\"\n",
    "# type_of_selection = \"frequency\"\n",
    "\n",
    "save_examples = f\"examples/{model_name}_{dataset_name}/class_prototypes\"\n",
    "Path(save_examples).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "for idx_class in tqdm(\n",
    "    data2save.keys(),\n",
    "    desc=\"Progress classes\",\n",
    "    total=len(data2save),\n",
    "    position=0,\n",
    "    ncols=100,\n",
    "):\n",
    "    # Determine relevant channels based on type_of_selection\n",
    "    channels = data2save[idx_class][\"idx_important_channels\"].copy()\n",
    "    channels[data2save[idx_class][\"values_channels\"] < 0] = -1\n",
    "\n",
    "    if type_of_selection == \"intersection\":\n",
    "        relevant_channels = reduce(np.intersect1d, channels)\n",
    "    elif type_of_selection == \"frequency\":\n",
    "        num = channels.shape[0]\n",
    "        unique, counts = np.unique(channels.flatten(), return_counts=True)\n",
    "        counts = dict(\n",
    "            sorted(\n",
    "                dict(zip(unique, counts)).items(),\n",
    "                key=lambda item: item[1],\n",
    "                reverse=True,\n",
    "            )\n",
    "        )\n",
    "        relevant_channels = [\n",
    "            key for key, val in counts.items() if val / num > threshold\n",
    "        ]\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            \"You can use only 'intersection' or 'frequency' type of selection\"\n",
    "        )\n",
    "\n",
    "    prototypes_indices_img = data_prototypes[\"prototypes_indices_img\"][\n",
    "        :, relevant_channels\n",
    "    ]\n",
    "    prototypes_values_img = data_prototypes[\"prototypes_values\"][:, relevant_channels]\n",
    "\n",
    "    # Process prototypes for relevant channels\n",
    "    list_prototypes = []\n",
    "    for j, channel in enumerate(relevant_channels):\n",
    "        images = []\n",
    "        targets = []\n",
    "        for idx_ in prototypes_indices_img[prototypes_values_img[:, j] > 0, j]:\n",
    "            image, target = dataset[idx_]\n",
    "            images.append(image)\n",
    "            targets.append(target)\n",
    "        images = torch.stack(images)\n",
    "\n",
    "        # Generate heatmap and process images with bounding boxes\n",
    "        try:\n",
    "            images_with_boxes = generate_heatmap_and_boxes(\n",
    "                model, images, targets, last_layer, channel\n",
    "            )\n",
    "\n",
    "            list_prototypes.append(images_with_boxes)\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "\n",
    "    # Concatenate and save prototypes for the class if they exist\n",
    "    if list_prototypes:\n",
    "        list_prototypes = np.concatenate(list_prototypes, axis=0)\n",
    "        arr_grid = make_grid(\n",
    "            list_prototypes,\n",
    "            nrow=len(relevant_channels),\n",
    "            padding=2,\n",
    "            normalize=False,\n",
    "            interval=None,\n",
    "            scale_each=False,\n",
    "            pad_value=1.0,\n",
    "        )\n",
    "        arr_grid = np.clip(255 * arr_grid + 0.5, a_min=0, a_max=255).astype(np.uint8)\n",
    "        im = Image.fromarray(arr_grid)\n",
    "        im.save(\n",
    "            f\"{save_examples}/class-{idx_class}_prototypes_channels-{'-'.join(map(str, relevant_channels))}.png\"\n",
    "        )\n",
    "    else:\n",
    "        warning_message = f\"Class {idx_class} does not have prototypes\"\n",
    "        warnings.warn(warning_message)\n",
    "\n",
    "    # break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b57b067-db52-4041-ae63-3489f6098b6f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b49a6a13-b265-4743-ad24-6af4bbd0aac2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Ex. III. Effect similar to Grad-CAM",
   "id": "32fc18e2-8c98-41e6-a831-b5c9f23f89c4"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f0981f9-1ec6-45ee-9304-79aa3420e936",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Selecting images that have consistent predictions in all models for the target\n",
    "\"\"\"\n",
    "\n",
    "save_examples = f\"examples\"\n",
    "Path(save_examples).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "batch_size = 10\n",
    "if not Path(\n",
    "    f\"{save_examples}/correct_pred_images_for_all_models_{dataset_name}.npz\"\n",
    ").is_file():\n",
    "    data2save = {i: None for i in range(len(dataset_test))}\n",
    "\n",
    "    images_indeces = range(len(dataset_test))\n",
    "    for model_name, model in models.items():\n",
    "        if not model_name.startswith(\"own_\"):\n",
    "            continue\n",
    "\n",
    "        images = []\n",
    "        targets = []\n",
    "        indeces = []\n",
    "        images_indeces = list(data2save.keys())\n",
    "        with torch.inference_mode():\n",
    "            for idx_image in tqdm(\n",
    "                images_indeces,\n",
    "                total=len(images_indeces),\n",
    "                desc=\"Progress\",\n",
    "                position=0,\n",
    "                ncols=100,\n",
    "            ):\n",
    "                if len(images) < batch_size:\n",
    "                    image, target = dataset_test[idx_image]\n",
    "                    if data2save[idx_image] is None:\n",
    "                        data2save[idx_image] = target\n",
    "                    images.append(image)\n",
    "                    targets.append(target)\n",
    "                    indeces.append(idx_image)\n",
    "\n",
    "                if len(images) == batch_size:\n",
    "                    images = torch.stack(images)\n",
    "                    targets = torch.tensor(targets)\n",
    "\n",
    "                    images = images.to(device, non_blocking=True)\n",
    "                    targets = targets.to(device, non_blocking=True)\n",
    "\n",
    "                    logit, _ = model(images)\n",
    "                    y = torch.argmax(logit, dim=1)\n",
    "\n",
    "                    idx = targets != y\n",
    "                    if idx.sum().item():\n",
    "                        indeces = np.array(indeces)[idx.detach().cpu().numpy()]\n",
    "                        for idx in indeces:\n",
    "                            del data2save[idx]\n",
    "\n",
    "                    images = []\n",
    "                    targets = []\n",
    "                    indeces = []\n",
    "\n",
    "    print(f\"\\nThe number of well-predicted image labels: {len(data2save.keys())}\")\n",
    "\n",
    "    correct_pred_images_all_models = {}\n",
    "    for idx, cl in data2save.items():\n",
    "        try:\n",
    "            correct_pred_images_all_models[f\"{cl}\"].append(idx)\n",
    "        except KeyError:\n",
    "            correct_pred_images_all_models[f\"{cl}\"] = [idx]\n",
    "\n",
    "    # Save to an NPZ file\n",
    "    np.savez(\n",
    "        f\"{save_examples}/correct_pred_images_for_all_models_{dataset_name}.npz\",\n",
    "        **correct_pred_images_all_models,\n",
    "    )\n",
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "954b32a5-5fe9-4b29-b688-70b7b73ce608",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Selected a few images from all classes and display \n",
    "\"\"\"\n",
    "\n",
    "data2save = np.load(\n",
    "    f\"{save_examples}/correct_pred_images_for_all_models_{dataset_name}.npz\"\n",
    ")\n",
    "data2save = dict(data2save)\n",
    "\n",
    "num_choice_images_per_class = 1\n",
    "images_indeces = [\n",
    "    np.random.choice(\n",
    "        idx_imgs,\n",
    "        size=num_choice_images_per_class,\n",
    "        replace=True if (tmp := len(idx_imgs)) < num_choice_images_per_class else False,\n",
    "    ).tolist()\n",
    "    for cl, idx_imgs in data2save.items()\n",
    "]\n",
    "images_indeces = np.array(images_indeces).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "897d98fc-e8a7-46dd-8623-7d71eceb307a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Displays image, heatmap, and their combinations for all models\n",
    "\"\"\"\n",
    "\n",
    "save_examples = f\"examples/{dataset_name}_modified_grad-CAM\"\n",
    "Path(save_examples).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "alpha = 0.5\n",
    "threshold = 0.0\n",
    "batch_size = 10\n",
    "\n",
    "# Create a custom color map\n",
    "colors_bg_white = [\n",
    "    (0, \"#0652ff\"),\n",
    "    (0.48, \"white\"),\n",
    "    (0.52, \"white\"),\n",
    "    # (1, '#b22c00')\n",
    "    (1, \"#d11919\"),\n",
    "]\n",
    "colors_bg_black = [\n",
    "    (0, \"#0537ff\"),\n",
    "    (0.48, \"black\"),\n",
    "    (0.52, \"black\"),\n",
    "    # (1, '#ff0000')\n",
    "    (1, \"#ff050d\"),\n",
    "]\n",
    "custom_cmap_bg_white = LinearSegmentedColormap.from_list(\n",
    "    \"custom_bg_white\", colors_bg_white\n",
    ")\n",
    "custom_cmap_bg_black = LinearSegmentedColormap.from_list(\n",
    "    \"custom_bg_black\", colors_bg_black\n",
    ")\n",
    "\n",
    "images = []\n",
    "targets = []\n",
    "indeces = []\n",
    "for model_name, model in models.items():\n",
    "    if not model.__class__.__name__.startswith(\"Custom\"):\n",
    "        continue\n",
    "\n",
    "    i = 0\n",
    "    with torch.inference_mode():\n",
    "        for idx_image in (\n",
    "            bar := tqdm(\n",
    "                images_indeces,\n",
    "                total=len(images_indeces),\n",
    "                desc=\"Progress\",\n",
    "                position=0,\n",
    "                ncols=100,\n",
    "            )\n",
    "        ):\n",
    "            if len(images) < batch_size:\n",
    "                image, target = dataset_test[idx_image]\n",
    "                images.append(image)\n",
    "                targets.append(target)\n",
    "                indeces.append(idx_image)\n",
    "\n",
    "            if len(images) == batch_size:\n",
    "                # raise ValueError(images, targets)\n",
    "                images = torch.stack(images)\n",
    "                targets = torch.tensor(targets)\n",
    "\n",
    "                images = images.to(device, non_blocking=True)\n",
    "                targets = targets.to(device, non_blocking=True)\n",
    "\n",
    "                logit, heatmap = model(images)\n",
    "                y = torch.argmax(logit, dim=1)\n",
    "\n",
    "                # Generate heatmap\n",
    "                if isinstance(model, ResNet):\n",
    "                    last_layer = model.fc\n",
    "                elif isinstance(model, DenseNet):\n",
    "                    last_layer = model.classifier\n",
    "                elif isinstance(model, (ConvNeXt, MaxVit)):\n",
    "                    last_layer = model.classifier[-1]\n",
    "                elif isinstance(model, VisionTransformer):\n",
    "                    last_layer = model.heads.head\n",
    "                elif isinstance(model, SwinTransformer):\n",
    "                    last_layer = model.head\n",
    "                else:\n",
    "                    raise NotImplementedError(f\"Model '{model_name}' not implemented\")\n",
    "\n",
    "                if isinstance(model, MaxVit):\n",
    "                    heatmap = model.classifier[:-1](heatmap).unsqueeze(-1).unsqueeze(\n",
    "                        -1\n",
    "                    ) * heatmap.div(heatmap.sum(dim=[-2, -1], keepdim=True))\n",
    "\n",
    "                weights = torch.abs(last_layer.weight[targets])\n",
    "                heatmap = torch.einsum(\n",
    "                    (\n",
    "                        \"bi,bai->ba\"\n",
    "                        if isinstance(model, VisionTransformer)\n",
    "                        else \"bi,biwh->bwh\"\n",
    "                    ),\n",
    "                    weights,\n",
    "                    heatmap,\n",
    "                )\n",
    "                if last_layer.bias is not None:\n",
    "                    bias_view_shape = (\n",
    "                        (-1, 1) if isinstance(model, VisionTransformer) else (-1, 1, 1)\n",
    "                    )\n",
    "                    heatmap = heatmap + last_layer.bias[targets].view(*bias_view_shape)\n",
    "\n",
    "                # Assert model output matches heatmap values\n",
    "                smallest_atol = find_smallest_atol(\n",
    "                    logit[range(images.shape[0]), targets],\n",
    "                    (\n",
    "                        heatmap.mean(dim=-1)\n",
    "                        if isinstance(model, VisionTransformer)\n",
    "                        else heatmap.sum(dim=[-2, -1]).div(np.prod(heatmap.shape[-2:]))\n",
    "                    ),\n",
    "                )\n",
    "                assert (\n",
    "                    smallest_atol is not None\n",
    "                ), \"Tensors are not close with any atol in the specified range.\"\n",
    "                # print(f\"Tensors are close with atol=1e{smallest_atol}\")\n",
    "\n",
    "                if isinstance(model, VisionTransformer):\n",
    "                    _, _, h, w = images.shape\n",
    "                    n_h = h // model.patch_size\n",
    "                    n_w = w // model.patch_size\n",
    "                    heatmap = heatmap[:, 1:].view(-1, n_h, n_w)\n",
    "\n",
    "                # Normalize the heatmap for visualization\n",
    "                heatmap = heatmap / torch.flatten(\n",
    "                    heatmap, start_dim=1, end_dim=-1\n",
    "                ).abs().max(dim=-1).values.view(\n",
    "                    -1, 1, 1\n",
    "                )  # changing value to [-1, 1]\n",
    "                heatmap = torch.div(heatmap, 2).add(0.5)  # changing value to [0, 1]\n",
    "\n",
    "                # Resize to image size\n",
    "                heatmap_np = (\n",
    "                    torch.nn.functional.interpolate(\n",
    "                        heatmap.unsqueeze(1),\n",
    "                        images.shape[-2:],\n",
    "                        mode=\"bilinear\",\n",
    "                    )\n",
    "                    .squeeze(1)\n",
    "                    .detach()\n",
    "                    .cpu()\n",
    "                    .numpy()\n",
    "                )\n",
    "\n",
    "                # Convert input images to numpy array and normalize them\n",
    "                images_np = images.permute(0, 2, 3, 1).detach().cpu().numpy()\n",
    "                min_ = np.min(\n",
    "                    images_np.reshape(images_np.shape[0], -1), axis=1\n",
    "                ).reshape(-1, 1, 1, 1)\n",
    "                max_ = np.max(\n",
    "                    images_np.reshape(images_np.shape[0], -1), axis=1\n",
    "                ).reshape(-1, 1, 1, 1)\n",
    "                images_np = (images_np - min_) / (max_ - min_)\n",
    "\n",
    "                # save to images\n",
    "                num = images_np.shape[0]\n",
    "                arr = np.empty((3 * num, *images_np.shape[1:]))\n",
    "                arr[:num, ...] = images_np\n",
    "                arr[num : 2 * num, ...] = custom_cmap_bg_white(heatmap_np)[..., :3]\n",
    "\n",
    "                # Blend the heatmap with the original image using transparency\n",
    "                arr[2 * num : 3 * num, ...] = (\n",
    "                    alpha * custom_cmap_bg_black(heatmap_np)[..., :3]\n",
    "                    + (1 - alpha) * images_np\n",
    "                )\n",
    "\n",
    "                idx = np.arange(0, 3 * num, step=num, dtype=int)\n",
    "                for j in range(num):\n",
    "                    arr_grid = make_grid(\n",
    "                        arr[idx + j],\n",
    "                        # nrow=num,\n",
    "                        nrow=3,\n",
    "                        padding=2,\n",
    "                        normalize=False,\n",
    "                        interval=None,\n",
    "                        scale_each=False,\n",
    "                        pad_value=0.5,\n",
    "                    )\n",
    "                    arr_grid = np.clip(255 * arr_grid + 0.5, a_min=0, a_max=255).astype(\n",
    "                        np.uint8\n",
    "                    )\n",
    "\n",
    "                    # Save the image\n",
    "                    im = Image.fromarray(arr_grid)\n",
    "                    im.save(f\"{save_examples}/{indeces[j]}_{model_name}.png\")\n",
    "\n",
    "                i += 1\n",
    "\n",
    "                images = []\n",
    "                targets = []\n",
    "                indeces = []\n",
    "\n",
    "                bar.set_description(f\"{model_name} | epsilon=1e{smallest_atol}\")\n",
    "\n",
    "                # break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd74b48f-20b9-405d-80f5-dd99b2551d21",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
