{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "The notebook runs, however, it requires the difflogic library which only works on certain colab instances... The best solution we found is to delete the runtime and start a new one if it fails to install."
      ],
      "metadata": {
        "id": "t42FTZ3ovTCC"
      },
      "id": "t42FTZ3ovTCC"
    },
    {
      "cell_type": "code",
      "source": [
        "#Install DiffLogic and correct CUDA version\n",
        "\n",
        "!sudo apt-get install -y cmake ninja-build\n",
        "!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121\n",
        "!pip install difflogic"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vlm0Q0G5EYF7",
        "outputId": "19d0e10a-abb7-4893-ae22-19be4fb95810"
      },
      "id": "vlm0Q0G5EYF7",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Reading package lists... Done\n",
            "Building dependency tree... Done\n",
            "Reading state information... Done\n",
            "cmake is already the newest version (3.22.1-1ubuntu1.22.04.2).\n",
            "The following NEW packages will be installed:\n",
            "  ninja-build\n",
            "0 upgraded, 1 newly installed, 0 to remove and 34 not upgraded.\n",
            "Need to get 111 kB of archives.\n",
            "After this operation, 358 kB of additional disk space will be used.\n",
            "Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 ninja-build amd64 1.10.1-1 [111 kB]\n",
            "Fetched 111 kB in 1s (154 kB/s)\n",
            "debconf: unable to initialize frontend: Dialog\n",
            "debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 78, <> line 1.)\n",
            "debconf: falling back to frontend: Readline\n",
            "debconf: unable to initialize frontend: Readline\n",
            "debconf: (This frontend requires a controlling tty.)\n",
            "debconf: falling back to frontend: Teletype\n",
            "dpkg-preconfigure: unable to re-open stdin: \n",
            "Selecting previously unselected package ninja-build.\n",
            "(Reading database ... 126102 files and directories currently installed.)\n",
            "Preparing to unpack .../ninja-build_1.10.1-1_amd64.deb ...\n",
            "Unpacking ninja-build (1.10.1-1) ...\n",
            "Setting up ninja-build (1.10.1-1) ...\n",
            "Processing triggers for man-db (2.10.2-1) ...\n",
            "Looking in indexes: https://download.pytorch.org/whl/cu121\n",
            "Collecting torch==2.4.0\n",
            "  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp311-cp311-linux_x86_64.whl (799.1 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m799.1/799.1 MB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting torchvision==0.19.0\n",
            "  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.0%2Bcu121-cp311-cp311-linux_x86_64.whl (7.1 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.1/7.1 MB\u001b[0m \u001b[31m43.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting torchaudio==2.4.0\n",
            "  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.0%2Bcu121-cp311-cp311-linux_x86_64.whl (3.4 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m51.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (3.18.0)\n",
            "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (4.13.2)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (1.13.1)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (3.1.6)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch==2.4.0) (2025.3.2)\n",
            "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.0)\n",
            "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m50.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.4.0)\n",
            "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m22.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.4.0)\n",
            "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.4.0)\n",
            "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.4.0)\n",
            "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.4.0)\n",
            "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch==2.4.0)\n",
            "  Downloading https://download.pytorch.org/whl/cu121/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch==2.4.0)\n",
            "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch==2.4.0)\n",
            "  Downloading https://download.pytorch.org/whl/cu121/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "c6a2e435",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 387
        },
        "id": "c6a2e435",
        "outputId": "187b4b30-d1ec-4b3e-c4c8-2c61b47d8129"
      },
      "outputs": [
        {
          "output_type": "error",
          "ename": "ModuleNotFoundError",
          "evalue": "No module named 'difflogic'",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-2-bb551264eb82>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtqdm\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdifflogic\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLogicLayer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mGroupSum\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mdifflogic\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpackbitstensor\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPackBitsTensor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mdifflogic\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdifflogic\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLogicLayerCudaFunction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'difflogic'",
            "",
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"
          ],
          "errorDetails": {
            "actions": [
              {
                "action": "open_url",
                "actionText": "Open Examples",
                "url": "/notebooks/snippets/importing_libraries.ipynb"
              }
            ]
          }
        }
      ],
      "source": [
        "import copy\n",
        "import time\n",
        "from types import MethodType\n",
        "from typing import List, Tuple, Callable\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "import torchvision\n",
        "from tqdm import tqdm\n",
        "\n",
        "from difflogic import LogicLayer, GroupSum\n",
        "from difflogic.packbitstensor import PackBitsTensor\n",
        "from difflogic.difflogic import LogicLayerCudaFunction\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "428a2e50",
      "metadata": {
        "id": "428a2e50"
      },
      "outputs": [],
      "source": [
        "seed = 42\n",
        "if seed is not None:\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "    print(f\"[INFO] Using fixed seed: {seed}\")\n",
        "\n",
        "\n",
        "BITS_TO_TORCH_FLOATING_POINT_TYPE = {\n",
        "    16: torch.float16,\n",
        "    32: torch.float32,\n",
        "    64: torch.float64\n",
        "}\n",
        "training_bit_count =  32\n",
        "\n",
        "\n",
        "def bin_op(a, b, i):\n",
        "    assert a[0].shape == b[0].shape, (a[0].shape, b[0].shape)\n",
        "    if a.shape[0] > 1:\n",
        "        assert a[1].shape == b[1].shape, (a[1].shape, b[1].shape)\n",
        "\n",
        "    if i == 0:\n",
        "        return torch.zeros_like(a)\n",
        "    elif i == 1:\n",
        "        return a * b\n",
        "    elif i == 2:\n",
        "        return a - a * b\n",
        "    elif i == 3:\n",
        "        return a\n",
        "    elif i == 4:\n",
        "        return b - a * b\n",
        "    elif i == 5:\n",
        "        return b\n",
        "    elif i == 6:\n",
        "        return a + b - 2 * a * b\n",
        "    elif i == 7:\n",
        "        return a + b - a * b\n",
        "    elif i == 8:\n",
        "        return 1 - (a + b - a * b)\n",
        "    elif i == 9:\n",
        "        return 1 - (a + b - 2 * a * b)\n",
        "    elif i == 10:\n",
        "        return 1 - b\n",
        "    elif i == 11:\n",
        "        return 1 - b + a * b\n",
        "    elif i == 12:\n",
        "        return 1 - a\n",
        "    elif i == 13:\n",
        "        return 1 - a + a * b\n",
        "    elif i == 14:\n",
        "        return 1 - a * b\n",
        "    elif i == 15:\n",
        "        return torch.ones_like(a)\n",
        "\n",
        "\n",
        "def bin_op_s(a, b, i_s):\n",
        "    r = torch.zeros_like(a)\n",
        "    for i in range(16):\n",
        "        u = bin_op(a, b, i)\n",
        "        r = r + i_s[..., i] * u\n",
        "    return r\n",
        "\n",
        "\n",
        "def load_n(loader, n):\n",
        "    i = 0\n",
        "    while i < n:\n",
        "        for x in loader:\n",
        "            yield x\n",
        "            i += 1\n",
        "            if i == n:\n",
        "                break"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1f1e5f29",
      "metadata": {
        "id": "1f1e5f29"
      },
      "outputs": [],
      "source": [
        "def train_step(model, x, y, loss_fn, optimizer):\n",
        "    out = model(x)\n",
        "    loss = loss_fn(out, y)\n",
        "    optimizer.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    return loss.item()\n",
        "\n",
        "\n",
        "def train(model, train_loader, test_loader, loss_fn, optimizer, n_steps, print_freq, full_eval: bool) -> dict[str, list]:\n",
        "    loss_values = 0\n",
        "    results = {'training':[]}\n",
        "    for i, (x, y) in tqdm(enumerate(load_n(train_loader, n_steps)), total=n_steps):\n",
        "        x = x.to(BITS_TO_TORCH_FLOATING_POINT_TYPE[training_bit_count]).to('cuda')\n",
        "        y = y.to('cuda')\n",
        "        loss = train_step(model, x, y, loss_fn, optimizer)\n",
        "        loss_values += loss\n",
        "        if i%print_freq == (print_freq - 1):\n",
        "            if full_eval:\n",
        "                train_soft = eval_model(model, train_loader, train_mode=True)\n",
        "                train_discrete = eval_model(model, train_loader, train_mode=False)\n",
        "                test_soft = eval_model(model, test_loader, train_mode=True)\n",
        "                test_discrete = eval_model(model, test_loader, train_mode=False)\n",
        "            else:\n",
        "                train_soft = -1\n",
        "                train_discrete = -1\n",
        "                test_soft = -1\n",
        "                test_discrete = -1\n",
        "            print(f\"Step {i+1}/{n_steps} - Loss: {loss_values/print_freq:7.4f} - \"\n",
        "                  f\"Train soft: {train_soft:4.2%} - \"\n",
        "                  f\"Train discrete: {train_discrete:4.2%} - \"\n",
        "                  f\"Test soft: {test_soft:4.2%} - \"\n",
        "                  f\"Test discrete: {test_discrete:4.2%}\")\n",
        "            results['training'].append({\n",
        "                'step': i,\n",
        "                'loss': loss_values/print_freq,\n",
        "                'time': time.time(),\n",
        "                'train_soft': train_soft,\n",
        "                'train_discrete': train_discrete,\n",
        "                'test_soft': test_soft,\n",
        "                'test_discrete': test_discrete\n",
        "            })\n",
        "            print()\n",
        "            loss_values = 0\n",
        "\n",
        "    return results\n",
        "\n",
        "\n",
        "def eval_model(model, loader, train_mode):\n",
        "    orig = model.training\n",
        "    with torch.no_grad():\n",
        "        model.train(mode=train_mode)\n",
        "        accs = []\n",
        "        for x, y in loader:\n",
        "            pred = model(x.to('cuda').round()).argmax(-1)\n",
        "            accs.append((pred == y.to('cuda')).float().mean().item())\n",
        "        res = float(np.mean(accs))\n",
        "    model.train(mode=orig)\n",
        "    return res\n",
        "\n",
        "\n",
        "def patch_logic_layer(model, default_tau=1.0, verbose=True):\n",
        "    \"\"\"\n",
        "    Monkey-patch every LogicLayer in `model` so it uses Gumbel-Softmax\n",
        "    and prints a confirmation the first time its forward is executed.\n",
        "    \"\"\"\n",
        "\n",
        "    def _weights_to_ops(self, training):\n",
        "        if training:\n",
        "            return F.gumbel_softmax(\n",
        "                self.weights,\n",
        "                tau=getattr(self, \"gumbel_tau\", default_tau),\n",
        "                hard=True,\n",
        "                dim=-1,\n",
        "            )\n",
        "        else:\n",
        "            return F.one_hot(self.weights.argmax(-1), 16).to(torch.float32)\n",
        "\n",
        "    # ------------- python path -------------\n",
        "    def forward_python_gumbel(self, x):\n",
        "        if verbose and not getattr(self, \"_print_done\", False):\n",
        "            print(f\"[GUMBEL] LogicLayer(id={id(self)}) using forward_python_gumbel\")\n",
        "            self._print_done = True\n",
        "\n",
        "        assert x.shape[-1] == self.in_dim\n",
        "        if self.indices[0].dtype != torch.long:\n",
        "            self.indices = self.indices[0].long(), self.indices[1].long()\n",
        "\n",
        "        a, b = x[..., self.indices[0]], x[..., self.indices[1]]\n",
        "        weights = _weights_to_ops(self, self.training)\n",
        "        return bin_op_s(a, b, weights)\n",
        "\n",
        "    # ------------- CUDA path ---------------\n",
        "    def forward_cuda_gumbel(self, x):\n",
        "        if verbose and not getattr(self, \"_print_done\", False):\n",
        "            print(f\"[GUMBEL] LogicLayer(id={id(self)}) using forward_cuda_gumbel\")\n",
        "            self._print_done = True\n",
        "\n",
        "        assert x.ndim == 2\n",
        "        assert x.device.type == \"cuda\", x.device\n",
        "        x = x.transpose(0, 1).contiguous()\n",
        "\n",
        "        a, b = self.indices\n",
        "        w = _weights_to_ops(self, self.training).to(x.dtype)\n",
        "\n",
        "        return LogicLayerCudaFunction.apply(\n",
        "            x, a, b, w,\n",
        "            self.given_x_indices_of_y_start,\n",
        "            self.given_x_indices_of_y\n",
        "        ).transpose(0, 1)\n",
        "\n",
        "    # ------------- master forward ----------\n",
        "    def forward_gumbel(self, x):\n",
        "        # print once here too if you want:\n",
        "        if verbose and not getattr(self, \"_print_done\", False):\n",
        "            print(f\"[GUMBEL] LogicLayer(id={id(self)}) master forward\")\n",
        "            # don't set _print_done here—let sub-forward do it\n",
        "\n",
        "        if self.implementation == \"cuda\":\n",
        "            if isinstance(x, PackBitsTensor):\n",
        "                return self.forward_cuda_eval(x)\n",
        "            return self.forward_cuda(x)\n",
        "        elif self.implementation == \"python\":\n",
        "            return self.forward_python(x)\n",
        "        else:\n",
        "            raise ValueError(self.implementation)\n",
        "\n",
        "    # -------- apply to every layer ---------\n",
        "    for layer in model.modules():\n",
        "        if isinstance(layer, LogicLayer):\n",
        "            layer.gumbel_tau = default_tau\n",
        "            layer.forward_python = MethodType(forward_python_gumbel, layer)\n",
        "            layer.forward_cuda   = MethodType(forward_cuda_gumbel,   layer)\n",
        "            layer.forward        = MethodType(forward_gumbel,        layer)\n",
        "\n",
        "\n",
        "def get_dataloaders() -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, int]:\n",
        "    \"\"\"\n",
        "    Returns the CIFAR-10 dataset and the corresponding DataLoaders.\n",
        "    \"\"\"\n",
        "    #Load CIFAR-10 dataset\n",
        "    binarize = lambda x: torch.cat([(x > (i+1)/32).float() for i in range(31)], dim=0)\n",
        "    transforms = torchvision.transforms.Compose([\n",
        "        torchvision.transforms.ToTensor(),\n",
        "        torchvision.transforms.Lambda(binarize)\n",
        "    ])\n",
        "    train_set_cifar = torchvision.datasets.CIFAR10('./data/cifar', train=True, download=True, transform=transforms)\n",
        "    test_set_cifar  = torchvision.datasets.CIFAR10('./data/cifar', train=False, download=True, transform=transforms)\n",
        "\n",
        "    in_dim_cifar    = 3 * 32 * 32 * 31 # Size of CIFAR-10\n",
        "\n",
        "    train_set = train_set_cifar\n",
        "    test_set  = test_set_cifar\n",
        "    in_dim = in_dim_cifar\n",
        "\n",
        "    # DataLoaders\n",
        "    train_loader = torch.utils.data.DataLoader(train_set, batch_size=128,\n",
        "                                                shuffle=True, pin_memory=True,\n",
        "                                                drop_last=True, num_workers=4)\n",
        "    test_loader  = torch.utils.data.DataLoader(test_set,  batch_size=1024,\n",
        "                                                shuffle=False, pin_memory=True,\n",
        "                                                drop_last=False, num_workers=2)\n",
        "    return train_loader, test_loader, in_dim\n",
        "\n",
        "\n",
        "def get_model(in_dim: int, width: int=256_000, depth: int=12, gumbel_tau: float=0.2,\n",
        "             group_sum_k: int=10, group_sum_tau: float=30, gumbel_model: bool=False):\n",
        "    \"\"\"\n",
        "    Returns a model with the specified parameters.\n",
        "    \"\"\"\n",
        "    model = torch.nn.Sequential(\n",
        "        torch.nn.Flatten(),\n",
        "        LogicLayer(in_dim, width),  # 1\n",
        "        *[LogicLayer(width, width) for _ in range(depth-1)], # 2 to depth\n",
        "        GroupSum(k=group_sum_k, tau=group_sum_tau)\n",
        "    )\n",
        "    model = model.to('cuda')\n",
        "    if gumbel_model:\n",
        "        patch_logic_layer(model, default_tau=gumbel_tau, verbose=False)\n",
        "    return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "48b8d4da",
      "metadata": {
        "id": "48b8d4da"
      },
      "outputs": [],
      "source": [
        "# Empirically, we observe that the free version of Google Collab is able to run ~4.5 iterations / second. For both Differentiable LGNs and Gumbel LGNs, we allow a compute budget of 2 hours, which equates to 4.5 x 2 x 60 x 60 = 32400 iterations.\n",
        "def main(gumbel_model: bool, model_depth: int):\n",
        "\n",
        "    n_steps = 3_000\n",
        "    print_freq = 1_200\n",
        "    full_eval=True\n",
        "    loss_fn   = torch.nn.CrossEntropyLoss()\n",
        "\n",
        "\n",
        "    # Load CIFAR-10 dataset\n",
        "    train_loader, test_loader, in_dim = get_dataloaders()\n",
        "    # Define model\n",
        "    model = get_model(in_dim, width=256_000, depth=model_depth, gumbel_tau=1.00, group_sum_k=10, group_sum_tau=30, gumbel_model=gumbel_model)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr = 0.1 if gumbel_model else 0.1)\n",
        "\n",
        "    results = train(model, train_loader, test_loader, loss_fn, optimizer, n_steps=n_steps, print_freq=print_freq, full_eval=full_eval)\n",
        "\n",
        "    results[\"final\"] = {\n",
        "        \"train_soft\": eval_model(model, train_loader, train_mode=True),\n",
        "        \"train_discrete\": eval_model(model, train_loader, train_mode=False),\n",
        "        \"test_soft\": eval_model(model, test_loader, train_mode=True),\n",
        "        \"test_discrete\": eval_model(model, test_loader, train_mode=False)\n",
        "    }\n",
        "    results[\"model\"] = model\n",
        "    print(\"Final evaluation:\")\n",
        "    print(\"soft:    \", results[\"final\"][\"train_soft\"])\n",
        "    print(\"discrete:\", results[\"final\"][\"train_discrete\"])\n",
        "    print(\"soft:    \", results[\"final\"][\"test_soft\"])\n",
        "    print(\"discrete:\", results[\"final\"][\"test_discrete\"])\n",
        "\n",
        "    return results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0d7bdf58",
      "metadata": {
        "id": "0d7bdf58"
      },
      "outputs": [],
      "source": [
        "gumbel_model = True\n",
        "model_depth = 12\n",
        "\n",
        "start_time = time.time()\n",
        "gumbel_results = main(gumbel_model=gumbel_model, model_depth=model_depth)\n",
        "end_time = time.time()\n",
        "print(f\"Time taken: {end_time - start_time} seconds\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8fec79a0",
      "metadata": {
        "id": "8fec79a0"
      },
      "outputs": [],
      "source": [
        "gumbel_model = False\n",
        "model_depth = 12\n",
        "\n",
        "start_time = time.time()\n",
        "softmax_results = main(gumbel_model=gumbel_model, model_depth=model_depth)\n",
        "end_time = time.time()\n",
        "print(f\"Time taken: {end_time - start_time} seconds\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d2813001",
      "metadata": {
        "id": "d2813001"
      },
      "outputs": [],
      "source": [
        "# Put the results in a dataframe and show the table\n",
        "import pandas as pd\n",
        "gumbel_df = pd.Series(gumbel_results['final'], name='Gumbel')\n",
        "softmax_df = pd.Series(softmax_results['final'], name='Softmax')\n",
        "\n",
        "# Create a dataframe from the series\n",
        "df = pd.DataFrame({\n",
        "    'Gumbel': gumbel_df,\n",
        "    'Softmax': softmax_df\n",
        "})\n",
        "\n",
        "# Compute train and test dicretization gap for both models\n",
        "# df.loc[\"Gumbel Train Discrete Gap\"] = df.loc[\"Gumbel\", \"train_soft\"] - df.loc[\"Gumbel\", \"train_discrete\"]\n",
        "# train discretization gap = train soft - train discrete\n",
        "df.loc[\"train_discretization_gap\"] = df.loc[\"train_soft\"] - df.loc[\"train_discrete\"]\n",
        "df.loc[\"test_discretization_gap\"] = df.loc[\"test_soft\"] - df.loc[\"test_discrete\"]\n",
        "# Sort index\n",
        "df = df.loc[[\"train_soft\", \"train_discrete\", \"train_discretization_gap\", \"test_soft\", \"test_discrete\", \"test_discretization_gap\"]]\n",
        "\n",
        "# Remove \"_\" from the index and capitalize the first letter\n",
        "df.index = df.index.str.replace(\"_\", \" \").str.capitalize()\n",
        "\n",
        "# Display the dataframe\n",
        "print(df)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0eb8bc76",
      "metadata": {
        "id": "0eb8bc76"
      },
      "source": [
        "We see much lower (even negative) discretization gaps and also better performance in the discrete setting for our gumbel method."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c557a9d1",
      "metadata": {
        "id": "c557a9d1"
      },
      "source": [
        "## Create plots of the loss landscapes"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9be6a601",
      "metadata": {
        "id": "9be6a601"
      },
      "source": [
        "We provide the following code for reference. However, the models have not converged within the 3k iterations used above to be within a fast runtime. For this, more than 50k iterations are needed."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "880bb5a2",
      "metadata": {
        "id": "880bb5a2"
      },
      "outputs": [],
      "source": [
        "\n",
        "def get_random_direction(model: torch.nn.Module) -> List[torch.Tensor]:\n",
        "    \"\"\"\n",
        "    Sample a random direction in the parameter space of the model.\n",
        "\n",
        "    Args:\n",
        "        model: PyTorch model\n",
        "\n",
        "    Returns:\n",
        "        List of tensors representing a normalized random direction\n",
        "    \"\"\"\n",
        "    direction = []\n",
        "    for param in model.parameters():\n",
        "        if param.requires_grad:\n",
        "            direction.append(torch.randn_like(param))\n",
        "\n",
        "    # Normalize the direction\n",
        "    norm = torch.sqrt(sum(torch.sum(d * d) for d in direction))\n",
        "    for i in range(len(direction)):\n",
        "        direction[i] = direction[i] / norm\n",
        "\n",
        "    return direction\n",
        "\n",
        "def add_direction_to_model(model: torch.nn.Module, direction: List[torch.Tensor], scale: float) -> None:\n",
        "    \"\"\"\n",
        "    Add a scaled direction to the model parameters.\n",
        "\n",
        "    Args:\n",
        "        model: PyTorch model\n",
        "        direction: Direction in parameter space\n",
        "        scale: Scaling factor\n",
        "    \"\"\"\n",
        "    for param, d in zip([p for p in model.parameters() if p.requires_grad], direction):\n",
        "        param.data.add_(scale * d)\n",
        "\n",
        "def _restore_params(params: List[torch.Tensor], originals: List[torch.Tensor]) -> None:\n",
        "    \"\"\"Copy each saved tensor in ``originals`` back into the matching entry of ``params``.\"\"\"\n",
        "    for param, orig in zip(params, originals):\n",
        "        param.data.copy_(orig)\n",
        "\n",
        "def _first_two_batches(\n",
        "    train_loader: DataLoader,\n",
        "    device: torch.device\n",
        ") -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:\n",
        "    \"\"\"\n",
        "    Return the first two (inputs, targets) batches of ``train_loader``, moved to ``device``.\n",
        "\n",
        "    A loader with a single batch supplies it again from a fresh iterator;\n",
        "    an empty loader raises ValueError.\n",
        "    \"\"\"\n",
        "    train_iter = iter(train_loader)\n",
        "    try:\n",
        "        batch1 = next(train_iter)\n",
        "    except StopIteration:\n",
        "        raise ValueError(\"train_loader yielded no batches; cannot compute a loss landscape\") from None\n",
        "    try:\n",
        "        batch2 = next(train_iter)\n",
        "    except StopIteration:\n",
        "        # Single-batch loader: start a fresh pass instead of failing.\n",
        "        batch2 = next(iter(train_loader))\n",
        "\n",
        "    (inputs1, targets1), (inputs2, targets2) = batch1, batch2\n",
        "    return (\n",
        "        (inputs1.to(device), targets1.to(device)),\n",
        "        (inputs2.to(device), targets2.to(device)),\n",
        "    )\n",
        "\n",
        "def compute_loss_landscape(\n",
        "    model: torch.nn.Module,\n",
        "    train_loader: DataLoader,\n",
        "    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
        "    alphas: np.ndarray,\n",
        "    betas: np.ndarray,\n",
        "    device: torch.device\n",
        ") -> Tuple[np.ndarray, List[torch.Tensor], List[torch.Tensor]]:\n",
        "    \"\"\"\n",
        "    Compute the loss landscape by sampling in two random directions.\n",
        "\n",
        "    For every pair ``(alpha, beta)`` the trainable parameters are set to\n",
        "    ``theta + alpha * d1 + beta * d2`` and the loss is averaged over the first\n",
        "    two batches of ``train_loader``. The model's parameters and its train/eval\n",
        "    mode are restored before returning, even if an exception is raised.\n",
        "\n",
        "    Args:\n",
        "        model: PyTorch model\n",
        "        train_loader: DataLoader for training data (must yield at least one batch;\n",
        "            a single-batch loader has that batch reused for the second slot)\n",
        "        loss_fn: Loss function\n",
        "        alphas: List of scaling factors for first direction\n",
        "        betas: List of scaling factors for second direction\n",
        "        device: Device to use for computation\n",
        "\n",
        "    Returns:\n",
        "        Tuple containing the loss landscape matrix of shape (len(alphas), len(betas))\n",
        "        and the two random directions\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If ``train_loader`` yields no batches.\n",
        "    \"\"\"\n",
        "    trainable = [p for p in model.parameters() if p.requires_grad]\n",
        "    # Snapshot the trainable parameters so every grid point starts from theta.\n",
        "    original_params = [p.detach().clone() for p in trainable]\n",
        "    was_training = model.training\n",
        "\n",
        "    # Sample two random directions\n",
        "    direction1 = get_random_direction(model)\n",
        "    direction2 = get_random_direction(model)\n",
        "\n",
        "    (inputs1, targets1), (inputs2, targets2) = _first_two_batches(train_loader, device)\n",
        "\n",
        "    loss_landscape = np.zeros((len(alphas), len(betas)))\n",
        "\n",
        "    try:\n",
        "        # NOTE(review): the loss is evaluated in train mode, as before; for layers whose\n",
        "        # forward differs between train and eval this measures the training-time\n",
        "        # forward pass. Confirm this is intended.\n",
        "        model.train()\n",
        "        for i, alpha in tqdm(enumerate(alphas), total=len(alphas), desc=\"Computing loss landscape\"):\n",
        "            for j, beta in enumerate(betas):\n",
        "                # Start each grid point from the original parameters\n",
        "                _restore_params(trainable, original_params)\n",
        "                add_direction_to_model(model, direction1, alpha)\n",
        "                add_direction_to_model(model, direction2, beta)\n",
        "\n",
        "                with torch.no_grad():\n",
        "                    loss1 = loss_fn(model(inputs1), targets1)\n",
        "                    loss2 = loss_fn(model(inputs2), targets2)\n",
        "                    # Average loss over two batches\n",
        "                    loss_landscape[i, j] = ((loss1 + loss2) / 2).item()\n",
        "    finally:\n",
        "        # Leave the model exactly as we found it, even on error or interrupt.\n",
        "        _restore_params(trainable, original_params)\n",
        "        model.train(was_training)\n",
        "\n",
        "    return loss_landscape, direction1, direction2\n",
        "\n",
        "def plot_loss_landscape(\n",
        "    loss_landscape: np.ndarray,\n",
        "    alphas: np.ndarray,\n",
        "    betas: np.ndarray,\n",
        "    title: str = \"Loss Landscape\"\n",
        ") -> plt.Figure:\n",
        "    \"\"\"\n",
        "    Plot the loss landscape as a 3D surface with a colour bar for the loss scale.\n",
        "\n",
        "    Args:\n",
        "        loss_landscape: Loss landscape matrix of shape (len(alphas), len(betas))\n",
        "        alphas: List of scaling factors for first direction\n",
        "        betas: List of scaling factors for second direction\n",
        "        title: Plot title\n",
        "\n",
        "    Returns:\n",
        "        Matplotlib figure\n",
        "    \"\"\"\n",
        "    fig = plt.figure(figsize=(10, 8))\n",
        "    ax = fig.add_subplot(111, projection='3d')\n",
        "\n",
        "    # meshgrid (default 'xy' indexing) yields (len(betas), len(alphas)) grids,\n",
        "    # hence the transpose of the landscape below.\n",
        "    alpha_grid, beta_grid = np.meshgrid(alphas, betas)\n",
        "\n",
        "    # Plot the surface\n",
        "    surf = ax.plot_surface(\n",
        "        alpha_grid, beta_grid, loss_landscape.T,\n",
        "        cmap=plt.cm.viridis,\n",
        "        linewidth=0,\n",
        "        antialiased=True\n",
        "    )\n",
        "\n",
        "    # Add labels and title\n",
        "    ax.set_xlabel('Direction 1')\n",
        "    ax.set_ylabel('Direction 2')\n",
        "    ax.set_title(title)\n",
        "\n",
        "    # Hide the z-axis ticks; the colour bar carries the loss scale instead.\n",
        "    ax.set_zticks([])\n",
        "\n",
        "    # Add a color bar\n",
        "    cbar = fig.colorbar(surf, ax=ax, shrink=0.75, aspect=15)\n",
        "    # Rotate via tick_params: re-setting the current labels with set_yticklabels would\n",
        "    # swap in a FixedFormatter and trigger a matplotlib UserWarning.\n",
        "    cbar.ax.tick_params(axis='y', labelrotation=0)\n",
        "    cbar.ax.yaxis.set_label_position('left')\n",
        "    cbar.ax.yaxis.set_label_coords(-0.1, 0.5)\n",
        "    cbar.ax.yaxis.set_label_text('Loss')\n",
        "\n",
        "    return fig\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "30bde983",
      "metadata": {
        "id": "30bde983"
      },
      "outputs": [],
      "source": [
        "# Example usage of the loss landscape computation: a 2D slice of the loss\n",
        "# around each trained model along two random directions.\n",
        "\n",
        "# Set device\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "loss_fn = torch.nn.CrossEntropyLoss()\n",
        "\n",
        "limit = 1.0\n",
        "samples = 10\n",
        "# Seed for the random directions, so the landscapes are reproducible on Restart & Run All\n",
        "LANDSCAPE_SEED = 0\n",
        "\n",
        "# Define range of scaling factors\n",
        "alphas = np.linspace(-limit, limit, samples)\n",
        "betas = np.linspace(-limit, limit, samples)\n",
        "\n",
        "train_loader, test_loader, in_dim = get_dataloaders()\n",
        "\n",
        "# Compute one landscape per model; the directions are drawn from the seeded torch RNG\n",
        "torch.manual_seed(LANDSCAPE_SEED)\n",
        "landscapes = {\n",
        "    name: compute_loss_landscape(\n",
        "        model=results[\"model\"],\n",
        "        train_loader=train_loader,\n",
        "        loss_fn=loss_fn,\n",
        "        alphas=alphas,\n",
        "        betas=betas,\n",
        "        device=device\n",
        "    )\n",
        "    for name, results in ((\"gumbel\", gumbel_results), (\"softmax\", softmax_results))\n",
        "}\n",
        "\n",
        "gumbel_loss_landscape, gumbel_direction1, gumbel_direction2 = landscapes[\"gumbel\"]\n",
        "softmax_loss_landscape, softmax_direction1, softmax_direction2 = landscapes[\"softmax\"]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "39c71fcd",
      "metadata": {
        "id": "39c71fcd"
      },
      "outputs": [],
      "source": [
        "# Draw one 3D surface per model (Gumbel first, then Softmax)\n",
        "gumbel_fig, softmax_fig = (\n",
        "    plot_loss_landscape(loss_landscape=landscape, alphas=alphas, betas=betas, title=title)\n",
        "    for landscape, title in (\n",
        "        (gumbel_loss_landscape, \"Gumbel Loss Landscape\"),\n",
        "        (softmax_loss_landscape, \"Softmax Loss Landscape\"),\n",
        "    )\n",
        ")\n",
        "# Render both figures\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "7308e535",
      "metadata": {
        "id": "7308e535"
      },
      "source": [
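        "As noted earlier, the loss landscapes plotted above are not representative, because the models were not trained to convergence. Each plot is also only a 2D slice of the loss (on two training batches) along two random directions around the current parameters, not the full landscape."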
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.9"
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}