{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XTu71N5FwVnb"
      },
      "source": [
        "## Install requirements"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DMDsjCENv7mW",
        "outputId": "09799337-674a-4240-cb6e-d66f9dea19ce"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting pytorch_lightning\n",
            "  Downloading pytorch_lightning-2.2.4-py3-none-any.whl (802 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m802.2/802.2 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning) (1.25.2)\n",
            "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning) (2.2.1+cu121)\n",
            "Requirement already satisfied: tqdm>=4.57.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning) (4.66.4)\n",
            "Requirement already satisfied: PyYAML>=5.4 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning) (6.0.1)\n",
            "Requirement already satisfied: fsspec[http]>=2022.5.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning) (2023.6.0)\n",
            "Collecting torchmetrics>=0.7.0 (from pytorch_lightning)\n",
            "  Downloading torchmetrics-1.4.0.post0-py3-none-any.whl (868 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m868.8/868.8 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning) (24.0)\n",
            "Requirement already satisfied: typing-extensions>=4.4.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_lightning) (4.11.0)\n",
            "Collecting lightning-utilities>=0.8.0 (from pytorch_lightning)\n",
            "  Downloading lightning_utilities-0.11.2-py3-none-any.whl (26 kB)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from fsspec[http]>=2022.5.0->pytorch_lightning) (2.31.0)\n",
            "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]>=2022.5.0->pytorch_lightning) (3.9.5)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from lightning-utilities>=0.8.0->pytorch_lightning) (67.7.2)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->pytorch_lightning) (3.14.0)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->pytorch_lightning) (1.12)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->pytorch_lightning) (3.3)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->pytorch_lightning) (3.1.4)\n",
            "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
            "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
            "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
            "Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
            "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
            "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
            "Collecting nvidia-curand-cu12==10.3.2.106 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
            "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
            "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
            "Collecting nvidia-nccl-cu12==2.19.3 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)\n",
            "Collecting nvidia-nvtx-cu12==12.1.105 (from torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
            "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->pytorch_lightning) (2.2.0)\n",
            "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.13.0->pytorch_lightning)\n",
            "  Using cached nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning) (1.3.1)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning) (23.2.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning) (1.4.1)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning) (6.0.5)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning) (1.9.4)\n",
            "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning) (4.0.3)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->pytorch_lightning) (2.1.5)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->fsspec[http]>=2022.5.0->pytorch_lightning) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->fsspec[http]>=2022.5.0->pytorch_lightning) (3.7)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->fsspec[http]>=2022.5.0->pytorch_lightning) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->fsspec[http]>=2022.5.0->pytorch_lightning) (2024.2.2)\n",
            "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->pytorch_lightning) (1.3.0)\n",
            "Installing collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics, pytorch_lightning\n",
            "Successfully installed lightning-utilities-0.11.2 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.4.127 nvidia-nvtx-cu12-12.1.105 pytorch_lightning-2.2.4 torchmetrics-1.4.0.post0\n"
          ]
        }
      ],
      "source": [
        "%pip install -q pytorch_lightning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GCXtENV8xXJx"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "5t4q01KIxYuJ"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/home/ANONYMIZED/miniconda3/envs/ccmm/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
            "  from .autonotebook import tqdm as notebook_tqdm\n",
            "/home/ANONYMIZED/miniconda3/envs/ccmm/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/ANONYMIZED/miniconda3/envs/ccmm/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c104warnERKNS_7WarningE'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n",
            "  warn(\n"
          ]
        }
      ],
      "source": [
        "import logging\n",
        "from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union\n",
        "\n",
        "import pytorch_lightning as pl\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torchmetrics\n",
        "from torch.optim import Optimizer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fSmkCPy0xOc2"
      },
      "source": [
        "## Reproducibility"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0PyJIw7nxPzp",
        "outputId": "b0f3ee5b-d810-4a3f-b389-b3a002e7eb3e"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Global seed set to 0\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "0"
            ]
          },
          "execution_count": 2,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Seed the global RNGs through Lightning so repeated runs start from the same state.\n",
        "SEED = 0\n",
        "pl.seed_everything(SEED)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0vM0UnnNwbGE"
      },
      "source": [
        "## Define models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "g0CtAZ6dvNHc"
      },
      "outputs": [],
      "source": [
        "class MLP(nn.Module):\n",
        "    def __init__(self, input=28 * 28, num_classes=10):\n",
        "        super().__init__()\n",
        "        self.input = input\n",
        "        self.layer0 = nn.Linear(input, 512)\n",
        "        self.layer1 = nn.Linear(512, 512)\n",
        "        self.layer2 = nn.Linear(512, 512)\n",
        "        self.layer3 = nn.Linear(512, 256)\n",
        "        self.layer4 = nn.Linear(256, num_classes)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = x.view(-1, self.input)\n",
        "\n",
        "        h0 = nn.functional.relu(self.layer0(x))\n",
        "\n",
        "        h1 = nn.functional.relu(self.layer1(h0))\n",
        "\n",
        "        h2 = nn.functional.relu(self.layer2(h1))\n",
        "\n",
        "        h3 = nn.functional.relu(self.layer3(h2))\n",
        "\n",
        "        h4 = self.layer4(h3)\n",
        "\n",
        "        embeddings = [h0, h1, h2, h3, h4]\n",
        "\n",
        "        return nn.functional.log_softmax(h4, dim=-1), embeddings"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "DD7fHC05v2UA"
      },
      "outputs": [],
      "source": [
        "\n",
        "class MyLightningModule(pl.LightningModule):\n",
        "    \"\"\"Lightning wrapper that trains and evaluates a classifier with cross-entropy.\n",
        "\n",
        "    Args:\n",
        "        model: module mapping a batch of inputs to class scores, returned either\n",
        "            directly or as the first element of a ``(scores, embeddings)`` tuple.\n",
        "        num_classes: number of target classes, forwarded to the accuracy metrics.\n",
        "        *args, **kwargs: accepted and ignored.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model, num_classes: Optional[int] = None, *args, **kwargs) -> None:\n",
        "        super().__init__()\n",
        "\n",
        "        self.num_classes = num_classes\n",
        "\n",
        "        # One independent accuracy tracker per stage so their states never mix.\n",
        "        metric = torchmetrics.Accuracy(task='multiclass', top_k=1, num_classes=num_classes)\n",
        "        self.train_accuracy = metric.clone()\n",
        "        self.val_accuracy = metric.clone()\n",
        "        self.test_accuracy = metric.clone()\n",
        "\n",
        "        self.model = model\n",
        "\n",
        "    def forward(self, x: torch.Tensor) -> Any:\n",
        "        \"\"\"Run the wrapped model on ``x`` and return its output unchanged.\n",
        "\n",
        "        'training_step', 'validation_step' and 'test_step' reach this method\n",
        "        through 'step', which computes the loss from the returned predictions.\n",
        "\n",
        "        Returns:\n",
        "            Whatever the wrapped model returns: the class scores, or a tuple\n",
        "            whose first element is the class scores.\n",
        "        \"\"\"\n",
        "        return self.model(x)\n",
        "\n",
        "    def unwrap_batch(self, batch):\n",
        "        \"\"\"Split ``batch`` into ``(x, y)``.\n",
        "\n",
        "        Mapping batches are read through their \"x\" and \"y\" keys; any other batch\n",
        "        is unpacked as a two-element sequence.\n",
        "        \"\"\"\n",
        "        if isinstance(batch, Mapping):\n",
        "            return batch[\"x\"], batch[\"y\"]\n",
        "        x, y = batch\n",
        "        return x, y\n",
        "\n",
        "    def step(self, x, y) -> Mapping[str, Any]:\n",
        "        \"\"\"Forward ``x`` and compute the cross-entropy loss against ``y``.\n",
        "\n",
        "        Returns:\n",
        "            Dict with the detached \"logits\", the differentiable \"loss\" and the\n",
        "            model's \"embeddings\" (None when the model does not return a tuple).\n",
        "        \"\"\"\n",
        "        output = self(x)\n",
        "        if isinstance(output, tuple):\n",
        "            logits, embeddings = output\n",
        "        else:\n",
        "            logits, embeddings = output, None\n",
        "        loss = F.cross_entropy(logits, y)\n",
        "\n",
        "        return {\"logits\": logits.detach(), \"loss\": loss, \"embeddings\": embeddings}\n",
        "\n",
        "    def _log_loss_and_accuracy(self, stage, step_out, y, **loss_log_kwargs) -> None:\n",
        "        \"\"\"Log ``loss/<stage>`` and update + log ``acc/<stage>`` for one batch.\n",
        "\n",
        "        Args:\n",
        "            stage: \"train\", \"val\" or \"test\"; selects the metric attribute and log keys.\n",
        "            step_out: dict produced by 'step'.\n",
        "            y: target labels of the batch.\n",
        "            **loss_log_kwargs: extra arguments for the loss 'log_dict' call.\n",
        "        \"\"\"\n",
        "        # Only detach: moving the loss to the CPU here would force a\n",
        "        # device-to-host synchronisation on every step just for logging.\n",
        "        self.log_dict({f\"loss/{stage}\": step_out[\"loss\"].detach()}, **loss_log_kwargs)\n",
        "\n",
        "        accuracy = getattr(self, f\"{stage}_accuracy\")\n",
        "        accuracy(torch.softmax(step_out[\"logits\"], dim=-1), y)\n",
        "        # The metric object itself is logged, not a number computed from it.\n",
        "        self.log_dict({f\"acc/{stage}\": accuracy}, on_epoch=True)\n",
        "\n",
        "    def training_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:\n",
        "        \"\"\"Compute the loss for one training batch and log loss and accuracy.\"\"\"\n",
        "        x, y = self.unwrap_batch(batch)\n",
        "        step_out = self.step(x, y)\n",
        "        self._log_loss_and_accuracy(\n",
        "            \"train\", step_out, y, on_step=True, on_epoch=True, prog_bar=True\n",
        "        )\n",
        "        return step_out\n",
        "\n",
        "    def validation_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:\n",
        "        \"\"\"Compute the loss for one validation batch and log epoch-level metrics.\"\"\"\n",
        "        x, y = self.unwrap_batch(batch)\n",
        "        step_out = self.step(x, y)\n",
        "        self._log_loss_and_accuracy(\n",
        "            \"val\", step_out, y, on_step=False, on_epoch=True, prog_bar=True\n",
        "        )\n",
        "        return step_out\n",
        "\n",
        "    def test_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:\n",
        "        \"\"\"Compute the loss for one test batch and log loss and accuracy.\"\"\"\n",
        "        x, y = self.unwrap_batch(batch)\n",
        "        step_out = self.step(x, y)\n",
        "        self._log_loss_and_accuracy(\"test\", step_out, y)\n",
        "        return step_out\n",
        "\n",
        "    def configure_optimizers(\n",
        "        self,\n",
        "    ) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[Any]]]:\n",
        "        \"\"\"Choose what optimizers and learning-rate schedulers to use in your optimization.\n",
        "\n",
        "        Normally you'd need one. But in the case of GANs or similar you might have multiple.\n",
        "\n",
        "        Return:\n",
        "            Any of these 6 options.\n",
        "            - Single optimizer.\n",
        "            - List or Tuple - List of optimizers.\n",
        "            - Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict).\n",
        "            - Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler'\n",
        "              key whose value is a single LR scheduler or lr_dict.\n",
        "            - Tuple of dictionaries as described, with an optional 'frequency' key.\n",
        "            - None - Fit will run without any optimizer.\n",
        "\n",
        "            This module returns a one-element list holding an Adam optimizer\n",
        "            (lr=1e-3) over all of its parameters.\n",
        "        \"\"\"\n",
        "        return [torch.optim.Adam(self.parameters(), lr=1e-3)]\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PEGlFftWxzEJ"
      },
      "source": [
        "## Load data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "A_tFuXkvyFR_"
      },
      "outputs": [],
      "source": [
        "from torchvision.datasets import MNIST\n",
        "from torchvision.transforms import Compose, ToTensor, Normalize\n",
        "\n",
        "# Convert images to tensors, then standardise them with a fixed mean / std.\n",
        "transform = Compose([ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,))])\n",
        "\n",
        "# Both splits share the download location and the preprocessing pipeline.\n",
        "_mnist_kwargs = dict(root='.', download=True, transform=transform)\n",
        "train_dataset = MNIST(train=True, **_mnist_kwargs)\n",
        "test_dataset = MNIST(train=False, **_mnist_kwargs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "r0I-cJ3rx1G_"
      },
      "outputs": [],
      "source": [
        "from torch.utils.data import DataLoader\n",
        "\n",
        "BATCH_SIZE = 1000\n",
        "NUM_WORKERS = 8\n",
        "\n",
        "# Reshuffle the training samples every epoch so batches differ between epochs;\n",
        "# the order is still reproducible because the global seed is fixed earlier.\n",
        "train_loader = DataLoader(\n",
        "    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS\n",
        ")\n",
        "\n",
        "# Evaluation does not depend on sample order, so the test loader is not shuffled.\n",
        "test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K2QBZxY8v3M7"
      },
      "source": [
        "## Train a bunch of models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_OVpiqJHvjBU",
        "outputId": "ecbc91a5-2141-419c-cbc6-c8cf339fed02"
      },
      "outputs": [],
      "source": [
        "num_classes = 10\n",
        "# Pass num_classes by keyword: MLP's first positional parameter is the input size,\n",
        "# so MLP(num_classes) would build a network expecting 10 input features.\n",
        "model_a = MyLightningModule(MLP(num_classes=num_classes), num_classes)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 193
        },
        "id": "5H-M4n6Dw1aC",
        "outputId": "56cd1d79-c503-4aea-e51b-dcfe0b7b2256"
      },
      "outputs": [
        {
          "ename": "KeyboardInterrupt",
          "evalue": "",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "Cell \u001b[0;32mIn[9], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpytorch_lightning\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Trainer\n\u001b[0;32m----> 3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[43mTrainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43menable_progress_bar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43menable_model_summary\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m50\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m      5\u001b[0m trainer\u001b[38;5;241m.\u001b[39mfit(model_a, train_loader)\n\u001b[1;32m      7\u001b[0m trainer\u001b[38;5;241m.\u001b[39mtest(model_a, test_loader)\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/utilities/argparse.py:345\u001b[0m, in \u001b[0;36m_defaults_from_env_vars.<locals>.insert_env_defaults\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    342\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mdict\u001b[39m(\u001b[38;5;28mlist\u001b[39m(env_variables\u001b[38;5;241m.\u001b[39mitems()) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mitems()))\n\u001b[1;32m    344\u001b[0m \u001b[38;5;66;03m# all args were already moved to kwargs\u001b[39;00m\n\u001b[0;32m--> 345\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:534\u001b[0m, in \u001b[0;36mTrainer.__init__\u001b[0;34m(self, logger, enable_checkpointing, callbacks, default_root_dir, gradient_clip_val, gradient_clip_algorithm, num_nodes, num_processes, devices, gpus, auto_select_gpus, tpu_cores, ipus, enable_progress_bar, overfit_batches, track_grad_norm, check_val_every_n_epoch, fast_dev_run, accumulate_grad_batches, max_epochs, min_epochs, max_steps, min_steps, max_time, limit_train_batches, limit_val_batches, limit_test_batches, limit_predict_batches, val_check_interval, log_every_n_steps, accelerator, strategy, sync_batchnorm, precision, enable_model_summary, weights_save_path, num_sanity_val_steps, resume_from_checkpoint, profiler, benchmark, deterministic, reload_dataloaders_every_n_epochs, auto_lr_find, replace_sampler_ddp, detect_anomaly, auto_scale_batch_size, plugins, amp_backend, amp_level, move_metrics_to_cpu, multiple_trainloader_mode)\u001b[0m\n\u001b[1;32m    531\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrack_grad_norm: \u001b[38;5;28mfloat\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mfloat\u001b[39m(track_grad_norm)\n\u001b[1;32m    533\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m detect_anomaly\n\u001b[0;32m--> 534\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_setup_on_init\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    536\u001b[0m \u001b[38;5;66;03m# configure tuner\u001b[39;00m\n\u001b[1;32m    537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtuner\u001b[38;5;241m.\u001b[39mon_trainer_init(auto_lr_find, auto_scale_batch_size)\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:619\u001b[0m, in \u001b[0;36mTrainer._setup_on_init\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    618\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_setup_on_init\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 619\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_log_device_info\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    621\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshould_stop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m    622\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate \u001b[38;5;241m=\u001b[39m TrainerState()\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1740\u001b[0m, in \u001b[0;36mTrainer._log_device_info\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1738\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_log_device_info\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1740\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mCUDAAccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_available\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m   1741\u001b[0m         gpu_available \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m   1742\u001b[0m         gpu_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m (cuda)\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/accelerators/cuda.py:91\u001b[0m, in \u001b[0;36mCUDAAccelerator.is_available\u001b[0;34m()\u001b[0m\n\u001b[1;32m     89\u001b[0m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[1;32m     90\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mis_available\u001b[39m() \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mbool\u001b[39m:\n\u001b[0;32m---> 91\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mdevice_parser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_cuda_devices\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/utilities/device_parser.py:348\u001b[0m, in \u001b[0;36mnum_cuda_devices\u001b[0;34m()\u001b[0m\n\u001b[1;32m    346\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mdevice_count()\n\u001b[1;32m    347\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m multiprocessing\u001b[38;5;241m.\u001b[39mget_context(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfork\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mPool(\u001b[38;5;241m1\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m pool:\n\u001b[0;32m--> 348\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice_count\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/multiprocessing/pool.py:357\u001b[0m, in \u001b[0;36mPool.apply\u001b[0;34m(self, func, args, kwds)\u001b[0m\n\u001b[1;32m    352\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mapply\u001b[39m(\u001b[38;5;28mself\u001b[39m, func, args\u001b[38;5;241m=\u001b[39m(), kwds\u001b[38;5;241m=\u001b[39m{}):\n\u001b[1;32m    353\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m'''\u001b[39;00m\n\u001b[1;32m    354\u001b[0m \u001b[38;5;124;03m    Equivalent of `func(*args, **kwds)`.\u001b[39;00m\n\u001b[1;32m    355\u001b[0m \u001b[38;5;124;03m    Pool must be running.\u001b[39;00m\n\u001b[1;32m    356\u001b[0m \u001b[38;5;124;03m    '''\u001b[39;00m\n\u001b[0;32m--> 357\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply_async\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/multiprocessing/pool.py:765\u001b[0m, in \u001b[0;36mApplyResult.get\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    764\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 765\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    766\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mready():\n\u001b[1;32m    767\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTimeoutError\u001b[39;00m\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/multiprocessing/pool.py:762\u001b[0m, in \u001b[0;36mApplyResult.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    761\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwait\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 762\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_event\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/threading.py:581\u001b[0m, in \u001b[0;36mEvent.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    579\u001b[0m signaled \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_flag\n\u001b[1;32m    580\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m signaled:\n\u001b[0;32m--> 581\u001b[0m     signaled \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cond\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    582\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m signaled\n",
            "File \u001b[0;32m~/miniconda3/envs/ccmm/lib/python3.9/threading.py:312\u001b[0m, in \u001b[0;36mCondition.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    310\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:    \u001b[38;5;66;03m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[39;00m\n\u001b[1;32m    311\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 312\u001b[0m         \u001b[43mwaiter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    313\u001b[0m         gotit \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m    314\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ],
      "source": [
        "from pytorch_lightning import Trainer\n",
        "\n",
        "trainer = Trainer(enable_progress_bar=True, enable_model_summary=False, max_epochs=50)\n",
        "\n",
        "trainer.fit(model_a, train_loader)\n",
        "\n",
        "trainer.test(model_a, test_loader)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TXo6520VxMVI"
      },
      "source": [
        "## Match the models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 33,
      "metadata": {
        "id": "ujovajWkvef3"
      },
      "outputs": [],
      "source": [
        "from typing import NamedTuple\n",
        "from collections import defaultdict\n",
        "\n",
        "class PermutationSpec(NamedTuple):\n",
        "    # maps permutation matrices to the layers they permute, expliciting the axis they act on\n",
        "    perm_to_layers_and_axes: dict\n",
        "\n",
        "    # maps layers to permutations: if a layer has k dimensions, it maps to a permutation matrix (or None) for each dimension\n",
        "    layer_and_axes_to_perm: dict\n",
        "\n",
        "\n",
        "class PermutationSpecBuilder:\n",
        "    def __init__(self) -> None:\n",
        "        pass\n",
        "\n",
        "    def create_permutation_spec(self) -> list:\n",
        "        pass\n",
        "\n",
        "    def permutation_spec_from_axes_to_perm(self, axes_to_perm: dict) -> PermutationSpec:\n",
        "        perm_to_axes = defaultdict(list)\n",
        "\n",
        "        for wk, axis_perms in axes_to_perm.items():\n",
        "            for axis, perm in enumerate(axis_perms):\n",
        "                if perm is not None:\n",
        "                    perm_to_axes[perm].append((wk, axis))\n",
        "\n",
        "        return PermutationSpec(perm_to_layers_and_axes=dict(perm_to_axes), layer_and_axes_to_perm=axes_to_perm)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 34,
      "metadata": {
        "id": "WoA_QxmKvZ5Y"
      },
      "outputs": [],
      "source": [
        "class MLPPermutationSpecBuilder(PermutationSpecBuilder):\n",
        "    def __init__(self, num_hidden_layers: int):\n",
        "        self.num_hidden_layers = num_hidden_layers\n",
        "\n",
        "    def create_permutation_spec(self) -> PermutationSpec:\n",
        "        L = self.num_hidden_layers\n",
        "        assert L >= 1\n",
        "\n",
        "        axes_to_perm = {\n",
        "            \"layer0.weight\": (\"P_0\", None),\n",
        "            **{f\"layer{i}.weight\": (f\"P_{i}\", f\"P_{i-1}\") for i in range(1, L)},\n",
        "            **{f\"layer{i}.bias\": (f\"P_{i}\",) for i in range(L)},\n",
        "            f\"layer{L}.weight\": (None, f\"P_{L-1}\"),\n",
        "            f\"layer{L}.bias\": (None,),\n",
        "        }\n",
        "\n",
        "        return self.permutation_spec_from_axes_to_perm(axes_to_perm)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "XTu71N5FwVnb",
        "GCXtENV8xXJx",
        "fSmkCPy0xOc2",
        "0vM0UnnNwbGE",
        "K2QBZxY8v3M7"
      ],
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.19"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
