{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f009a926",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-11-16 18:41:19.412211: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n"
     ]
    }
   ],
   "source": [
    "from jax import random\n",
    "import jax\n",
    "import neural_tangents as nt\n",
    "from neural_tangents import stax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b89718d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fully-connected network: two Dense(512) + ReLU hidden layers, then a Dense(1) readout.\n",
    "# stax.serial composes the layers and returns the (init_fn, apply_fn, kernel_fn) triple.\n",
    "layers = (\n",
    "    stax.Dense(512),\n",
    "    stax.Relu(),\n",
    "    stax.Dense(512),\n",
    "    stax.Relu(),\n",
    "    stax.Dense(1),\n",
    ")\n",
    "init_fn, apply_fn, kernel_fn = stax.serial(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "47d850d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
      "/tmp/ipykernel_1074742/4235229312.py:6: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.\n",
      "  param_count = sum(x.size for x in jax.tree_leaves(params))\n"
     ]
    }
   ],
   "source": [
    "# Problem sizes: number of training points, number of test points, input dimension.\n",
    "n_train = 5000\n",
    "n_test = 1\n",
    "d = 100\n",
    "# One PRNG key each for parameter init (key1), training data (key2) and test data (key3).\n",
    "key1, key2, key3 = random.split(random.PRNGKey(1), num=3)\n",
    "output_shape, params = init_fn(key1, input_shape=(d,))\n",
    "# jax.tree_leaves is deprecated (it emits a FutureWarning); use the jax.tree_util spelling.\n",
    "param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))\n",
    "delta = jax.numpy.sqrt(param_count/n_train)\n",
    "# Input distribution: zero mean, identity covariance except entry [0, 0], which is 1 + delta.\n",
    "cov = jax.numpy.identity(d, dtype=None)\n",
    "cov = cov.at[0,0].set(1+delta)\n",
    "mean = jax.numpy.zeros(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "321bdad7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A JAX PRNG key must not be reused for two different draws: sampling inputs and targets\n",
    "# from the same key2 makes them depend on the same random bits. Split it first.\n",
    "key_x, key_y = random.split(key2)\n",
    "x_train = jax.random.multivariate_normal(key_x, mean, cov, (n_train,))\n",
    "y_train = random.uniform(key_y, shape=(n_train, 1))  # training targets\n",
    "x_test = jax.random.multivariate_normal(key3, mean, cov, (n_test,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "aa12bdc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def returnsNTkKernel(x_train, kernel_fn, x=None, train=True):\n",
    "    \"\"\"Evaluate the 'ntk' kernel against the training inputs.\n",
    "\n",
    "    Args:\n",
    "        x_train: training inputs, passed to ``kernel_fn`` unchanged.\n",
    "        kernel_fn: callable invoked as ``kernel_fn(x1, x2, 'ntk')``.\n",
    "        x: query inputs; required when ``train`` is False, ignored otherwise.\n",
    "        train: if True, return the train-train kernel; otherwise return the\n",
    "            kernel between ``x`` and ``x_train``.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``kernel_fn`` returns for the selected pair of inputs.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``train`` is False and ``x`` is None.\n",
    "    \"\"\"\n",
    "    if train:\n",
    "        return kernel_fn(x_train, x_train, 'ntk')\n",
    "    # Fail early with a clear message instead of handing None to kernel_fn.\n",
    "    if x is None:\n",
    "        raise ValueError('x must be provided when train=False')\n",
    "    return kernel_fn(x, x_train, 'ntk')\n",
    "\n",
    "def EigenDecompositonOfKernel(ntk_train_train):\n",
    "    \"\"\"Run ``jax.numpy.linalg.eigh`` on the kernel matrix.\n",
    "\n",
    "    Returns:\n",
    "        An ``(eigenvalues, eigenvectors)`` tuple, in the order produced by ``eigh``.\n",
    "    \"\"\"\n",
    "    decomposition = jax.numpy.linalg.eigh(ntk_train_train)\n",
    "    spectrum, basis = decomposition\n",
    "    return spectrum, basis\n",
    "\n",
    "def feature(y_train, ntk_test_train, eigenvalues, eigenvectors, i):\n",
    "    \"\"\"Contribution of the i-th eigenpair: ``k @ v_i @ v_i.T @ y / lambda_i``.\n",
    "\n",
    "    Args:\n",
    "        y_train: training targets, used as the right-most factor of the product.\n",
    "        ntk_test_train: kernel between the query point(s) and the training set.\n",
    "        eigenvalues: 1-D array of eigenvalues, e.g. from ``linalg.eigh``.\n",
    "        eigenvectors: 2-D array whose COLUMN ``i`` is the eigenvector paired\n",
    "            with ``eigenvalues[i]`` (the layout ``linalg.eigh`` returns).\n",
    "        i: index of the eigenpair to use.\n",
    "\n",
    "    Returns:\n",
    "        ``ntk_test_train @ v_i @ v_i.T @ y_train / eigenvalues[i]``. Summed over\n",
    "        all ``i`` this equals ``ntk_test_train @ inv(K) @ y_train`` when the\n",
    "        eigenpairs are the full eigendecomposition of a symmetric matrix ``K``.\n",
    "    \"\"\"\n",
    "    lambda_i = eigenvalues[i]\n",
    "    # eigh stores eigenvectors column-wise; indexing eigenvectors[i] would take a\n",
    "    # row, which is not the eigenvector belonging to eigenvalues[i].\n",
    "    v_i = eigenvectors[:, i].reshape(-1, 1)\n",
    "    output = ntk_test_train @ v_i @ v_i.T @ y_train\n",
    "    return output / lambda_i\n",
    "\n",
    "def returnsXPerturbed(x_test, epsilon, key):\n",
    "    \"\"\"Add a random perturbation of L2 norm ``epsilon`` to every row of ``x_test``.\n",
    "\n",
    "    Gaussian noise with the shape of ``x_test`` is drawn from ``key``, each row is\n",
    "    rescaled to norm ``epsilon`` (norm taken along axis 1), and the result is\n",
    "    added to ``x_test``.\n",
    "    \"\"\"\n",
    "    direction = random.normal(key, x_test.shape)\n",
    "    row_norms = jax.numpy.linalg.norm(direction, axis=1)\n",
    "    # Column vector of norms so the division broadcasts across each row.\n",
    "    row_norms = row_norms.reshape(row_norms.shape[0], 1)\n",
    "    return x_test + epsilon * direction / row_norms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1bc9b8eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train-train NTK matrix (default train=True path), then its eigendecomposition.\n",
    "ntk_train_train = returnsNTkKernel(x_train, kernel_fn)\n",
    "eigenvalues, eigenvectors = EigenDecompositonOfKernel(ntk_train_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61ef7e3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each perturbation size epsilon, record how much every eigen-direction feature of\n",
    "# the test point moves: norms[str(epsilon)][i] = ||feature_i(x + noise) - feature_i(x)||.\n",
    "norms = {}\n",
    "\n",
    "# The clean test kernel and its features do not depend on epsilon: compute them once.\n",
    "ntk_test_train_clean = returnsNTkKernel(x_train, kernel_fn, x_test, train=False)\n",
    "clean_features = [\n",
    "    feature(y_train, ntk_test_train_clean, eigenvalues, eigenvectors, i)\n",
    "    for i in range(x_train.shape[0])\n",
    "]\n",
    "\n",
    "# key3 already generated x_test; reusing it for the noise would tie the perturbation\n",
    "# to x_test itself. Derive a separate key. It is shared by all epsilons so that every\n",
    "# epsilon rescales the same noise direction, as before.\n",
    "noise_key = random.fold_in(key3, 1)\n",
    "\n",
    "for epsilon in [0.5, 0.7, 1, 2]:\n",
    "    x_perturbed = returnsXPerturbed(x_test, epsilon, noise_key)\n",
    "    ntk_test_train_perturbed = returnsNTkKernel(x_train, kernel_fn, x_perturbed, train=False)\n",
    "    norms[str(epsilon)] = [\n",
    "        jax.numpy.linalg.norm(\n",
    "            feature(y_train, ntk_test_train_perturbed, eigenvalues, eigenvectors, i) - f_x\n",
    "        )\n",
    "        for i, f_x in enumerate(clean_features)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56077def",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "\n",
    "# Left panel: feature shift per eigen-direction against its eigenvalue.\n",
    "# Right panel: the eigenvalue spectrum in the order returned by the decomposition.\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n",
    "for epsilon in norms.keys():\n",
    "    # Only the epsilon = 0.5 curve is drawn; extend this list to overlay more.\n",
    "    if epsilon in ['0.5']:\n",
    "        ax1.plot(eigenvalues, norms[epsilon], label=epsilon)\n",
    "        ax2.plot(eigenvalues, label=epsilon)\n",
    "ax1.set_ylabel(\"||f_x_perturbed-f_x||\")\n",
    "ax1.set_xlabel(\"Eigenvalues\")\n",
    "ax2.set_ylabel(\"Eigenvalues\")\n",
    "ax2.set_xlabel(\"Eigenvalue index\")\n",
    "# plt.legend() only annotates the current axes, so request a legend on each panel.\n",
    "ax1.legend()\n",
    "ax2.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "911d4cc7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
