{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "example.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "otnqBdyaYBrf"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/neurips2022sub/ntk_activations/blob/main/example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Examples of using new nonlinearities from [`ntk_activations`](https://github.com/neurips2022sub/ntk_activations)"
      ],
      "metadata": {
        "id": "WUJKQgtPlpyp"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install git+https://github.com/neurips2022sub/ntk_activations.git"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rmJ6T0nFlcwA",
        "outputId": "6ebe6ae9-439d-4618-e99f-f61e0af8de9c"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting git+https://github.com/neurips2022sub/ntk_activations.git\n",
            "  Cloning https://github.com/neurips2022sub/ntk_activations.git to /tmp/pip-req-build-hz0_m17q\n",
            "  Running command git clone -q https://github.com/neurips2022sub/ntk_activations.git /tmp/pip-req-build-hz0_m17q\n",
            "Collecting neural-tangents>=0.5.0\n",
            "  Downloading neural_tangents-0.5.0-py2.py3-none-any.whl (193 kB)\n",
            "\u001b[K     |████████████████████████████████| 193 kB 4.6 MB/s \n",
            "\u001b[?25hRequirement already satisfied: typing-extensions>=4.0.1 in /usr/local/lib/python3.7/dist-packages (from neural-tangents>=0.5.0->ntk-activations==0.0.1) (4.2.0)\n",
            "Collecting frozendict>=2.3\n",
            "  Downloading frozendict-2.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (99 kB)\n",
            "\u001b[K     |████████████████████████████████| 99 kB 1.5 MB/s \n",
            "\u001b[?25hRequirement already satisfied: jax>=0.3 in /usr/local/lib/python3.7/dist-packages (from neural-tangents>=0.5.0->ntk-activations==0.0.1) (0.3.8)\n",
            "Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->neural-tangents>=0.5.0->ntk-activations==0.0.1) (1.4.1)\n",
            "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->neural-tangents>=0.5.0->ntk-activations==0.0.1) (1.0.0)\n",
            "Requirement already satisfied: numpy>=1.19 in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->neural-tangents>=0.5.0->ntk-activations==0.0.1) (1.21.6)\n",
            "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->neural-tangents>=0.5.0->ntk-activations==0.0.1) (3.3.0)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax>=0.3->neural-tangents>=0.5.0->ntk-activations==0.0.1) (1.15.0)\n",
            "Building wheels for collected packages: ntk-activations\n",
            "  Building wheel for ntk-activations (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for ntk-activations: filename=ntk_activations-0.0.1-py3-none-any.whl size=12912 sha256=8d3424a6173b5b2f9d7a30dff518300fca97c0bbfaeb1203e45f28e762b0f3c0\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-8g93oen8/wheels/b5/6c/15/4329dce81d43ef4f29b84fcd6d52a9a0ffab67c4f133dc1419\n",
            "Successfully built ntk-activations\n",
            "Installing collected packages: frozendict, neural-tangents, ntk-activations\n",
            "Successfully installed frozendict-2.3.2 neural-tangents-0.5.0 ntk-activations-0.0.1\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "2-Y93-C7lPOC"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "from jax import numpy as np, random\n",
        "from neural_tangents import stax\n",
        "from ntk_activations import stax_extensions"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1. Using `ntk_activations.stax_extensions` with `neural_tangents.stax`"
      ],
      "metadata": {
        "id": "5DIW_x8Yl4Oz"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "You can seamlessly combine layers from `neural_tangents.stax` and `ntk_activations.stax_extensions`."
      ],
      "metadata": {
        "id": "-7g1Y4eRmMj1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "key1, key2, key_init = random.split(random.PRNGKey(1), 3)\n",
        "x1 = random.normal(key1, (3, 2))\n",
        "x2 = random.normal(key2, (4, 2))\n",
        "\n",
        "init_fn, apply_fn, kernel_fn = stax.serial(\n",
        "    stax.Dense(128),\n",
        "    stax_extensions.Gaussian(),\n",
        "    stax.Dense(10)\n",
        ")\n",
        "\n",
        "_, params = init_fn(key_init, x1.shape)\n",
        "outputs = apply_fn(params, x1)\n",
        "kernel = kernel_fn(x1, x2)\n",
        "print(kernel)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zILtW4RsmB7k",
        "outputId": "0402e43e-29f7-4b2d-9285-de587c8ff757"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Kernel(nngp=DeviceArray([[0.4366224 , 0.45867887, 0.38753727, 0.52652365],\n",
            "             [0.4365696 , 0.4073899 , 0.27690652, 0.38512793],\n",
            "             [0.35802934, 0.44013247, 0.29739842, 0.44431883]],            dtype=float32), ntk=DeviceArray([[0.47899756, 0.54227614, 0.6494441 , 0.7565323 ],\n",
            "             [0.6249889 , 0.49593154, 0.3092413 , 0.38644305],\n",
            "             [0.3599234 , 0.6529554 , 0.39999655, 0.61957943]],            dtype=float32), cov1=DeviceArray([0.48287258, 0.41158044, 0.4013612 ], dtype=float32), cov2=DeviceArray([0.54617095, 0.5546928 , 0.3666358 , 0.5852136 ], dtype=float32), x1_is_x2=False, is_gaussian=True, is_reversed=False, is_input=False, diagonal_batch=True, diagonal_spatial=False, shape1=(3, 10), shape2=(4, 10), batch_axis=0, channel_axis=1, mask1=None, mask2=None)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 2. Using `Elementwise` to automatically derive the NTK in closed form from the NNGP."
      ],
      "metadata": {
        "id": "z0QxD2fTmp0x"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Under the hood, `ntk_activations.stax_extensions.Elementwise` uses autodiff to derive the NTK function from the NNGP function."
      ],
      "metadata": {
        "id": "3UDv9K_Nm5cs"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Hand-derived NTK expression for the sine nonlinearity.\n",
        "_, _, kernel_fn_manual = stax.serial(stax.Dense(1),\n",
        "                                     stax_extensions.Sin())\n",
        "\n",
        "# NNGP function for the sine nonlinearity:\n",
        "def nngp_fn(cov12, var1, var2):\n",
        "  sum_ = (var1 + var2)\n",
        "  s1 = np.exp((-0.5 * sum_ + cov12))\n",
        "  s2 = np.exp((-0.5 * sum_ - cov12))\n",
        "  return (s1 - s2) / 2\n",
        "\n",
        "# Let the `Elementwise` derive the NTK function in closed form automatically.\n",
        "_, _, kernel_fn = stax.serial(stax.Dense(1),\n",
        "                              stax_extensions.Elementwise(nngp_fn=nngp_fn))\n",
        "\n",
        "\n",
        "k_auto = kernel_fn(x1, x2, 'ntk')\n",
        "k_manual = kernel_fn_manual(x1, x2, 'ntk')\n",
        "\n",
        "# The two kernels match!\n",
        "print(np.max(np.abs(k_manual - k_auto)))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jRVet8A0nOHO",
        "outputId": "7a9ff1ce-93f8-4987-fb93-3a49fff3399e"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/ntk_activations/stax_extensions.py:602: UserWarning: Using JAX autodiff to compute the `fn` derivative for NTK. Beware of https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.\n",
            "  'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 3. Using `ElementwiseNumerical` to approximate kernels given only the nonlinearity."
      ],
      "metadata": {
        "id": "AEyUoK_Un8y3"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "`neural_tangents.stax.ElementwiseNumerical` approximates the NNGP and NTK using Gaussian quadrature and autodiff, given only the elementwise nonlinearity."
      ],
      "metadata": {
        "id": "vBL3s42Dn8y4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# A nonlinearity with a known closed-form expression (GeLU).\n",
        "_, _, kernel_fn_closed_form = stax.serial(\n",
        "  stax.Dense(1),\n",
        "  stax_extensions.Gelu(),  # Contains the closed-form GeLU NNGP/NTK expression.\n",
        "  stax.Dense(1)\n",
        ")\n",
        "kernel_closed_form = kernel_fn_closed_form(x1, x2)\n",
        "\n",
        "# Construct the layer from only the elementwise forward-pass GeLU.\n",
        "_, _, kernel_fn_numerical = stax.serial(\n",
        "  stax.Dense(1),\n",
        "  stax.ElementwiseNumerical(jax.nn.gelu, deg=25),  # quadrature and autodiff.\n",
        "  stax.Dense(1)\n",
        ")\n",
        "kernel_numerical = kernel_fn_numerical(x1, x2)\n",
        "\n",
        "# The two kernels are close!\n",
        "print(np.max(np.abs(kernel_closed_form.nngp - kernel_numerical.nngp)))\n",
        "print(np.max(np.abs(kernel_closed_form.ntk - kernel_numerical.ntk)))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "symjuXs9ogsg",
        "outputId": "15f2f626-4a49-4242-bda7-51e09e372d58"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/neural_tangents/_src/stax/elementwise.py:804: UserWarning: Numerical Activation Layer with fn=<function gelu at 0x7f6e7c60bc20>, deg=25 used!Note that numerical error is controlled by `deg` and for a given tolerance level, required `deg` will highly be dependent on the choice of `fn`.\n",
            "  f'Numerical Activation Layer with fn={fn}, deg={deg} used!'\n",
            "/usr/local/lib/python3.7/dist-packages/neural_tangents/_src/stax/elementwise.py:813: UserWarning: Using JAX autodiff to compute the `fn` derivative for NTK. Beware of https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.\n",
            "  'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "3.823638e-05\n",
            "8.529425e-05\n"
          ]
        }
      ]
    }
  ]
}