{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Install AIHWKIT\n",
        "- This installation is for a GPU-enabled Google Colab environment: the wheel below is AIHWKIT 0.8.0 built for CUDA 11.7 and Python 3.10.\n",
        "- To install it on your local machine, please refer to the [documentation](https://aihwkit.readthedocs.io/en/latest/install.html).\n"
      ],
      "metadata": {
        "id": "XnMyj4kHY4rH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!wget https://aihwkit-gpu-demo.s3.us-east.cloud-object-storage.appdomain.cloud/aihwkit-0.8.0+cuda117-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl\n",
        "!pip install aihwkit-0.8.0+cuda117-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
      ],
      "metadata": {
        "id": "s6dBOCsaGzoM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Install the MONAI package\n",
        "- For advanced installation options, please refer to the [documentation](https://docs.monai.io/en/latest/installation.html#installation-guide)."
      ],
      "metadata": {
        "id": "FsbVLeWfZVCw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python -c \"import monai\" || pip install -q \"monai-weekly[gdown, nibabel, tqdm, ignite, einops]\"\n",
        "!python -c \"import matplotlib\" || pip install -q matplotlib\n",
        "%matplotlib inline"
      ],
      "metadata": {
        "id": "EraNm11VHnq0",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a72c5928-1442-4f22-da9c-64e3edd8a7e1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Dataset download\n",
        "- We use the same transformations and dependencies as [pytorch-nested-unet](https://github.com/4uiiurz1/pytorch-nested-unet) (UNet++), so we clone that repository and reuse its directory structure."
      ],
      "metadata": {
        "id": "to2brybmZ6UH"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m88RmTo9Gm4e",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "363fc3db-b921-42b7-809a-71f947479597"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/4uiiurz1/pytorch-nested-unet"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%cd pytorch-nested-unet"
      ],
      "metadata": {
        "id": "JQFNEb34Gy0A",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "39045ade-42ce-4153-d5eb-0d65c0035528"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/pytorch-nested-unet\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -r requirements.txt"
      ],
      "metadata": {
        "id": "NEfn9lLYGvPR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download the dataset from Kaggle\n",
        "- Upload your `kaggle.json` API token to `/content` and run the following cells. For more details, please refer to this [blog](https://www.analyticsvidhya.com/blog/2021/06/how-to-load-kaggle-datasets-directly-into-google-colab/)."
      ],
      "metadata": {
        "id": "MqzZ_H6FaeEP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!mkdir -p ~/.kaggle\n",
        "!cp /content/kaggle.json ~/.kaggle\n",
        "!chmod 600 ~/.kaggle/kaggle.json\n",
        "!kaggle competitions download data-science-bowl-2018"
      ],
      "metadata": {
        "id": "fQMvfj_9GvRd",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4faec93e-313e-48db-c211-0a3cc7aa5ace"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading data-science-bowl-2018.zip to /content/pytorch-nested-unet\n",
            " 97% 349M/358M [00:03<00:00, 137MB/s]\n",
            "100% 358M/358M [00:03<00:00, 119MB/s]\n"
          ]
        }
      ]
    },
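    {
      "cell_type": "markdown",
      "source": [
        "If you would rather create the Kaggle credentials file from Python than upload it, the sketch below writes a `kaggle.json` that contains the `username` and `key` fields of a Kaggle API token. It then restricts the file to owner read/write, which has the same effect as the `mkdir` and `chmod 600` commands in the cell above. The helper name `write_kaggle_credentials` is hypothetical. The target directory is a parameter, so you can try the helper on a temporary folder first."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "import os\n",
        "import stat\n",
        "\n",
        "\n",
        "def write_kaggle_credentials(username, key, config_dir='~/.kaggle'):\n",
        "    # Hypothetical helper: write kaggle.json and make it readable/writable by the owner only.\n",
        "    config_dir = os.path.expanduser(config_dir)\n",
        "    os.makedirs(config_dir, exist_ok=True)  # like `mkdir -p`\n",
        "    path = os.path.join(config_dir, 'kaggle.json')\n",
        "    with open(path, 'w') as f:\n",
        "        json.dump({'username': username, 'key': key}, f)\n",
        "    os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)  # 0o600, same as `chmod 600`\n",
        "    return path\n",
        "\n",
        "\n",
        "# Example (replace the placeholders with the values from your own API token):\n",
        "# write_kaggle_credentials('your-username', 'your-api-key')"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },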
    {
      "cell_type": "code",
      "source": [
        "!unzip /content/pytorch-nested-unet/data-science-bowl-2018.zip -d /content/pytorch-nested-unet/inputs\n",
        "!unzip /content/pytorch-nested-unet/inputs/stage1_train.zip -d /content/pytorch-nested-unet/inputs/stage1_train/\n",
        "!unzip /content/pytorch-nested-unet/inputs/stage1_test.zip -d /content/pytorch-nested-unet/inputs/stage1_test/\n",
        "!unzip /content/pytorch-nested-unet/inputs/stage1_train_labels.csv.zip -d /content/pytorch-nested-unet/inputs/stage1_train_labels/"
      ],
      "metadata": {
        "id": "8mtwUO6rGvTp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%cd /content/pytorch-nested-unet/\n",
        "!pwd"
      ],
      "metadata": {
        "id": "2VYmv8sFGvVp",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e4559dcf-43f0-48c2-d276-c7215fef056e"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/pytorch-nested-unet\n",
            "/content/pytorch-nested-unet\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
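        "# Data preprocessing\n",
        "- Merge the per-nucleus masks of each DSB2018 training sample into one binary mask. Then resize the images and masks to 96x96 pixels and save them under `inputs/dsb2018_96/`."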
      ],
      "metadata": {
        "id": "wsoAOWw4frDP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from glob import glob\n",
        "\n",
        "import cv2\n",
        "import numpy as np\n",
        "from tqdm import tqdm\n",
        "\n",
        "\n",
        "def main():\n",
        "    img_size = 96\n",
        "\n",
        "    paths = glob('/content/pytorch-nested-unet/inputs/stage1_train/*')\n",
        "\n",
        "    os.makedirs('inputs/dsb2018_%d/images' % img_size, exist_ok=True)\n",
        "    os.makedirs('inputs/dsb2018_%d/masks/0' % img_size, exist_ok=True)\n",
        "    for path in tqdm(paths):\n",
        "        img = cv2.imread(os.path.join(path, 'images',\n",
        "                         os.path.basename(path) + '.png'))\n",
        "        mask = np.zeros((img.shape[0], img.shape[1]))\n",
        "        for mask_path in glob(os.path.join(path, 'masks', '*')):\n",
        "            mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127\n",
        "            mask[mask_] = 1\n",
        "        if len(img.shape) == 2:\n",
        "            img = np.tile(img[..., None], (1, 1, 3))\n",
        "        if img.shape[2] == 4:\n",
        "            img = img[..., :3]\n",
        "        img = cv2.resize(img, (img_size, img_size))\n",
        "        mask = cv2.resize(mask, (img_size, img_size))\n",
        "        cv2.imwrite(os.path.join('inputs/dsb2018_%d/images' % img_size,\n",
        "                    os.path.basename(path) + '.png'), img)\n",
        "        cv2.imwrite(os.path.join('inputs/dsb2018_%d/masks/0' % img_size,\n",
        "                    os.path.basename(path) + '.png'), (mask * 255).astype('uint8'))\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "sTPpHGovGvX2"
      },
      "execution_count": null,
      "outputs": []
    },
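    {
      "cell_type": "markdown",
      "source": [
        "Each DSB2018 training sample stores one mask file per nucleus. The preprocessing cell above thresholds every instance mask at 127 and takes their union, so all nuclei of an image end up in a single binary foreground mask. The NumPy-only sketch below reproduces that merge on two tiny synthetic masks. The arrays are made up, and no dataset files or OpenCV are needed. `merge_instance_masks` is a hypothetical name."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def merge_instance_masks(instance_masks, threshold=127):\n",
        "    # Union of the per-nucleus masks, thresholded like `cv2.imread(...) > 127` above\n",
        "    merged = np.zeros(instance_masks[0].shape, dtype=np.float64)\n",
        "    for m in instance_masks:\n",
        "        merged[m > threshold] = 1\n",
        "    return merged\n",
        "\n",
        "\n",
        "# Two synthetic 4x4 nucleus masks with values in 0..255 that overlap in one pixel\n",
        "a = np.zeros((4, 4), dtype=np.uint8)\n",
        "a[0:2, 0:2] = 255\n",
        "b = np.zeros((4, 4), dtype=np.uint8)\n",
        "b[1:3, 1:3] = 255\n",
        "\n",
        "merged = merge_instance_masks([a, b])\n",
        "print(int(merged.sum()))  # 7 foreground pixels (4 + 4 - 1 shared)\n",
        "\n",
        "# Before the bilinear resize, the saved mask (mask * 255).astype('uint8') holds only 0 and 255\n",
        "saved = (merged * 255).astype('uint8')"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },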
    {
      "cell_type": "markdown",
      "source": [
        "# Dataset class"
      ],
      "metadata": {
        "id": "J4WjjtBLgAcg"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import cv2\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.utils.data\n",
        "\n",
        "\n",
        "class Dataset(torch.utils.data.Dataset):\n",
        "    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            img_ids (list): Image ids.\n",
        "            img_dir: Image file directory.\n",
        "            mask_dir: Mask file directory.\n",
        "            img_ext (str): Image file extension.\n",
        "            mask_ext (str): Mask file extension.\n",
        "            num_classes (int): Number of classes.\n",
        "            transform (Compose, optional): Compose transforms of albumentations. Defaults to None.\n",
        "\n",
        "        Note:\n",
        "            Make sure to put the files as the following structure:\n",
        "            <dataset name>\n",
        "            ├── images\n",
        "            |   ├── 0a7e06.jpg\n",
        "            │   ├── 0aab0a.jpg\n",
        "            │   ├── 0b1761.jpg\n",
        "            │   ├── ...\n",
        "            |\n",
        "            └── masks\n",
        "                ├── 0\n",
        "                |   ├── 0a7e06.png\n",
        "                |   ├── 0aab0a.png\n",
        "                |   ├── 0b1761.png\n",
        "                |   ├── ...\n",
        "                |\n",
        "                ├── 1\n",
        "                |   ├── 0a7e06.png\n",
        "                |   ├── 0aab0a.png\n",
        "                |   ├── 0b1761.png\n",
        "                |   ├── ...\n",
        "                ...\n",
        "        \"\"\"\n",
        "        self.img_ids = img_ids\n",
        "        self.img_dir = img_dir\n",
        "        self.mask_dir = mask_dir\n",
        "        self.img_ext = img_ext\n",
        "        self.mask_ext = mask_ext\n",
        "        self.num_classes = num_classes\n",
        "        self.transform = transform\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.img_ids)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        img_id = self.img_ids[idx]\n",
        "\n",
        "        img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))\n",
        "\n",
        "        mask = []\n",
        "        for i in range(self.num_classes):\n",
        "            mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),\n",
        "                        img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])\n",
        "        mask = np.dstack(mask)\n",
        "\n",
        "        if self.transform is not None:\n",
        "            augmented = self.transform(image=img, mask=mask)\n",
        "            img = augmented['image']\n",
        "            mask = augmented['mask']\n",
        "\n",
        "        img = img.astype('float32') / 255\n",
        "        img = img.transpose(2, 0, 1)\n",
        "        mask = mask.astype('float32') / 255\n",
        "        mask = mask.transpose(2, 0, 1)\n",
        "\n",
        "        return img, mask, {'img_id': img_id}"
      ],
      "metadata": {
        "id": "msHL3n3DGvZy"
      },
      "execution_count": null,
      "outputs": []
    },
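    {
      "cell_type": "markdown",
      "source": [
        "`Dataset.__getitem__` stacks one grayscale mask per class with `np.dstack`. It returns both the image and the mask in channel-first (C, H, W) layout, scaled by 1/255. The sketch below mimics those steps on random synthetic arrays, with the albumentations transform left out. It shows the shapes one sample should have for `num_classes=1` and 96x96 inputs."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "num_classes, h, w = 1, 96, 96\n",
        "\n",
        "img = rng.integers(0, 256, size=(h, w, 3)).astype(np.uint8)  # HWC, as cv2.imread returns it\n",
        "# One grayscale mask per class, each given a trailing channel axis with [..., None]\n",
        "masks = [rng.choice([0, 255], size=(h, w)).astype(np.uint8)[..., None]\n",
        "         for _ in range(num_classes)]\n",
        "mask = np.dstack(masks)  # (H, W, num_classes)\n",
        "\n",
        "img_chw = (img.astype('float32') / 255).transpose(2, 0, 1)    # (3, H, W)\n",
        "mask_chw = (mask.astype('float32') / 255).transpose(2, 0, 1)  # (num_classes, H, W)\n",
        "print(img_chw.shape, mask_chw.shape)  # (3, 96, 96) (1, 96, 96)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },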
    {
      "cell_type": "markdown",
      "source": [
        "# Metric functions\n"
      ],
      "metadata": {
        "id": "FX01faBbgF13"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "def iou_score(output, target):\n",
        "    smooth = 1e-5\n",
        "\n",
        "    if torch.is_tensor(output):\n",
        "        output = torch.sigmoid(output).data.cpu().numpy()\n",
        "    if torch.is_tensor(target):\n",
        "        target = target.data.cpu().numpy()\n",
        "    output_ = output > 0.5\n",
        "    target_ = target > 0.5\n",
        "    intersection = (output_ & target_).sum()\n",
        "    union = (output_ | target_).sum()\n",
        "\n",
        "    return (intersection + smooth) / (union + smooth)\n",
        "\n",
        "\n",
        "def dice_coef(output, target):\n",
        "    smooth = 1e-5\n",
        "\n",
        "    output = torch.sigmoid(output).view(-1).data.cpu().numpy()\n",
        "    target = target.view(-1).data.cpu().numpy()\n",
        "    intersection = (output * target).sum()\n",
        "\n",
        "    return (2. * intersection + smooth) / \\\n",
        "        (output.sum() + target.sum() + smooth)"
      ],
      "metadata": {
        "id": "IBXBInw2Gvci"
      },
      "execution_count": null,
      "outputs": []
    },
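    {
      "cell_type": "markdown",
      "source": [
        "A small worked example shows how the two metrics relate. `iou_score` thresholds the sigmoid outputs and the targets at 0.5, whereas `dice_coef` works on the raw probabilities. For hard binary masks, the two are linked by Dice = 2 * IoU / (1 + IoU), so Dice is never smaller than IoU. The NumPy sketch below checks this on a toy 2x2 prediction. The helper names are hypothetical, and the `smooth` term is omitted for readability."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def iou_binary(pred, target):\n",
        "    # Intersection over union of two boolean masks\n",
        "    return (pred & target).sum() / (pred | target).sum()\n",
        "\n",
        "\n",
        "def dice_binary(pred, target):\n",
        "    # 2 * |A & B| / (|A| + |B|) for two boolean masks\n",
        "    return 2 * (pred & target).sum() / (pred.sum() + target.sum())\n",
        "\n",
        "\n",
        "probs = np.array([[0.9, 0.8], [0.2, 0.6]])  # e.g. sigmoid outputs\n",
        "target = np.array([[1, 1], [0, 0]]) > 0.5\n",
        "pred = probs > 0.5  # same 0.5 threshold as iou_score\n",
        "\n",
        "iou = iou_binary(pred, target)    # 2 / 3\n",
        "dice = dice_binary(pred, target)  # 4 / 5\n",
        "print(round(iou, 4), round(dice, 4))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },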
    {
      "cell_type": "code",
      "source": [
        "import argparse\n",
        "\n",
        "\n",
        "def str2bool(v):\n",
        "    if v.lower() in ['true', '1']:\n",
        "        return True\n",
        "    elif v.lower() in ['false', '0']:\n",
        "        return False\n",
        "    else:\n",
        "        raise argparse.ArgumentTypeError('Boolean value expected.')\n",
        "\n",
        "\n",
        "def count_params(model):\n",
        "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
        "\n",
        "\n",
        "class AverageMeter(object):\n",
        "    \"\"\"Computes and stores the average and current value\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.reset()\n",
        "\n",
        "    def reset(self):\n",
        "        self.val = 0\n",
        "        self.avg = 0\n",
        "        self.sum = 0\n",
        "        self.count = 0\n",
        "\n",
        "    def update(self, val, n=1):\n",
        "        self.val = val\n",
        "        self.sum += val * n\n",
        "        self.count += n\n",
        "        self.avg = self.sum / self.count"
      ],
      "metadata": {
        "id": "OoDjgZF5GvvY"
      },
      "execution_count": null,
      "outputs": []
    },
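    {
      "cell_type": "markdown",
      "source": [
        "`AverageMeter.update(val, n)` keeps a sample-weighted running mean. This is why the training loop passes `input.size(0)` as `n`: a smaller last batch contributes proportionally less to the epoch average. The standalone sketch below repeats the class's update rule on made-up batch losses. The class is renamed `RunningMean` here so that it does not shadow the notebook's `AverageMeter`."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "class RunningMean:\n",
        "    # Same update rule as AverageMeter above: a mean weighted by the number of samples n\n",
        "    def __init__(self):\n",
        "        self.sum = 0.0\n",
        "        self.count = 0\n",
        "        self.avg = 0.0\n",
        "\n",
        "    def update(self, val, n=1):\n",
        "        self.sum += val * n\n",
        "        self.count += n\n",
        "        self.avg = self.sum / self.count\n",
        "\n",
        "\n",
        "meter = RunningMean()\n",
        "for batch_loss, batch_size in [(0.50, 64), (0.40, 64), (0.10, 8)]:  # made-up values\n",
        "    meter.update(batch_loss, batch_size)\n",
        "\n",
        "unweighted = (0.50 + 0.40 + 0.10) / 3\n",
        "print(round(meter.avg, 4), round(unweighted, 4))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },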
    {
      "cell_type": "markdown",
      "source": [
        "# Loss functions"
      ],
      "metadata": {
        "id": "wx4zp_YAgNjb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "try:\n",
        "    from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge\n",
        "except ImportError:\n",
        "    pass\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "class BCEDiceLoss(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "    def forward(self, input, target):\n",
        "        bce = F.binary_cross_entropy_with_logits(input, target)\n",
        "        smooth = 1e-5\n",
        "        input = torch.sigmoid(input)\n",
        "        num = target.size(0)\n",
        "        input = input.view(num, -1)\n",
        "        target = target.view(num, -1)\n",
        "        intersection = (input * target)\n",
        "        dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)\n",
        "        dice = 1 - dice.sum() / num\n",
        "        return 0.5 * bce + dice\n",
        "\n",
        "\n",
        "class LovaszHingeLoss(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "    def forward(self, input, target):\n",
        "        input = input.squeeze(1)\n",
        "        target = target.squeeze(1)\n",
        "        loss = lovasz_hinge(input, target, per_image=True)\n",
        "\n",
        "        return loss"
      ],
      "metadata": {
        "id": "dIlgXExlGvxb"
      },
      "execution_count": null,
      "outputs": []
    },
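    {
      "cell_type": "markdown",
      "source": [
        "`BCEDiceLoss` combines half of the binary cross-entropy on the logits with a soft Dice loss. The Dice term is computed per image on the sigmoid probabilities and averaged over the batch. The NumPy sketch below recomputes the same formula for a toy batch. For BCE it uses the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)), which is mathematically the quantity `F.binary_cross_entropy_with_logits` averages over all elements. The sketch is meant only to build intuition and does not replace the PyTorch module. `bce_dice_loss_np` is a hypothetical name."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def bce_dice_loss_np(logits, target, smooth=1e-5):\n",
        "    # logits, target: arrays of shape (N, C, H, W); target values in {0, 1}\n",
        "    # Numerically stable BCE with logits, averaged over all elements\n",
        "    bce = np.mean(np.maximum(logits, 0) - logits * target + np.log1p(np.exp(-np.abs(logits))))\n",
        "    probs = 1 / (1 + np.exp(-logits))\n",
        "    n = target.shape[0]\n",
        "    p = probs.reshape(n, -1)\n",
        "    t = target.reshape(n, -1)\n",
        "    dice = (2 * (p * t).sum(1) + smooth) / (p.sum(1) + t.sum(1) + smooth)  # per image\n",
        "    return 0.5 * bce + (1 - dice.sum() / n)\n",
        "\n",
        "\n",
        "target = np.array([[[[1.0, 0.0], [0.0, 1.0]]]])  # one 2x2 mask, shape (1, 1, 2, 2)\n",
        "good = np.where(target > 0.5, 8.0, -8.0)          # confident, correct logits\n",
        "bad = -good                                       # confident, wrong logits\n",
        "print(bce_dice_loss_np(good, target), bce_dice_loss_np(bad, target))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },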
    {
      "cell_type": "markdown",
      "source": [
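        "# Swin UNETR\n",
        "- The segmentation network is `SwinUNETR` from the MONAI package (`monai.networks.nets`), configured for 2D inputs in the training cell."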
      ],
      "metadata": {
        "id": "RElrFu-HByR2"
      }
    },
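    {
      "cell_type": "markdown",
      "source": [
        "For 2D inputs, SwinUNETR embeds 2x2 patches and then halves the resolution four more times with patch merging. As far as we understand MONAI's implementation, this has three consequences: each input side must be divisible by 2**5 = 32, `feature_size` must be divisible by 12, and the encoder stages produce `feature_size * 2**i` channels. The pure-Python sketch below lists the expected stage shapes for the 96x96 inputs and `feature_size=24` used in this notebook. It uses a hypothetical helper and does not import MONAI."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def swin_unetr_2d_stages(h, w, feature_size=24, num_stages=5):\n",
        "    # Hypothetical helper: expected (channels, height, width) of each encoder stage output,\n",
        "    # assuming a 2x2 patch embedding followed by four 2x patch-merging steps.\n",
        "    factor = 2 ** num_stages\n",
        "    if h % factor or w % factor:\n",
        "        raise ValueError(f'input size {h}x{w} must be divisible by {factor}')\n",
        "    if feature_size % 12:\n",
        "        raise ValueError('feature_size must be divisible by 12')\n",
        "    return [(feature_size * 2 ** i, h // 2 ** (i + 1), w // 2 ** (i + 1))\n",
        "            for i in range(num_stages)]\n",
        "\n",
        "\n",
        "for stage in swin_unetr_2d_stages(96, 96):\n",
        "    print(stage)  # one (channels, height, width) tuple per stage"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },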
    {
      "cell_type": "markdown",
      "source": [
        "# Configuration parameters"
      ],
      "metadata": {
        "id": "BxQp3JPTgRbV"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "config = {\n",
        "        'name': None,\n",
        "        'epochs': 200,\n",
        "        'batch_size': 64,\n",
        "        'arch': 'NestedUNet',  # used to name the output folder; the network trained below is SwinUNETR\n",
        "        'deep_supervision': False,\n",
        "        'input_channels': 3,\n",
        "        'num_classes': 1,\n",
        "        'input_w': 96,\n",
        "        'input_h': 96,\n",
        "        'loss': 'BCEDiceLoss',\n",
        "        'dataset': 'dsb2018_96',\n",
        "        'img_ext': '.png',\n",
        "        'mask_ext': '.png',\n",
        "        'optimizer': 'SGD',\n",
        "        'lr': 1e-3,\n",
        "        'momentum': 0.9,\n",
        "        'weight_decay': 1e-4,\n",
        "        'nesterov': False,\n",
        "        'scheduler': 'CosineAnnealingLR',\n",
        "        'min_lr': 1e-5,\n",
        "        'factor': 0.1,\n",
        "        'patience': 2,\n",
        "        'milestones': '1,2',\n",
        "        'gamma': 2/3,\n",
        "        'early_stopping': -1,\n",
        "        'num_workers': 4\n",
        "    }"
      ],
      "metadata": {
        "id": "59KiNnLsGvzw"
      },
      "execution_count": null,
      "outputs": []
    },
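    {
      "cell_type": "markdown",
      "source": [
        "With the default `'scheduler': 'CosineAnnealingLR'`, the learning rate follows eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2, where `eta_max = lr`, `eta_min = min_lr` and `T_max = epochs`. The short sketch below evaluates that closed form for the values in `config`. It restates the formula from the PyTorch documentation and does not call `torch.optim`. The output shows how slowly the rate decays: at epoch 100 it is still about half of the initial value."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import math\n",
        "\n",
        "\n",
        "def cosine_lr(epoch, lr=1e-3, min_lr=1e-5, t_max=200):\n",
        "    # Closed form of cosine annealing without warm restarts\n",
        "    return min_lr + (lr - min_lr) * (1 + math.cos(math.pi * epoch / t_max)) / 2\n",
        "\n",
        "\n",
        "for epoch in (0, 50, 100, 150, 200):\n",
        "    print(epoch, f'{cosine_lr(epoch):.2e}')"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },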
    {
      "cell_type": "markdown",
      "source": [
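        "# Training the digital model\n",
        "- Train SwinUNETR as a conventional (digital) model. The checkpoint with the best validation IoU is saved to `models/<name>/model.pth`."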
      ],
      "metadata": {
        "id": "HuBxoahVgZTJ"
      }
    },
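    {
      "cell_type": "markdown",
      "source": [
        "At the end of every epoch, the training loop in the training cell below saves the weights whenever the validation IoU improves. It uses `trigger` to count the epochs since the last improvement and stops early once `trigger` reaches `config['early_stopping']`. The configured value of -1 disables early stopping. The sketch below replays that bookkeeping on a made-up sequence of validation IoUs; `early_stop_epoch` is a hypothetical name."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def early_stop_epoch(val_ious, early_stopping):\n",
        "    # Mirrors the loop: trigger counts epochs since the best val IoU; a negative value disables stopping.\n",
        "    best_iou, trigger = 0, 0\n",
        "    for epoch, val_iou in enumerate(val_ious):\n",
        "        trigger += 1\n",
        "        if val_iou > best_iou:\n",
        "            best_iou = val_iou  # the loop saves model.pth here\n",
        "            trigger = 0\n",
        "        if early_stopping >= 0 and trigger >= early_stopping:\n",
        "            return epoch  # epoch at which training would stop\n",
        "    return None  # all epochs were run\n",
        "\n",
        "\n",
        "ious = [0.10, 0.20, 0.15, 0.18, 0.19, 0.25]  # made-up validation IoUs\n",
        "print(early_stop_epoch(ious, early_stopping=3))   # 4\n",
        "print(early_stop_epoch(ious, early_stopping=-1))  # None"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },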
    {
      "cell_type": "code",
      "source": [
        "import argparse\n",
        "import os\n",
        "from collections import OrderedDict\n",
        "from glob import glob\n",
        "\n",
        "import pandas as pd\n",
        "import torch\n",
        "import torch.backends.cudnn as cudnn\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import yaml\n",
        "from albumentations.augmentations import transforms\n",
        "from albumentations.core.composition import Compose, OneOf\n",
        "from sklearn.model_selection import train_test_split\n",
        "from torch.optim import lr_scheduler\n",
        "from tqdm import tqdm\n",
        "from albumentations.augmentations.geometric.rotate import RandomRotate90\n",
        "from albumentations.augmentations.geometric import Flip, Resize\n",
        "\n",
        "\n",
        "\n",
        "ARCH_NAMES = ['UNet', 'NestedUNet']\n",
        "LOSS_NAMES = ['BCEDiceLoss', 'LovaszHingeLoss']\n",
        "LOSS_NAMES.append('BCEWithLogitsLoss')\n",
        "\n",
        "\n",
        "\n",
        "def train(config, train_loader, model, criterion, optimizer):\n",
        "    avg_meters = {'loss': AverageMeter(),\n",
        "                  'iou': AverageMeter()}\n",
        "\n",
        "    model.train()\n",
        "\n",
        "    pbar = tqdm(total=len(train_loader))\n",
        "    for input, target, _ in train_loader:\n",
        "        input = input.cuda()\n",
        "        target = target.cuda()\n",
        "\n",
        "        # compute output\n",
        "        if config['deep_supervision']:\n",
        "            outputs = model(input)\n",
        "            loss = 0\n",
        "            for output in outputs:\n",
        "                loss += criterion(output, target)\n",
        "            loss /= len(outputs)\n",
        "            iou = iou_score(outputs[-1], target)\n",
        "        else:\n",
        "            output = model(input)\n",
        "            loss = criterion(output, target)\n",
        "            iou = iou_score(output, target)\n",
        "\n",
        "        # compute gradient and do optimizing step\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        avg_meters['loss'].update(loss.item(), input.size(0))\n",
        "        avg_meters['iou'].update(iou, input.size(0))\n",
        "\n",
        "        postfix = OrderedDict([\n",
        "            ('loss', avg_meters['loss'].avg),\n",
        "            ('iou', avg_meters['iou'].avg),\n",
        "        ])\n",
        "        pbar.set_postfix(postfix)\n",
        "        pbar.update(1)\n",
        "    pbar.close()\n",
        "\n",
        "    return OrderedDict([('loss', avg_meters['loss'].avg),\n",
        "                        ('iou', avg_meters['iou'].avg)])\n",
        "\n",
        "\n",
        "def validate(config, val_loader, model, criterion):\n",
        "    avg_meters = {'loss': AverageMeter(),\n",
        "                  'iou': AverageMeter()}\n",
        "\n",
        "    # switch to evaluate mode\n",
        "    model.eval()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        pbar = tqdm(total=len(val_loader))\n",
        "        for input, target, _ in val_loader:\n",
        "            input = input.cuda()\n",
        "            target = target.cuda()\n",
        "\n",
        "            # compute output\n",
        "            if config['deep_supervision']:\n",
        "                outputs = model(input)\n",
        "                loss = 0\n",
        "                for output in outputs:\n",
        "                    loss += criterion(output, target)\n",
        "                loss /= len(outputs)\n",
        "                iou = iou_score(outputs[-1], target)\n",
        "            else:\n",
        "                output = model(input)\n",
        "                loss = criterion(output, target)\n",
        "                iou = iou_score(output, target)\n",
        "\n",
        "            avg_meters['loss'].update(loss.item(), input.size(0))\n",
        "            avg_meters['iou'].update(iou, input.size(0))\n",
        "\n",
        "            postfix = OrderedDict([\n",
        "                ('loss', avg_meters['loss'].avg),\n",
        "                ('iou', avg_meters['iou'].avg),\n",
        "            ])\n",
        "            pbar.set_postfix(postfix)\n",
        "            pbar.update(1)\n",
        "        pbar.close()\n",
        "\n",
        "    return OrderedDict([('loss', avg_meters['loss'].avg),\n",
        "                        ('iou', avg_meters['iou'].avg)])\n",
        "\n",
        "\n",
        "if config['name'] is None:\n",
        "    if config['deep_supervision']:\n",
        "        config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])\n",
        "    else:\n",
        "        config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])\n",
        "os.makedirs('models/%s' % config['name'], exist_ok=True)\n",
        "\n",
        "print('-' * 20)\n",
        "for key in config:\n",
        "    print('%s: %s' % (key, config[key]))\n",
        "print('-' * 20)\n",
        "\n",
        "with open('models/%s/config.yml' % config['name'], 'w') as f:\n",
        "    yaml.dump(config, f)\n",
        "\n",
        "# define loss function (criterion)\n",
        "if config['loss'] == 'BCEWithLogitsLoss':\n",
        "    criterion = nn.BCEWithLogitsLoss().cuda()\n",
        "elif config['loss'] == 'BCEDiceLoss':\n",
        "    criterion = BCEDiceLoss().cuda()\n",
        "else:\n",
        "    criterion = LovaszHingeLoss().cuda()\n",
        "\n",
        "cudnn.benchmark = True\n",
        "\n",
        "# create model: SwinUNETR from MONAI -----------------------------------------------------------\n",
        "import monai\n",
        "from monai.networks.nets import SwinUNETR\n",
        "\n",
        "\n",
        "def create_model(config):\n",
        "    # 2D Swin UNETR. depths, num_heads and feature_size are MONAI's default values and can be tuned.\n",
        "    model = SwinUNETR(\n",
        "        # Each side must be divisible by 32 for 2D inputs. `img_size` is deprecated in newer MONAI\n",
        "        # releases; drop this argument if your installed version rejects it.\n",
        "        img_size=(config['input_h'], config['input_w']),\n",
        "        in_channels=config['input_channels'],  # 3-channel (BGR) images\n",
        "        out_channels=config['num_classes'],  # one logit map per class\n",
        "        depths=(2, 2, 2, 2),  # number of Swin Transformer blocks in each stage\n",
        "        num_heads=(3, 6, 12, 24),  # number of attention heads in each stage\n",
        "        feature_size=24,  # embedding dimension; must be divisible by 12\n",
        "        norm_name='instance',  # instance normalization in the convolutional encoder/decoder blocks\n",
        "        drop_rate=0.0,  # dropout rate\n",
        "        attn_drop_rate=0.0,  # dropout rate inside the attention layers\n",
        "        dropout_path_rate=0.0,  # stochastic depth rate\n",
        "        normalize=True,  # normalize the intermediate features output by each stage\n",
        "        use_checkpoint=False,  # gradient checkpointing saves memory at the cost of extra compute\n",
        "        spatial_dims=2,  # 2D images\n",
        "        downsample='merging',  # patch-merging downsampling\n",
        "        use_v2=False,  # True selects the SwinUNETR-V2 variant\n",
        "    )\n",
        "    return model\n",
        "\n",
        "\n",
        "# The network returns raw logits, which is what BCEDiceLoss and iou_score expect,\n",
        "# so no final activation layer is added.\n",
        "model = create_model(config).cuda()\n",
        "\n",
        "params = filter(lambda p: p.requires_grad, model.parameters())\n",
        "if config['optimizer'] == 'Adam':\n",
        "    optimizer = optim.Adam(\n",
        "        params, lr=config['lr'], weight_decay=config['weight_decay'])\n",
        "elif config['optimizer'] == 'SGD':\n",
        "    optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],\n",
        "                          nesterov=config['nesterov'], weight_decay=config['weight_decay'])\n",
        "else:\n",
        "    raise NotImplementedError\n",
        "\n",
        "if config['scheduler'] == 'CosineAnnealingLR':\n",
        "    scheduler = lr_scheduler.CosineAnnealingLR(\n",
        "        optimizer, T_max=config['epochs'], eta_min=config['min_lr'])\n",
        "elif config['scheduler'] == 'ReduceLROnPlateau':\n",
        "    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'],\n",
        "                                                verbose=1, min_lr=config['min_lr'])\n",
        "elif config['scheduler'] == 'MultiStepLR':\n",
        "    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])\n",
        "elif config['scheduler'] == 'ConstantLR':\n",
        "    scheduler = None\n",
        "else:\n",
        "    raise NotImplementedError\n",
        "\n",
        "# Data loading code\n",
        "img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))\n",
        "img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]\n",
        "\n",
        "train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)\n",
        "\n",
        "train_transform = Compose([\n",
        "    RandomRotate90(),\n",
        "    Flip(),\n",
        "    OneOf([\n",
        "        transforms.HueSaturationValue(),\n",
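        "        transforms.RandomBrightnessContrast(contrast_limit=0),  # brightness only; replaces the deprecated RandomBrightness\n",
        "        transforms.RandomBrightnessContrast(brightness_limit=0),  # contrast only; replaces the deprecated RandomContrast\n",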
        "    ], p=1),\n",
        "    Resize(config['input_h'], config['input_w']),\n",
        "    transforms.Normalize(),\n",
        "])\n",
        "\n",
        "val_transform = Compose([\n",
        "    Resize(config['input_h'], config['input_w']),\n",
        "    transforms.Normalize(),\n",
        "])\n",
        "\n",
        "train_dataset = Dataset(\n",
        "    img_ids=train_img_ids,\n",
        "    img_dir=os.path.join('inputs', config['dataset'], 'images'),\n",
        "    mask_dir=os.path.join('inputs', config['dataset'], 'masks'),\n",
        "    img_ext=config['img_ext'],\n",
        "    mask_ext=config['mask_ext'],\n",
        "    num_classes=config['num_classes'],\n",
        "    transform=train_transform)\n",
        "val_dataset = Dataset(\n",
        "    img_ids=val_img_ids,\n",
        "    img_dir=os.path.join('inputs', config['dataset'], 'images'),\n",
        "    mask_dir=os.path.join('inputs', config['dataset'], 'masks'),\n",
        "    img_ext=config['img_ext'],\n",
        "    mask_ext=config['mask_ext'],\n",
        "    num_classes=config['num_classes'],\n",
        "    transform=val_transform)\n",
        "\n",
        "train_loader = torch.utils.data.DataLoader(\n",
        "    train_dataset,\n",
        "    batch_size=config['batch_size'],\n",
        "    shuffle=True,\n",
        "    num_workers=config['num_workers'],\n",
        "    drop_last=True)\n",
        "val_loader = torch.utils.data.DataLoader(\n",
        "    val_dataset,\n",
        "    batch_size=config['batch_size'],\n",
        "    shuffle=False,\n",
        "    num_workers=config['num_workers'],\n",
        "    drop_last=False)\n",
        "\n",
        "log = OrderedDict([\n",
        "    ('epoch', []),\n",
        "    ('lr', []),\n",
        "    ('loss', []),\n",
        "    ('iou', []),\n",
        "    ('val_loss', []),\n",
        "    ('val_iou', []),\n",
        "])\n",
        "\n",
        "best_iou = 0\n",
        "trigger = 0\n",
        "for epoch in range(config['epochs']):\n",
        "    print('Epoch [%d/%d]' % (epoch, config['epochs']))\n",
        "\n",
        "    # train for one epoch\n",
        "    train_log = train(config, train_loader, model, criterion, optimizer)\n",
        "    # evaluate on validation set\n",
        "    val_log = validate(config, val_loader, model, criterion)\n",
        "\n",
        "    if config['scheduler'] == 'CosineAnnealingLR':\n",
        "        scheduler.step()\n",
        "    elif config['scheduler'] == 'ReduceLROnPlateau':\n",
        "        scheduler.step(val_log['loss'])\n",
        "\n",
        "    print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'\n",
        "          % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))\n",
        "\n",
        "    log['epoch'].append(epoch)\n",
        "    log['lr'].append(config['lr'])\n",
        "    log['loss'].append(train_log['loss'])\n",
        "    log['iou'].append(train_log['iou'])\n",
        "    log['val_loss'].append(val_log['loss'])\n",
        "    log['val_iou'].append(val_log['iou'])\n",
        "\n",
        "    pd.DataFrame(log).to_csv('models/%s/log.csv' %\n",
        "                              config['name'], index=False)\n",
        "\n",
        "    trigger += 1\n",
        "\n",
        "    if val_log['iou'] > best_iou:\n",
        "        torch.save(model.state_dict(), 'models/%s/model.pth' %\n",
        "                    config['name'])\n",
        "        best_iou = val_log['iou']\n",
        "        print(\"=> saved best model\")\n",
        "        trigger = 0\n",
        "\n",
        "    # early stopping\n",
        "    if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:\n",
        "        print(\"=> early stopping\")\n",
        "        break\n",
        "\n",
        "    torch.cuda.empty_cache()\n",
        "\n"
      ],
      "metadata": {
        "id": "hKHwveTDGv2W"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Analog training"
      ],
      "metadata": {
        "id": "d_sECpRQHv31"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from aihwkit.nn import AnalogConv2d, AnalogLinear, AnalogSequential\n",
        "from aihwkit.nn.conversion import convert_to_analog_mapped, convert_to_analog\n",
        "from aihwkit.optim import AnalogSGD\n",
        "from torch.optim import SGD\n",
        "from aihwkit.simulator.configs import FloatingPointRPUConfig, SingleRPUConfig, UnitCellRPUConfig, InferenceRPUConfig, DigitalRankUpdateRPUConfig\n",
        "from aihwkit.simulator.configs.devices import *\n",
        "from aihwkit.simulator.configs.utils import PulseType\n",
        "# from aihwkit.simulator.rpu_base import cuda\n",
        "from aihwkit.inference import BaseNoiseModel, PCMLikeNoiseModel, StateIndependentNoiseModel\n",
        "from aihwkit.simulator.configs.utils import WeightClipType,WeightModifierType, IOParameters\n",
        "from aihwkit.inference.compensation.drift import GlobalDriftCompensation\n",
        "\n",
        "from aihwkit.simulator.configs.utils import BoundManagementType\n",
        "from aihwkit.simulator.presets.utils import PresetIOParameters\n",
        "import math"
      ],
      "metadata": {
        "id": "NwK0Dnr2Hk12"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def create_rpu_config_new():\n",
        "    rpu_config = InferenceRPUConfig()\n",
        "\n",
        "    rpu_config.clip.type = WeightClipType.FIXED_VALUE\n",
        "    rpu_config.clip.fixed_value = 1.0\n",
        "    rpu_config.modifier.pdrop = 0  # Drop connect.\n",
        "\n",
        "    rpu_config.modifier.std_dev = 0.5\n",
        "\n",
        "    rpu_config.modifier.rel_to_actual_wmax = True\n",
        "    rpu_config.mapping.digital_bias = True\n",
        "    rpu_config.mapping.weight_scaling_omega = 0.4\n",
        "    rpu_config.mapping.weight_scaling_omega = True\n",
        "    rpu_config.mapping.max_input_size = 256\n",
        "    rpu_config.mapping.max_output_size = 256\n",
        "\n",
        "    rpu_config.mapping.learn_out_scaling_alpha = True\n",
        "\n",
        "    rpu_config.forward = PresetIOParameters()\n",
        "    rpu_config.forward.inp_res = 1/256  # 8-bit DAC discretization.\n",
        "    rpu_config.forward.out_res = 1/256  # 8-bit ADC discretization.\n",
        "    rpu_config.forward.bound_management = BoundManagementType.NONE\n",
        "\n",
        "    # Inference noise model.\n",
        "    rpu_config.noise_model = PCMLikeNoiseModel(g_max=25)\n",
        "\n",
        "    # drift compensation\n",
        "    rpu_config.drift_compensation = GlobalDriftCompensation()\n",
        "    return rpu_config\n",
        "\n",
        "def create_analog_optimizer(model, lr):\n",
        "    \"\"\"Create the analog-aware optimizer.\n",
        "\n",
        "    Args:\n",
        "        model (nn.Module): model to be trained\n",
        "\n",
        "    Returns:\n",
        "        Optimizer: created analog optimizer\n",
        "    \"\"\"\n",
        "\n",
        "    optimizer = AnalogSGD(model.parameters(), lr) # we will use a learning rate of 0.01 as in the paper\n",
        "    optimizer.regroup_param_groups(model)\n",
        "\n",
        "    return optimizer"
      ],
      "metadata": {
        "id": "9FBWQ9gnHk4K"
      },
      "execution_count": null,
      "outputs": []
    },
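    {
      "cell_type": "markdown",
      "source": [
        "#### Sketch: what the 8-bit DAC/ADC settings do\n",
        "- `inp_res` and `out_res` set the resolution of the simulated input DAC and output ADC. The cell below is a simplified sketch, not aihwkit's implementation: it assumes an ideal uniform converter that clips the signal to `[-1, 1]` and rounds it to the nearest multiple of the resolution, and it ignores the noise and bound management that aihwkit also simulates. `quantize` is a hypothetical helper, not an aihwkit function.\n",
        "- For in-range inputs the rounding error is at most half a step; out-of-range inputs saturate at the bound."
      ],
      "metadata": {
        "id": "dacAdcSketchMd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "# Hypothetical helper (not part of aihwkit): an ideal uniform converter with\n",
        "# step size `res` and a symmetric clipping range [-bound, bound].\n",
        "def quantize(x, res=1/256, bound=1.0):\n",
        "    x = torch.clamp(x, -bound, bound)  # out-of-range values saturate\n",
        "    return torch.round(x / res) * res  # round to the nearest level\n",
        "\n",
        "torch.manual_seed(0)\n",
        "x = torch.rand(10000) * 2 - 1  # inputs inside [-1, 1]\n",
        "max_err = (quantize(x) - x).abs().max().item()\n",
        "print(f'max rounding error: {max_err:.6f} (half a step: {1/512:.6f})')\n",
        "print(quantize(torch.tensor([-2.0, 0.3, 2.0])))  # -2.0 and 2.0 are clipped to -1.0 and 1.0"
      ],
      "metadata": {
        "id": "dacAdcSketchCode"
      },
      "execution_count": null,
      "outputs": []
    },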
    {
      "cell_type": "markdown",
      "source": [
        "#### Converting the model to analog\n",
        "- We use the rpu_config that defines the hardware and the type of noise that is applied\n"
      ],
      "metadata": {
        "id": "5XZ_Ib_tggZ2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "rpu_config = create_rpu_config_new()\n",
        "model_analog = convert_to_analog_mapped(model, rpu_config)"
      ],
      "metadata": {
        "id": "UC9zO5bMHk6g"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model_analog"
      ],
      "metadata": {
        "id": "6PRiM8JTHk81"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from aihwkit.utils.analog_info import analog_summary\n",
        "analog_summary(model,(16,3,96,96),rpu_config)"
      ],
      "metadata": {
        "id": "p29TMPGnHk_P"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Analog training loop\n"
      ],
      "metadata": {
        "id": "-cUfQ4POg5Ip"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "if config['name'] is None:\n",
        "    if config['deep_supervision']:\n",
        "        config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])\n",
        "    else:\n",
        "        config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])\n",
        "os.makedirs('models/%s' % config['name'], exist_ok=True)\n",
        "\n",
        "print('-' * 20)\n",
        "for key in config:\n",
        "    print('%s: %s' % (key, config[key]))\n",
        "print('-' * 20)\n",
        "\n",
        "with open('models/%s/config.yml' % config['name'], 'w') as f:\n",
        "    yaml.dump(config, f)\n",
        "\n",
        "# define loss function (criterion)\n",
        "if config['loss'] == 'BCEWithLogitsLoss':\n",
        "    criterion = nn.BCEWithLogitsLoss().cuda()\n",
        "elif config['loss'] == 'BCEDiceLoss':\n",
        "    criterion = BCEDiceLoss().cuda()\n",
        "else:\n",
        "    criterion = LovaszHingeLoss().cuda()\n",
        "\n",
        "cudnn.benchmark = True\n",
        "\n",
        "# create model\n",
        "print(\"=> creating model %s\" % config['arch'])\n",
        "\n",
        "params = filter(lambda p: p.requires_grad, model_analog.parameters())\n",
        "if config['optimizer'] == 'Adam':\n",
        "    optimizer = optim.Adam(\n",
        "        params, lr=config['lr'], weight_decay=config['weight_decay'])\n",
        "elif config['optimizer'] == 'SGD':\n",
        "    optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],\n",
        "                          nesterov=config['nesterov'], weight_decay=config['weight_decay'])\n",
        "\n",
        "# Create an analog optimizer\n",
        "optimizer = create_analog_optimizer(model_analog, lr= config['lr'])\n",
        "\n",
        "if config['scheduler'] == 'CosineAnnealingLR':\n",
        "    scheduler = lr_scheduler.CosineAnnealingLR(\n",
        "        optimizer, T_max=config['epochs'], eta_min=config['min_lr'])\n",
        "elif config['scheduler'] == 'ReduceLROnPlateau':\n",
        "    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'],\n",
        "                                                verbose=1, min_lr=config['min_lr'])\n",
        "elif config['scheduler'] == 'MultiStepLR':\n",
        "    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])\n",
        "elif config['scheduler'] == 'ConstantLR':\n",
        "    scheduler = None\n",
        "else:\n",
        "    raise NotImplementedError\n",
        "\n",
        "# Data loading code\n",
        "img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))\n",
        "img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]\n",
        "\n",
        "train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)\n",
        "\n",
        "train_transform = Compose([\n",
        "    RandomRotate90(),\n",
        "    Flip(),\n",
        "    OneOf([\n",
        "        transforms.HueSaturationValue(),\n",
        "        transforms.RandomBrightness(),\n",
        "        transforms.RandomContrast(),\n",
        "    ], p=1),\n",
        "    Resize(config['input_h'], config['input_w']),\n",
        "    transforms.Normalize(),\n",
        "])\n",
        "\n",
        "val_transform = Compose([\n",
        "    Resize(config['input_h'], config['input_w']),\n",
        "    transforms.Normalize(),\n",
        "])\n",
        "\n",
        "train_dataset = Dataset(\n",
        "    img_ids=train_img_ids,\n",
        "    img_dir=os.path.join('inputs', config['dataset'], 'images'),\n",
        "    mask_dir=os.path.join('inputs', config['dataset'], 'masks'),\n",
        "    img_ext=config['img_ext'],\n",
        "    mask_ext=config['mask_ext'],\n",
        "    num_classes=config['num_classes'],\n",
        "    transform=train_transform)\n",
        "val_dataset = Dataset(\n",
        "    img_ids=val_img_ids,\n",
        "    img_dir=os.path.join('inputs', config['dataset'], 'images'),\n",
        "    mask_dir=os.path.join('inputs', config['dataset'], 'masks'),\n",
        "    img_ext=config['img_ext'],\n",
        "    mask_ext=config['mask_ext'],\n",
        "    num_classes=config['num_classes'],\n",
        "    transform=val_transform)\n",
        "\n",
        "train_loader = torch.utils.data.DataLoader(\n",
        "    train_dataset,\n",
        "    batch_size=config['batch_size'],\n",
        "    shuffle=True,\n",
        "    num_workers=config['num_workers'],\n",
        "    drop_last=True)\n",
        "val_loader = torch.utils.data.DataLoader(\n",
        "    val_dataset,\n",
        "    batch_size=config['batch_size'],\n",
        "    shuffle=False,\n",
        "    num_workers=config['num_workers'],\n",
        "    drop_last=False)\n",
        "\n",
        "log = OrderedDict([\n",
        "    ('epoch', []),\n",
        "    ('lr', []),\n",
        "    ('loss', []),\n",
        "    ('iou', []),\n",
        "    ('val_loss', []),\n",
        "    ('val_iou', []),\n",
        "])\n",
        "\n",
        "best_iou = 0\n",
        "trigger = 0\n",
        "for epoch in range(config['epochs']):\n",
        "    print('Epoch [%d/%d]' % (epoch, config['epochs']))\n",
        "\n",
        "    # train for one epoch\n",
        "    train_log = train(config, train_loader, model_analog, criterion, optimizer)\n",
        "    # evaluate on validation set\n",
        "    val_log = validate(config, val_loader, model_analog, criterion)\n",
        "\n",
        "    if config['scheduler'] == 'CosineAnnealingLR':\n",
        "        scheduler.step()\n",
        "    elif config['scheduler'] == 'ReduceLROnPlateau':\n",
        "        scheduler.step(val_log['loss'])\n",
        "\n",
        "    print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'\n",
        "          % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))\n",
        "\n",
        "    log['epoch'].append(epoch)\n",
        "    log['lr'].append(config['lr'])\n",
        "    log['loss'].append(train_log['loss'])\n",
        "    log['iou'].append(train_log['iou'])\n",
        "    log['val_loss'].append(val_log['loss'])\n",
        "    log['val_iou'].append(val_log['iou'])\n",
        "\n",
        "    pd.DataFrame(log).to_csv('models/%s/log.csv' %\n",
        "                              config['name'], index=False)\n",
        "\n",
        "    trigger += 1\n",
        "\n",
        "    if val_log['iou'] > best_iou:\n",
        "        torch.save(model_analog.state_dict(), 'models/%s/model2.pth' %\n",
        "                    config['name'])\n",
        "        best_iou = val_log['iou']\n",
        "        print(\"=> saved best model\")\n",
        "        trigger = 0\n",
        "\n",
        "    # early stopping\n",
        "    if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:\n",
        "        print(\"=> early stopping\")\n",
        "        break\n",
        "\n",
        "    torch.cuda.empty_cache()\n",
        "\n"
      ],
      "metadata": {
        "id": "jMcMp1CNHlBv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Test the model after different days\n",
        "- We test the drift after 1 second, 1 hour, 1 day and 30 days\n"
      ],
      "metadata": {
        "id": "uaIvedzUhCfc"
      }
    },
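    {
      "cell_type": "markdown",
      "source": [
        "#### Sketch: power-law drift and global drift compensation\n",
        "- PCM conductance drift is commonly modeled as a power law, `G(t) = G0 * ((t + t0) / t0) ** (-nu)`, where `t` is the time since programming and the drift exponent `nu` varies from device to device. The cell below is a NumPy sketch under that assumption; the values of `t0`, the mean of `nu` and its spread are illustrative, not taken from `PCMLikeNoiseModel`.\n",
        "- `GlobalDriftCompensation` rescales each layer's output by a single factor estimated from a reference read-out. The sketch approximates this with the ratio of mean absolute outputs before and after drift. Because one scalar cannot undo device-to-device variation in `nu`, a small residual error remains after compensation."
      ],
      "metadata": {
        "id": "driftSketchMd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "# Assumed illustrative values (not read from aihwkit or the PCM model).\n",
        "T0 = 20.0        # reference time in seconds\n",
        "NU_MEAN = 0.05   # mean drift exponent\n",
        "NU_STD = 0.01    # device-to-device spread of the drift exponent\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "g0 = rng.uniform(0.1, 1.0, size=(64, 32))    # programmed conductances (arbitrary units)\n",
        "nu = rng.normal(NU_MEAN, NU_STD, size=g0.shape)\n",
        "x = rng.uniform(0.0, 1.0, size=(16, 64))     # a batch of inputs\n",
        "y_ref = x @ g0                               # outputs right after programming\n",
        "\n",
        "def drift_errors(t):\n",
        "    # Relative output error after drift, without and with global compensation\n",
        "    g_t = g0 * ((t + T0) / T0) ** (-nu)  # power-law drift per device\n",
        "    y_t = x @ g_t\n",
        "    # A single global scale factor restores the mean absolute output level\n",
        "    alpha = np.abs(y_ref).mean() / np.abs(y_t).mean()\n",
        "    scale = np.abs(y_ref).mean()\n",
        "    err_raw = np.abs(y_t - y_ref).mean() / scale\n",
        "    err_comp = np.abs(alpha * y_t - y_ref).mean() / scale\n",
        "    return err_raw, err_comp\n",
        "\n",
        "for t in [1, 3600, 3600 * 24, 3600 * 24 * 30]:\n",
        "    err_raw, err_comp = drift_errors(t)\n",
        "    print(f't = {t:>8d} s: uncompensated {err_raw:.4f}, compensated {err_comp:.4f}')"
      ],
      "metadata": {
        "id": "driftSketchCode"
      },
      "execution_count": null,
      "outputs": []
    },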
    {
      "cell_type": "code",
      "source": [
        "from collections import OrderedDict\n",
        "from tqdm import tqdm\n",
        "\n",
        "def test_inference(config, model, criterion, test_loader):\n",
        "    #model.eval()  # ensure the model is in evaluation mode\n",
        "\n",
        "    # Initializing metric trackers\n",
        "    avg_meters = {'loss': AverageMeter(),\n",
        "                  'iou': AverageMeter(),\n",
        "                  'accuracy': AverageMeter(),\n",
        "                  'error': AverageMeter()}\n",
        "\n",
        "    with torch.no_grad():\n",
        "        pbar = tqdm(total=len(test_loader))\n",
        "\n",
        "        # Simulation of inference at different times after training.\n",
        "        for t_inference in [1, 3600,3600*24, 3600*24*30]:  # Example: Simulate the drift for 1 day.\n",
        "            print(t_inference)\n",
        "            model.drift_analog_weights(t_inference)  # Apply the drift simulation if applicable.\n",
        "\n",
        "            for data_batch in test_loader:\n",
        "                # Adjust the unpacking to account for the third item in the batch.\n",
        "                images, labels, _ = data_batch  # The third item is ignored as in training.\n",
        "\n",
        "                images = images.cuda()\n",
        "                labels = labels.cuda()\n",
        "\n",
        "                # Compute model output\n",
        "                if config['deep_supervision']:\n",
        "                    outputs = model(images)\n",
        "                    loss = 0\n",
        "                    for output in outputs:\n",
        "                        loss += criterion(output, labels)\n",
        "                    loss /= len(outputs)\n",
        "                    iou = iou_score(outputs[-1], labels)\n",
        "                else:\n",
        "                    output = model(images)\n",
        "                    loss = criterion(output, labels)\n",
        "                    iou = iou_score(output, labels)\n",
        "\n",
        "                # Calculate accuracy and error\n",
        "                _, predicted = torch.max(output.data, 1)\n",
        "                total = labels.size(0)\n",
        "                correct = (predicted == labels).sum().item()\n",
        "                accuracy = correct / total\n",
        "                error = 1 - accuracy\n",
        "\n",
        "                # Update tracking variables\n",
        "                avg_meters['loss'].update(loss.item(), total)\n",
        "                avg_meters['iou'].update(iou, total)\n",
        "\n",
        "                pbar.update(1)\n",
        "\n",
        "            # Displaying statistics after inference\n",
        "            print(f'Inference Time: {t_inference: .2e} seconds')\n",
        "            print(f'Average Loss: {avg_meters[\"loss\"].avg:.4f}\\tAverage IoU: {avg_meters[\"iou\"].avg:.4f}')\n",
        "\n",
        "            pbar.close()\n",
        "\n",
        "            # Resetting the average meters for the next inference time point\n",
        "            for meter in avg_meters.values():\n",
        "                meter.reset()\n",
        "\n",
        "    return OrderedDict([('loss', avg_meters['loss'].avg),\n",
        "                        ('iou', avg_meters['iou'].avg)])\n",
        "\n",
        "test_inference(config, model_analog,criterion, val_loader )"
      ],
      "metadata": {
        "id": "6lnrY3jKH9cc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def dice_score(pred, target, epsilon=1e-6):\n",
        "    \"\"\"\n",
        "    Compute the Dice score.\n",
        "\n",
        "    Args:\n",
        "    - pred (torch.Tensor): the predicted tensor\n",
        "    - target (torch.Tensor): the ground truth tensor\n",
        "    - epsilon (float): a small value to avoid division by zero\n",
        "\n",
        "    Returns:\n",
        "    - dice (torch.Tensor): computed Dice score\n",
        "    \"\"\"\n",
        "    pred_flat = pred.contiguous().view(-1)\n",
        "    target_flat = target.contiguous().view(-1)\n",
        "\n",
        "    intersection = (pred_flat * target_flat).sum()\n",
        "    denominator = pred_flat.sum() + target_flat.sum()\n",
        "\n",
        "    dice = (2 * intersection + epsilon) / (denominator + epsilon)\n",
        "    return dice"
      ],
      "metadata": {
        "id": "zqWHg1ZoR2oK"
      },
      "execution_count": null,
      "outputs": []
    }
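    ,
    {
      "cell_type": "markdown",
      "source": [
        "#### Relating the Dice score to IoU\n",
        "- For a pair of binary masks, the Dice score and IoU are linked by `Dice = 2 * IoU / (1 + IoU)`, so for a single prediction they carry the same information. The cell below checks this numerically on random masks. It assumes that the network outputs logits, which are thresholded after a sigmoid; `binarize`, `binary_dice` and `binary_iou` are illustrative helpers, not functions defined elsewhere in this notebook."
      ],
      "metadata": {
        "id": "diceIouSketchMd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "def binarize(logits, threshold=0.5):\n",
        "    # Threshold sigmoid probabilities to obtain a binary mask\n",
        "    return (torch.sigmoid(logits) > threshold).float()\n",
        "\n",
        "def binary_dice(pred, target, eps=1e-6):\n",
        "    inter = (pred * target).sum()\n",
        "    return (2 * inter + eps) / (pred.sum() + target.sum() + eps)\n",
        "\n",
        "def binary_iou(pred, target, eps=1e-6):\n",
        "    inter = (pred * target).sum()\n",
        "    union = pred.sum() + target.sum() - inter\n",
        "    return (inter + eps) / (union + eps)\n",
        "\n",
        "torch.manual_seed(0)\n",
        "logits = torch.randn(2, 1, 8, 8)\n",
        "target = (torch.rand(2, 1, 8, 8) > 0.5).float()\n",
        "pred = binarize(logits)\n",
        "\n",
        "d = binary_dice(pred, target).item()\n",
        "j = binary_iou(pred, target).item()\n",
        "print(f'Dice {d:.4f}, IoU {j:.4f}, 2*IoU/(1+IoU) {2 * j / (1 + j):.4f}')"
      ],
      "metadata": {
        "id": "diceIouSketchCode"
      },
      "execution_count": null,
      "outputs": []
    }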
  ]
}