{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "collapsed_sections": [
        "Fnhbfud2jyxz",
        "JCeNkaOHpGhE",
        "Vui2tk-5FoxF",
        "uVhXmlY3q42W",
        "-NoJnMfBi3ye",
        "t1j9ACzhIMPs",
        "d2w5jl286m6P"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "#@title License\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "metadata": {
        "id": "4I8aXPHdPf9m"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Table of contents"
      ],
      "metadata": {
        "id": "jin3IPhsiW4L"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        ">[Table of contents](#scrollTo=jin3IPhsiW4L)\n",
        "\n",
        ">[External packages](#scrollTo=BMj0796JsHrw)\n",
        "\n",
        ">[Preamble](#scrollTo=jxHrA25jhYSN)\n",
        "\n",
        ">[Main implementations](#scrollTo=yvGcuHtshUpV)\n",
        "\n",
        ">>[CPU variants](#scrollTo=r-RIw7vofKoE)\n",
        "\n",
        ">>[GPU variants](#scrollTo=VnUiCbPnfOIN)\n",
        "\n",
        ">[Main experiments](#scrollTo=KzJPcoufhM1Z)\n",
        "\n",
        ">>[Gradient norm benchmarks on CPU/GPU](#scrollTo=Fnhbfud2jyxz)\n",
        "\n",
        ">>>[Experiment executors](#scrollTo=JCeNkaOHpGhE)\n",
        "\n",
        ">>>[Experiment printers](#scrollTo=Vui2tk-5FoxF)\n",
        "\n",
        ">>>[CPU benchmarks](#scrollTo=uVhXmlY3q42W)\n",
        "\n",
        ">>>[GPU benchmarks](#scrollTo=-NoJnMfBi3ye)\n",
        "\n",
        ">>[End-to-end benchmarks on GPU](#scrollTo=OlznSp5TEfz1)\n",
        "\n",
        ">>>[Opacus gradient samplers](#scrollTo=t1j9ACzhIMPs)\n",
        "\n",
        ">>>[Experiment executors](#scrollTo=d2w5jl286m6P)\n",
        "\n",
        ">>>[GPU benchmarks](#scrollTo=4-hvSid56o9I)\n",
        "\n",
        ">>>>[Both naive and fast clipping](#scrollTo=xFN5JIrfBpWe)\n",
        "\n",
        ">>>>[Fast clipping only](#scrollTo=4FJw63nw3tvg)\n",
        "\n",
        ">[Additional experiments](#scrollTo=UkF7cVd1fjLQ)\n",
        "\n",
        ">>[Gradient norm benchmarks on CPU/GPU](#scrollTo=tJqYxkKvfz9Y)\n",
        "\n",
        ">>>[GPU benchmarks](#scrollTo=USkSHgYMgq2l)\n",
        "\n"
      ],
      "metadata": {
        "colab_type": "toc",
        "id": "ErjRzeACiTDT"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# External packages"
      ],
      "metadata": {
        "id": "BMj0796JsHrw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install opacus"
      ],
      "metadata": {
        "collapsed": true,
        "id": "ILk_ag9-EqYK",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "fdda8508-65a7-493c-f675-c00f756afd0b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting opacus\n",
            "  Downloading opacus-1.5.3-py3-none-any.whl.metadata (8.4 kB)\n",
            "Collecting numpy<2.0,>=1.15 (from opacus)\n",
            "  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.0/61.0 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: torch>=2.0 in /usr/local/lib/python3.11/dist-packages (from opacus) (2.6.0+cu124)\n",
            "Requirement already satisfied: scipy>=1.2 in /usr/local/lib/python3.11/dist-packages (from opacus) (1.15.2)\n",
            "Requirement already satisfied: opt-einsum>=3.3.0 in /usr/local/lib/python3.11/dist-packages (from opacus) (3.4.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=2.0->opacus) (3.18.0)\n",
            "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0->opacus) (4.13.2)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0->opacus) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0->opacus) (3.1.6)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch>=2.0->opacus) (2025.3.2)\n",
            "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0->opacus)\n",
            "  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0->opacus)\n",
            "  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0->opacus)\n",
            "  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0->opacus)\n",
            "  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0->opacus)\n",
            "  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0->opacus)\n",
            "  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=2.0->opacus)\n",
            "  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=2.0->opacus)\n",
            "  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch>=2.0->opacus)\n",
            "  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0->opacus) (0.6.2)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0->opacus) (2.21.5)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0->opacus) (12.4.127)\n",
            "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch>=2.0->opacus)\n",
            "  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0->opacus) (3.2.0)\n",
            "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0->opacus) (1.13.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0->opacus) (1.3.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=2.0->opacus) (3.0.2)\n",
            "Downloading opacus-1.5.3-py3-none-any.whl (251 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.7/251.7 kB\u001b[0m \u001b[31m23.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.3/18.3 MB\u001b[0m \u001b[31m118.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m120.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m95.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m65.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m55.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, opacus\n",
            "  Attempting uninstall: nvidia-nvjitlink-cu12\n",
            "    Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n",
            "    Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-curand-cu12\n",
            "    Found existing installation: nvidia-curand-cu12 10.3.6.82\n",
            "    Uninstalling nvidia-curand-cu12-10.3.6.82:\n",
            "      Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n",
            "  Attempting uninstall: nvidia-cufft-cu12\n",
            "    Found existing installation: nvidia-cufft-cu12 11.2.3.61\n",
            "    Uninstalling nvidia-cufft-cu12-11.2.3.61:\n",
            "      Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n",
            "  Attempting uninstall: nvidia-cuda-runtime-cu12\n",
            "    Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n",
            "    Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
            "    Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n",
            "    Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-cuda-cupti-cu12\n",
            "    Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n",
            "    Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-cublas-cu12\n",
            "    Found existing installation: nvidia-cublas-cu12 12.5.3.2\n",
            "    Uninstalling nvidia-cublas-cu12-12.5.3.2:\n",
            "      Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n",
            "  Attempting uninstall: numpy\n",
            "    Found existing installation: numpy 2.0.2\n",
            "    Uninstalling numpy-2.0.2:\n",
            "      Successfully uninstalled numpy-2.0.2\n",
            "  Attempting uninstall: nvidia-cusparse-cu12\n",
            "    Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n",
            "    Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n",
            "      Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n",
            "  Attempting uninstall: nvidia-cudnn-cu12\n",
            "    Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n",
            "    Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n",
            "      Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n",
            "  Attempting uninstall: nvidia-cusolver-cu12\n",
            "    Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n",
            "    Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n",
            "      Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed numpy-1.26.4 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 opacus-1.5.3\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install nvidia-ml-py3"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "collapsed": true,
        "id": "Pi_C8o6dmulf",
        "outputId": "a3733c89-d24a-4316-a6f4-629e61629ba8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting nvidia-ml-py3\n",
            "  Downloading nvidia-ml-py3-7.352.0.tar.gz (19 kB)\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Building wheels for collected packages: nvidia-ml-py3\n",
            "  Building wheel for nvidia-ml-py3 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for nvidia-ml-py3: filename=nvidia_ml_py3-7.352.0-py3-none-any.whl size=19172 sha256=475611eaefa7e6e3781d2fb94c64240ff590c3ebb327d6f680903ed2e9813395\n",
            "  Stored in directory: /root/.cache/pip/wheels/47/50/9e/29dc79037d74c3c1bb4a8661fb608e8674b7e4260d6a3f8f51\n",
            "Successfully built nvidia-ml-py3\n",
            "Installing collected packages: nvidia-ml-py3\n",
            "Successfully installed nvidia-ml-py3-7.352.0\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "pynvml"
                ]
              },
              "id": "78c223ededf244fbb09b781544011b39"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Preamble"
      ],
      "metadata": {
        "id": "jxHrA25jhYSN"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import math\n",
        "import re\n",
        "import subprocess\n",
        "import numpy as np\n",
        "import time\n",
        "import cupy as cp\n",
        "from collections import defaultdict\n",
        "from typing import Sequence, Mapping, Callable, Any, List, Dict, Optional\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from functools import lru_cache\n",
        "from numba import cuda\n",
        "import psutil\n",
        "import os\n",
        "import tracemalloc\n",
        "import platform\n",
        "import distro\n",
        "import gc"
      ],
      "metadata": {
        "id": "KsvdMCB6R2s0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import nvidia_smi\n",
        "import opacus\n",
        "from opacus.grad_sample.utils import register_norm_sampler\n",
        "from opacus.utils.per_sample_gradients_utils import clone_module\n",
        "from opacus.grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping\n",
        "from opacus.optimizers import DPOptimizer, DPOptimizerFastGradientClipping\n",
        "from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping"
      ],
      "metadata": {
        "id": "JQFAqXTzg0V6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Check GPU availability\n",
        "\n",
        "@lru_cache()\n",
        "def is_gpu_available() -> bool:\n",
        "    \"\"\"Check if nvidia-smi is available.\"\"\"\n",
        "    try:\n",
        "        nvidia_smi.nvmlInit()\n",
        "        nvidia_smi.nvmlShutdown()\n",
        "    except Exception:\n",
        "        return False\n",
        "    return True\n",
        "\n",
        "is_gpu_available()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7z3BWf9GnbJc",
        "outputId": "8484af48-bc43-4b70-8622-ce2a50230aac"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def get_processor_name():\n",
        "    if platform.system() == \"Windows\":\n",
        "        return platform.processor()\n",
        "    elif platform.system() == \"Darwin\":\n",
        "        os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin'\n",
        "        command =\"sysctl -n machdep.cpu.brand_string\"\n",
        "        return subprocess.check_output(command).strip()\n",
        "    elif platform.system() == \"Linux\":\n",
        "        command = \"cat /proc/cpuinfo\"\n",
        "        all_info = subprocess.check_output(command, shell=True).decode().strip()\n",
        "        for line in all_info.split(\"\\n\"):\n",
        "            if \"model name\" in line:\n",
        "                return re.sub( \".*model name.*:\", \"\", line,1)\n",
        "    return \"\""
      ],
      "metadata": {
        "id": "lvYbbj5OM0to"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Check system info\n",
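        "def get_size(num_bytes, suffix=\"B\"):\n",
        "  \"\"\"Formats a byte count using decimal (SI) prefixes, e.g. 13.61GB.\"\"\"\n",
        "  factor = 1000\n",
        "  for unit in [\"\", \"K\", \"M\", \"G\", \"T\", \"P\"]:\n",
        "    if num_bytes < factor:\n",
        "      return f\"{num_bytes:.2f}{unit}{suffix}\"\n",
        "    num_bytes /= factor\n",
        "  # Fall through for values of at least 1000**6 bytes instead of returning None.\n",
        "  return f\"{num_bytes:.2f}E{suffix}\"\n",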
        "\n",
        "uname = platform.uname()\n",
        "print(f\"System: {uname.system}\")\n",
        "print(f\"Distro: {distro.name()} {distro.version()}\")\n",
        "# print(f\"Node Name: {uname.node}\")\n",
        "# print(f\"Release: {uname.release}\")\n",
        "# print(f\"Version: {uname.version}\")\n",
        "# print(f\"Machine: {uname.machine}\")\n",
        "print(f\"Processor: {get_processor_name()}\")\n",
        "# # number of cores\n",
        "# print(\"Physical cores:\", psutil.cpu_count(logical=False))\n",
        "# print(\"Total cores:\", psutil.cpu_count(logical=True))\n",
        "svmem = psutil.virtual_memory()\n",
        "print(f\"Total RAM: {get_size(svmem.total)}\")\n",
        "print(f\"GPU: {torch.cuda.get_device_name()}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mt3pGJGxmi-e",
        "outputId": "5d153685-0a7b-4908-fb3d-cfef9fbc549a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "System: Linux\n",
            "Distro: Ubuntu 22.04\n",
            "Processor:  Intel(R) Xeon(R) CPU @ 2.00GHz\n",
            "Total RAM: 13.61GB\n",
            "GPU: Tesla T4\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Main implementations\n"
      ],
      "metadata": {
        "id": "yvGcuHtshUpV"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## CPU variants"
      ],
      "metadata": {
        "id": "r-RIw7vofKoE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"This is a library for fast and memory efficient gradient norm computation.\n",
        "\n",
        "The functions operate on a single Convolutional Neural Network (CNN) layer and\n",
        "take as input one sample of the batch and the corresponding partial gradient.\n",
        "The sample is assumed to be a 2D matrix whose rows correspond o the 1D vectors\n",
        "of the different input channels.The partial gradient is assumed to be a 2D\n",
        "matrix whose rows correspond to the 1D vectors of the different output channels.\n",
        "The other two args are the kernel size and stride of the layer.\n",
        "\"\"\"\n",
        "\n",
        "def _check_value_and_shape_of_arguments(\n",
        "    input_matrix: np.ndarray,\n",
        "    partial_gradient: np.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        "):\n",
        "  \"\"\"Checks the arguments of the functions in this library.\"\"\"\n",
        "  if input_matrix.ndim != 2:\n",
        "    raise ValueError(\"input_matrix must be a 2D matrix\")\n",
        "  if partial_gradient.ndim != 2:\n",
        "    raise ValueError(\"partial_gradient must be a 2D matrix\")\n",
        "  if input_matrix.shape[1] == 0:\n",
        "    raise ValueError(\"input_matrix must be non-empty\")\n",
        "  if partial_gradient.shape[1] == 0:\n",
        "    raise ValueError(\"partial_gradient must be non-empty\")\n",
        "  if kernel_size <= 0:\n",
        "    raise ValueError(\"kernel_size must be a positive integer\")\n",
        "  if stride <= 0:\n",
        "    raise ValueError(\"stride must be a positive integer\")\n",
        "  if (\n",
        "      # This is the formula for the output dimension of a CNN layer.\n",
        "      math.floor((input_matrix.shape[1] - kernel_size) / stride + 1)\n",
        "      != partial_gradient.shape[1]\n",
        "  ):\n",
        "    raise ValueError(\n",
        "        \"Number of columns of partial_gradient must be equal to the\"\n",
        "        \" output dimension of the layer\"\n",
        "    )\n",
        "\n",
        "\n",
        "def in_place_fast_grad_norm(\n",
        "    input_matrix: np.ndarray,\n",
        "    partial_gradient: np.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"Computes the gradient norm squared of a single sample in a batch.\n",
        "\n",
        "  This function avoids explicitly instantiating the intermediate matrix that is\n",
        "  used in the gradient norm computation. This is useful when the batch size is\n",
        "  large and the gradient norm is computed for each sample in the batch.\n",
        "\n",
        "  More formally, it implements the following logic: let x be the input matrix, g\n",
        "  be the partial gradient, U(x[i]) be the matrix whose rows correspond to\n",
        "  the different kernel windows of the i-th input channel, n_in be the number of\n",
        "  input channels, n_out be the number of output channels, and res be the\n",
        "  l_2 gradient norm squared. Then,\n",
        "  res = sum_{i in n_in} sum_{j in n_out} ||U(x[i])^T g[j]||^2.\n",
        "\n",
        "  Args:\n",
        "    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.\n",
        "    partial_gradient: 2D matrix whose rows are 1D vectors of the partial\n",
        "      gradient across the output channels.\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "\n",
        "  Returns:\n",
        "    l_2 norm squared of the gradient as a float.\n",
        "  \"\"\"\n",
        "\n",
        "  # This function is checking that the values and shapes of args are valid.\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix, partial_gradient, kernel_size, stride\n",
        "  )\n",
        "\n",
        "  res = 0\n",
        "\n",
        "  for input_vector in input_matrix:\n",
        "    # We use the sliding window view to avoid explicitly instantiating the\n",
        "    # intermediate matrix.\n",
        "    u_input_vector = np.lib.stride_tricks.sliding_window_view(\n",
        "        input_vector, window_shape=(kernel_size,)\n",
        "    )[::stride]\n",
        "    u_input_vector_transpose = u_input_vector.T\n",
        "    for output in partial_gradient:\n",
        "      res += np.sum(\n",
        "          np.square(np.tensordot(u_input_vector_transpose, output, (1, 0)))\n",
        "      )\n",
        "\n",
        "  return res\n",
        "\n",
        "\n",
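        "# NOTE: Hypothetical helper, not part of the original library. It is a minimal\n",
        "# sketch of a brute-force reference, assuming a 1D convolution without padding\n",
        "# or dilation. It materializes the full weight gradient, so it is only meant\n",
        "# for sanity-checking the memory-efficient functions in this cell on small inputs.\n",
        "def _reference_grad_norm_squared(\n",
        "    input_matrix: np.ndarray,\n",
        "    partial_gradient: np.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"Computes sum_{i, j} ||U(x[i])^T g[j]||^2 via the explicit weight gradient.\"\"\"\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix, partial_gradient, kernel_size, stride\n",
        "  )\n",
        "  # windows[i, t, k] = input_matrix[i, t * stride + k], i.e. windows[i] = U(x[i]).\n",
        "  windows = np.lib.stride_tricks.sliding_window_view(\n",
        "      input_matrix, window_shape=kernel_size, axis=1\n",
        "  )[:, ::stride, :]\n",
        "  # weight_gradient[j, i, k] = sum_t partial_gradient[j, t] * windows[i, t, k].\n",
        "  weight_gradient = np.einsum(\"jt,itk->jik\", partial_gradient, windows)\n",
        "  return float(np.sum(np.square(weight_gradient)))\n",
        "\n",
        "# Example check (illustrative shapes): for x of shape (2, 7), kernel_size=3 and\n",
        "# stride=2, g must have floor((7 - 3) / 2) + 1 = 3 columns, and\n",
        "#   np.isclose(_reference_grad_norm_squared(x, g, 3, 2),\n",
        "#              in_place_fast_grad_norm(x, g, 3, 2))\n",
        "# is expected to hold up to floating-point error.\n",
        "\n",
        "\n",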
        "def in_place_ghost_norm(\n",
        "    input_matrix: np.ndarray,\n",
        "    partial_gradient: np.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"Computes the gradient norm squared of a single sample in a batch.\n",
        "\n",
        "  This function uses the ghost norm trick to compute the gradient norm squared.\n",
        "  It avoids explicitly instantiating the intermediate matrices that are\n",
        "  used in the gradient norm computations. This is useful when the batch size is\n",
        "  large and the gradient norm is computed for each sample in the batch.\n",
        "\n",
        "  More formally, it implements the following logic: let x be the input matrix, g\n",
        "  be the partial gradient, U(x[i]) be the matrix whose rows correspond to\n",
        "  the different kernel windows of the i-th input channel, n_in be the number of\n",
        "  input channels, n_out be the number of output channels, and res be the\n",
        "  l_2 gradient norm squared. Then,\n",
        "  res = <sum_{i in n_in} U(x[i]) U(x[i])^T, sum_{j in n_out} g[j] (g[j])^T>,\n",
        "  where <,> is the Frobenius inner product.\n",
        "\n",
        "  Args:\n",
        "    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.\n",
        "    partial_gradient: 2D matrix whose rows are 1D vectors of the partial\n",
        "      gradient across the output channels.\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "\n",
        "  Returns:\n",
        "    l_2 norm squared of the gradient as a float.\n",
        "  \"\"\"\n",
        "\n",
        "  # checking shapes and values of the arguments\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix, partial_gradient, kernel_size, stride\n",
        "  )\n",
        "\n",
        "  output_dimension = partial_gradient.shape[1]\n",
        "  number_input_channels = input_matrix.shape[0]\n",
        "  number_output_channels = partial_gradient.shape[0]\n",
        "  res = 0\n",
        "\n",
        "  for j_1 in range(output_dimension):\n",
        "    for j_2 in range(output_dimension):\n",
        "      # This expression computes the following:\n",
        "      # sum_{i in n_in} (U(x[i]) U(x[i])^T)[j_1][j_2])\n",
        "      # Recall that the j_1 row of U(x[i]) corresponds to the kernel\n",
        "      # window of the i-th input channel at the j_1-th position.\n",
        "      temp_1 = 0\n",
        "      for i in range(number_input_channels):\n",
        "        temp_1 += np.dot(\n",
        "            input_matrix[i][j_1 * stride : j_1 * stride + kernel_size],\n",
        "            input_matrix[i][j_2 * stride : j_2 * stride + kernel_size],\n",
        "        )\n",
        "      # This expression computes the following:\n",
        "      # sum_{k in n_out} g[k](g[k])^T[j_1][j_2]\n",
        "      temp_2 = 0\n",
        "      for k in range(number_output_channels):\n",
        "        temp_2 += np.dot(partial_gradient[k][j_1], partial_gradient[k][j_2])\n",
        "      res += temp_1 * temp_2\n",
        "\n",
        "  return res\n",
        "\n",
        "\n",
        "def in_place_norm_fft(\n",
        "    input_matrix: np.ndarray,\n",
        "    partial_gradient: np.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"Computes the gradient norm squared of a single sample in a batch.\n",
        "\n",
        "  This function uses the Fast Fourier Transform to compute the gradient\n",
        "  norm squared, by efficiently transforming the gradient norm computation to a\n",
        "  multiplication of a circulant matrix with a vector. It avoids explicitly\n",
        "  instantiating the intermediate matrices and vectors that are used in the\n",
        "  computations. This is useful when the batch size is large and\n",
        "  the gradient norm is computed for each sample in the batch.\n",
        "\n",
        "  More formally, it implements the following logic: let x be the input matrix, g\n",
        "  be the partial gradient, U(x[i]) be the matrix whose rows correspond to\n",
        "  the different kernel windows of the i-th input channel, n_in be the number of\n",
        "  input channels, n_out be the number of output channels, and res be the l_2\n",
        "  gradient norm squared. Then,\n",
        "  res = sum_{i in n_in} sum_{j in n_out} ||R P U'(x[i])^T g'[j]||^2, where R is\n",
        "  an operator that returns some specific entries of the vector it is applied to,\n",
        "  P is an appropriate permutation matrix, U'(x[i]) is some circulantmatrix that\n",
        "  is defined based on U(x[i]), and g'[j] is the vector that is obtained by\n",
        "  padding g[j] appropriately.\n",
        "\n",
        "  Args:\n",
        "    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.\n",
        "    partial_gradient: 2D matrix whose rows are 1D vectors of the partial\n",
        "      gradient across the output channels.\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "\n",
        "  Returns:\n",
        "    l_2 norm squared of the gradient as a float.\n",
        "  \"\"\"\n",
        "\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix, partial_gradient, kernel_size, stride\n",
        "  )\n",
        "\n",
        "  input_dimension = input_matrix.shape[1]\n",
        "  output_dimension = partial_gradient.shape[1]\n",
        "  number_input_channels = input_matrix.shape[0]\n",
        "  number_output_channels = partial_gradient.shape[0]\n",
        "  res = 0\n",
        "\n",
        "  # We only allocate memory for the entries of the padded partial gradient once\n",
        "  padded_partial_gradient = np.zeros(\n",
        "      input_dimension, dtype=partial_gradient.dtype\n",
        "  )\n",
        "\n",
        "  for j in range(number_output_channels):\n",
        "    # We start by populating the non-zero entries of the padded partial gradient\n",
        "    upper_bound = (output_dimension - 1) * stride + 1\n",
        "    padded_partial_gradient[:upper_bound:stride] = partial_gradient[j]\n",
        "\n",
        "    partial_derivative_fft = np.fft.fft(padded_partial_gradient)\n",
        "    for i in range(number_input_channels):\n",
        "\n",
        "      # FFT of the first column of the circulant matrix that is defined based\n",
        "      # on U(x[i])\n",
        "      vector_in_fft = np.fft.fft(np.flip(input_matrix[i]))\n",
        "\n",
        "      temp = np.flip(\n",
        "          np.fft.ifft(np.multiply(vector_in_fft, partial_derivative_fft))\n",
        "      )\n",
        "\n",
        "      res += np.sum(np.real(temp[0:kernel_size] ** 2))\n",
        "\n",
        "  return res\n",
        "\n",
        "\n",
        "def _unfold(\n",
        "    input_matrix: np.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        "    output_dimension: int,\n",
        "):\n",
        "  \"\"\"Unfolds the input matrix.\n",
        "\n",
        "  Args:\n",
        "    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "    output_dimension: output dimension of the layer.\n",
        "\n",
        "  Returns:\n",
        "    A 2D matrix whose rows correspond to the different kernel windows of\n",
        "    all the input channels.\n",
        "  \"\"\"\n",
        "\n",
        "  input_channels = input_matrix.shape[0]\n",
        "\n",
        "  # Create slices of the input_matrix and stack them\n",
        "  # Create an array of indices to slice\n",
        "  indices = (\n",
        "      np.arange(kernel_size)[None, :]\n",
        "      + np.arange(output_dimension)[:, None] * stride\n",
        "  )\n",
        "\n",
        "  # Extract slices for each channel\n",
        "  slices = np.stack(\n",
        "      [input_matrix[j, indices] for j in range(input_channels)], axis=-1\n",
        "  )\n",
        "  res = slices.reshape(output_dimension, -1)\n",
        "  return res\n",
        "\n",
        "\n",
        "def naive_fast_grad_norm(\n",
        "    input_matrix: np.ndarray,\n",
        "    partial_gradient: np.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"Computes the gradient norm squared of a single sample in a batch.\n",
        "\n",
        "  This function is a memory inefficient implementation of the gradient norm\n",
        "  computation. It can be faster than the in-place implementations of this\n",
        "  function, but it requires more memory.\n",
        "\n",
        "  More formally, it implements the following logic: let x be the input matrix, g\n",
        "  be the partial gradient, U be the matrix whose i-th row consists of\n",
        "  consecutive blocks of the i-th kernel windows of all input channels, n_in be\n",
        "  the number of input channels, n_out be the number of output channels, and res\n",
        "  be the l_2 gradient norm squared. Then,\n",
        "  res = sum_{i in n_in} sum_{j in n_out} ||U(x[i])^T g[j]||^2.\n",
        "\n",
        "  Unlike the in-place implementations, this function explicitly instantiates the\n",
        "  intermediate matrix that is used in the gradient norm computation.\n",
        "\n",
        "  Args:\n",
        "    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.\n",
        "    partial_gradient: 2D matrix whose rows are 1D vectors of the partial\n",
        "      gradient across the output channels.\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "\n",
        "  Returns:\n",
        "    l_2 norm squared of the gradient as a float.\n",
        "  \"\"\"\n",
        "\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix, partial_gradient, kernel_size, stride\n",
        "  )\n",
        "\n",
        "  unfolded_input_matrix = _unfold(\n",
        "      input_matrix, kernel_size, stride, partial_gradient.shape[1]\n",
        "  )\n",
        "  grad = unfolded_input_matrix.T @ partial_gradient.T\n",
        "  norm_grad = (\n",
        "      np.einsum(\"ij,ij->\", grad, grad)\n",
        "  )\n",
        "\n",
        "  return norm_grad\n",
        "\n",
        "\n",
        "def naive_ghost_norm(\n",
        "    input_matrix: np.ndarray,\n",
        "    partial_gradient: np.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"Computes the gradient norm squared of a single sample in a batch.\n",
        "\n",
        "  This function is a memory inefficient implementation of the gradient norm\n",
        "  computation. It can be faster than the in-place implementations of this\n",
        "  function, but it requires more memory.\n",
        "\n",
        "  More formally, it implements the following logic: let x be the input matrix, g\n",
        "  be the partial gradient, U be the matrix whose i-th row consists of\n",
        "  consecutive blocks of the i-th kernel windows of all input channels, n_in be\n",
        "  the number of input channels, n_out be the number of output channels, and res\n",
        "  be the l_2 gradient norm squared. Then, res = <U U^T, g g^T>, where <,> is the\n",
        "  Frobenius inner product.\n",
        "\n",
        "  Unlike the in-place implementations, this function explicitly instantiates the\n",
        "  intermediate matrices that are used in the gradient norm computation.\n",
        "\n",
        "  Args:\n",
        "    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.\n",
        "    partial_gradient: 2D matrix whose rows are 1D vectors of the partial\n",
        "      gradient across the output channels.\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "\n",
        "  Returns:\n",
        "    l_2 norm squared of the gradient as a float.\n",
        "  \"\"\"\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix, partial_gradient, kernel_size, stride\n",
        "  )\n",
        "\n",
        "  # computation of UU^T\n",
        "  unfolded_input_matrix = _unfold(\n",
        "      input_matrix, kernel_size, stride, partial_gradient.shape[1]\n",
        "  )\n",
        "  v = unfolded_input_matrix @ unfolded_input_matrix.T\n",
        "\n",
        "  # computation of gg^T\n",
        "  partial_matrix = partial_gradient.T @ partial_gradient\n",
        "\n",
        "  norm_grad = np.tensordot(v, partial_matrix, axes=[[0, 1], [0, 1]])\n",
        "  return norm_grad\n"
      ],
      "metadata": {
        "id": "REt4I4g9Knah"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## GPU variants"
      ],
      "metadata": {
        "id": "VnUiCbPnfOIN"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def cp_in_place_fast_grad_norm(\n",
        "    input_matrix: cp.ndarray,\n",
        "    partial_gradient: cp.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"Computes the gradient norm squared of a single sample in a batch (GPU version).\n",
        "\n",
        "  This function avoids explicitly instantiating the intermediate matrix that is\n",
        "  used in the gradient norm computation by leveraging CuPy's capabilities.\n",
        "\n",
        "  Args:\n",
        "      input_matrix: 2D CuPy array whose rows are 1D vectors of the input to the\n",
        "        layer.\n",
        "      partial_gradient: 2D CuPy array whose rows are 1D vectors of the partial\n",
        "        gradient across the output channels.\n",
        "      kernel_size: kernel size of the layer.\n",
        "      stride: stride of the layer.\n",
        "\n",
        "  Returns:\n",
        "      l_2 norm squared of the gradient as a float (on the CPU).\n",
        "  \"\"\"\n",
        "\n",
        "  # Assuming _check_value_and_shape_of_arguments can handle CuPy arrays\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix, partial_gradient, kernel_size, stride\n",
        "  )\n",
        "\n",
        "  res_gpu = cp.zeros((1,), dtype=input_matrix.dtype)  # Initialize on the GPU\n",
        "\n",
        "  for input_vector in input_matrix:\n",
        "    # CuPy's stride_tricks is similar to NumPy's\n",
        "    u_input_vector = cp.lib.stride_tricks.sliding_window_view(\n",
        "        input_vector, window_shape=(kernel_size,)\n",
        "    )[::stride]\n",
        "    u_input_vector_transpose = u_input_vector.T\n",
        "    for output in partial_gradient:\n",
        "      res_gpu += cp.sum(\n",
        "          cp.square(cp.tensordot(u_input_vector_transpose, output, (1, 0)))\n",
        "      )\n",
        "\n",
        "  return res_gpu.item()  # Transfer the scalar result back to the CPU\n",
        "\n",
        "def cp_in_place_ghost_norm(\n",
        "    input_matrix_cp: cp.ndarray,\n",
        "    partial_gradient_cp: cp.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"\n",
        "  Computes the gradient norm squared using CuPy for GPU acceleration.\n",
        "\n",
        "  Args:\n",
        "    input_matrix_cp: 2D CuPy matrix (n_in, input_dim).\n",
        "    partial_gradient_cp: 2D CuPy matrix (n_out, output_dim).\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "\n",
        "  Returns:\n",
        "    l_2 norm squared of the gradient as a float.\n",
        "  \"\"\"\n",
        "  # --- Input Validation ---\n",
        "  # It's often better to validate shapes *before* potential large transfers\n",
        "  # to GPU. If inputs are already on GPU, validate directly.\n",
        "  # Assuming _check_value_and_shape_of_arguments works with CuPy or\n",
        "  # you perform checks beforehand.\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix_cp, partial_gradient_cp, kernel_size, stride\n",
        "  )\n",
        "\n",
        "  # --- Get Dimensions ---\n",
        "  number_input_channels, input_dim = input_matrix_cp.shape\n",
        "  number_output_channels, output_dimension = partial_gradient_cp.shape\n",
        "\n",
        "  # Calculate the expected input dimension based on output dim, kernel, stride\n",
        "  # This helps determine the shape for as_strided\n",
        "  expected_input_dim = (output_dimension - 1) * stride + kernel_size\n",
        "  if input_dim != expected_input_dim:\n",
        "       # Adjust if padding is involved, or raise error if inconsistent\n",
        "       raise ValueError(f\"Input dimension {input_dim} does not match expected \"\n",
        "                        f\"dimension {expected_input_dim} based on output_dim, \"\n",
        "                        f\"kernel_size, and stride.\")\n",
        "\n",
        "\n",
        "  # --- Vectorize temp_2 Calculation ---\n",
        "  # temp_2_matrix[j1, j2] = sum_{k} g[k][j1] * g[k][j2]\n",
        "  # This is equivalent to partial_gradient.T @ partial_gradient\n",
        "  temp_2_matrix = cp.matmul(partial_gradient_cp.T, partial_gradient_cp)\n",
        "  # Shape: (output_dimension, output_dimension)\n",
        "\n",
        "  # --- Vectorize temp_1 Calculation ---\n",
        "  # We need temp_1_matrix[j1, j2] = sum_{i} dot(patch(x[i], j1), patch(x[i], j2))\n",
        "  # 1. Create the patch tensor P using as_strided\n",
        "  #    P[i, j, k] = input_matrix_cp[i, j*stride + k]\n",
        "  shape = (number_input_channels, output_dimension, kernel_size)\n",
        "  # Calculate strides for the view:\n",
        "  # stride_i: distance between elements along axis 0 (input channels)\n",
        "  # stride_j: distance between elements along axis 1 (output patches)\n",
        "  # stride_k: distance between elements along axis 2 (kernel elements)\n",
        "  itemsize = input_matrix_cp.itemsize\n",
        "  stride_i = input_matrix_cp.strides[0]\n",
        "  stride_j = stride * input_matrix_cp.strides[1] # stride * itemsize if flat\n",
        "  stride_k = input_matrix_cp.strides[1]          # itemsize if flat\n",
        "  strides = (stride_i, stride_j, stride_k)\n",
        "\n",
        "  # Create the strided view (no data copied)\n",
        "  patches = cp.lib.stride_tricks.as_strided(input_matrix_cp, shape=shape, strides=strides)\n",
        "  # Shape: (n_in, output_dim, kernel_size)\n",
        "\n",
        "  # 2. Compute the sum of outer products: sum_i (P[i] @ P[i].T)\n",
        "  #    where P[i] has shape (output_dim, kernel_size)\n",
        "  #    Result matrix A[j1, j2] = sum_i dot(P[i, j1, :], P[i, j2, :])\n",
        "  #    Can use einsum: 'ijk,ilk->jl' sums over i (channels) and k (kernel)\n",
        "  temp_1_matrix = cp.einsum('ijk,ilk->jl', patches, patches, optimize='optimal')\n",
        "  # Shape: (output_dimension, output_dimension)\n",
        "\n",
        "  # --- Final Calculation (Frobenius Inner Product) ---\n",
        "  # res = sum_{j1, j2} temp_1_matrix[j1, j2] * temp_2_matrix[j1, j2]\n",
        "  res_cp = cp.sum(temp_1_matrix * temp_2_matrix)\n",
        "\n",
        "  return res_cp.item()\n",
        "\n",
        "import cupy as cp\n",
        "\n",
        "def cp_unfold(\n",
        "    input_matrix_cp: cp.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        "    output_dimension: int,\n",
        ") -> cp.ndarray:\n",
        "  \"\"\"\n",
        "  Unfolds the input matrix using CuPy, optimized with as_strided.\n",
        "\n",
        "  Args:\n",
        "    input_matrix_cp: 2D CuPy matrix (n_in, input_dim).\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "    output_dimension: output dimension of the layer.\n",
        "\n",
        "  Returns:\n",
        "    A 2D CuPy matrix (output_dimension, n_in * kernel_size) whose rows\n",
        "    correspond to the different kernel windows of all the input channels.\n",
        "  \"\"\"\n",
        "  number_input_channels = input_matrix_cp.shape[0]\n",
        "  input_dim = input_matrix_cp.shape[1] # Get actual input dim\n",
        "\n",
        "  # Calculate expected input dim for validation if needed (optional here)\n",
        "  # expected_input_dim = (output_dimension - 1) * stride + kernel_size\n",
        "  # if input_dim < expected_input_dim: # Check if input is large enough\n",
        "  #     raise ValueError(...)\n",
        "\n",
        "  # 1. Create the patch tensor view P using as_strided\n",
        "  #    P[i, j, k] = input_matrix_cp[i, j*stride + k]\n",
        "  shape_view = (number_input_channels, output_dimension, kernel_size)\n",
        "\n",
        "  # Calculate strides for the view\n",
        "  stride_i = input_matrix_cp.strides[0] # Stride between channels\n",
        "  stride_j = stride * input_matrix_cp.strides[1] # Stride between patches start\n",
        "  stride_k = input_matrix_cp.strides[1] # Stride within a patch (along input_dim)\n",
        "  strides = (stride_i, stride_j, stride_k)\n",
        "\n",
        "  # Create the strided view (no data copied if possible)\n",
        "  patches_view = cp.lib.stride_tricks.as_strided(\n",
        "      input_matrix_cp, shape=shape_view, strides=strides\n",
        "  )\n",
        "  # Shape: (n_in, output_dim, kernel_size)\n",
        "\n",
        "  # 2. Transpose and reshape to match the original _unfold output format\n",
        "  # Target shape: (output_dimension, n_in * kernel_size)\n",
        "  # Transpose: (output_dim, n_in, kernel_size)\n",
        "  # Reshape: (output_dim, n_in * kernel_size)\n",
        "  # Note: transpose and reshape might create a copy if memory isn't contiguous\n",
        "  unfolded_matrix = patches_view.transpose(1, 0, 2).reshape(output_dimension, -1)\n",
        "\n",
        "  return unfolded_matrix\n",
        "\n",
        "import cupy as cp\n",
        "import cupy.fft\n",
        "\n",
        "# Assuming _check_value_and_shape_of_arguments is defined elsewhere\n",
        "# and works with CuPy arrays or is called before GPU transfer.\n",
        "# from your_module import _check_value_and_shape_of_arguments\n",
        "\n",
        "def cp_in_place_norm_fft(\n",
        "    input_matrix_cp: cp.ndarray,\n",
        "    partial_gradient_cp: cp.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"\n",
        "  Computes the gradient norm squared using CuPy FFT for GPU acceleration.\n",
        "\n",
        "  Args:\n",
        "    input_matrix_cp: 2D CuPy matrix (n_in, input_dim).\n",
        "    partial_gradient_cp: 2D CuPy matrix (n_out, output_dim).\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "\n",
        "  Returns:\n",
        "    l_2 norm squared of the gradient as a float.\n",
        "  \"\"\"\n",
        "  # --- Input Validation ---\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix_cp, partial_gradient_cp, kernel_size, stride\n",
        "  )\n",
        "\n",
        "  # --- Get Dimensions ---\n",
        "  number_input_channels, input_dimension = input_matrix_cp.shape\n",
        "  number_output_channels, output_dimension = partial_gradient_cp.shape\n",
        "\n",
        "  # Ensure data type supports complex numbers for FFT\n",
        "  if not cp.can_cast(input_matrix_cp.dtype, cp.complex64):\n",
        "      input_matrix_cp = input_matrix_cp.astype(cp.float32) # Use float32 base\n",
        "  if not cp.can_cast(partial_gradient_cp.dtype, cp.complex64):\n",
        "       partial_gradient_cp = partial_gradient_cp.astype(cp.float32)\n",
        "\n",
        "  # --- Vectorized Padding ---\n",
        "  padded_partial_gradient_batch = cp.zeros(\n",
        "      (number_output_channels, input_dimension), dtype=partial_gradient_cp.dtype\n",
        "  )\n",
        "  upper_bound = (output_dimension - 1) * stride + 1\n",
        "  if upper_bound > input_dimension:\n",
        "       raise ValueError(\"Input dimension too small for output dim/stride/kernel.\")\n",
        "  padded_partial_gradient_batch[:, :upper_bound:stride] = partial_gradient_cp\n",
        "\n",
        "  # --- Batch FFTs ---\n",
        "  # FFT of padded gradients (batch along axis 0)\n",
        "  partial_derivative_fft_batch = cp.fft.fft(padded_partial_gradient_batch, axis=1)\n",
        "  # Shape: (n_out, input_dim)\n",
        "\n",
        "  # FFT of flipped input channels (batch along axis 0)\n",
        "  flipped_input = cp.flip(input_matrix_cp, axis=1)\n",
        "  vector_in_fft_batch = cp.fft.fft(flipped_input, axis=1)\n",
        "  # Shape: (n_in, input_dim)\n",
        "\n",
        "  # --- Combine FFTs (Broadcasting) ---\n",
        "  # vector_in_fft_batch shape: (n_in, 1, input_dim)\n",
        "  # partial_derivative_fft_batch shape: (1, n_out, input_dim)\n",
        "  multiplied_ffts = (\n",
        "      vector_in_fft_batch[:, None, :] * partial_derivative_fft_batch[None, :, :]\n",
        "  )\n",
        "  # Shape: (n_in, n_out, input_dim)\n",
        "\n",
        "  # --- Batch IFFT ---\n",
        "  temp_batch = cp.fft.ifft(multiplied_ffts, axis=2)\n",
        "  # Shape: (n_in, n_out, input_dim)\n",
        "\n",
        "  # --- Flip and Extract Relevant Part ---\n",
        "  temp_flipped_batch = cp.flip(temp_batch, axis=2)\n",
        "  relevant_part = temp_flipped_batch[:, :, :kernel_size]\n",
        "  # Shape: (n_in, n_out, kernel_size)\n",
        "\n",
        "  # --- Calculate Final Result ---\n",
        "  # Sum of squares of the real part\n",
        "  res_cp = cp.sum(cp.real(relevant_part) ** 2)\n",
        "\n",
        "  # --- Return Result ---\n",
        "  return res_cp.item()\n",
        "\n",
        "import cupy as cp\n",
        "\n",
        "def cp_naive_fast_grad_norm(\n",
        "    input_matrix_cp: cp.ndarray,\n",
        "    partial_gradient_cp: cp.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"\n",
        "  Computes gradient norm squared using explicit unfolding (CuPy version).\n",
        "\n",
        "  Memory inefficient but potentially faster for some GPU/problem sizes.\n",
        "\n",
        "  Args:\n",
        "    input_matrix_cp: 2D CuPy matrix (n_in, input_dim).\n",
        "    partial_gradient_cp: 2D CuPy matrix (n_out, output_dim).\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "\n",
        "  Returns:\n",
        "    l_2 norm squared of the gradient as a float.\n",
        "  \"\"\"\n",
        "  # --- Input Validation ---\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix_cp, partial_gradient_cp, kernel_size, stride\n",
        "  )\n",
        "  output_dimension = partial_gradient_cp.shape[1]\n",
        "\n",
        "  # --- Unfold Input ---\n",
        "  unfolded_input_matrix_cp = cp_unfold(\n",
        "      input_matrix_cp, kernel_size, stride, output_dimension\n",
        "  )\n",
        "  # Shape: (output_dim, n_in * kernel_size)\n",
        "\n",
        "  # --- Compute Gradient Matrix ---\n",
        "  # grad = U.T @ g.T\n",
        "  grad_cp = cp.matmul(unfolded_input_matrix_cp.T, partial_gradient_cp.T)\n",
        "  # Shape: (n_in * kernel_size, n_out)\n",
        "\n",
        "  # --- Compute Squared Norm ---\n",
        "  # Equivalent to cp.einsum(\"ij,ij->\", grad_cp, grad_cp) or cp.linalg.norm(grad_cp)**2\n",
        "  norm_grad_cp = cp.sum(grad_cp * grad_cp) # Often very efficient\n",
        "\n",
        "  # --- Return Result ---\n",
        "  return norm_grad_cp.item()\n",
        "\n",
        "import cupy as cp\n",
        "\n",
        "# Assuming _check_value_and_shape_of_arguments is defined elsewhere\n",
        "# Assuming cp_unfold is defined as above\n",
        "# from your_module import _check_value_and_shape_of_arguments, cp_unfold\n",
        "\n",
        "def cp_naive_ghost_norm(\n",
        "    input_matrix_cp: cp.ndarray,\n",
        "    partial_gradient_cp: cp.ndarray,\n",
        "    kernel_size: int,\n",
        "    stride: int,\n",
        ") -> float:\n",
        "  \"\"\"\n",
        "  Computes gradient norm squared via <UU^T, gg^T> (CuPy version).\n",
        "\n",
        "  Memory inefficient but potentially faster for some GPU/problem sizes.\n",
        "\n",
        "  Args:\n",
        "    input_matrix_cp: 2D CuPy matrix (n_in, input_dim).\n",
        "    partial_gradient_cp: 2D CuPy matrix (n_out, output_dim).\n",
        "    kernel_size: kernel size of the layer.\n",
        "    stride: stride of the layer.\n",
        "\n",
        "  Returns:\n",
        "    l_2 norm squared of the gradient as a float.\n",
        "  \"\"\"\n",
        "  # --- Input Validation ---\n",
        "  _check_value_and_shape_of_arguments(\n",
        "      input_matrix_cp, partial_gradient_cp, kernel_size, stride\n",
        "  )\n",
        "  output_dimension = partial_gradient_cp.shape[1]\n",
        "\n",
        "  # --- Unfold Input ---\n",
        "  unfolded_input_matrix_cp = cp_unfold(\n",
        "      input_matrix_cp, kernel_size, stride, output_dimension\n",
        "  )\n",
        "  # Shape: (output_dim, n_in * kernel_size)\n",
        "\n",
        "  # --- Compute UU^T ---\n",
        "  # v = U @ U.T\n",
        "  v_cp = cp.matmul(unfolded_input_matrix_cp, unfolded_input_matrix_cp.T)\n",
        "  # Shape: (output_dim, output_dim)\n",
        "\n",
        "  # --- Compute gg^T ---\n",
        "  # partial_matrix = g.T @ g\n",
        "  partial_matrix_cp = cp.matmul(partial_gradient_cp.T, partial_gradient_cp)\n",
        "  # Shape: (output_dim, output_dim)\n",
        "\n",
        "  # --- Compute Frobenius Inner Product ---\n",
        "  # Equivalent to cp.tensordot(v_cp, partial_matrix_cp, axes=2)\n",
        "  norm_grad_cp = cp.sum(v_cp * partial_matrix_cp)\n",
        "\n",
        "  # --- Return Result ---\n",
        "  return norm_grad_cp.item()\n"
      ],
      "metadata": {
        "id": "kjB3cO4xlTOv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Main experiments"
      ],
      "metadata": {
        "id": "KzJPcoufhM1Z"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Gradient norm benchmarks on CPU/GPU"
      ],
      "metadata": {
        "id": "Fnhbfud2jyxz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Global constants\n",
        "NUM_REPEATS = 5\n",
        "STRIDE = 1\n",
        "N_CHANNEL = 3"
      ],
      "metadata": {
        "id": "uChFLVqHsXNu",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Experiment executors"
      ],
      "metadata": {
        "id": "JCeNkaOHpGhE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def compute_setting_params(d, stride, setting_number):\n",
        "  \"\"\"Computes (d_in, d_out, d_k).\"\"\"\n",
        "  if setting_number == 1:\n",
        "    d_k = d // 2\n",
        "  elif setting_number == 2:\n",
        "    d_k = d - 13\n",
        "  elif setting_number == 3:\n",
        "    d_k = d\n",
        "  else:\n",
        "    return ValueError(f\"Unknown setting number {setting_number}\")\n",
        "\n",
        "  d_in = d\n",
        "  d_out = math.floor((d_in - d_k) / stride + 1) # output dimension\n",
        "  return d_in, d_out, d_k"
      ],
      "metadata": {
        "id": "i0iyZu0YmeYh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "RuntimeMap = Mapping[str, Mapping[int, Sequence[float]]]\n",
        "MemoryMap = Mapping[str, Mapping[int, Sequence[float]]]\n",
        "ResultMap = Mapping[str, Mapping[int, Sequence[float]]]\n",
        "\n",
        "def run_experiments(\n",
        "    functions_to_test: Mapping[str, Callable[..., Any]],\n",
        "    d_list: Sequence[int],\n",
        "    n_channel: int,\n",
        "    stride: int,\n",
        "    num_repeats: int,\n",
        "    setting_number: int,\n",
        "    print_results = False,\n",
        "    aggregation_function = np.median,\n",
        ") -> tuple[\n",
        "    RuntimeMap,\n",
        "    MemoryMap,  # CPU\n",
        "    MemoryMap,  # GPU\n",
        "    ResultMap,\n",
        "]:\n",
        "  \"\"\"Computes experiment timings across different dimensions.\n",
        "\n",
        "  Args:\n",
        "    functions_to_test: Mapping of which functions to benchmark\n",
        "    d_list: List of base dimensions.\n",
        "    n_channel: The number of input/output channels.\n",
        "    stride: Kernel stride when computing convolutions.\n",
        "    num_repeats: Number of experiment trials for a given base dimension.\n",
        "    setting: Which setting in our paper that we are running (values: {1, 2, 3}).\n",
        "    print_results: Whether to log results to the terminal.\n",
        "    aggregation_function: Aggregation function for trial runtimes when printing\n",
        "      to the terminal.\n",
        "\n",
        "  Returns:\n",
        "    A mapping from method name to another mapping from base dimension to the\n",
        "    list of runtimes/memory/results.\n",
        "  \"\"\"\n",
        "  nvidia_smi.nvmlInit()\n",
        "  assert nvidia_smi.nvmlDeviceGetCount() == 1  # assume single GPU\n",
        "  handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)\n",
        "\n",
        "  n_in = n_channel\n",
        "  n_out = n_channel\n",
        "\n",
        "  runtimes = defaultdict(lambda: defaultdict(list))\n",
        "  peak_rams = defaultdict(lambda: defaultdict(list))\n",
        "  delta_vrams = defaultdict(lambda: defaultdict(list))\n",
        "  results = defaultdict(lambda: defaultdict(list))\n",
        "  d_list = [int(d) for d in d_list]\n",
        "\n",
        "  for d in d_list:\n",
        "    print(f\"Running d = {d}...\")\n",
        "    for name in functions_to_test.keys():\n",
        "      for _ in range(num_repeats):\n",
        "\n",
        "        if name.startswith(\"cp_\"):\n",
        "          cp._default_memory_pool.free_all_blocks()\n",
        "          start_vram = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used\n",
        "\n",
        "        dim_in, dim_out, kernel_size = compute_setting_params(\n",
        "            d, stride, setting_number\n",
        "        )\n",
        "        vector_in = np.random.rand(n_in, dim_in)\n",
        "        partial_derivative = np.random.rand(n_out, dim_out)\n",
        "\n",
        "        # Use float32 for better GPU performance generally\n",
        "        dtype_np = np.float32\n",
        "        dtype_cp = cp.float32\n",
        "\n",
        "        # Create random data on CPU\n",
        "        input_matrix_np = np.random.rand(n_in, dim_in).astype(dtype_np)\n",
        "        partial_gradient_np = np.random.rand(n_out, dim_out).astype(dtype_np)\n",
        "\n",
        "        # Transfer data to GPU\n",
        "        input_matrix_cp = cp.asarray(input_matrix_np, dtype=dtype_cp)\n",
        "        partial_gradient_cp = cp.asarray(partial_gradient_np, dtype=dtype_cp)\n",
        "\n",
        "        # --- Run and Time Functions ---\n",
        "        np_args = [input_matrix_np, partial_gradient_np, kernel_size, stride]\n",
        "        cp_args = [input_matrix_cp, partial_gradient_cp, kernel_size, stride]\n",
        "\n",
        "        function = functions_to_test[name]\n",
        "        function_args = {name : np_args if 'np' in name else cp_args}\n",
        "\n",
        "        # Synchronize for accurate timing if CuPy function\n",
        "        if name.startswith(\"cp_\"):\n",
        "          cp.cuda.Stream.null.synchronize()\n",
        "\n",
        "        # Execution is done here.\n",
        "        start_time = time.time()\n",
        "        tracemalloc.start()\n",
        "\n",
        "        try:\n",
        "          results[name][d].append(function(*function_args[name]))\n",
        "          _, peak_local_ram = tracemalloc.get_traced_memory()\n",
        "          end_time = time.time()\n",
        "          # Synchronize again for CuPy\n",
        "          if name.startswith(\"cp_\"):\n",
        "            cp.cuda.Stream.null.synchronize()\n",
        "            end_vram = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used\n",
        "            cp._default_memory_pool.free_all_blocks()\n",
        "        except:\n",
        "          end_time = float(\"inf\")\n",
        "          peak_local_ram = float(\"inf\")\n",
        "          end_vram = float(\"inf\")\n",
        "\n",
        "        tracemalloc.stop()\n",
        "        runtimes[name][d].append(end_time - start_time)\n",
        "        peak_rams[name][d].append(peak_local_ram)\n",
        "\n",
        "        if name.startswith(\"cp_\"):\n",
        "          delta_vrams[name][d].append(end_vram - start_vram)\n",
        "\n",
        "\n",
        "    if print_results:\n",
        "      # --- Compare Results and Speedups (Example: compare FFT versions) ---\n",
        "      print(\"\\n[Time in ms]\")\n",
        "      for name in functions_to_test.keys():\n",
        "        print(f\"{name:<15} {aggregation_function(runtimes[name][d])*1e3:.6f}\")\n",
        "      print(\"\\n[Peak RAM in MB]\")\n",
        "      for name in functions_to_test.keys():\n",
        "        print(f\"{name:<15} {aggregation_function(peak_rams[name][d])/1e6:.6f}\")\n",
        "      print(\"\\n[Delta VRAM in MB]\")\n",
        "      for name in functions_to_test.keys():\n",
        "        print(f\"{name:<15} {aggregation_function(delta_vrams[name][d])/1e6:.6f}\")\n",
        "      print(\"\\n\")\n",
        "\n",
        "  nvidia_smi.nvmlShutdown()\n",
        "  return runtimes, peak_rams, delta_vrams, results"
      ],
      "metadata": {
        "id": "s0NLPKTMj0Mn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Experiment printers"
      ],
      "metadata": {
        "id": "Vui2tk-5FoxF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def print_table(\n",
        "    experiment_results: ResultMap | RuntimeMap | MemoryMap,\n",
        "    dvalues: Sequence[int],\n",
        "    aggregation_function = np.median,\n",
        "    sep=\",\"\n",
        "):\n",
        "  print(f\"d\", end=sep)\n",
        "  for d in dvalues:\n",
        "    print(f\"{d}\", end=sep)\n",
        "  print()\n",
        "  for k, v in experiment_results.items():\n",
        "    print(f\"{k}\", end=sep)\n",
        "    for d in dvalues:\n",
        "      print(f\"{aggregation_function(v[d])}\", end=sep)\n",
        "    print()"
      ],
      "metadata": {
        "id": "U_DUaAdJFsOG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### CPU benchmarks"
      ],
      "metadata": {
        "id": "uVhXmlY3q42W"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "CPU_DVALUES = [\n",
        "    100,\n",
        "    200,\n",
        "    400,\n",
        "    800,\n",
        "    1_600,\n",
        "    3_200,\n",
        "    6_400,\n",
        "    12_800,\n",
        "    25_600,\n",
        "]\n",
        "\n",
        "# Which methods to benchmark.\n",
        "CPU_FUNCTIONS = {\n",
        "  \"np_fft\": in_place_norm_fft,\n",
        "  \"np_naive_fast\": naive_fast_grad_norm,\n",
        "  \"np_naive_ghost\": naive_ghost_norm,\n",
        "}\n",
        "\n",
        "# d_k = d / 2\n",
        "cpu_runtimes, cpu_peak_rams, cpu_delta_vrams, cpu_results = run_experiments(\n",
        "    CPU_FUNCTIONS,\n",
        "    CPU_DVALUES,\n",
        "    N_CHANNEL,\n",
        "    STRIDE,\n",
        "    NUM_REPEATS,\n",
        "    setting_number=1,\n",
        ")\n",
        "\n",
        "# For copying into Excel\n",
        "print(\"\\n[Runtime in seconds]\")\n",
        "print_table(cpu_runtimes, CPU_DVALUES)\n",
        "print(\"\\n[Peak RAM in bytes]\")\n",
        "print_table(cpu_peak_rams, CPU_DVALUES)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "CQkKV-Auq4Av",
        "outputId": "a781a463-c39e-4028-d880-4c70814cf2b1"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Running d = 100...\n",
            "Running d = 200...\n",
            "Running d = 400...\n",
            "Running d = 800...\n",
            "Running d = 1600...\n",
            "Running d = 3200...\n",
            "Running d = 6400...\n",
            "Running d = 12800...\n",
            "Running d = 25600...\n",
            "\n",
            "[Runtime in seconds]\n",
            "d,100,200,400,800,1600,3200,6400,12800,25600,\n",
            "np_fft,0.0008511543273925781,0.0005800724029541016,0.0010218620300292969,0.0012087821960449219,0.0016374588012695312,0.002469301223754883,0.003998756408691406,0.0074689388275146484,0.01009988784790039,\n",
            "np_naive_fast,0.0002238750457763672,0.0003342628479003906,0.0010647773742675781,0.0037941932678222656,0.013909578323364258,0.06909012794494629,0.27854347229003906,1.0922863483428955,24.85206389427185,\n",
            "np_naive_ghost,0.0005457401275634766,0.0010247230529785156,0.0019216537475585938,0.0066487789154052734,0.027704954147338867,0.1853630542755127,1.1202032566070557,7.788317680358887,78.77303791046143,\n",
            "\n",
            "[Peak RAM in bytes]\n",
            "d,100,200,400,800,1600,3200,6400,12800,25600,\n",
            "np_fft,9192.0,17592.0,34456.0,68120.0,135320.0,269720.0,538520.0,1076120.0,2151320.0,\n",
            "np_naive_fast,82920.0,324520.0,1287720.0,5134216.0,20507016.0,81974843.0,327787864.0,1310935874.0,5243346702.0,\n",
            "np_naive_ghost,82920.0,324520.0,1287720.0,5134216.0,20507016.0,81976159.0,327790782.0,1310944850.0,5243345551.0,\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### GPU benchmarks"
      ],
      "metadata": {
        "id": "-NoJnMfBi3ye"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "GPU_DVALUES = [\n",
        "    4_000,\n",
        "    8_000,\n",
        "    16_000,\n",
        "    32_000,\n",
        "    64_000,\n",
        "    128_000,\n",
        "    256_000,\n",
        "    512_000,\n",
        "    1_024_000,\n",
        "]\n",
        "\n",
        "# Which methods to benchmark.\n",
        "GPU_FUNCTIONS = {\n",
        "  \"cp_fft\": cp_in_place_norm_fft,\n",
        "  \"cp_naive_fast\": cp_naive_fast_grad_norm,\n",
        "  \"cp_naive_ghost\": cp_naive_ghost_norm,\n",
        "}\n",
        "\n",
        "# d_k = d / 2\n",
        "gpu_runtimes, gpu_peak_rams, gpu_delta_vrams, gpu_results = run_experiments(\n",
        "    GPU_FUNCTIONS,\n",
        "    GPU_DVALUES,\n",
        "    N_CHANNEL,\n",
        "    STRIDE,\n",
        "    NUM_REPEATS,\n",
        "    setting_number=1,\n",
        ")\n",
        "\n",
        "# For copying into Excel\n",
        "print(\"\\n[Runtime in seconds]\")\n",
        "print_table(gpu_runtimes, GPU_DVALUES)\n",
        "print(\"\\n[Peak RAM in bytes]\")\n",
        "print_table(gpu_peak_rams, GPU_DVALUES)\n",
        "print(\"\\n[Delta VRAM in bytes]\")\n",
        "print_table(gpu_delta_vrams, GPU_DVALUES)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "m8y2FyUAlf6R",
        "outputId": "b9f95ab7-3064-4d63-f421-0fea93925bb6",
        "collapsed": true
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Running d = 4000...\n",
            "Running d = 8000...\n",
            "Running d = 16000...\n",
            "Running d = 32000...\n",
            "Running d = 64000...\n",
            "Running d = 128000...\n",
            "Running d = 256000...\n",
            "Running d = 512000...\n",
            "Running d = 1024000...\n",
            "\n",
            "[Runtime in seconds]\n",
            "d,4000,8000,16000,32000,64000,128000,256000,512000,1024000,\n",
            "cp_fft,0.0020003318786621094,0.0021271705627441406,0.002496480941772461,0.0028612613677978516,0.0026824474334716797,0.0025663375854492188,0.0029802322387695312,0.007040739059448242,0.012893199920654297,\n",
            "cp_naive_fast,0.0016911029815673828,0.0029211044311523438,0.008950233459472656,0.03448319435119629,0.45615363121032715,inf,inf,inf,inf,\n",
            "cp_naive_ghost,0.019704341888427734,0.10161733627319336,0.9627914428710938,8.625009536743164,inf,inf,inf,inf,inf,\n",
            "\n",
            "[Peak RAM in bytes]\n",
            "d,4000,8000,16000,32000,64000,128000,256000,512000,1024000,\n",
            "cp_fft,8303.0,7103.0,7167.0,6815.0,6815.0,8519.0,8639.0,8015.0,7175.0,\n",
            "cp_naive_fast,5844.0,5724.0,5668.0,5668.0,43301.0,inf,inf,inf,inf,\n",
            "cp_naive_ghost,6876.0,8470.0,50407.0,51151.0,inf,inf,inf,inf,inf,\n",
            "\n",
            "[Delta VRAM in bytes]\n",
            "d,4000,8000,16000,32000,64000,128000,256000,512000,1024000,\n",
            "cp_fft,0.0,2097152.0,4194304.0,10485760.0,18874368.0,35651584.0,62914560.0,123731968.0,245366784.0,\n",
            "cp_naive_fast,48234496.0,192937984.0,769654784.0,3072327680.0,12293505024.0,inf,inf,inf,inf,\n",
            "cp_naive_ghost,98566144.0,387973120.0,1543503872.0,6148849664.0,inf,inf,inf,inf,inf,\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Debug only\n",
        "# =================================================================================\n",
        "# GPU_DVALUES = [\n",
        "#     64_000,\n",
        "#     128_000,\n",
        "#     256_000,\n",
        "#     512_000,\n",
        "#     1_024_000,\n",
        "# ]\n",
        "\n",
        "# # Which methods to benchmark.\n",
        "# GPU_FUNCTIONS = {\n",
        "#   \"cp_fft\": cp_in_place_norm_fft,\n",
        "# }\n",
        "\n",
        "# # d_k = d / 2\n",
        "# gpu_runtimes, gpu_peak_rams, gpu_delta_vrams, gpu_results = run_experiments(\n",
        "#     GPU_FUNCTIONS,\n",
        "#     GPU_DVALUES,\n",
        "#     N_CHANNEL,\n",
        "#     STRIDE,\n",
        "#     NUM_REPEATS,\n",
        "#     setting_number=1,\n",
        "# )\n",
        "\n",
        "# print(\"\\n[Runtime in seconds]\")\n",
        "# print_table(gpu_runtimes, GPU_DVALUES)\n",
        "# print(\"\\n[Peak RAM in bytes]\")\n",
        "# print_table(gpu_peak_rams, GPU_DVALUES)\n",
        "# print(\"\\n[Delta VRAM in bytes]\")\n",
        "# print_table(gpu_delta_vrams, GPU_DVALUES)"
      ],
      "metadata": {
        "cellView": "form",
        "id": "jIE2GzFIoKop"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## End-to-end benchmarks on GPU"
      ],
      "metadata": {
        "id": "OlznSp5TEfz1"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Opacus gradient samplers"
      ],
      "metadata": {
        "id": "t1j9ACzhIMPs"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "@register_norm_sampler(nn.Conv1d)\n",
        "def compute_conv1d_norm_sample(\n",
        "    layer: nn.Conv1d,\n",
        "    inputs: List[torch.Tensor],\n",
        "    backprops: torch.Tensor,\n",
        ") -> Dict[nn.Parameter, torch.Tensor]:\n",
        "\n",
        "  stride = layer.stride[0]\n",
        "  kernel_size = layer.kernel_size[0]\n",
        "  input_matrix_cp = cp.asarray(inputs[0].numpy()[0, :, :])\n",
        "  backprops_cp = cp.asarray(backprops.numpy()[0, :, :])\n",
        "  number_input_channels, input_dimension = input_matrix_cp.shape\n",
        "  number_output_channels, output_dimension = backprops_cp.shape\n",
        "\n",
        "  if not cp.can_cast(input_matrix_cp.dtype, cp.complex64):\n",
        "    input_matrix_cp = input_matrix_cp.astype(cp.float32) # Use float32 base\n",
        "  if not cp.can_cast(backprops_cp.dtype, cp.complex64):\n",
        "    backprops_cp = backprops_cp.astype(cp.float32)\n",
        "\n",
        "  padded_partial_gradient_batch = cp.zeros(\n",
        "      (number_output_channels, input_dimension), dtype=backprops_cp.dtype\n",
        "  )\n",
        "  upper_bound = (output_dimension - 1) * stride + 1\n",
        "  padded_partial_gradient_batch[:, :upper_bound:stride] = backprops_cp\n",
        "  partial_derivative_fft_batch = cp.fft.fft(padded_partial_gradient_batch, axis=1)\n",
        "\n",
        "  flipped_input = cp.flip(input_matrix_cp, axis=1)\n",
        "  vector_in_fft_batch = cp.fft.fft(flipped_input, axis=1)\n",
        "  multiplied_ffts = (\n",
        "      vector_in_fft_batch[:, None, :] * partial_derivative_fft_batch[None, :, :]\n",
        "  )\n",
        "  temp_batch = cp.fft.ifft(multiplied_ffts, axis=2)\n",
        "  temp_flipped_batch = cp.flip(temp_batch, axis=2)\n",
        "  relevant_part = temp_flipped_batch[:, :, :kernel_size]\n",
        "  res_cp = cp.sum(cp.real(relevant_part) ** 2)\n",
        "  norms = torch.asarray([res_cp.item()])\n",
        "\n",
        "  return {layer.weight: torch.sqrt(norms)}"
      ],
      "metadata": {
        "id": "9OhZjf75Gand"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Unit test\n",
        "# n_in = 3\n",
        "# n_out = 4\n",
        "# d = 5\n",
        "# d_out = 7\n",
        "# batch_size = 1\n",
        "# layer1 = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=d, stride=1)\n",
        "# inputs1 = torch.randn(batch_size, n_in, d)\n",
        "# backprops1 = torch.randn(batch_size, n_out, d)\n",
        "# compute_conv1d_norm_sample(layer1, inputs1, backprops1)"
      ],
      "metadata": {
        "id": "waBHmzLDIFNW",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Experiment executors"
      ],
      "metadata": {
        "id": "d2w5jl286m6P"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class SampleConv1dModule(nn.Module):\n",
        "  # stride = 1\n",
        "  def __init__(self, n, d_k):\n",
        "    super(SampleConv1dModule, self).__init__()\n",
        "    self.conv1d = nn.Conv1d(n, n, d_k, stride=1, bias=False)\n",
        "\n",
        "  def forward(self, x):\n",
        "    x = self.conv1d(x)\n",
        "    return(x)"
      ],
      "metadata": {
        "id": "I3L9_eHaZGGJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def run_normal_dp_sgd(\n",
        "   num_iterations,\n",
        "   num_channels,\n",
        "   kernel_size,\n",
        "   input_dim,\n",
        "   batch_size = 1,\n",
        "   noise_multiplier = 1.0,\n",
        "   max_grad_norm = 1.0,\n",
        "   criterion = torch.nn.MSELoss(reduce=\"None\"),\n",
        "):\n",
        "  input_data = torch.rand((batch_size, num_channels, input_dim))\n",
        "  sample_module = SampleConv1dModule(num_channels, kernel_size)\n",
        "  model_normal = GradSampleModule(clone_module(sample_module))\n",
        "  optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1)\n",
        "  optimizer_normal = DPOptimizer(\n",
        "      optimizer_normal,\n",
        "      noise_multiplier=noise_multiplier,\n",
        "      max_grad_norm=max_grad_norm,\n",
        "      expected_batch_size=batch_size,\n",
        "  )\n",
        "\n",
        "  t0 = time.time()\n",
        "  for _ in range(num_iterations):\n",
        "    optimizer_normal.zero_grad()\n",
        "    output_normal = model_normal(input_data)\n",
        "    target_data = torch.rand_like(output_normal)\n",
        "    loss_normal = torch.mean(criterion(output_normal, target_data))\n",
        "    loss_normal.backward()\n",
        "  return time.time() - t0"
      ],
      "metadata": {
        "id": "jdX92JMg5Zr_",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a83a142e-6273-4f8c-f47d-ab9a6d6df135"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/torch/nn/_reduction.py:51: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.\n",
            "  warnings.warn(warning.format(ret))\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def run_gc_dp_sgd(\n",
        "    num_iterations,\n",
        "    num_channels,\n",
        "    kernel_size,\n",
        "    input_dim,\n",
        "    batch_size = 1,\n",
        "    noise_multiplier = 1.0,\n",
        "    max_grad_norm = 1.0,\n",
        "    criterion = torch.nn.MSELoss(reduce=\"None\"),\n",
        "):\n",
        "  input_data = torch.rand((batch_size, num_channels, input_dim))\n",
        "  sample_module = SampleConv1dModule(num_channels, kernel_size)\n",
        "  model_gc = GradSampleModuleFastGradientClipping(\n",
        "      clone_module(sample_module),\n",
        "      max_grad_norm=max_grad_norm,\n",
        "      use_ghost_clipping=True,\n",
        "  )\n",
        "  optimizer_gc = torch.optim.SGD(model_gc.parameters(), lr=1)\n",
        "  optimizer_gc = DPOptimizerFastGradientClipping(\n",
        "      optimizer_gc,\n",
        "      noise_multiplier=noise_multiplier,\n",
        "      max_grad_norm=max_grad_norm,\n",
        "      expected_batch_size=batch_size,\n",
        "  )\n",
        "\n",
        "  t0 = time.time()\n",
        "  for i in range(num_iterations):\n",
        "    model_gc.enable_hooks()\n",
        "    output_gc = model_gc(input_data)\n",
        "    target_data = torch.rand_like(output_gc)\n",
        "    first_loss_per_sample = criterion(output_gc, target_data)\n",
        "    first_loss = torch.mean(first_loss_per_sample)\n",
        "    first_loss.backward(retain_graph=True)\n",
        "    optimizer_gc.zero_grad()\n",
        "    coeff = model_gc.get_clipping_coef()\n",
        "    second_loss_per_sample = coeff * first_loss_per_sample\n",
        "    second_loss = torch.sum(second_loss_per_sample)\n",
        "    model_gc.disable_hooks()\n",
        "    second_loss.backward()\n",
        "  return time.time() - t0"
      ],
      "metadata": {
        "id": "M13Dmy066k-L"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def run_all_dp_sgd_executors(\n",
        "    num_iterations,\n",
        "    num_channels,\n",
        "    kernel_size,\n",
        "    input_dim,\n",
        "    batch_size = 1,\n",
        "    noise_multiplier = 1.0,\n",
        "    max_grad_norm = 1.0,\n",
        "):\n",
        "  nvidia_smi.nvmlInit()\n",
        "  assert nvidia_smi.nvmlDeviceGetCount() == 1  # assume single GPU\n",
        "\n",
        "  gc.collect()\n",
        "  torch.cuda.empty_cache()\n",
        "  cp._default_memory_pool.free_all_blocks()\n",
        "  try:\n",
        "    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)\n",
        "    start_vram = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used\n",
        "    dp_sgd_time = run_normal_dp_sgd(\n",
        "        num_iterations,\n",
        "        num_channels,\n",
        "        kernel_size,\n",
        "        input_dim,\n",
        "        batch_size,\n",
        "        noise_multiplier,\n",
        "        max_grad_norm,\n",
        "    )\n",
        "    end_vram = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used\n",
        "    dp_sgd_vram = end_vram - start_vram\n",
        "  except:\n",
        "    dp_sgd_time = float(\"inf\")\n",
        "    dp_sgd_vram = float(\"inf\")\n",
        "\n",
        "  gc.collect()\n",
        "  torch.cuda.empty_cache()\n",
        "  cp._default_memory_pool.free_all_blocks()\n",
        "  try:\n",
        "    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)\n",
        "    start_vram = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used\n",
        "    gc_dp_sgd_time = run_gc_dp_sgd(\n",
        "        num_iterations,\n",
        "        num_channels,\n",
        "        kernel_size,\n",
        "        input_dim,\n",
        "        batch_size,\n",
        "        noise_multiplier,\n",
        "        max_grad_norm,\n",
        "    )\n",
        "    end_vram = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used\n",
        "    gc_dp_sgd_vram = end_vram - start_vram\n",
        "  except:\n",
        "    gc_dp_sgd_time = float(\"inf\")\n",
        "    gc_dp_sgd_vram = float(\"inf\")\n",
        "\n",
        "  return dp_sgd_time, gc_dp_sgd_time, dp_sgd_vram, gc_dp_sgd_vram"
      ],
      "metadata": {
        "id": "HvM4eYIS8L4M"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### GPU benchmarks"
      ],
      "metadata": {
        "id": "4-hvSid56o9I"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Both naive and fast clipping"
      ],
      "metadata": {
        "id": "xFN5JIrfBpWe"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "GPU_E2E_DVALUES = [500, 1_000, 2_000, 4_000, 8_000]\n",
        "dp_sgd_times = []\n",
        "gc_dp_sgd_times = []\n",
        "dp_sgd_vrams = []\n",
        "gc_dp_sgd_vrams = []\n",
        "for d in GPU_E2E_DVALUES:\n",
        "  print(f\"d = {d}\")\n",
        "  t0, t1, m0, m1 = run_all_dp_sgd_executors(\n",
        "    num_iterations=5,\n",
        "    num_channels=1,\n",
        "    kernel_size=d//2,\n",
        "    input_dim=d,\n",
        "    batch_size=128,\n",
        "  )\n",
        "  dp_sgd_times.append(t0)\n",
        "  gc_dp_sgd_times.append(t1)\n",
        "  dp_sgd_vrams.append(m0)\n",
        "  gc_dp_sgd_vrams.append(m1)"
      ],
      "metadata": {
        "id": "o1ZRAJ3q6da-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"d,\", end=\"\")\n",
        "for d in GPU_E2E_DVALUES:\n",
        "  print(d, end=\",\")\n",
        "print(\"\\ndp_sgd_time,\", end=\"\")\n",
        "for t in dp_sgd_times:\n",
        "  print(t, end=\",\")\n",
        "print(\"\\ngc_dp_sgd_time,\", end=\"\")\n",
        "for t in gc_dp_sgd_times:\n",
        "  print(t, end=\",\")"
      ],
      "metadata": {
        "id": "KN9XoS2207rU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Fast clipping only\n",
        "\n",
        "d=16K for naive DP-SGD OOMs"
      ],
      "metadata": {
        "id": "4FJw63nw3tvg"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "GPU_E2E_EXT_DVALUES = [16_000, 32_000, 64_000]\n",
        "gc_dp_sgd_ext_times = []\n",
        "for d in GPU_E2E_EXT_DVALUES:\n",
        "  print(f\"d = {d}\")\n",
        "  t = run_gc_dp_sgd(\n",
        "    num_iterations=5,\n",
        "    num_channels=1,\n",
        "    kernel_size=d//2,\n",
        "    input_dim=d,\n",
        "    batch_size=128,\n",
        "  )\n",
        "  gc_dp_sgd_ext_times.append(t)"
      ],
      "metadata": {
        "id": "B8x5tgPNzqFI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"d,\", end=\"\")\n",
        "for d in GPU_E2E_EXT_DVALUES:\n",
        "  print(d, end=\",\")\n",
        "print(\"\\ngc_dp_sgd_time,\", end=\"\")\n",
        "for t in gc_dp_sgd_ext_times:\n",
        "  print(t, end=\",\")"
      ],
      "metadata": {
        "id": "2AVVNoGw1pu7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Additional experiments"
      ],
      "metadata": {
        "id": "UkF7cVd1fjLQ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Gradient norm benchmarks on CPU"
      ],
      "metadata": {
        "id": "tJqYxkKvfz9Y"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### CPU benchmarks"
      ],
      "metadata": {
        "id": "USkSHgYMgq2l"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Setting 1"
      ],
      "metadata": {
        "id": "mjTrS7EXI6-Q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "D_LIST1 = [400, 800, 1_600, 3_200, 6_400]\n",
        "\n",
        "# Which methods to benchmark.\n",
        "FUNCTIONS1 = {\n",
        "  \"np_fft\": in_place_norm_fft,\n",
        "  \"np_fast_grad\": in_place_fast_grad_norm,\n",
        "  # \"np_in_place\": in_place_ghost_norm,\n",
        "  \"np_naive_fast\": naive_fast_grad_norm,\n",
        "  \"np_naive_ghost\": naive_ghost_norm,\n",
        "}\n",
        "\n",
        "# Setting 1: d_k = d_in / 2\n",
        "print(\"Setting 1\\n\")\n",
        "(\n",
        "    cpu_runtimes_ex1,\n",
        "    cpu_peak_rams_ex1,\n",
        "    cpu_delta_vrams_ex1,\n",
        "    cpu_results_ex1,\n",
        ") = run_experiments(\n",
        "    FUNCTIONS1,\n",
        "    D_LIST1,\n",
        "    N_CHANNEL,\n",
        "    STRIDE,\n",
        "    NUM_REPEATS,\n",
        "    setting_number=1,\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6G8fpTmJG0xU",
        "outputId": "fae46236-cfd2-4427-e7b9-7e2c0cd9b9ef"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Setting 1\n",
            "\n",
            "Running d = 400...\n",
            "Running d = 800...\n",
            "Running d = 1600...\n",
            "Running d = 3200...\n",
            "Running d = 6400...\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# For copying into Excel\n",
        "print(\"\\n[Runtime in seconds]\")\n",
        "print_table(cpu_runtimes_ex1, D_LIST1)\n",
        "print(\"\\n[Peak RAM in bytes]\")\n",
        "print_table(cpu_peak_rams_ex1, D_LIST1)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "X90hiZLcHuNq",
        "outputId": "7d8b0159-6ad3-4ccd-d8c9-9eda4cb2ebf3"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "[Runtime in seconds]\n",
            "d,400,800,1600,3200,6400,\n",
            "np_fft,0.0006990432739257812,0.0012874603271484375,0.0015435218811035156,0.0024542808532714844,0.003905773162841797,\n",
            "np_fast_grad,0.004522562026977539,0.005242824554443359,0.007205486297607422,0.013261556625366211,0.08928394317626953,\n",
            "np_naive_fast,0.0010628700256347656,0.0037775039672851562,0.014162540435791016,0.07788252830505371,0.29234766960144043,\n",
            "np_naive_ghost,0.0019783973693847656,0.006479740142822266,0.029323101043701172,0.19436955451965332,1.1241252422332764,\n",
            "\n",
            "[Peak RAM in bytes]\n",
            "d,400,800,1600,3200,6400,\n",
            "np_fft,34456.0,68120.0,135320.0,269720.0,538520.0,\n",
            "np_fast_grad,164319.0,646143.0,2569799.0,10255743.0,40993047.0,\n",
            "np_naive_fast,1287720.0,5134216.0,20507016.0,81973968.0,327792291.0,\n",
            "np_naive_ghost,1287720.0,5134216.0,20507168.0,81973088.0,327788552.0,\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Setting 2"
      ],
      "metadata": {
        "id": "VQkoPKUUI_F7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "D_LIST2 = [64_000, 128_000, 256_000, 512_000, 1_024_000]\n",
        "\n",
        "# Which methods to benchmark.\n",
        "FUNCTIONS2 = {\n",
        "  \"np_fft\": in_place_norm_fft,\n",
        "  \"np_fast_grad\": in_place_fast_grad_norm,\n",
        "  \"np_in_place\": in_place_ghost_norm,\n",
        "  \"np_naive_fast\": naive_fast_grad_norm,\n",
        "  \"np_naive_ghost\": naive_ghost_norm,\n",
        "}\n",
        "\n",
        "# Setting 2: d_k = d_in - 13\n",
        "print(\"Setting 2\\n\")\n",
        "(\n",
        "    cpu_runtimes_ex2,\n",
        "    cpu_peak_rams_ex2,\n",
        "    cpu_delta_vrams_ex2,\n",
        "    cpu_results_ex2,\n",
        ") = run_experiments(\n",
        "    FUNCTIONS2,\n",
        "    D_LIST2,\n",
        "    N_CHANNEL,\n",
        "    STRIDE,\n",
        "    NUM_REPEATS,\n",
        "    setting_number=2,\n",
        ")"
      ],
      "metadata": {
        "id": "acatAJfCUOzI",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "180cfe53-cc2f-4df3-8490-f841b0cb82d2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Setting 2\n",
            "\n",
            "Running d = 64000...\n",
            "Running d = 128000...\n",
            "Running d = 256000...\n",
            "Running d = 512000...\n",
            "Running d = 1024000...\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# For copying into Excel\n",
        "print(\"\\n[Runtime in seconds]\")\n",
        "print_table(cpu_runtimes_ex2, D_LIST2)\n",
        "print(\"\\n[Peak RAM in bytes]\")\n",
        "print_table(cpu_peak_rams_ex2, D_LIST2)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PiA3PCiQJTcn",
        "outputId": "9ee4fdd7-75d4-436a-a8a3-3842f2f42bdb"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "[Runtime in seconds]\n",
            "d,64000,128000,256000,512000,1024000,\n",
            "np_fft,0.04720807075500488,0.06013035774230957,0.14534854888916016,0.32793593406677246,0.8023104667663574,\n",
            "np_fast_grad,0.029539108276367188,0.02435159683227539,0.04384875297546387,0.20948457717895508,0.27891993522644043,\n",
            "np_in_place,0.044385433197021484,0.05373072624206543,0.06265926361083984,0.11540675163269043,0.18042659759521484,\n",
            "np_naive_fast,0.019695281982421875,0.04124641418457031,0.08883976936340332,0.17839574813842773,0.38510656356811523,\n",
            "np_naive_ghost,0.023937225341796875,0.03459930419921875,0.06670308113098145,0.16159534454345703,0.3193166255950928,\n",
            "\n",
            "[Peak RAM in bytes]\n",
            "d,64000,128000,256000,512000,1024000,\n",
            "np_fft,5377312.0,10753312.0,21511520.0,43019901.0,86037512.0,\n",
            "np_fast_grad,3842419.0,7682123.0,15363098.0,30729138.0,61453472.0,\n",
            "np_in_place,852.0,3553.0,3108.0,3388.0,8316.0,\n",
            "np_naive_fast,28667680.0,57339680.0,114684040.0,229378768.0,458757838.0,\n",
            "np_naive_ghost,28667528.0,57339528.0,114683888.0,229378824.0,458757814.0,\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Setting 3"
      ],
      "metadata": {
        "id": "UkEu4ocmJhA7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "D_LIST3 = [10]\n",
        "N_LIST = [40, 80, 160, 320, 640]\n",
        "\n",
        "# Which methods to benchmark.\n",
        "FUNCTIONS3 = {\n",
        "  \"np_fft\": in_place_norm_fft,\n",
        "  # \"np_fast_grad\": in_place_fast_grad_norm,\n",
        "  \"np_in_place\": in_place_ghost_norm,\n",
        "  \"np_naive_fast\": naive_fast_grad_norm,\n",
        "  \"np_naive_ghost\": naive_ghost_norm,\n",
        "}\n",
        "\n",
        "# Setting 3: variable n\n",
        "print(\"Setting 3\\n\")\n",
        "runtimes_ex3 = []\n",
        "peak_rams_ex3 = []\n",
        "for n in N_LIST:\n",
        "  print(f\"n = {n}\")\n",
        "  (\n",
        "      cpu_runtimes_ex3,\n",
        "      cpu_peak_rams_ex3,\n",
        "      _,\n",
        "      _,\n",
        "  ) = run_experiments(\n",
        "      FUNCTIONS3,\n",
        "      D_LIST3,\n",
        "      n,\n",
        "      STRIDE,\n",
        "      NUM_REPEATS,\n",
        "      setting_number=3\n",
        "  )\n",
        "  runtimes_ex3.append(cpu_runtimes_ex3)\n",
        "  peak_rams_ex3.append(cpu_peak_rams_ex3)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "68c454fc-d367-43ed-9c58-02bfd50c773a",
        "id": "0FMgeUFpJhA7"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Setting 3\n",
            "\n",
            "n = 40\n",
            "Running d = 10...\n",
            "n = 80\n",
            "Running d = 10...\n",
            "n = 160\n",
            "Running d = 10...\n",
            "n = 320\n",
            "Running d = 10...\n",
            "n = 640\n",
            "Running d = 10...\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def print_channel_table(\n",
        "    experiment_results: ResultMap | RuntimeMap | MemoryMap,\n",
        "    n_values: Sequence[int],\n",
        "    d_value: int,\n",
        "    aggregation_function = np.median,\n",
        "    sep=\",\"\n",
        "):\n",
        "  print(f\"n\", end=sep)\n",
        "  for n in n_values:\n",
        "    print(f\"{n}\", end=sep)\n",
        "  print()\n",
        "  methods = list(experiment_results[0].keys())\n",
        "  for m in methods:\n",
        "    print(m, end=sep)\n",
        "    for e in experiment_results:\n",
        "      print(f\"{aggregation_function(e[m][d_value])}\", end=sep)\n",
        "    print()"
      ],
      "metadata": {
        "id": "nwNGNU2OLBsV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# For copying into Excel\n",
        "print(\"\\n[Runtime in seconds]\")\n",
        "print_channel_table(runtimes_ex3, N_LIST, D_LIST3[0])\n",
        "print(\"\\n[Peak RAM in bytes]\")\n",
        "print_channel_table(peak_rams_ex3, N_LIST, D_LIST3[0])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "68f0f81b-5012-42b1-9294-1ecd39bd7c71",
        "id": "SiMIbxoFJhA8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "[Runtime in seconds]\n",
            "n,40,80,160,320,640,\n",
            "np_fft,0.11008262634277344,0.29825496673583984,1.061675786972046,4.766016721725464,19.153496503829956,\n",
            "np_in_place,0.003210306167602539,0.0034322738647460938,0.006872653961181641,0.013518571853637695,0.034094810485839844,\n",
            "np_naive_fast,0.0008285045623779297,0.0008981227874755859,0.0024089813232421875,0.00498199462890625,0.016141414642333984,\n",
            "np_naive_ghost,0.0010898113250732422,0.0008485317230224609,0.0014989376068115234,0.002501964569091797,0.0048220157623291016,\n",
            "\n",
            "[Peak RAM in bytes]\n",
            "n,40,80,160,320,640,\n",
            "np_fft,5696.0,11087.0,50924.0,52227.0,56192.0,\n",
            "np_in_place,540.0,540.0,540.0,668.0,668.0,\n",
            "np_naive_fast,100476.0,294076.0,1065276.0,4143676.0,16444476.0,\n",
            "np_naive_ghost,15992.0,31960.0,63320.0,126552.0,251992.0,\n"
          ]
        }
      ]
    }
  ]
}