{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=7\n"
     ]
    }
   ],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=7\n",
    "import jax\n",
    "from jax import numpy as jnp, vmap, jit, random, lax, value_and_grad\n",
    "from jax.numpy import linalg as jla\n",
    "from jax.flatten_util import ravel_pytree\n",
    "from neural_tangents import taylor_expand\n",
    "\n",
    "from lax_util import fold, laxmap\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import pickle\n",
    "import scipy\n",
    "\n",
    "import numpy as np\n",
    "import torchvision\n",
    "import torch\n",
    "\n",
    "from flax import linen as nn\n",
    "from typing import Callable\n",
    "\n",
    "from flax import linen as nn\n",
    "from flax.linen import initializers as jinit\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kernel initializer shared by every layer: uniform variance scaling over fan_in\n",
    "# with scale 1/2 (described by the original author as the PyTorch default for ReLU).\n",
    "# NOTE(review): confirm the scale against torch.nn's own defaults if exact parity matters.\n",
    "torch_init = jinit.variance_scaling(scale=1 / 2, mode=\"fan_in\", distribution=\"uniform\")\n",
    "\n",
    "# Layer constructors pre-bound to that kernel initializer and to zero biases.\n",
    "_torch_layer_kwargs = dict(kernel_init=torch_init, bias_init=jinit.zeros, dtype=None)\n",
    "TorchLinear = partial(nn.Dense, **_torch_layer_kwargs)\n",
    "TorchConv = partial(nn.Conv, **_torch_layer_kwargs)\n",
    "\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    \"\"\"Four conv -> activation -> 2x2 average-pool stages followed by a one-unit linear head.\n",
    "\n",
    "    Attributes:\n",
    "        sigma: activation applied after every convolution.\n",
    "    \"\"\"\n",
    "\n",
    "    sigma: Callable\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x):\n",
    "        # Each stage: 512-feature 3x3 convolution, activation, stride-2 average pooling.\n",
    "        for _ in range(4):\n",
    "            x = TorchConv(512, (3, 3))(x)\n",
    "            x = self.sigma(x)\n",
    "            x = nn.avg_pool(x, (2, 2), (2, 2), \"SAME\")\n",
    "        # Flatten everything except the leading axis, then keep the single output per row.\n",
    "        flat = x.reshape(x.shape[0], -1)\n",
    "        return TorchLinear(1)(flat)[..., 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = \"~/datasets/\"\n",
    "\n",
    "# Per-channel offset and scale applied to the raw pixel values\n",
    "# (presumably the CIFAR-10 training-set mean / std -- TODO confirm).\n",
    "CHANNEL_MEAN = np.array([125.30691805, 122.95039414, 113.86538318])\n",
    "CHANNEL_STD = np.array([62.99321928, 62.08870764, 66.70489964])\n",
    "\n",
    "d = 3*32*32\n",
    "\n",
    "# The two CIFAR-10 class indices that define the binary task.\n",
    "class0 = 3\n",
    "class1 = 7\n",
    "\n",
    "\n",
    "def _binary_subset(dataset):\n",
    "    \"\"\"Return (images, labels) for the class0 / class1 examples of `dataset`.\n",
    "\n",
    "    Images are standardised with CHANNEL_MEAN / CHANNEL_STD; labels are +1 for\n",
    "    class0 and -1 for class1.\n",
    "    \"\"\"\n",
    "    images = np.array(dataset.data)\n",
    "    targets = np.array(dataset.targets)\n",
    "    keep = np.logical_or(targets == class0, targets == class1)\n",
    "    # Filter first so that only the retained images are converted to float and standardised.\n",
    "    images = (images[keep] - CHANNEL_MEAN) / CHANNEL_STD\n",
    "    labels = np.array(targets[keep] == class0, dtype=int)*2 - 1\n",
    "    return images, labels\n",
    "\n",
    "\n",
    "traindata = torchvision.datasets.CIFAR10(DATA_DIR, train=True, download=True)\n",
    "testdata = torchvision.datasets.CIFAR10(DATA_DIR, train=False, download=True)\n",
    "\n",
    "# Identical preprocessing for both splits.\n",
    "train_x, train_y = _binary_subset(traindata)\n",
    "test_x, test_y = _binary_subset(testdata)\n",
    "\n",
    "print(train_x.shape)\n",
    "print(test_x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training / test arrays used by every later cell.\n",
    "X_train = train_x\n",
    "y_train = train_y\n",
    "\n",
    "X_test = test_x\n",
    "y_test = test_y\n",
    "\n",
    "# Sizes are read off the arrays so they cannot drift out of sync with the data.\n",
    "n_train = X_train.shape[0]\n",
    "n_test = X_test.shape[0]\n",
    "n = n_train\n",
    "m = 100\n",
    "T = 10000\n",
    "lr = 5e-3  # SGD step size\n",
    "print(d, n_train, m)\n",
    "\n",
    "seed = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_tot = {}\n",
    "# Five independent PRNG keys derived from the single notebook seed.\n",
    "model_rng, train_rng, test_rng, fn_rng, key = random.split(random.PRNGKey(seed), num=5)\n",
    "\n",
    "\n",
    "def sigma(z):\n",
    "    \"\"\"Activation used by the network (ReLU).\"\"\"\n",
    "    return jax.nn.relu(z)\n",
    "\n",
    "\n",
    "model = CNN(sigma=sigma)\n",
    "# Initialise on a single training image, then flatten the parameter pytree into one vector.\n",
    "init_variables = model.init(model_rng, train_x[:1])\n",
    "init_params, unravel = ravel_pytree(init_variables)\n",
    "\n",
    "\n",
    "def f(p, x):\n",
    "    \"\"\"Evaluate the network on inputs `x` at the flat parameter vector `p`.\"\"\"\n",
    "    return model.apply(unravel(p), x)\n",
    "\n",
    "@jit\n",
    "def full(x, W):\n",
    "    \"\"\"Network output on `x` at flat parameters `W` (jit-compiled).\"\"\"\n",
    "    return f(W, x)\n",
    "\n",
    "\n",
    "def _taylor_model(x, W, degree):\n",
    "    \"\"\"Evaluate at `W` the degree-`degree` Taylor expansion (in the parameters) around init_params.\"\"\"\n",
    "    network_of_params = lambda params: f(params, x)\n",
    "    return taylor_expand(network_of_params, init_params, degree)(W)\n",
    "\n",
    "\n",
    "def lin_plus_quad(x, W):\n",
    "    \"\"\"Second-order Taylor approximation of the network around init_params, evaluated at `W`.\"\"\"\n",
    "    return _taylor_model(x, W, 2)\n",
    "\n",
    "\n",
    "def linearize(x, W):\n",
    "    \"\"\"First-order Taylor approximation of the network around init_params, evaluated at `W`.\"\"\"\n",
    "    return _taylor_model(x, W, 1)\n",
    "\n",
    "def ntk_feature(x):\n",
    "    \"\"\"Derivative of every network output w.r.t. the flat parameters, taken at init_params.\n",
    "\n",
    "    f returns one value per row of `x`, and jax.grad only accepts scalar-output\n",
    "    functions, so the Jacobian is computed instead. The result has f's output\n",
    "    shape followed by one trailing axis indexing the parameters.\n",
    "    \"\"\"\n",
    "    return jax.jacrev(lambda params: f(params, x))(init_params)\n",
    "\n",
    "\n",
    "def _mean_squared_error(predict, X, y, W):\n",
    "    \"\"\"Mean squared residual between targets `y` and predict(X, W).\"\"\"\n",
    "    return jnp.mean((y - predict(X, W))**2)\n",
    "\n",
    "\n",
    "# Train / test squared losses of the full network.\n",
    "def full_loss_fn(W):\n",
    "    return _mean_squared_error(full, X_train, y_train, W)\n",
    "\n",
    "\n",
    "def full_test_loss_fn(W):\n",
    "    return _mean_squared_error(full, X_test, y_test, W)\n",
    "\n",
    "\n",
    "# Train / test squared losses of the second-order Taylor model.\n",
    "def quad_loss_fn(W):\n",
    "    return _mean_squared_error(lin_plus_quad, X_train, y_train, W)\n",
    "\n",
    "\n",
    "def quad_test_loss_fn(W):\n",
    "    return _mean_squared_error(lin_plus_quad, X_test, y_test, W)\n",
    "\n",
    "\n",
    "# Train / test squared losses of the first-order Taylor model.\n",
    "def lin_loss_fn(W):\n",
    "    return _mean_squared_error(linearize, X_train, y_train, W)\n",
    "\n",
    "\n",
    "def lin_test_loss_fn(W):\n",
    "    return _mean_squared_error(linearize, X_test, y_test, W)\n",
    "\n",
    "\n",
    "# Number of passes over the training set and mini-batch size.\n",
    "num_epochs = 100\n",
    "batch_size = 64\n",
    "\n",
    "def _sgd_epoch(predict, params):\n",
    "    \"\"\"Run one pass of sequential mini-batch SGD on the squared loss of `predict`.\n",
    "\n",
    "    Batches are consecutive, unshuffled slices of the training set; a trailing\n",
    "    partial batch is skipped. Returns the updated parameters and the mean of the\n",
    "    per-batch losses, each evaluated just before its own update.\n",
    "    \"\"\"\n",
    "    num_batches = n_train // batch_size\n",
    "    running_loss = 0.\n",
    "    for it in range(num_batches):  # plain Python loop: unrolled when traced by jit\n",
    "        batch = slice(it*batch_size, (it+1)*batch_size)\n",
    "        X_batch, y_batch = X_train[batch], y_train[batch]\n",
    "\n",
    "        def loss_fn(W):\n",
    "            return jnp.mean((y_batch - predict(X_batch, W))**2)\n",
    "\n",
    "        # A single call yields both the loss value and its gradient.\n",
    "        loss, grads = value_and_grad(loss_fn)(params)\n",
    "        running_loss = running_loss + loss\n",
    "        params = params - lr*grads\n",
    "    return params, running_loss/num_batches\n",
    "\n",
    "\n",
    "@jit\n",
    "def full_step(input):\n",
    "    \"\"\"One epoch of SGD on the full network, plus test-set diagnostics.\n",
    "\n",
    "    `input` is the (params, key) state; key is passed through unchanged.\n",
    "    Saves (train loss, test loss, test accuracy of the network, of its\n",
    "    linearization, and of the quadratic-only term), with the test quantities\n",
    "    evaluated at the updated parameters.\n",
    "    \"\"\"\n",
    "    params, key = input\n",
    "    params, train_loss = _sgd_epoch(full, params)\n",
    "\n",
    "    test_loss = full_test_loss_fn(params)\n",
    "    # Pure second-order term: degree-2 expansion minus degree-1 expansion.\n",
    "    quad_only = lin_plus_quad(X_test, params) - linearize(X_test, params)\n",
    "\n",
    "    # A prediction is correct when its sign agrees with the label; the mean\n",
    "    # replaces a hard-coded division by the test-set size.\n",
    "    test_acc = jnp.mean(full(X_test, params)*y_test > 0)\n",
    "    lin_acc = jnp.mean(linearize(X_test, params)*y_test > 0)\n",
    "    quad_acc = jnp.mean(quad_only*y_test > 0)\n",
    "\n",
    "    return dict(state=(params, key),save=(train_loss, test_loss, test_acc, lin_acc, quad_acc))\n",
    "\n",
    "\n",
    "@jit\n",
    "def lin_step(input):\n",
    "    \"\"\"One epoch of SGD on the linearized model; saves (train loss, test loss).\n",
    "\n",
    "    `input` is the (params, key) state; key is passed through unchanged.\n",
    "    \"\"\"\n",
    "    params, key = input\n",
    "    params, train_loss = _sgd_epoch(linearize, params)\n",
    "    test_loss = lin_test_loss_fn(params)\n",
    "    return dict(state=(params, key),save=(train_loss, test_loss))\n",
    "\n",
    "\n",
    "@jit\n",
    "def quad_step(input):\n",
    "    \"\"\"One epoch of SGD on the second-order Taylor model; saves (train loss, test loss).\n",
    "\n",
    "    `input` is the (params, key) state; key is passed through unchanged.\n",
    "    \"\"\"\n",
    "    params, key = input\n",
    "    params, train_loss = _sgd_epoch(lin_plus_quad, params)\n",
    "    test_loss = quad_test_loss_fn(params)\n",
    "    return dict(state=(params, key),save=(train_loss, test_loss))\n",
    "\n",
    "\n",
    "params = init_params\n",
    "\n",
    "\n",
    "def _train_and_save(step_fn, filename):\n",
    "    \"\"\"Fold `step_fn` for num_epochs steps from (params, key) and pickle the result to `filename`.\"\"\"\n",
    "    result = fold(step_fn, (params, key), steps=num_epochs, show_progress=True)\n",
    "    # Context manager so the file is flushed and closed even if pickling fails.\n",
    "    with open(filename, 'wb') as out_file:\n",
    "        pickle.dump(result, out_file)\n",
    "    return result\n",
    "\n",
    "\n",
    "# NOTE: these files are pickles despite the .npy suffix; read them back with pickle.load.\n",
    "# All three runs start from the same initial parameters and key.\n",
    "full_res = _train_and_save(full_step, 'full_data.npy')\n",
    "lin_res = _train_and_save(lin_step, 'lin_data.npy')\n",
    "quad_res = _train_and_save(quad_step, 'quad_data.npy')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters reached by training the full network.\n",
    "p = full_res['state'][0]\n",
    "\n",
    "# A bare expression is only displayed when it is the last line of a cell, so the\n",
    "# first two accuracies used to be computed and silently dropped; print all three.\n",
    "for label, predict in (('full', full), ('linear', linearize), ('linear + quadratic', lin_plus_quad)):\n",
    "    # Fraction of test examples whose prediction has the same sign as the label.\n",
    "    accuracy = np.mean(np.asarray(predict(X_test, p))*y_test > 0)\n",
    "    print(label, accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test losses at p of the full network, its linearization, and the second-order model (in that order).\n",
    "for test_loss_fn in (full_test_loss_fn, lin_test_loss_fn, quad_test_loss_fn):\n",
    "    print(test_loss_fn(p))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['font.size'] = '16'\n",
    "\n",
    "fig, axs = plt.subplots(1, 3, figsize=(15, 5))\n",
    "loss_ax, test_ax, acc_ax = axs\n",
    "\n",
    "# (result, colour, legend label) for each trained model.\n",
    "runs = (\n",
    "    (full_res, 'orange', r'$f$'),\n",
    "    (quad_res, 'blue', r'$f_L + f_Q$'),\n",
    "    (lin_res, 'purple', r'$f_L$'),\n",
    ")\n",
    "\n",
    "# Panels 1-2: per-epoch train loss (save[0]) and test loss (save[1]) of every run.\n",
    "for ax, save_idx, ylabel in ((loss_ax, 0, 'Train Loss'), (test_ax, 1, 'Test Loss')):\n",
    "    for result, color, label in runs:\n",
    "        ax.plot(result['save'][save_idx], color=color, label=label)\n",
    "    ax.set_xlabel('epochs')\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.legend()\n",
    "loss_ax.set_ylim(bottom=0.)\n",
    "\n",
    "# Panel 3 (previously left empty): test accuracies recorded by full_step along the\n",
    "# full-network run -- the network, its linearization and the quadratic-only term.\n",
    "acc_ax.plot(full_res['save'][2], color='orange', label=r'$f$')\n",
    "acc_ax.plot(full_res['save'][3], color='purple', label=r'$f_L$')\n",
    "acc_ax.plot(full_res['save'][4], color='green', label=r'$f_Q$')\n",
    "acc_ax.set_xlabel('epochs')\n",
    "acc_ax.set_ylabel('Test Accuracy')\n",
    "acc_ax.legend()\n",
    "\n",
    "fig.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('jax')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a55c72e2be434f5765594e9ed5464ac27f0df6c2513a32031c896f1652198cf7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
