{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c7732c23",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T16:07:30.029929218Z",
     "start_time": "2025-05-21T16:07:29.988436338Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import ipywidgets as widgets\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display\n",
    "from src.datamodules.pets import PetsDataModule\n",
    "from src.datamodules.cars import CarsDataModule\n",
    "from src.datamodules.imagenette import ImagenetteDataModule\n",
    "from src.models.patch_importance import PatchImportance\n",
    "from src.models.patch_importance_cnn import PatchImportanceCNN\n",
    "from src.models.resnet_importance import ResNetImportanceModel\n",
    "from src.visualize.patch_importance import visualize_patch_importance\n",
    "from src.datamodules.imagenet import ImagenetDataModule\n",
    "import torch\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8885ebcc-dcc6-4525-9d93-ceb54bf37d10",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T16:07:30.103802034Z",
     "start_time": "2025-05-21T16:07:30.014152247Z"
    }
   },
   "outputs": [],
   "source": [
    "# Mutable UI state shared by the widget callbacks defined in the cells below.\n",
    "active_datamodule = None    # datamodule run_inference reads test images from\n",
    "last_datamodule_key = None  # dataset key that active_datamodule was built for\n",
    "# Set by load_models for the selected dataset / embedding size / patch size.\n",
    "ba_model = fa_model = ra_model = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d6f56295",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T16:07:30.115985033Z",
     "start_time": "2025-05-21T16:07:30.089678580Z"
    }
   },
   "outputs": [],
   "source": [
    "# Dataset selector: (label, key) pairs; the key is what get_datamodule receives.\n",
    "dataset = widgets.Dropdown(\n",
    "    options=[('Oxford Pets', 'pets'), ('Stanford Cars', 'cars'), ('Imagenette', 'imagenette'), ('Imagenet', 'imagenet')],\n",
    "    value='pets',\n",
    "    description='Dataset:'\n",
    ")\n",
    "\n",
    "# Patch and embedding size pick the checkpoint file; get_avaialable_params later\n",
    "# replaces these default options with the sizes found in checkpoints/{dataset}.\n",
    "patch_size = widgets.Dropdown(\n",
    "    options=[4, 8, 16, 32, 56, 112],\n",
    "    value=16,\n",
    "    description='Patch size:',\n",
    ")\n",
    "\n",
    "embedding_size = widgets.Dropdown(\n",
    "    options=[16, 32, 64, 128, 256, 512],\n",
    "    value=128,\n",
    "    description='Embedding size:',\n",
    ")\n",
    "\n",
    "# Index into the test split; max is set when a datamodule is loaded.\n",
    "image_slider = widgets.IntSlider(\n",
    "    value=0,\n",
    "    min=0,\n",
    "    max=0,\n",
    "    step=1,\n",
    "    description='Test index:',\n",
    "    disabled=False,\n",
    "    continuous_update=True,\n",
    "    orientation='horizontal',\n",
    "    readout=True,\n",
    "    readout_format='d'\n",
    ")\n",
    "\n",
    "# Importance threshold for the RA panel; load_models sets it from the RA\n",
    "# checkpoint (0 when no RA model is available).\n",
    "threshold_slider = widgets.FloatSlider(\n",
    "    value=0,\n",
    "    min=0,\n",
    "    max=1,\n",
    "    step=0.01,\n",
    "    description='Threshold:',\n",
    "    disabled=False,\n",
    "    continuous_update=True,\n",
    "    orientation='horizontal',\n",
    "    readout=True,\n",
    "    readout_format='.2f'\n",
    ")\n",
    "\n",
    "status_label = widgets.Label(value=\"\")\n",
    "output = widgets.Output()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9685ba4f-237e-4c26-be51-b30ca0d420ee",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T16:07:30.163764049Z",
     "start_time": "2025-05-21T16:07:30.102268555Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_datamodule(name):\n",
    "    \"\"\"Build the datamodule for dataset key `name`, with batch_size=3.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `name` is not a known dataset key, or if `name` is\n",
    "            'imagenet' and the IMAGENET_PATH environment variable is not set.\n",
    "    \"\"\"\n",
    "    if name == 'pets':\n",
    "        return PetsDataModule(batch_size=3, target_type='category')\n",
    "    elif name == 'cars':\n",
    "        return CarsDataModule(batch_size=3)\n",
    "    elif name == 'imagenette':\n",
    "        return ImagenetteDataModule(batch_size=3)\n",
    "    elif name == 'imagenet':\n",
    "        # The ImageNet location must come from the environment; fail with a\n",
    "        # readable message instead of a bare KeyError('IMAGENET_PATH').\n",
    "        imagenet_path = os.environ.get('IMAGENET_PATH')\n",
    "        if not imagenet_path:\n",
    "            raise ValueError(\"IMAGENET_PATH environment variable is not set\")\n",
    "        return ImagenetDataModule(data_dir=imagenet_path, batch_size=3)\n",
    "    else:\n",
    "        raise ValueError(f\"Invalid dataset: {name!r}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "68646805",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T16:07:30.199912221Z",
     "start_time": "2025-05-21T16:07:30.153710040Z"
    }
   },
   "outputs": [],
   "source": [
    "def _load_checkpoint_or_none(model_cls, path):\n",
    "    \"\"\"Return `model_cls` loaded from `path` in eval mode on the CPU, or None if loading fails.\"\"\"\n",
    "    try:\n",
    "        return model_cls.load_from_checkpoint(path, strict=False).eval().cpu()\n",
    "    except Exception:\n",
    "        # Best-effort: any failure (e.g. no checkpoint for this size combination)\n",
    "        # just disables that model, and run_inference skips its panel.\n",
    "        return None\n",
    "\n",
    "\n",
    "def load_models(change=None):\n",
    "    \"\"\"Load the RA, BA and FA checkpoints for the current widget selection.\n",
    "\n",
    "    Files are read from checkpoints/{dataset}/checkpoint_{ra|ba|fa}_{embedding}_{patch}.ckpt.\n",
    "    A model whose checkpoint cannot be loaded is set to None. The threshold\n",
    "    slider is then set to the RA model's threshold (0 without an RA model).\n",
    "\n",
    "    Args:\n",
    "        change: traitlets change dict when called as a widget observer, or None\n",
    "            when called directly (e.g. from load_selected_datamodule).\n",
    "    \"\"\"\n",
    "    global ba_model, fa_model, ra_model\n",
    "    ckpt_path = f'checkpoints/{dataset.value}/checkpoint_'\n",
    "    size_suffix = f'{embedding_size.value}_{patch_size.value}.ckpt'\n",
    "\n",
    "    ra_model = _load_checkpoint_or_none(ResNetImportanceModel, f'{ckpt_path}ra_{size_suffix}')\n",
    "    ba_model = _load_checkpoint_or_none(PatchImportance, f'{ckpt_path}ba_{size_suffix}')\n",
    "    fa_model = _load_checkpoint_or_none(PatchImportanceCNN, f'{ckpt_path}fa_{size_suffix}')\n",
    "\n",
    "    # Touch the threshold only once all three models are replaced: the slider's\n",
    "    # observer re-renders immediately and must not mix old and new models.\n",
    "    previous_threshold = threshold_slider.value\n",
    "    threshold_slider.value = ra_model.threshold if ra_model is not None else 0\n",
    "\n",
    "    # A widget-triggered reload must refresh the figure. If the threshold\n",
    "    # changed, its observer has already re-rendered, so don't do it twice.\n",
    "    if change is not None and threshold_slider.value == previous_threshold:\n",
    "        run_inference()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "222b01ad-e981-4f10-8753-8eddfccd5f37",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T16:07:30.252983968Z",
     "start_time": "2025-05-21T16:07:30.192859344Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_avaialable_params(dataset_name):\n",
    "    \"\"\"Restrict the patch/embedding dropdowns to sizes that have checkpoints.\n",
    "\n",
    "    Scans checkpoints/{dataset_name} for files named\n",
    "    checkpoint_{model}_{embedding}_{patch}.ckpt (the pattern load_models reads)\n",
    "    and sets each dropdown's options to the distinct sizes found, as ints sorted\n",
    "    ascending. Files that do not match the pattern are ignored.\n",
    "    \"\"\"\n",
    "    patch_sizes = set()\n",
    "    embedding_sizes = set()\n",
    "    for filename in os.listdir(f'checkpoints/{dataset_name}'):\n",
    "        stem, ext = os.path.splitext(filename)\n",
    "        parts = stem.split('_')\n",
    "        if ext != '.ckpt' or len(parts) != 4:\n",
    "            continue\n",
    "        try:\n",
    "            embedding, patch = int(parts[2]), int(parts[3])\n",
    "        except ValueError:\n",
    "            continue\n",
    "        embedding_sizes.add(embedding)\n",
    "        patch_sizes.add(patch)\n",
    "    # Ints keep the options the same type as the widgets' defaults, and sorting\n",
    "    # gives a stable numeric order (a set of strings has neither).\n",
    "    patch_size.options = sorted(patch_sizes)\n",
    "    embedding_size.options = sorted(embedding_sizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c6610526-5a4e-4e2d-8b8a-92615d1d881e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T16:07:30.276947554Z",
     "start_time": "2025-05-21T16:07:30.227183943Z"
    }
   },
   "outputs": [],
   "source": [
    "def load_selected_datamodule(change=None):\n",
    "    \"\"\"(Re)load the datamodule, checkpoints and figure for the selected dataset.\n",
    "\n",
    "    Does nothing if the selected dataset is already loaded. On failure the\n",
    "    error is shown in the status label and the cached state is cleared, so\n",
    "    selecting the same dataset again retries the load.\n",
    "\n",
    "    Args:\n",
    "        change: traitlets change dict when used as an observer; unused.\n",
    "    \"\"\"\n",
    "    global active_datamodule, last_datamodule_key\n",
    "    if dataset.value == last_datamodule_key:\n",
    "        return\n",
    "    status_label.value = \"🔄 Loading datamodule...\"\n",
    "    try:\n",
    "        dm = get_datamodule(dataset.value)\n",
    "        dm.prepare_data()\n",
    "        dm.setup()\n",
    "        active_datamodule = dm\n",
    "        last_datamodule_key = dataset.value\n",
    "        image_slider.max = len(dm.test_data) - 1\n",
    "        image_slider.value = 0\n",
    "        get_avaialable_params(dataset.value)\n",
    "        load_models()\n",
    "        run_inference()\n",
    "        status_label.value = \"✅ Datamodule loaded.\"\n",
    "    except Exception as e:\n",
    "        status_label.value = f\"⚠️ Error loading datamodule: {e}\"\n",
    "        active_datamodule = None\n",
    "        # Forget the key as well: otherwise re-selecting this dataset would be\n",
    "        # treated as \"already loaded\" while active_datamodule is None.\n",
    "        last_datamodule_key = None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9307bd0d-83d4-4b90-9b57-c6f0411bba96",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T16:07:30.317539555Z",
     "start_time": "2025-05-21T16:07:30.265678112Z"
    }
   },
   "outputs": [],
   "source": [
    "def run_inference(change=None):\n",
    "    \"\"\"Render the test image selected by image_slider with each loaded model.\n",
    "\n",
    "    Panel 0 shows the input image and its true class, panel 1 the BA\n",
    "    importance map (its title also lists the FA prediction), and panel 2 the\n",
    "    RA importance map at threshold_slider.value. Models that are None are\n",
    "    skipped.\n",
    "\n",
    "    Args:\n",
    "        change: traitlets change dict when used as an observer; unused.\n",
    "    \"\"\"\n",
    "    output.clear_output()\n",
    "    status_label.value = \"🔄 Running inference...\"\n",
    "    with output:\n",
    "        if active_datamodule is None:\n",
    "            status_label.value = \"⚠️ Datamodule not loaded.\"\n",
    "            return\n",
    "\n",
    "        batch = active_datamodule.test_data[image_slider.value]\n",
    "\n",
    "        # Add a leading batch dimension of size 1 for the models.\n",
    "        x, y = batch['x'].unsqueeze(0), batch['y']\n",
    "\n",
    "        fig, axs = plt.subplots(1, 3, figsize=(8, 4))\n",
    "        for ax in axs:\n",
    "            ax.axis('off')\n",
    "\n",
    "        # Visualization only: no autograd graph is needed.\n",
    "        with torch.no_grad():\n",
    "            if ba_model is not None:\n",
    "                pred = ba_model(x)\n",
    "                y_pred = ba_model.strategy.process_outputs(pred['logits'])\n",
    "                images, original = visualize_patch_importance(x, pred['importance'], patch_size=ba_model.patch_size)\n",
    "                axs[0].imshow(original[0])\n",
    "                axs[0].set_title(f'True class: {y}')\n",
    "                axs[1].imshow(images[0])\n",
    "                axs[1].set_title(f'BA pred: {y_pred[0]}\\n')\n",
    "\n",
    "            if fa_model is not None:\n",
    "                pred = fa_model(x)\n",
    "                y_pred = fa_model.strategy.process_outputs(pred['logits'])\n",
    "                # Use the FA model's own patch size (ba_model may be None here);\n",
    "                # fall back to the selected size, which also names the checkpoint.\n",
    "                fa_patch_size = getattr(fa_model, 'patch_size', patch_size.value)\n",
    "                images, original = visualize_patch_importance(x, pred['importance'], patch_size=fa_patch_size)\n",
    "                axs[0].imshow(original[0])\n",
    "                axs[0].set_title(f'True class: {y}')\n",
    "                current_title = axs[1].get_title()\n",
    "                axs[1].set_title(f'{current_title}FA pred: {y_pred[0]}')\n",
    "\n",
    "            if ra_model is not None:\n",
    "                pred = ra_model(x)\n",
    "                y_pred = ra_model.strategy.process_outputs(pred['logits'])\n",
    "                images, original = visualize_patch_importance(\n",
    "                    x, pred['importance'], threshold=threshold_slider.value, color=(0, 0, 0), patch_size=ra_model.importance.patch_size)\n",
    "                axs[2].imshow(images[0])\n",
    "                # At the model's own threshold show its prediction; otherwise\n",
    "                # show the threshold being explored.\n",
    "                if threshold_slider.value == ra_model.threshold:\n",
    "                    axs[2].set_title(f'RA pred: {y_pred[0]}')\n",
    "                else:\n",
    "                    axs[2].set_title(f'Threshold: {threshold_slider.value}')\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "        status_label.value = \"✅ Inference complete.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "19c81e7c-0d50-47f5-aa58-5d1e754c78f3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T16:07:30.331414407Z",
     "start_time": "2025-05-21T16:07:30.304231948Z"
    }
   },
   "outputs": [],
   "source": [
    "# Wire each control to its callback; all of them react to 'value' changes.\n",
    "for _widget, _handler in [\n",
    "    (dataset, load_selected_datamodule),\n",
    "    (patch_size, load_models),\n",
    "    (embedding_size, load_models),\n",
    "    (threshold_slider, run_inference),\n",
    "    (image_slider, run_inference),\n",
    "]:\n",
    "    _widget.observe(_handler, names='value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "808dde71",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-21T16:07:53.108604163Z",
     "start_time": "2025-05-21T16:07:30.344811733Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4e3929fed4874c2cac66ba5a74ca5a78",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Dropdown(description='Dataset:', options=(('Oxford Pets', 'pets'), ('Stanford Cars', 'cars'), (…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "ui = widgets.VBox([\n",
    "    dataset, patch_size, embedding_size, threshold_slider, image_slider,\n",
    "    status_label, output,\n",
    "])\n",
    "display(ui)\n",
    "\n",
    "# Populate the UI for the default dataset selection.\n",
    "load_selected_datamodule()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
