{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e7681e8d-02c8-4f58-afe7-6b02a218bc59",
   "metadata": {},
   "source": [
    "Please refer to [Neural Tangents](https://github.com/google/neural-tangents) for dependencies.  \n",
    "You may need [JAX with GPU Support](https://jax.readthedocs.io/en/latest/installation.html#install-nvidia-gpu)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "678b4650-e0d4-4d02-8811-800f2663269e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import random\n",
    "from neural_tangents import stax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import neural_tangents as nt\n",
    "\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e8649b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################\n",
    "# Data Generation\n",
    "\n",
    "width = 2\n",
    "margin = 1\n",
    "\n",
    "def get_Y(x, noise = 0):\n",
    "    noise_vec = np.random.normal(scale = noise, size = x.shape)\n",
    "    return np.sin(1.4 * x) + 0.1 * np.sin(10 * x) + noise_vec\n",
    "\n",
    "def get_train_X(N = 1024):\n",
    "    return np.concatenate([np.random.rand(N,) * width - (margin + width),\n",
    "                           np.random.rand(N,) * width +  margin,\n",
    "                           # np.array([0])\n",
    "           ])\n",
    "\n",
    "def get_val_X(N = 1024):\n",
    "    return np.linspace(-5, 5, num = N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e8bde83-0c32-40bd-8e8d-3ad9272956a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################\n",
    "# Network definition\n",
    "\n",
    "hidden_dim = 2048\n",
    "in_dim = 1\n",
    "\n",
    "init_fn, apply_fn, kernel_fn = stax.serial(\n",
    "    stax.Dense(hidden_dim, W_std=2.0, b_std=0.05), stax.Erf(),# stax.Erf(),\n",
    "    stax.Dense(hidden_dim, W_std=2.0, b_std=0.05), stax.Erf(),# stax.Erf(),\n",
    "    # stax.Dense(hidden_dim, W_std=1.5, b_std=0.05), stax.Erf(),# stax.Erf(),\n",
    "    # stax.Dense(hidden_dim, W_std=1.5, b_std=0.05), stax.Erf(),# stax.Erf(),\n",
    "    # stax.Dense(hidden_dim, W_std=1.5, b_std=0.05), stax.Erf(),# stax.Erf(),\n",
    "    stax.Dense(1)\n",
    ")\n",
    "\n",
    "key = random.PRNGKey(1)\n",
    "# x = random.normal(key, (10, 1))\n",
    "\n",
    "#########################################\n",
    "# Data generation\n",
    "\n",
    "ntx = get_train_X(N = 5)\n",
    "nvx = get_val_X(N = 256)\n",
    "\n",
    "x_tr = jnp.array(ntx)[:, None]\n",
    "x_va = jnp.array(nvx)[:, None]\n",
    "y_tr = jnp.array(get_Y(ntx, 0.03))[:, None]\n",
    "y_va = jnp.array(get_Y(nvx))[:, None]\n",
    "\n",
    "_, params = init_fn(key, input_shape=x_tr.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "99fababe-14a5-45f8-9c02-9d99a18103c3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "################### Direct computation of Linearized dynamics on MSE #####################\n",
    "\n",
    "predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_tr, y_tr, diag_reg = 1e-3)\n",
    "# y_val_nngp = predict_fn(x_test = x_va, get = 'nngp')\n",
    "\n",
    "y_val_ntk, y_val_ntk_cov = predict_fn(x_test = x_va, get = 'ntk', compute_cov = True)\n",
    "# y_val_ntk = predict_fn(x_test = x_va, get = 'nngp', compute_cov = True)\n",
    "\n",
    "y = apply_fn(params, x_va)\n",
    "y_var = kernel_fn(x_va, x_va, 'nngp')\n",
    "\n",
    "# print(x_va)\n",
    "# print(y_val_ntk)\n",
    "\n",
    "ntk_mean = jnp.reshape(y_val_ntk, (-1,))\n",
    "ntk_std = jnp.sqrt(jnp.diag(y_val_ntk_cov))\n",
    "\n",
    "cov_scale = 2\n",
    "y_lower = ntk_mean - ntk_std * cov_scale\n",
    "y_upper = ntk_mean + ntk_std * cov_scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "c5220a9d-3b68-458e-9621-cc323611badc",
   "metadata": {},
   "outputs": [],
   "source": [
    "##################### Upper bound by Gronwall Inequality ######################\n",
    "\n",
    "Ozz = jnp.diag(kernel_fn(x_va, x_va, 'ntk'))\n",
    "Oxx = jnp.diag(kernel_fn(x_tr, x_tr, 'ntk'))\n",
    "Ozx = kernel_fn(x_va, x_tr, 'ntk')\n",
    "\n",
    "addition_term = jnp.min(Oxx[None, :] - 2 * Ozx, axis = 1)\n",
    "UB = Ozz + addition_term\n",
    "\n",
    "# UB = 0.2 * jnp.exp(2 * jnp.sqrt(UB) - 1)\n",
    "UB = 0.6 * jnp.sqrt(UB)\n",
    "\n",
    "mean_term = jnp.mean(Oxx[None, :] - 2 * Ozx, axis = 1)\n",
    "UB_mean = Ozz + mean_term"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e36a1845-0fcb-46c8-bba0-6c8c9b54979b",
   "metadata": {},
   "outputs": [],
   "source": [
    "############################### VISUALIZATION #################################\n",
    "plt.style.use('./style.mplstyle')\n",
    "plt.tight_layout()\n",
    "# plt.figure(figsize = (3.5, 1.5))\n",
    "\n",
    "fig, ax = plt.subplots(figsize = (2.8, 1.5))\n",
    "\n",
    "plt.xlim(-4.5, 4.5)\n",
    "\n",
    "plt.grid(linestyle = '--', linewidth = 0.5, alpha = 0.5)\n",
    "\n",
    "# plt.fill_between(x_va[:, 0], ntk_mean - Ozz, ntk_mean + Ozz, alpha = 0.04, color = 'C3', label = \"O(z,z)\")\n",
    "# plt.fill_between(x_va[:, 0], ntk_mean - UB_mean, ntk_mean + UB_mean, alpha = 0.14, color = 'C0', label = \"Upper bound (mean)\")\n",
    "plt.fill_between(x_va[:, 0], ntk_mean - UB, ntk_mean + UB, alpha = 0.3, color = 'C1', label = \"Upper bound\")\n",
    "plt.fill_between(x_va[:, 0], y_lower, y_upper, alpha = 0.7, color = 'C0', label = \"Exact trained ensemble\")\n",
    "\n",
    "plt.plot(x_va[:, 0], ntk_mean, linewidth=0.65, label = \"trained\")\n",
    "plt.plot(x_tr[:, 0], y_tr[:, 0], 'o', markersize=2, color='k', label = \"dataset\")\n",
    "plt.plot(x_va[:, 0], y_va[:, 0], '--', linewidth=0.65, alpha = 0.4, color='C2', label = \"target\")\n",
    "\n",
    "# ax.text(.1,.88,'\\\\textbf{a)}',\n",
    "#         horizontalalignment='center',\n",
    "#         transform=ax.transAxes)\n",
    "\n",
    "# plt.legend()\n",
    "# plt.savefig(\"test.png\", dpi=300)\n",
    "plt.savefig(\"test.pdf\", format=\"pdf\", bbox_inches = 'tight')\n",
    "# plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
