{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5139aa9a",
   "metadata": {},
   "source": [
    "# setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7710649d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cuda\n"
     ]
    }
   ],
   "source": [
    "import os, math, time, random\n",
    "from dataclasses import dataclass, replace\n",
    "from typing import Tuple\n",
    "import copy\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR, ConstantLR\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision\n",
    "from torchvision import transforms, datasets, models\n",
    "from torchvision.datasets.utils import download_and_extract_archive\n",
    "\n",
    "# Pick the compute device once; build_model(), evaluate() and the EMP_* pruning helpers read this global.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(\"Device:\", device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b900f77f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): not referenced by any code visible in this part of the notebook; presumably gates an\n",
    "# automatic TinyImageNet download elsewhere -- confirm or remove.\n",
    "AUTO_DOWNLOAD_TINYIN = True\n",
    "\n",
    "@dataclass\n",
    "class TrainCfg:\n",
    "    \"\"\"Settings for one experiment: which architecture, which dataset, and its hyper-parameters.\n",
    "\n",
    "    ``name``, ``dataset``, ``num_classes``, ``input_size``, ``batch_size`` and\n",
    "    ``cifar_style_resnet`` are read by ``build_model`` / ``get_dataloaders`` in this notebook.\n",
    "    NOTE(review): the remaining fields are not read by any code visible in this part of the\n",
    "    notebook; presumably a training loop elsewhere consumes them -- confirm.\n",
    "    \"\"\"\n",
    "    # Architecture selector matched by build_model(): exactly \"FC2\"/\"FC5\"/\"FC12\"/\"AlexNet\"/\"VGG16\",\n",
    "    # or any string containing \"ResNet18\" / \"ResNet50\". Also forms the checkpoint tag with `dataset`.\n",
    "    name: str\n",
    "    # Dataset selector matched by get_dataloaders().\n",
    "    dataset: str\n",
    "    # Output size of the final classification layer.\n",
    "    num_classes: int\n",
    "    # Side length of the square input image; also used to size the first FC layer.\n",
    "    input_size: int\n",
    "    # Batch size used for both the train and the test DataLoader.\n",
    "    batch_size: int = 128\n",
    "    # Training hyper-parameters (see NOTE(review) in the class docstring).\n",
    "    epochs: int = 200\n",
    "    optimizer: str = \"SGD\"\n",
    "    lr: float = 0.01\n",
    "    momentum: float = 0.9\n",
    "    weight_decay: float = 0.0\n",
    "    # presumably enables a cosine learning-rate schedule with `warmup_epochs` of warm-up -- TODO confirm\n",
    "    cosine: bool = True\n",
    "    warmup_epochs: int = 5\n",
    "    # When True, build_model() replaces the ResNet stem with a 3x3 stride-1 conv and removes the max-pool.\n",
    "    cifar_style_resnet: bool = False\n",
    "\n",
    "# Registry of experiment configurations keyed by a short experiment id.\n",
    "# The key is not always equal to TrainCfg.name (e.g. \"FC5-M\" -> name \"FC5\"); checkpoints are\n",
    "# looked up as f\"{name}_{dataset}\", so the same architecture name can be reused across datasets.\n",
    "cfgs = {\n",
    "    # MNIST (1×28×28), FC baselines\n",
    "    \"FC2-M\":  TrainCfg(\"FC2\",  \"MNIST\",         10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"FC5-M\":  TrainCfg(\"FC5\",  \"MNIST\",         10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"FC12-M\": TrainCfg(\"FC12\", \"MNIST\",         10, input_size=28, epochs=10, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "\n",
    "    # Fashion-MNIST (1×28×28), FC baselines\n",
    "    \"FC2-FM\":  TrainCfg(\"FC2\",  \"FashionMNIST\", 10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"FC5-FM\":  TrainCfg(\"FC5\",  \"FashionMNIST\", 10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"FC12-FM\": TrainCfg(\"FC12\", \"FashionMNIST\", 10, input_size=28, epochs=10, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "\n",
    "    # CIFAR-10\n",
    "    \"FC5\":     TrainCfg(\"FC5\",    \"CIFAR10\", 10, input_size=32, lr=0.01, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"FC12\":    TrainCfg(\"FC12\",   \"CIFAR10\", 10, input_size=32, lr=0.01, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"VGG16\":   TrainCfg(\"VGG16\",  \"CIFAR10\", 10, input_size=32, lr=0.01, cosine=True,  warmup_epochs=5, weight_decay=5e-4),\n",
    "    \"AlexNet\": TrainCfg(\"AlexNet\",\"CIFAR10\", 10, input_size=224,lr=0.01, cosine=True,  warmup_epochs=5, weight_decay=5e-4),  # upsample to 224\n",
    "\n",
    "    # CIFAR-100\n",
    "    \"ResNet18_C100\": TrainCfg(\"ResNet18_C100\",\"CIFAR100\",100,input_size=32,lr=0.1, cosine=True, warmup_epochs=5, weight_decay=5e-4, cifar_style_resnet=True),\n",
    "    \"ResNet50_C100\": TrainCfg(\"ResNet50_C100\",\"CIFAR100\",100,input_size=32,lr=0.1, cosine=True, warmup_epochs=5, weight_decay=5e-4, cifar_style_resnet=True),\n",
    "\n",
    "    # TinyImageNet (optional; place dataset under ./data/tiny-imagenet-200)\n",
    "    \"ResNet18_TinyIN\": TrainCfg(\"ResNet18_TinyIN\",\"TinyImageNet\",200,input_size=64,lr=0.01,cosine=True,warmup_epochs=5,weight_decay=5e-4),\n",
    "    \"ResNet50_TinyIN\": TrainCfg(\"ResNet50_TinyIN\",\"TinyImageNet\",200,input_size=64,lr=0.01,cosine=True,warmup_epochs=5,weight_decay=5e-4),\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "04cc0510",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FCNet(nn.Module):\n",
    "    def __init__(self, in_dim: int, widths):\n",
    "        super().__init__()\n",
    "        dims = [in_dim] + list(widths)\n",
    "        layers = []\n",
    "        for i in range(len(dims) - 2):\n",
    "            layers += [nn.Linear(dims[i], dims[i + 1]), nn.ReLU(inplace=True)]\n",
    "        layers += [nn.Linear(dims[-2], dims[-1])]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.flatten(x, 1)\n",
    "        return self.net(x)\n",
    "\n",
    "def _infer_in_dim(cfg: TrainCfg) -> int:\n",
    "    \"\"\"Return the flattened input length (channels * side * side) implied by ``cfg``.\"\"\"\n",
    "    # MNIST and Fashion-MNIST are treated as single-channel; every other dataset as 3-channel.\n",
    "    if cfg.dataset in (\"MNIST\", \"FashionMNIST\"):\n",
    "        channels = 1\n",
    "    else:\n",
    "        channels = 3\n",
    "    side = cfg.input_size\n",
    "    return channels * side * side\n",
    "\n",
    "def build_model(cfg: TrainCfg) -> nn.Module:\n",
    "    \"\"\"Instantiate the untrained architecture selected by ``cfg.name`` and move it to ``device``.\n",
    "\n",
    "    Supported names: ``FC2``, ``FC5``, ``FC12`` (fully connected nets), ``AlexNet`` and ``VGG16``\n",
    "    (torchvision models with the classification head resized to ``cfg.num_classes``), and any name\n",
    "    containing ``ResNet18`` or ``ResNet50`` (``ResNet18`` is checked first). When\n",
    "    ``cfg.cifar_style_resnet`` is set, the ResNet stem becomes a 3x3 stride-1 convolution and the\n",
    "    max-pool is replaced by an identity.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``cfg.name`` matches none of the above.\n",
    "    \"\"\"\n",
    "    # Hidden-layer widths of the fully connected baselines; the class count is appended below.\n",
    "    fc_hidden = {\n",
    "        \"FC2\": [100],\n",
    "        \"FC5\": [1000, 600, 300, 100],\n",
    "        \"FC12\": [1000, 900, 800, 750, 700, 650, 600, 500, 400, 200, 100],\n",
    "    }\n",
    "    arch = cfg.name\n",
    "    flat_dim = _infer_in_dim(cfg)\n",
    "\n",
    "    if arch in fc_hidden:\n",
    "        net = FCNet(flat_dim, fc_hidden[arch] + [cfg.num_classes])\n",
    "\n",
    "    elif arch == \"AlexNet\":\n",
    "        net = models.alexnet(weights=None)\n",
    "        head = net.classifier\n",
    "        if isinstance(head, nn.Sequential):\n",
    "            # Swap only the last Linear so the output matches the number of classes.\n",
    "            head[-1] = nn.Linear(head[-1].in_features, cfg.num_classes)\n",
    "\n",
    "    elif arch == \"VGG16\":\n",
    "        net = models.vgg16(weights=None)\n",
    "        # Pool the feature map to 1x1 and classify from the 512 channels with one Linear layer.\n",
    "        net.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        net.classifier = nn.Linear(512, cfg.num_classes)\n",
    "\n",
    "    elif \"ResNet18\" in arch or \"ResNet50\" in arch:\n",
    "        factory = models.resnet18 if \"ResNet18\" in arch else models.resnet50\n",
    "        net = factory(weights=None, num_classes=cfg.num_classes)\n",
    "        if cfg.cifar_style_resnet:\n",
    "            net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "            net.maxpool = nn.Identity()\n",
    "\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown model name {cfg.name}\")\n",
    "\n",
    "    return net.to(device)\n",
    "\n",
    "def count_params(model: nn.Module):\n",
    "    total = sum(p.numel() for p in model.parameters())\n",
    "    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    return total, trainable\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "47671b80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset statistics: per-channel (mean, std) pairs handed to transforms.Normalize below.\n",
    "# One value per tuple for the grayscale datasets, three (one per RGB channel) for the others.\n",
    "# NOTE(review): these are hard-coded constants, not computed in this notebook -- confirm their source.\n",
    "MNIST_MEAN,   MNIST_STD   = (0.1307,), (0.3081,)\n",
    "FASHION_MEAN, FASHION_STD = (0.2860,), (0.3530,)\n",
    "CIFAR10_MEAN, CIFAR10_STD   = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)\n",
    "CIFAR100_MEAN, CIFAR100_STD = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)\n",
    "IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406),    (0.229, 0.224, 0.225)\n",
    "\n",
    "def _cifar_transforms(size: int, mean, std, resize_to_224: bool = False):\n",
    "    \"\"\"Build the ``(train, test)`` torchvision pipelines for CIFAR-style RGB images.\n",
    "\n",
    "    When ``size`` is 32 and ``resize_to_224`` is false, training uses a padded random crop plus a\n",
    "    horizontal flip and evaluation uses the image as is. In every other case both pipelines first\n",
    "    resize to ``size`` and training only adds the flip. Each pipeline ends with ``ToTensor``\n",
    "    followed by ``Normalize(mean, std)``.\n",
    "    \"\"\"\n",
    "    def _finish(steps):\n",
    "        # Append the tensor-conversion + normalisation tail (fresh instances for each pipeline).\n",
    "        return transforms.Compose(steps + [transforms.ToTensor(), transforms.Normalize(mean, std)])\n",
    "\n",
    "    keep_native = (size == 32) and (not resize_to_224)\n",
    "    if keep_native:\n",
    "        train_steps = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]\n",
    "        eval_steps = []\n",
    "    else:\n",
    "        train_steps = [transforms.Resize(size), transforms.RandomHorizontalFlip()]\n",
    "        eval_steps = [transforms.Resize(size)]\n",
    "\n",
    "    return _finish(train_steps), _finish(eval_steps)\n",
    "\n",
    "def _gray_transforms(size: int, mean, std):\n",
    "    \"\"\"Build identical ``(train, test)`` pipelines: ``ToTensor`` then ``Normalize(mean, std)``.\n",
    "\n",
    "    ``size`` is accepted for signature symmetry with ``_cifar_transforms`` but is not used.\n",
    "    \"\"\"\n",
    "    def _pipeline():\n",
    "        # No augmentation for the grayscale datasets: convert and normalise only.\n",
    "        return transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])\n",
    "\n",
    "    return _pipeline(), _pipeline()\n",
    "\n",
    "def get_dataloaders(cfg: TrainCfg, data_root: str = \"./data\") -> Tuple[DataLoader, DataLoader]:\n",
    "    \"\"\"Create the ``(train_loader, test_loader)`` pair for ``cfg.dataset``.\n",
    "\n",
    "    MNIST, FashionMNIST, CIFAR10 and CIFAR100 are obtained through torchvision with\n",
    "    ``download=True`` under ``data_root``. TinyImageNet is read with ``ImageFolder`` from\n",
    "    ``<data_root>/tiny-imagenet-200/train`` and ``.../val``.\n",
    "    NOTE(review): ``ImageFolder`` derives labels from sub-directory names, so ``val`` is assumed\n",
    "    to be organised as one folder per class -- confirm.\n",
    "\n",
    "    Only the training loader shuffles. Both loaders use ``cfg.batch_size``, 4 worker processes\n",
    "    and pinned memory when CUDA is available.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``cfg.dataset`` is not one of the supported names.\n",
    "    \"\"\"\n",
    "    kind = cfg.dataset\n",
    "\n",
    "    def _download_split_pair(dataset_cls, tfms_pair):\n",
    "        # Instantiate the train split first, then the test split, each with its own transform.\n",
    "        train_tfms, test_tfms = tfms_pair\n",
    "        train_part = dataset_cls(root=data_root, train=True, download=True, transform=train_tfms)\n",
    "        test_part = dataset_cls(root=data_root, train=False, download=True, transform=test_tfms)\n",
    "        return train_part, test_part\n",
    "\n",
    "    if kind == \"MNIST\":\n",
    "        train_set, test_set = _download_split_pair(\n",
    "            datasets.MNIST, _gray_transforms(28, MNIST_MEAN, MNIST_STD))\n",
    "\n",
    "    elif kind == \"FashionMNIST\":\n",
    "        train_set, test_set = _download_split_pair(\n",
    "            datasets.FashionMNIST, _gray_transforms(28, FASHION_MEAN, FASHION_STD))\n",
    "\n",
    "    elif kind == \"CIFAR10\":\n",
    "        # A 224-pixel request (the AlexNet config) switches to the resize pipeline.\n",
    "        wants_224 = cfg.input_size == 224\n",
    "        train_set, test_set = _download_split_pair(\n",
    "            datasets.CIFAR10,\n",
    "            _cifar_transforms(cfg.input_size, CIFAR10_MEAN, CIFAR10_STD, resize_to_224=wants_224))\n",
    "\n",
    "    elif kind == \"CIFAR100\":\n",
    "        train_set, test_set = _download_split_pair(\n",
    "            datasets.CIFAR100,\n",
    "            _cifar_transforms(cfg.input_size, CIFAR100_MEAN, CIFAR100_STD, resize_to_224=False))\n",
    "\n",
    "    elif kind == \"TinyImageNet\":\n",
    "        tiny_root = os.path.join(data_root, \"tiny-imagenet-200\")\n",
    "        side = cfg.input_size\n",
    "        augment = transforms.Compose([\n",
    "            transforms.RandomResizedCrop(side),\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),\n",
    "        ])\n",
    "        # Evaluation: resize slightly larger than the target, then take the centre crop.\n",
    "        center = transforms.Compose([\n",
    "            transforms.Resize(side + 8),\n",
    "            transforms.CenterCrop(side),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),\n",
    "        ])\n",
    "        train_set = datasets.ImageFolder(os.path.join(tiny_root, \"train\"), transform=augment)\n",
    "        test_set = datasets.ImageFolder(os.path.join(tiny_root, \"val\"), transform=center)\n",
    "\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown dataset {cfg.dataset}\")\n",
    "\n",
    "    shared = dict(batch_size=cfg.batch_size, num_workers=4, pin_memory=torch.cuda.is_available())\n",
    "    train_loader = DataLoader(train_set, shuffle=True, **shared)\n",
    "    test_loader = DataLoader(test_set, shuffle=False, **shared)\n",
    "    return train_loader, test_loader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "25302c03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AUTO_DOWNLOAD_TINYIN, TrainCfg and cfgs are defined once in the configuration cell near the top of\n",
    "# the notebook. They are deliberately NOT re-declared here: a second verbatim copy silently overrides\n",
    "# any edit made to the first one and rebinds TrainCfg to a new, distinct class object.\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate(model: nn.Module, loader: DataLoader, criterion):\n",
    "    \"\"\"Evaluate ``model`` on every batch of ``loader`` without tracking gradients.\n",
    "\n",
    "    The model is switched to eval mode and left in that mode on return. Inputs and targets are\n",
    "    moved to the global ``device``.\n",
    "\n",
    "    Args:\n",
    "        model: network producing one row of class logits per sample.\n",
    "        loader: iterable of ``(inputs, targets)`` batches.\n",
    "        criterion: loss callable ``criterion(logits, targets)``. Its value is weighted by the batch\n",
    "            size, so it is assumed to return the batch MEAN (the default reduction of\n",
    "            ``nn.CrossEntropyLoss``) -- NOTE(review): confirm for other criteria.\n",
    "\n",
    "    Returns:\n",
    "        ``(mean_loss_per_sample, accuracy_percent)``.\n",
    "\n",
    "    Raises:\n",
    "        ZeroDivisionError: if ``loader`` yields no samples.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    total_loss, total_correct, n = 0.0, 0, 0\n",
    "    for x, y in loader:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        logits = model(x)\n",
    "        loss = criterion(logits, y)\n",
    "        # Re-weight the batch mean by the batch size so a smaller final batch is averaged correctly.\n",
    "        total_loss += loss.item() * y.size(0)\n",
    "        total_correct += (logits.argmax(1) == y).sum().item()\n",
    "        n += y.size(0)\n",
    "    return total_loss / n, 100.0 * total_correct / n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "57ae599e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def EMP_global_magnitude(model, beta=1.0):\n",
    "    \n",
    "    model1 = copy.deepcopy(model)\n",
    "    params = []\n",
    "    for m in model1.modules():\n",
    "        if isinstance(m, (nn.Conv2d, nn.Linear)):\n",
    "            params.append(m.weight.data.view(-1))\n",
    "    if not params:\n",
    "        return model, 0.0\n",
    "    \n",
    "    s = torch.cat(params)\n",
    "    x_norm = torch.abs(s) / torch.sum(torch.abs(s))\n",
    "    neff = 1/torch.sum((x_norm ** 2))\n",
    "    r_neff = torch.floor(beta * neff)\n",
    "    r_neff = r_neff.clamp(min=1, max=len(s)-1)\n",
    "    _, indices = torch.sort(torch.abs(s), descending=True)\n",
    "    thresh = torch.abs(s)[indices[int(r_neff)]]\n",
    "    \n",
    "    total, kept = 0, 0\n",
    "    for m in model.modules():\n",
    "        if isinstance(m, (nn.Conv2d, nn.Linear)):\n",
    "            w = m.weight.data\n",
    "            mask = (torch.abs(w) >= thresh)\n",
    "            kept += mask.sum().item()\n",
    "            total += mask.numel()\n",
    "            w.mul_(mask)\n",
    "    \n",
    "    sparsity = 1.0 - (kept / total)\n",
    "    return sparsity\n",
    "\n",
    "def EMP_loss_change(model: nn.Module, loader: DataLoader, beta: float = 1.0):\n",
    "    \"\"\"Taylor-based importance pruning: return a pruned copy keeping ~``beta * N_eff`` weights.\n",
    "\n",
    "    Every Conv2d/Linear weight is scored by ``|weight * grad|``, with the gradient of the\n",
    "    cross-entropy loss taken on the first batch of ``loader`` (the copy is put in train mode for\n",
    "    that pass). With ``p = score / sum(score)`` the effective number of weights is\n",
    "    ``N_eff = 1 / sum(p ** 2)``. The ``floor(beta * N_eff)`` highest-scoring weights (clamped to\n",
    "    ``[1, N - 1]``) are kept and all others are set to zero, mirroring ``EMP_global_magnitude``.\n",
    "    The input ``model`` is never modified.\n",
    "\n",
    "    Args:\n",
    "        model: network to prune (left untouched; a deep copy is pruned).\n",
    "        loader: yields ``(inputs, targets)`` batches; only the first batch is used.\n",
    "        beta: multiplier applied to ``N_eff``; larger values keep more weights.\n",
    "\n",
    "    Returns:\n",
    "        ``(pruned_copy, sparsity)``. If the model has no Conv2d/Linear module the original\n",
    "        ``model`` is returned with sparsity ``0.0``. The copy is returned in train mode.\n",
    "\n",
    "    Raises:\n",
    "        StopIteration: if ``loader`` yields no batch.\n",
    "        ValueError: if every importance score is zero, so no ranking exists.\n",
    "    \"\"\"\n",
    "    pruned = copy.deepcopy(model)\n",
    "    layers = [m for m in pruned.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]\n",
    "    if not layers:\n",
    "        return model, 0.0\n",
    "\n",
    "    pruned.train()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    # Gradients from a single batch.\n",
    "    x, y = next(iter(loader))\n",
    "    x, y = x.to(device), y.to(device)\n",
    "    pruned.zero_grad()\n",
    "    criterion(pruned(x), y).backward()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # The train-mode forward pass updates buffers such as BatchNorm running statistics.\n",
    "        # Restore them from the untouched original so the copy differs only by the pruning mask.\n",
    "        for dst, src in zip(pruned.buffers(), model.buffers()):\n",
    "            dst.copy_(src)\n",
    "\n",
    "        # Importance per layer, computed once and reused for the threshold and the masks.\n",
    "        importances = []\n",
    "        for layer in layers:\n",
    "            weight = layer.weight.detach()\n",
    "            grad = layer.weight.grad\n",
    "            importances.append((weight * grad).abs() if grad is not None else weight.abs())\n",
    "\n",
    "        all_scores = torch.cat([imp.reshape(-1) for imp in importances])\n",
    "        mass = all_scores.sum()\n",
    "        if mass == 0:\n",
    "            raise ValueError(\"EMP_loss_change: all importance scores are zero; cannot rank weights\")\n",
    "\n",
    "        share = all_scores / mass\n",
    "        neff = 1 / torch.sum(share ** 2)\n",
    "        keep = int(torch.clamp(torch.floor(beta * neff), 1, all_scores.numel() - 1))\n",
    "        # Keep the `keep` HIGHEST scores. The earlier `torch.quantile(scores, keep / N)` threshold\n",
    "        # removed the `keep` lowest scores instead (the opposite of EMP_global_magnitude), and\n",
    "        # torch.quantile is known to reject very large inputs; a descending sort avoids both.\n",
    "        threshold = torch.sort(all_scores, descending=True).values[keep]\n",
    "\n",
    "        # Apply mask\n",
    "        total, kept = 0, 0\n",
    "        for layer, importance in zip(layers, importances):\n",
    "            mask = importance > threshold\n",
    "            layer.weight.mul_(mask)\n",
    "            kept += mask.sum().item()\n",
    "            total += mask.numel()\n",
    "\n",
    "    actual_sparsity = 1.0 - (kept / total)\n",
    "    return pruned, actual_sparsity\n",
    "\n",
    "\n",
    "def EMP_saliency(model: nn.Module, loader: DataLoader, beta: float = 1.0):\n",
    "    \"\"\"Gradient-saliency pruning: return a pruned copy keeping ~``beta * N_eff`` weights.\n",
    "\n",
    "    The saliency of each Conv2d/Linear weight is the mean absolute gradient of the cross-entropy\n",
    "    loss over the first ``min(10, len(loader))`` batches (the copy is put in train mode for those\n",
    "    passes). With ``p = saliency / sum(saliency)`` the effective number of weights is\n",
    "    ``N_eff = 1 / sum(p ** 2)``. The ``floor(beta * N_eff)`` most salient weights (clamped to\n",
    "    ``[1, N - 1]``) are kept and all others are set to zero, mirroring ``EMP_global_magnitude``.\n",
    "    The input ``model`` is never modified.\n",
    "\n",
    "    Args:\n",
    "        model: network to prune (left untouched; a deep copy is pruned).\n",
    "        loader: sized iterable of ``(inputs, targets)`` batches.\n",
    "        beta: multiplier applied to ``N_eff``; larger values keep more weights.\n",
    "\n",
    "    Returns:\n",
    "        ``(pruned_copy, sparsity)``. If the model has no Conv2d/Linear module the original\n",
    "        ``model`` is returned with sparsity ``0.0``. The copy is returned in train mode.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``loader`` is empty or every saliency score is zero.\n",
    "    \"\"\"\n",
    "    pruned = copy.deepcopy(model)\n",
    "    layers = [m for m in pruned.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]\n",
    "    if not layers:\n",
    "        return model, 0.0\n",
    "\n",
    "    # Accumulate gradients over multiple batches for stability\n",
    "    num_batches = min(10, len(loader))\n",
    "    if num_batches == 0:\n",
    "        # Averaging over zero batches would divide by zero and yield NaN scores.\n",
    "        raise ValueError(\"EMP_saliency: loader yields no batches\")\n",
    "\n",
    "    pruned.train()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    saliency = [torch.zeros_like(layer.weight) for layer in layers]\n",
    "\n",
    "    batches = iter(loader)\n",
    "    for _ in range(num_batches):\n",
    "        x, y = next(batches)\n",
    "        x, y = x.to(device), y.to(device)\n",
    "\n",
    "        # Run forward/backward on the COPY. Using the caller's `model` here left stale gradients on\n",
    "        # it and, in train mode, altered its BatchNorm statistics.\n",
    "        pruned.zero_grad()\n",
    "        criterion(pruned(x), y).backward()\n",
    "\n",
    "        for acc, layer in zip(saliency, layers):\n",
    "            if layer.weight.grad is not None:\n",
    "                acc += layer.weight.grad.abs()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Restore buffers (e.g. BatchNorm running statistics) changed by the train-mode passes so\n",
    "        # the copy differs from the original only by the pruning mask.\n",
    "        for dst, src in zip(pruned.buffers(), model.buffers()):\n",
    "            dst.copy_(src)\n",
    "\n",
    "        # Average gradients\n",
    "        for acc in saliency:\n",
    "            acc /= num_batches\n",
    "\n",
    "        all_scores = torch.cat([acc.reshape(-1) for acc in saliency])\n",
    "        mass = all_scores.sum()\n",
    "        if mass == 0:\n",
    "            raise ValueError(\"EMP_saliency: all saliency scores are zero; cannot rank weights\")\n",
    "\n",
    "        share = all_scores / mass\n",
    "        neff = 1 / torch.sum(share ** 2)\n",
    "        keep = int(torch.clamp(torch.floor(beta * neff), 1, all_scores.numel() - 1))\n",
    "        # Keep the `keep` HIGHEST scores. The earlier `torch.quantile(scores, keep / N)` threshold\n",
    "        # removed the `keep` lowest scores instead (the opposite of EMP_global_magnitude), and\n",
    "        # torch.quantile is known to reject very large inputs; a descending sort avoids both.\n",
    "        threshold = torch.sort(all_scores, descending=True).values[keep]\n",
    "\n",
    "        # Apply mask\n",
    "        total, kept = 0, 0\n",
    "        for acc, layer in zip(saliency, layers):\n",
    "            mask = acc > threshold\n",
    "            layer.weight.mul_(mask)\n",
    "            kept += mask.sum().item()\n",
    "            total += mask.numel()\n",
    "\n",
    "    actual_sparsity = 1.0 - (kept / total)\n",
    "    return pruned, actual_sparsity\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd5817dc",
   "metadata": {},
   "source": [
    "# Loss/accuracy change after EMP magnitude pruning: FC5, FC12, VGG16, AlexNet, ResNet18, ResNet50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "99d1bba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment ids (keys of `cfgs`) processed by the pruning loop in the next cell.\n",
    "keys = ['FC5', 'FC12', 'VGG16', 'AlexNet', 'ResNet18_C100', 'ResNet50_C100', 'ResNet18_TinyIN', 'ResNet50_TinyIN']\n",
    "# Loss used when evaluating the models before pruning.\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1fd83749",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded ./checkpoints/FC5_CIFAR10_best.pth\n",
      "before pruning: val_loss=1.287621498298645 | val_acc=57.34%\n",
      "Achieved global sparsity: 48.594308524336846%\n",
      "After pruning: val_loss=1.2694350761413575 | val_acc=57.52% | change of loss=0.018186422157287385  |  | change of Acc=-0.17999999999999972%\n",
      "Loaded ./checkpoints/FC12_CIFAR10_best.pth\n",
      "before pruning: val_loss=1.2540489635467529 | val_acc=58.26%\n",
      "Achieved global sparsity: 39.328941208866866%\n",
      "After pruning: val_loss=1.231849143409729 | val_acc=58.42% | change of loss=0.02219982013702393  |  | change of Acc=-0.1600000000000037%\n",
      "Loaded ./checkpoints/VGG16_CIFAR10_best.pth\n",
      "before pruning: val_loss=0.4234209942817688 | val_acc=91.12%\n",
      "Achieved global sparsity: 59.47641629445355%\n",
      "After pruning: val_loss=0.31842048754692076 | val_acc=90.98% | change of loss=0.10500050673484806  |  | change of Acc=0.14000000000000057%\n",
      "Loaded ./checkpoints/AlexNet_CIFAR10_best.pth\n",
      "before pruning: val_loss=0.48460997281074525 | val_acc=89.84%\n",
      "Achieved global sparsity: 60.65771789393602%\n",
      "After pruning: val_loss=0.448030006980896 | val_acc=89.89% | change of loss=0.03657996582984924  |  | change of Acc=-0.04999999999999716%\n",
      "Loaded ./checkpoints/ResNet18_C100_CIFAR100_best.pth\n",
      "before pruning: val_loss=0.873947188949585 | val_acc=78.64%\n",
      "Achieved global sparsity: 56.19920802338394%\n",
      "After pruning: val_loss=0.9286916805267333 | val_acc=77.55% | change of loss=-0.054744491577148335  |  | change of Acc=1.0900000000000034%\n",
      "Loaded ./checkpoints/ResNet50_C100_CIFAR100_best.pth\n",
      "before pruning: val_loss=0.8449632821083068 | val_acc=80.08%\n",
      "Achieved global sparsity: 54.27148923187657%\n",
      "After pruning: val_loss=0.8519582453727722 | val_acc=79.08% | change of loss=-0.006994963264465359  |  | change of Acc=1.0%\n",
      "Loaded ./checkpoints/ResNet18_TinyIN_TinyImageNet_best.pth\n",
      "before pruning: val_loss=3.9637248611450193 | val_acc=14.44%\n",
      "Achieved global sparsity: 43.66664087390606%\n",
      "After pruning: val_loss=4.067306075286865 | val_acc=13.46% | change of loss=-0.10358121414184573  |  | change of Acc=0.9799999999999986%\n",
      "Loaded ./checkpoints/ResNet50_TinyIN_TinyImageNet_best.pth\n",
      "before pruning: val_loss=2.0287342670440673 | val_acc=56.71%\n",
      "Achieved global sparsity: 48.10430651169402%\n",
      "After pruning: val_loss=1.9927550197601318 | val_acc=55.77% | change of loss=0.035979247283935534  |  | change of Acc=0.9399999999999977%\n"
     ]
    }
   ],
   "source": [
    "# For every experiment in `keys`: load the best checkpoint, evaluate it, prune it in place with\n",
    "# EMP_global_magnitude (beta = 1.0) and evaluate again.\n",
    "# Defined here (once) so the cell does not depend on another cell for its loss function.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "for key in keys:\n",
    "    model_config = cfgs[key]\n",
    "    tag = f\"{model_config.name}_{model_config.dataset}\"\n",
    "    ckpt_path = f'./checkpoints/{tag}_best.pth'\n",
    "    if not os.path.exists(ckpt_path):\n",
    "        # Without trained weights there is no baseline: previously the loop went on to prune an\n",
    "        # untrained model and printed undefined or stale `pre_val_*` values from an earlier iteration.\n",
    "        print(f\"Skipping {key}: checkpoint {ckpt_path} not found\")\n",
    "        continue\n",
    "\n",
    "    train_loader, test_loader = get_dataloaders(model_config)\n",
    "    model = build_model(model_config)\n",
    "    # NOTE: torch.load unpickles the file -- only load checkpoints you produced yourself.\n",
    "    state = torch.load(ckpt_path, map_location=device)\n",
    "    model.load_state_dict(state[\"model\"])\n",
    "    print(\"Loaded\", ckpt_path)\n",
    "\n",
    "    pre_val_loss, pre_val_acc = evaluate(model, test_loader, criterion)\n",
    "    print(f\"before pruning: val_loss={pre_val_loss} | val_acc={pre_val_acc}%\")\n",
    "\n",
    "    # Prunes `model` in place and returns the achieved sparsity as a fraction.\n",
    "    sparsity = EMP_global_magnitude(model)\n",
    "    print(f\"Achieved global sparsity: {sparsity*100}%\")\n",
    "\n",
    "    val_loss, val_acc = evaluate(model, test_loader, criterion)\n",
    "    # The reported changes are (before - after): a positive loss change means the loss decreased.\n",
    "    print(f\"After pruning: val_loss={val_loss} | val_acc={val_acc}% | change of loss={pre_val_loss-val_loss}  |  | change of Acc={pre_val_acc-val_acc}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6874d467",
   "metadata": {},
   "source": [
    "# Sweep beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "00aced84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: rebinds `keys` (earlier bound to the CIFAR / TinyImageNet experiment ids) to the MNIST and\n",
    "# Fashion-MNIST experiments; re-run this cell before the sweep if the earlier cell was executed since.\n",
    "keys = [\"FC2-M\",\"FC5-M\",\"FC12-M\",\"FC2-FM\",\"FC5-FM\",\"FC12-FM\"]\n",
    "# Candidate multipliers on N_eff; presumably each is passed as `beta` to the pruning call in the\n",
    "# sweep cell that follows (its source is not shown in this part of the notebook) -- confirm.\n",
    "beta_list = [0.5, 0.75, 1.0, 1.25, 1.5, 2.0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "cf5c250c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded ./checkpoints/FC2_MNIST_best.pth\n",
      "before pruning: val_loss=0.21534884063005447 | val_acc=93.8%\n",
      "\n",
      "beta = 0.5\n",
      "Achieved global sparsity: 68.48362720403023%\n",
      "After pruning: val_loss=0.24116790586709977 | val_acc=93.57% | change of loss=-0.025819065237045302  |  | change of Acc=0.23000000000000398%\n",
      "\n",
      "beta = 0.75\n",
      "Achieved global sparsity: 52.72544080604534%\n",
      "After pruning: val_loss=0.22307048230171203 | val_acc=93.81% | change of loss=-0.007721641671657564  |  | change of Acc=-0.010000000000005116%\n",
      "\n",
      "beta = 1.0\n",
      "Achieved global sparsity: 36.968513853904284%\n",
      "After pruning: val_loss=0.21815758819580078 | val_acc=93.78% | change of loss=-0.0028087475657463112  |  | change of Acc=0.01999999999999602%\n",
      "\n",
      "beta = 1.25\n",
      "Achieved global sparsity: 21.210327455919398%\n",
      "After pruning: val_loss=0.2155759802699089 | val_acc=93.81% | change of loss=-0.00022713963985443453  |  | change of Acc=-0.010000000000005116%\n",
      "\n",
      "beta = 1.5\n",
      "Achieved global sparsity: 5.4521410579345115%\n",
      "After pruning: val_loss=0.21537971198558808 | val_acc=93.8% | change of loss=-3.0871355533618194e-05  |  | change of Acc=0.0%\n",
      "\n",
      "beta = 2.0\n",
      "Achieved global sparsity: 0.0%\n",
      "After pruning: val_loss=0.21534884063005447 | val_acc=93.8% | change of loss=0.0  |  | change of Acc=0.0%\n",
      "Loaded ./checkpoints/FC5_MNIST_best.pth\n",
      "before pruning: val_loss=0.09303029469400645 | val_acc=97.18%\n",
      "\n",
      "beta = 0.5\n",
      "Achieved global sparsity: 64.8103448275862%\n",
      "After pruning: val_loss=0.1297045162320137 | val_acc=96.47% | change of loss=-0.03667422153800726  |  | change of Acc=0.710000000000008%\n",
      "\n",
      "beta = 0.75\n",
      "Achieved global sparsity: 47.215548589341694%\n",
      "After pruning: val_loss=0.10041035265773535 | val_acc=96.94% | change of loss=-0.007380057963728898  |  | change of Acc=0.2400000000000091%\n",
      "\n",
      "beta = 1.0\n",
      "Achieved global sparsity: 29.620689655172416%\n",
      "After pruning: val_loss=0.09303944034352898 | val_acc=97.2% | change of loss=-9.145649522535049e-06  |  | change of Acc=-0.01999999999999602%\n",
      "\n",
      "beta = 1.25\n",
      "Achieved global sparsity: 12.025893416927902%\n",
      "After pruning: val_loss=0.09334719210714101 | val_acc=97.2% | change of loss=-0.00031689741313456476  |  | change of Acc=-0.01999999999999602%\n",
      "\n",
      "beta = 1.5\n",
      "Achieved global sparsity: 0.0%\n",
      "After pruning: val_loss=0.09303029469400645 | val_acc=97.18% | change of loss=0.0  |  | change of Acc=0.0%\n",
      "\n",
      "beta = 2.0\n",
      "Achieved global sparsity: 0.0%\n",
      "After pruning: val_loss=0.09303029469400645 | val_acc=97.18% | change of loss=0.0  |  | change of Acc=0.0%\n",
      "Loaded ./checkpoints/FC12_MNIST_best.pth\n",
      "before pruning: val_loss=0.11948255969975144 | val_acc=97.05%\n",
      "\n",
      "beta = 0.5\n",
      "Achieved global sparsity: 64.16042211055276%\n",
      "After pruning: val_loss=0.36684334478378294 | val_acc=93.15% | change of loss=-0.2473607850840315  |  | change of Acc=3.8999999999999915%\n",
      "\n",
      "beta = 0.75\n",
      "Achieved global sparsity: 46.2406231155779%\n",
      "After pruning: val_loss=0.14520602426826953 | val_acc=96.36% | change of loss=-0.025723464568518095  |  | change of Acc=0.6899999999999977%\n",
      "\n",
      "beta = 1.0\n",
      "Achieved global sparsity: 28.320824120603017%\n",
      "After pruning: val_loss=0.11955472102407366 | val_acc=96.94% | change of loss=-7.2161324322223e-05  |  | change of Acc=0.10999999999999943%\n",
      "\n",
      "beta = 1.25\n",
      "Achieved global sparsity: 10.401045226130655%\n",
      "After pruning: val_loss=0.11816273110006005 | val_acc=97.09% | change of loss=0.0013198285996913889  |  | change of Acc=-0.04000000000000625%\n",
      "\n",
      "beta = 1.5\n",
      "Achieved global sparsity: 0.0%\n",
      "After pruning: val_loss=0.11948255969975144 | val_acc=97.05% | change of loss=0.0  |  | change of Acc=0.0%\n",
      "\n",
      "beta = 2.0\n",
      "Achieved global sparsity: 0.0%\n",
      "After pruning: val_loss=0.11948255969975144 | val_acc=97.05% | change of loss=0.0  |  | change of Acc=0.0%\n",
      "Loaded ./checkpoints/FC2_FashionMNIST_best.pth\n",
      "before pruning: val_loss=0.43519594194889066 | val_acc=84.48%\n",
      "\n",
      "beta = 0.5\n",
      "Achieved global sparsity: 68.0617128463476%\n",
      "After pruning: val_loss=0.45921978254318235 | val_acc=83.71% | change of loss=-0.02402384059429169  |  | change of Acc=0.7700000000000102%\n",
      "\n",
      "beta = 0.75\n",
      "Achieved global sparsity: 52.09319899244333%\n",
      "After pruning: val_loss=0.43980459237098696 | val_acc=84.31% | change of loss=-0.004608650422096294  |  | change of Acc=0.1700000000000017%\n",
      "\n",
      "beta = 1.0\n",
      "Achieved global sparsity: 36.12468513853905%\n",
      "After pruning: val_loss=0.43587259707450865 | val_acc=84.36% | change of loss=-0.0006766551256179865  |  | change of Acc=0.12000000000000455%\n",
      "\n",
      "beta = 1.25\n",
      "Achieved global sparsity: 20.15617128463476%\n",
      "After pruning: val_loss=0.4345844689130783 | val_acc=84.36% | change of loss=0.0006114730358123821  |  | change of Acc=0.12000000000000455%\n",
      "\n",
      "beta = 1.5\n",
      "Achieved global sparsity: 4.187657430730473%\n",
      "After pruning: val_loss=0.43521507568359374 | val_acc=84.5% | change of loss=-1.9133734703080663e-05  |  | change of Acc=-0.01999999999999602%\n",
      "\n",
      "beta = 2.0\n",
      "Achieved global sparsity: 0.0%\n",
      "After pruning: val_loss=0.43519594194889066 | val_acc=84.48% | change of loss=0.0  |  | change of Acc=0.0%\n",
      "Loaded ./checkpoints/FC5_FashionMNIST_best.pth\n",
      "before pruning: val_loss=0.35514459912776947 | val_acc=86.96%\n",
      "\n",
      "beta = 0.5\n",
      "Achieved global sparsity: 64.61134796238244%\n",
      "After pruning: val_loss=0.40074955146312713 | val_acc=86.44% | change of loss=-0.04560495233535766  |  | change of Acc=0.519999999999996%\n",
      "\n",
      "beta = 0.75\n",
      "Achieved global sparsity: 46.91705329153605%\n",
      "After pruning: val_loss=0.35902017390727997 | val_acc=86.84% | change of loss=-0.0038755747795105044  |  | change of Acc=0.11999999999999034%\n",
      "\n",
      "beta = 1.0\n",
      "Achieved global sparsity: 29.22275862068966%\n",
      "After pruning: val_loss=0.35611285412311555 | val_acc=86.86% | change of loss=-0.0009682549953460851  |  | change of Acc=0.09999999999999432%\n",
      "\n",
      "beta = 1.25\n",
      "Achieved global sparsity: 11.52846394984326%\n",
      "After pruning: val_loss=0.3551785221934319 | val_acc=86.96% | change of loss=-3.392306566241121e-05  |  | change of Acc=0.0%\n",
      "\n",
      "beta = 1.5\n",
      "Achieved global sparsity: 0.0%\n",
      "After pruning: val_loss=0.35514459912776947 | val_acc=86.96% | change of loss=0.0  |  | change of Acc=0.0%\n",
      "\n",
      "beta = 2.0\n",
      "Achieved global sparsity: 0.0%\n",
      "After pruning: val_loss=0.35514459912776947 | val_acc=86.96% | change of loss=0.0  |  | change of Acc=0.0%\n",
      "Loaded ./checkpoints/FC12_FashionMNIST_best.pth\n",
      "before pruning: val_loss=0.3501719362020492 | val_acc=87.85%\n",
      "\n",
      "beta = 0.5\n",
      "Achieved global sparsity: 64.0462311557789%\n",
      "After pruning: val_loss=0.7277445816040039 | val_acc=78.62% | change of loss=-0.3775726454019547  |  | change of Acc=9.22999999999999%\n",
      "\n",
      "beta = 0.75\n",
      "Achieved global sparsity: 46.06934673366834%\n",
      "After pruning: val_loss=0.364203066444397 | val_acc=87.5% | change of loss=-0.014031130242347756  |  | change of Acc=0.3499999999999943%\n",
      "\n",
      "beta = 1.0\n",
      "Achieved global sparsity: 28.092482412060306%\n",
      "After pruning: val_loss=0.3492264039516449 | val_acc=87.69% | change of loss=0.0009455322504043351  |  | change of Acc=0.1599999999999966%\n",
      "\n",
      "beta = 1.25\n",
      "Achieved global sparsity: 10.115597989949753%\n",
      "After pruning: val_loss=0.34992141020298007 | val_acc=87.88% | change of loss=0.0002505259990691622  |  | change of Acc=-0.030000000000001137%\n",
      "\n",
      "beta = 1.5\n",
      "Achieved global sparsity: 0.0%\n",
      "After pruning: val_loss=0.3501719362020492 | val_acc=87.85% | change of loss=0.0  |  | change of Acc=0.0%\n",
      "\n",
      "beta = 2.0\n",
      "Achieved global sparsity: 0.0%\n",
      "After pruning: val_loss=0.3501719362020492 | val_acc=87.85% | change of loss=0.0  |  | change of Acc=0.0%\n"
     ]
    }
   ],
   "source": [
    "# Loss used for every evaluation in this cell. It is created before the loop so the\n",
    "# baseline evaluation does not depend on a `criterion` left behind by another cell.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "for key in keys:\n",
    "    model_config = cfgs[key]\n",
    "    train_loader, test_loader = get_dataloaders(model_config)\n",
    "    model = build_model(model_config)\n",
    "    tag = f\"{model_config.name}_{model_config.dataset}\"\n",
    "    ckpt_path = f'./checkpoints/{tag}_best.pth'\n",
    "    if not os.path.exists(ckpt_path):\n",
    "        # Without a checkpoint there are no baseline metrics for this model. Skip it\n",
    "        # instead of hitting a NameError or reusing the previous model's baseline.\n",
    "        print(f\"Checkpoint not found, skipping {key}: {ckpt_path}\")\n",
    "        continue\n",
    "\n",
    "    state = torch.load(ckpt_path, map_location=device)\n",
    "    model.load_state_dict(state[\"model\"])\n",
    "    print(\"Loaded\", ckpt_path)\n",
    "\n",
    "    # Baseline metrics of the loaded model, evaluated on `test_loader`.\n",
    "    pre_val_loss, pre_val_acc = evaluate(model, test_loader, criterion)\n",
    "    print(f\"before pruning: val_loss={pre_val_loss} | val_acc={pre_val_acc}%\")\n",
    "\n",
    "    for beta in beta_list:\n",
    "        print(f'\\nbeta = {beta}')\n",
    "        # Work on a deep copy so every beta starts from the same loaded weights.\n",
    "        new_model = copy.deepcopy(model)\n",
    "        sparsity = EMP_global_magnitude(new_model, beta)\n",
    "        print(f\"Achieved global sparsity: {sparsity*100}%\")\n",
    "\n",
    "        val_loss, val_acc = evaluate(new_model, test_loader, criterion)\n",
    "        # Changes are reported as (baseline - pruned): a negative loss change means the loss went up.\n",
    "        print(f\"After pruning: val_loss={val_loss} | val_acc={val_acc}% | change of loss={pre_val_loss-val_loss}  |  change of Acc={pre_val_acc-val_acc}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b78e32fc",
   "metadata": {},
   "source": [
    "# EMP loss change & saliency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a85d65c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded ./checkpoints/VGG16_CIFAR10_best.pth\n",
      "before pruning: val_loss=0.4234209942817688 | val_acc=91.12%\n"
     ]
    }
   ],
   "source": [
    "# Build the VGG16 model, load its best checkpoint and record the baseline metrics\n",
    "# (`pre_val_loss`, `pre_val_acc`) that the pruning cells below compare against.\n",
    "model_config = cfgs['VGG16']\n",
    "train_loader, test_loader = get_dataloaders(model_config)\n",
    "model = build_model(model_config)\n",
    "tag = f\"{model_config.name}_{model_config.dataset}\"\n",
    "ckpt_path = f'./checkpoints/{tag}_best.pth'\n",
    "if not os.path.exists(ckpt_path):\n",
    "    # Fail loudly: continuing would either raise a NameError on the print below\n",
    "    # or compare pruned models against stale baseline values.\n",
    "    raise FileNotFoundError(f\"Checkpoint not found: {ckpt_path}\")\n",
    "\n",
    "state = torch.load(ckpt_path, map_location=device)\n",
    "model.load_state_dict(state[\"model\"])\n",
    "print(\"Loaded\", ckpt_path)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "pre_val_loss, pre_val_acc = evaluate(model, test_loader, criterion)\n",
    "print(f\"before pruning: val_loss={pre_val_loss} | val_acc={pre_val_acc}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d4583fed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EMP loss change\n",
      "Achieved global sparsity: 2.0700775450026354%\n",
      "After pruning: val_loss=0.4233589924812317 | val_acc=91.11% | change of loss=6.200180053711479e-05  |  | change of Acc=0.010000000000005116%\n"
     ]
    }
   ],
   "source": [
    "# Loss function for the post-pruning evaluation.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# EMP_loss_change receives a deep copy of `model` together with the training loader.\n",
    "model1, sparsity = EMP_loss_change(copy.deepcopy(model), train_loader)\n",
    "print(f\"EMP loss change\")\n",
    "print(f\"Achieved global sparsity: {sparsity*100}%\")\n",
    "\n",
    "# Evaluate the returned model on `test_loader`; changes are printed as (baseline - pruned).\n",
    "val_loss, val_acc = evaluate(model1, test_loader, criterion)\n",
    "print(f\"After pruning: val_loss={val_loss} | val_acc={val_acc}% | change of loss={pre_val_loss-val_loss}  |  | change of Acc={pre_val_acc-val_acc}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2173f7db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EMP loss change\n",
      "Achieved global sparsity: 25.72811245547577%\n",
      "After pruning: val_loss=0.3816164824008942 | val_acc=90.97% | change of loss=0.04180451188087464  |  | change of Acc=0.15000000000000568%\n"
     ]
    }
   ],
   "source": [
    "# Saliency-based EMP pruning applied to a deep copy of `model`.\n",
    "# The scores are computed on `train_loader`, matching the loss-change cell: scoring on\n",
    "# `test_loader` and then evaluating on the same loader would leak test data into the\n",
    "# pruning decision and make the reported metrics optimistic.\n",
    "model2 = copy.deepcopy(model)\n",
    "model2, sparsity = EMP_saliency(model2, train_loader)\n",
    "# Label corrected: this cell reports the saliency variant, not the loss-change one.\n",
    "print(\"EMP saliency\")\n",
    "print(f\"Achieved global sparsity: {sparsity*100}%\")\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "val_loss, val_acc = evaluate(model2, test_loader, criterion)\n",
    "# Changes are printed as (baseline - pruned).\n",
    "print(f\"After pruning: val_loss={val_loss} | val_acc={val_acc}% | change of loss={pre_val_loss-val_loss}  |  | change of Acc={pre_val_acc-val_acc}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "069a418f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
