{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e7681e8d-02c8-4f58-afe7-6b02a218bc59",
   "metadata": {},
   "source": [
    "Please refer to [Neural Tangents](https://github.com/google/neural-tangents) for installation instructions and dependencies.  \n",
    "You may also need [JAX with GPU support](https://jax.readthedocs.io/en/latest/installation.html#install-nvidia-gpu)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "678b4650-e0d4-4d02-8811-800f2663269e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import random\n",
    "from neural_tangents import stax\n",
    "import jax.numpy as jnp\n",
    "import neural_tangents as nt\n",
    "import numpy as np\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from sklearn.datasets import make_moons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e8bde83-0c32-40bd-8e8d-3ad9272956a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################\n",
    "# Network definition\n",
    "# MLP: n_hidden blocks of (Dense -> Erf) followed by a 1-unit Dense readout. stax.serial\n",
    "# returns the finite-width init/apply functions and the analytic kernel_fn.\n",
    "\n",
    "hidden_dim = 2048\n",
    "in_dim = 2\n",
    "n_hidden = 2              # number of Dense+Erf hidden blocks\n",
    "W_STD, B_STD = 2.0, 0.05  # weight / bias std of every hidden Dense layer\n",
    "\n",
    "hidden_layers = []\n",
    "for _ in range(n_hidden):\n",
    "    hidden_layers += [stax.Dense(hidden_dim, W_std=W_STD, b_std=B_STD), stax.Erf()]\n",
    "\n",
    "init_fn, apply_fn, kernel_fn = stax.serial(*hidden_layers, stax.Dense(1))\n",
    "\n",
    "key = random.PRNGKey(1)\n",
    "\n",
    "#########################################\n",
    "# Data generation\n",
    "\n",
    "X_min = -2.25\n",
    "X_max = 2.25\n",
    "Y_min = -2.25\n",
    "Y_max = 2.25\n",
    "N_val = 64  # the validation grid has N_val x N_val points\n",
    "N_tra = 32  # number of training samples\n",
    "\n",
    "# Two noisy half-moons; the shift presumably centres them in the plotting window.\n",
    "X, y = make_moons(n_samples = N_tra, noise = 0.2, random_state = 330)\n",
    "X = X + np.array([-0.5, -0.25])\n",
    "\n",
    "# Regular grid over [X_min, X_max] x [Y_min, Y_max], flattened to (N_val**2, 2).\n",
    "val_x = np.meshgrid(np.linspace(X_min, X_max, num = N_val), np.linspace(Y_min, Y_max, num = N_val))\n",
    "val_x = np.stack(val_x, axis = -1).reshape(N_val ** 2, 2)\n",
    "\n",
    "x_tr = jnp.array(X)\n",
    "x_va = jnp.array(val_x)\n",
    "y_tr = jnp.array(y)[:, None]  # (N_tra, 1) column of 0/1 labels\n",
    "\n",
    "# Draw each class from its own boolean selection (rather than an alpha mask over all points).\n",
    "plt.scatter(X[y == 0, 0], X[y == 0, 1], s = 12, color = \"#fe4a49\", marker = '^')\n",
    "plt.scatter(X[y == 1, 0], X[y == 1, 1], s = 30, color = \"#fcba04\", marker = '+')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77c2bd19-68cc-4732-8ed2-54f383fbbcdd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "y_tr.ravel()  # labels as a flat 1-D array: one line instead of one printed row per sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "99fababe-14a5-45f8-9c02-9d99a18103c3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "################### Direct computation of Linearized dynamics on BCE #####################\n",
    "# Function-space training dynamics of the linearised (NTK) network under binary\n",
    "# cross-entropy, optimised by gradient descent with momentum.\n",
    "\n",
    "from jax.nn import log_sigmoid\n",
    "\n",
    "k_tr_tr = kernel_fn(x_tr, x_tr, 'ntk')\n",
    "lr = 1e-2\n",
    "momentum = 0.9\n",
    "\n",
    "\n",
    "def loss_fn(fx, labels):\n",
    "    \"\"\"Mean binary cross-entropy between logits `fx` and 0/1 targets `labels`.\n",
    "\n",
    "    Uses log(1 - sigmoid(f)) == log_sigmoid(-f) so both branches go through the\n",
    "    numerically stable `log_sigmoid`.\n",
    "    \"\"\"\n",
    "    return -jnp.mean(labels * log_sigmoid(fx) + (1 - labels) * log_sigmoid(-fx))\n",
    "\n",
    "\n",
    "# Per the Neural Tangents API: predict_fn(t, fx_train_0, fx_test_0, k_test_train)\n",
    "# returns (fx_train_t, fx_test_t).\n",
    "predict_fn = nt.predict.gradient_descent(\n",
    "    loss_fn, k_tr_tr, y_tr, lr, momentum\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f7f8a9cb-3d44-415d-9eba-70a8650d50ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NTK self-similarities k(z, z) on the validation grid and k(x, x) on the training set,\n",
    "# plus the grid-vs-training cross kernel. NOTE: the full grid-by-grid kernel is\n",
    "# materialised only to read its diagonal.\n",
    "Ozx = kernel_fn(x_va, x_tr, 'ntk')\n",
    "Oxx = jnp.diagonal(kernel_fn(x_tr, x_tr, 'ntk'))\n",
    "Ozz = jnp.diagonal(kernel_fn(x_va, x_va, 'ntk'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c591f2b8-7526-4e72-a9dd-99b703965429",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Spread of the trained network's prediction over random initialisations: each ensemble\n",
    "# member draws fresh parameters, evaluates the network at init on the training points and\n",
    "# the validation grid, and evolves both outputs to time t with predict_fn.\n",
    "t = 6e3\n",
    "n_samples = 20\n",
    "\n",
    "# One independent key per member. `key` is only read, never reassigned, so no key is both\n",
    "# consumed and split again, and re-running this cell reproduces the same ensemble.\n",
    "sample_keys = random.split(key, n_samples)\n",
    "\n",
    "y_vals = []\n",
    "for net_key in tqdm(sample_keys):\n",
    "    _, params = init_fn(net_key, input_shape=x_tr.shape)\n",
    "    y_tr_init = apply_fn(params, x_tr)\n",
    "    y_va_init = apply_fn(params, x_va)\n",
    "    _, y_val_ntk = predict_fn(t, y_tr_init, y_va_init, Ozx)\n",
    "    y_vals.append(y_val_ntk.reshape(N_val, N_val))\n",
    "\n",
    "y_vals = np.stack(y_vals)  # (n_samples, N_val, N_val)\n",
    "ntk_std = y_vals.std(0)    # per-grid-point std across the ensemble\n",
    "assert ntk_std.shape == (N_val, N_val), ntk_std.shape\n",
    "ntk_std.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c5220a9d-3b68-458e-9621-cc323611badc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "##################### Upper bound by Gronwall Inequality ######################\n",
    "# For a grid point z and training point x, Ozz + Oxx - 2 * Ozx = k(z,z) + k(x,x) - 2 k(z,x)\n",
    "# is the squared distance between z and x in the NTK feature space. The bound below uses\n",
    "# the distance to the nearest training point.\n",
    "\n",
    "UB_SCALE = 0.6  # NOTE(review): hand-tuned scale; its origin is not documented here\n",
    "UB_EPS = 1e-2   # offset inside the sqrt; keeps it away from 0 and tiny negative round-off\n",
    "\n",
    "addition_term = jnp.min(Oxx[None, :] - 2 * Ozx, axis = 1)\n",
    "sq_dist_nearest = Ozz + addition_term\n",
    "UB = UB_SCALE * jnp.sqrt(sq_dist_nearest + UB_EPS)\n",
    "\n",
    "# Same squared distance averaged over all training points instead of the nearest one.\n",
    "mean_term = jnp.mean(Oxx[None, :] - 2 * Ozx, axis = 1)\n",
    "UB_mean = Ozz + mean_term"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9e509a08-6e0a-4411-909d-b2e91e2a3a1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "\n",
    "def create_custom_colormap(color1, color2):\n",
    "    \"\"\"Build a two-stop linear colormap blending from `color1` to `color2`.\n",
    "    \n",
    "    Parameters:\n",
    "    color1 (str): Color placed at the low end (0.0) of the colormap.\n",
    "    color2 (str): Color placed at the high end (1.0) of the colormap.\n",
    "    \n",
    "    Returns:\n",
    "    matplotlib.colors.LinearSegmentedColormap: The blended colormap, named \"custom_cmap\".\n",
    "    \"\"\"\n",
    "    return mcolors.LinearSegmentedColormap.from_list(\"custom_cmap\", [color1, color2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e36a1845-0fcb-46c8-bba0-6c8c9b54979b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.ticker import FixedLocator, FuncFormatter\n",
    "############################### VISUALIZATION #################################\n",
    "# Two stacked panels sharing the x-axis: top = ensemble std of the trained network\n",
    "# (ntk_std), bottom = the kernel-distance bound UB; training points overlaid on both.\n",
    "# Layout is set manually (hspace=0 + set_position) and trimmed by bbox_inches on save.\n",
    "plt.style.use('./style.mplstyle')  # path is relative to the kernel's working directory\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)\n",
    "\n",
    "# Stack the panels with no vertical gap, then make each 0.5% taller (growing upward) so\n",
    "# ax2's top edge slightly overlaps ax1's bottom edge.\n",
    "fig.subplots_adjust(hspace=0)\n",
    "pos1 = ax1.get_position()\n",
    "pos2 = ax2.get_position()\n",
    "ax1.set_position([pos1.x0, pos1.y0, pos1.width, pos1.height * 1.005])\n",
    "ax2.set_position([pos2.x0, pos2.y0, pos2.width, pos2.height * 1.005])\n",
    "\n",
    "# NOTE(review): im_cmap is built but never passed to imshow; add cmap=im_cmap below\n",
    "# if the white-to-blue map is intended.\n",
    "im_cmap = create_custom_colormap('#ffffff', '#0496ff')\n",
    "\n",
    "extent = (X_min, X_max, Y_min, Y_max)\n",
    "labels = np.asarray(y_tr).ravel()\n",
    "points = np.asarray(x_tr)\n",
    "for ax in (ax1, ax2):\n",
    "    # class 0: red triangles, class 1: pale-blue pluses\n",
    "    ax.scatter(points[labels == 0, 0], points[labels == 0, 1], s = 4, color = \"#ff4f4f\", marker = '^')\n",
    "    ax.scatter(points[labels == 1, 0], points[labels == 1, 1], s = 15, color = \"#d6f5ff\", marker = '+')\n",
    "\n",
    "ax1.imshow(ntk_std, origin = 'lower', extent = extent, interpolation = 'lanczos')\n",
    "ax2.imshow(UB.reshape(N_val, N_val), origin = 'lower', extent = extent)\n",
    "\n",
    "# NOTE(review): these labels are assigned positionally to ax2's automatic y-ticks; the\n",
    "# leading '?' looks like a placeholder for a tick outside the visible range - confirm.\n",
    "ax2.set_yticklabels(['?', '-2', '0', '2'], va='top')\n",
    "\n",
    "plt.savefig(\"toy-classification-both.pdf\", format=\"pdf\", bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6557e83b-a14a-47ef-9d40-381e64cd291a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
