{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report\n",
    "import seaborn as sns\n",
    "\n",
    "from resnet_spectral import ResNet\n",
    "from interpretable_resnet_spectral import InterpretableResNet\n",
    "from torch.utils.data import Dataset, DataLoader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the full spectral dataset and its labels (expected: one label per row of `data`).\n",
    "data = np.load(\"../ECG-Spectral/Spectral/two_class_resnet/simple_binary_dataset.npy\")\n",
    "labels = np.load(\"../ECG-Spectral/Spectral/two_class_resnet/simple_binary_labels.npy\")\n",
    "# The export cell below indexes `labels` by row of `data`, so the two must align.\n",
    "assert len(data) == len(labels), f\"data/labels length mismatch: {len(data)} vs {len(labels)}\"\n",
    "data.shape, labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved 8065 files to ../ECG-Spectral/Spectral/two_class_resnet/two_class_spectra\n"
     ]
    }
   ],
   "source": [
    "# Export each row of `data` to its own .npy file for the Dataset defined below.\n",
    "# The label is encoded in the filename suffix (..._class_<label>.npy), which TwoClassData parses back;\n",
    "# casting to int keeps that suffix parseable even if `labels` has a float dtype (1.0 -> \"1\").\n",
    "# NOTE(review): files left from earlier runs are not removed and will also be loaded.\n",
    "output_dir = \"../ECG-Spectral/Spectral/two_class_resnet/two_class_spectra\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# strict=True raises if data and labels differ in length instead of silently truncating.\n",
    "for idx, (spectrum, label) in enumerate(zip(data, labels, strict=True)):\n",
    "    filename = f\"two_class_file_{idx}_class_{int(label)}.npy\"\n",
    "    np.save(os.path.join(output_dir, filename), spectrum)\n",
    "\n",
    "print(f\"Saved {len(data)} files to {output_dir}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Testing model performance of two class dataset\n",
    "class TwoClassData(Dataset):\n",
    "    \"\"\"In-memory dataset of single-channel spectra stored as one .npy file each.\n",
    "\n",
    "    Every *.npy file in `data_path` is loaded eagerly. The integer class label is parsed\n",
    "    from the filename suffix after the last underscore, e.g.\n",
    "    `two_class_file_12_class_1.npy` -> label 1.\n",
    "\n",
    "    Files are visited in natural-sort order (file_2 before file_10) so sample indices are\n",
    "    reproducible; os.listdir order is arbitrary and filesystem-dependent.\n",
    "\n",
    "    Attributes:\n",
    "        data: float32 tensor of shape (n_files, 1, *sample_shape); a channel axis is added.\n",
    "        labels: float32 tensor of shape (n_files,).\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if `data_path` contains no .npy files.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_path):\n",
    "        import re\n",
    "\n",
    "        def natural_key(name):\n",
    "            # Split out digit runs so they compare numerically, not lexically.\n",
    "            return [int(tok) if tok.isdigit() else tok for tok in re.split(r\"(\\d+)\", name)]\n",
    "\n",
    "        files = sorted((f for f in os.listdir(data_path) if f.endswith(\".npy\")), key=natural_key)\n",
    "        if not files:\n",
    "            raise FileNotFoundError(f\"No .npy files found in {data_path!r}\")\n",
    "\n",
    "        self.data = []\n",
    "        self.labels = []\n",
    "        for file in files:\n",
    "            # Plain numeric arrays need no pickle support; leaving it off avoids\n",
    "            # executing arbitrary code from a tampered file.\n",
    "            array = np.load(os.path.join(data_path, file))\n",
    "            self.data.append(torch.from_numpy(array).unsqueeze(0))\n",
    "            stem = file[:-len(\".npy\")]\n",
    "            self.labels.append(int(stem.rsplit(\"_\", 1)[1]))\n",
    "\n",
    "        self.data = torch.stack(self.data).float()\n",
    "        self.labels = torch.tensor(self.labels).float()\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx], self.labels[idx]\n",
    "\n",
    "# Load every exported spectrum into memory (same directory the export cell writes to).\n",
    "data_path = \"../ECG-Spectral/Spectral/two_class_resnet/two_class_spectra\"\n",
    "dataset = TwoClassData(data_path=data_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1356])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Input tensor of the first sample (element [1] is its label); leading axis is the channel added by TwoClassData.\n",
    "dataset[0][0].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResNet(\n",
       "  (conv1): Conv1d(1, 64, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (encoder): Sequential(\n",
       "    (0): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(64, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(64, 100, kernel_size=(1,), stride=(1,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "    (1): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "    (2): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "    (3): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "    (4): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "    (5): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (linear): Linear(in_features=4300, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluation loader: metrics are aggregated over all samples, so no shuffling is needed.\n",
    "dataloader = DataLoader(dataset,\n",
    "                        batch_size=64,\n",
    "                        shuffle=False)\n",
    "\n",
    "model = ResNet()\n",
    "# map_location=\"cpu\": a checkpoint saved from a GPU still loads on CPU-only machines\n",
    "# (evaluate_model moves the model to its device later). weights_only=True restricts\n",
    "# unpickling to tensors and plain containers, which is all a state_dict needs.\n",
    "model.load_state_dict(torch.load(\"../ECG-Spectral/Spectral/two_class_resnet/resnet_binary_invitro.pt\",\n",
    "                                 map_location=\"cpu\", weights_only=True))\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "InterpretableResNet(\n",
       "  (conv1): Conv1d(1, 64, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (encoder): Sequential(\n",
       "    (0): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(64, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(64, 100, kernel_size=(1,), stride=(1,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "    (1): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "    (2): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "    (3): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "    (4): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "    (5): Sequential(\n",
       "      (0): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential(\n",
       "          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)\n",
       "          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): ResidualBlock(\n",
       "        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (shortcut): Sequential()\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (linear): Linear(in_features=4300, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "interp_model = InterpretableResNet()\n",
    "# Same checkpoint file as the baseline ResNet. map_location=\"cpu\" keeps GPU-saved weights\n",
    "# loadable on CPU-only machines; weights_only=True avoids arbitrary unpickling.\n",
    "interp_model.load_state_dict(torch.load(\"../ECG-Spectral/Spectral/two_class_resnet/resnet_binary_invitro.pt\",\n",
    "                                        map_location=\"cpu\", weights_only=True))\n",
    "interp_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Comprehensive Model Evaluation\n",
    "def evaluate_model(model, dataloader, device='cuda:1', explanation_mode=False, masking_value=-9999):\n",
    "    \"\"\"\n",
    "    Run `model` over every batch of `dataloader` without gradients and collect results.\n",
    "\n",
    "    Args:\n",
    "        model: torch module; its output is treated as class logits, with softmax and\n",
    "            argmax taken over dim=1.\n",
    "        dataloader: iterable of (inputs, labels) batches; labels are only collected.\n",
    "        device: device to evaluate on. If a CUDA device is requested but not available\n",
    "            (no CUDA, or the index exceeds the visible GPU count), evaluation falls\n",
    "            back to CPU with a warning instead of failing.\n",
    "        explanation_mode: if True, call model(x, explanation_mode=True,\n",
    "            masking_value=masking_value) instead of model(x).\n",
    "        masking_value: forwarded to the model in explanation mode only.\n",
    "\n",
    "    Returns:\n",
    "        (predictions, labels, probabilities) numpy arrays in dataloader order: the\n",
    "        argmax class per sample, the labels as yielded by the loader, and the softmax\n",
    "        probabilities per sample.\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "\n",
    "    device = torch.device(device)\n",
    "    if device.type == 'cuda':\n",
    "        gpu_index = 0 if device.index is None else device.index\n",
    "        if not torch.cuda.is_available() or gpu_index >= torch.cuda.device_count():\n",
    "            warnings.warn(f\"{device} is not available; evaluating on CPU instead.\")\n",
    "            device = torch.device('cpu')\n",
    "\n",
    "    model.eval()\n",
    "    model.to(device)\n",
    "    all_predictions = []\n",
    "    all_labels = []\n",
    "    all_probabilities = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_data, batch_labels in tqdm(dataloader, desc=\"Evaluating\"):\n",
    "            batch_data = batch_data.to(device)\n",
    "\n",
    "            # Forward pass\n",
    "            if explanation_mode:\n",
    "                outputs = model(batch_data,\n",
    "                                explanation_mode=True,\n",
    "                                masking_value=masking_value)\n",
    "            else:\n",
    "                outputs = model(batch_data)\n",
    "            probabilities = F.softmax(outputs, dim=1)\n",
    "            predictions = torch.argmax(outputs, dim=1)\n",
    "\n",
    "            # Labels are never sent to the device: they are not used in the forward pass.\n",
    "            all_predictions.extend(predictions.cpu().numpy())\n",
    "            all_labels.extend(batch_labels.cpu().numpy())\n",
    "            all_probabilities.extend(probabilities.cpu().numpy())\n",
    "    return np.array(all_predictions), np.array(all_labels), np.array(all_probabilities)\n",
    "\n",
    "def _print_section_header(title):\n",
    "    # One banner style for every section: a 50-character rule above and below the title.\n",
    "    print(\"=\" * 50)\n",
    "    print(title)\n",
    "    print(\"=\" * 50)\n",
    "\n",
    "def calculate_metrics(predictions, true_labels, probabilities, class_names=None):\n",
    "    \"\"\"\n",
    "    Print overall, per-class, and sklearn classification-report metrics.\n",
    "\n",
    "    Args:\n",
    "        predictions: predicted class label per sample.\n",
    "        true_labels: ground-truth label per sample (numeric; may be float, e.g. 0.0/1.0).\n",
    "        probabilities: per-sample class probabilities. Not used by the current report;\n",
    "            accepted so the tuple returned by evaluate_model can be passed straight in.\n",
    "        class_names: display names, one per class in sorted label order. If None, names\n",
    "            are derived as \"Class <label>\" from the labels present in either array.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if class_names does not have one entry per observed class.\n",
    "    \"\"\"\n",
    "    # sklearn reports average=None scores for the sorted union of labels seen in\n",
    "    # true_labels and predictions; deriving names from that same set keeps the\n",
    "    # per-class loop and the report aligned (e.g. when only one class is present).\n",
    "    observed = np.unique(np.concatenate([np.ravel(true_labels), np.ravel(predictions)]))\n",
    "    if class_names is None:\n",
    "        class_names = [f\"Class {int(label)}\" for label in observed]\n",
    "    elif len(class_names) != len(observed):\n",
    "        raise ValueError(f\"Expected {len(observed)} class names, got {len(class_names)}\")\n",
    "\n",
    "    # Calculate metrics\n",
    "    accuracy = accuracy_score(true_labels, predictions)\n",
    "    precision_macro = precision_score(true_labels, predictions, average='macro')\n",
    "    precision_weighted = precision_score(true_labels, predictions, average='weighted')\n",
    "    recall_macro = recall_score(true_labels, predictions, average='macro')\n",
    "    recall_weighted = recall_score(true_labels, predictions, average='weighted')\n",
    "    f1_macro = f1_score(true_labels, predictions, average='macro')\n",
    "    f1_weighted = f1_score(true_labels, predictions, average='weighted')\n",
    "\n",
    "    # Print overall metrics\n",
    "    _print_section_header(\"OVERALL PERFORMANCE METRICS\")\n",
    "    print(f\"Accuracy: {accuracy:.4f}\")\n",
    "    print(f\"Precision (Macro): {precision_macro:.4f}\")\n",
    "    print(f\"Precision (Weighted): {precision_weighted:.4f}\")\n",
    "    print(f\"Recall (Macro): {recall_macro:.4f}\")\n",
    "    print(f\"Recall (Weighted): {recall_weighted:.4f}\")\n",
    "    print(f\"F1-Score (Macro): {f1_macro:.4f}\")\n",
    "    print(f\"F1-Score (Weighted): {f1_weighted:.4f}\")\n",
    "\n",
    "    # Get per-class metrics (ordered like `observed`, matching class_names)\n",
    "    precision_per_class = precision_score(true_labels, predictions, average=None)\n",
    "    recall_per_class = recall_score(true_labels, predictions, average=None)\n",
    "    f1_per_class = f1_score(true_labels, predictions, average=None)\n",
    "\n",
    "    _print_section_header(\"PER-CLASS PERFORMANCE METRICS\")\n",
    "    for name, prec, rec, f1 in zip(class_names, precision_per_class, recall_per_class, f1_per_class):\n",
    "        print(f\"{name}:\")\n",
    "        print(f\"  Precision: {prec:.4f}\")\n",
    "        print(f\"  Recall: {rec:.4f}\")\n",
    "        print(f\"  F1-Score: {f1:.4f}\")\n",
    "\n",
    "    # Detailed classification report\n",
    "    _print_section_header(\"DETAILED CLASSIFICATION REPORT\")\n",
    "    print(classification_report(true_labels, predictions, target_names=class_names))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running model evaluation...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|██████████| 127/127 [00:01<00:00, 80.19it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===================================================\n",
      "OVERALL PERFORMANCE METRICS\n",
      "==================================================\n",
      "Accuracy: 0.9988\n",
      "Precision (Macro): 0.9988\n",
      "Precision (Weighted): 0.9988\n",
      "Recall (Macro): 0.9988\n",
      "Recall (Weighted): 0.9988\n",
      "F1-Score (Macro): 0.9988\n",
      "F1-Score (Weighted): 0.9988\n",
      "===================================================\n",
      "PER-CLASS PERFORMANCE METRICS\n",
      "==================================================\n",
      "Class 0:\n",
      "  Precision: 0.9975\n",
      "  Recall: 1.0000\n",
      "  F1-Score: 0.9988\n",
      "Class 1:\n",
      "  Precision: 1.0000\n",
      "  Recall: 0.9975\n",
      "  F1-Score: 0.9988\n",
      "===================================================\n",
      "DETAILED CLASSIFICATION REPORT\n",
      "==================================================\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "     Class 0       1.00      1.00      1.00      4048\n",
      "     Class 1       1.00      1.00      1.00      4017\n",
      "\n",
      "    accuracy                           1.00      8065\n",
      "   macro avg       1.00      1.00      1.00      8065\n",
      "weighted avg       1.00      1.00      1.00      8065\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the baseline ResNet (plain forward pass) and print its metrics.\n",
    "print(\"Running model evaluation...\")\n",
    "predictions, true_labels, probabilities = evaluate_model(model=model, dataloader=dataloader)\n",
    "calculate_metrics(predictions=predictions, true_labels=true_labels, probabilities=probabilities)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running model evaluation...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating:   0%|          | 0/127 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|██████████| 127/127 [00:02<00:00, 51.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===================================================\n",
      "OVERALL PERFORMANCE METRICS\n",
      "==================================================\n",
      "Accuracy: 0.9988\n",
      "Precision (Macro): 0.9988\n",
      "Precision (Weighted): 0.9988\n",
      "Recall (Macro): 0.9988\n",
      "Recall (Weighted): 0.9988\n",
      "F1-Score (Macro): 0.9988\n",
      "F1-Score (Weighted): 0.9988\n",
      "===================================================\n",
      "PER-CLASS PERFORMANCE METRICS\n",
      "==================================================\n",
      "Class 0:\n",
      "  Precision: 0.9975\n",
      "  Recall: 1.0000\n",
      "  F1-Score: 0.9988\n",
      "Class 1:\n",
      "  Precision: 1.0000\n",
      "  Recall: 0.9975\n",
      "  F1-Score: 0.9988\n",
      "===================================================\n",
      "DETAILED CLASSIFICATION REPORT\n",
      "==================================================\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "     Class 0       1.00      1.00      1.00      4048\n",
      "     Class 1       1.00      1.00      1.00      4017\n",
      "\n",
      "    accuracy                           1.00      8065\n",
      "   macro avg       1.00      1.00      1.00      8065\n",
      "weighted avg       1.00      1.00      1.00      8065\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Run evaluation (AD): InterpretableResNet in explanation mode, on the same loader.\n",
    "print(\"Running model evaluation...\")\n",
    "predictions, true_labels, probabilities = evaluate_model(model=interp_model, dataloader=dataloader, explanation_mode=True)\n",
    "calculate_metrics(predictions=predictions, true_labels=true_labels, probabilities=probabilities)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "AD",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
