{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=8\n"
     ]
    }
   ],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=8\n",
    "import jax\n",
    "from jax import numpy as jnp, vmap, jit, random, lax, value_and_grad\n",
    "from jax.numpy import linalg as jla\n",
    "from jax.flatten_util import ravel_pytree\n",
    "from neural_tangents import taylor_expand\n",
    "\n",
    "from lax_util import fold, laxmap\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "from flax import linen as nn\n",
    "from flax.linen import initializers as jinit\n",
    "from functools import partial\n",
    "from typing import Callable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniform, fan-in variance-scaling kernel initialiser with scale 1/2.\n",
    "# NOTE(review): originally described as the default PyTorch init for ReLU;\n",
    "# confirm the 1/2 scale against torch.nn.Linear.reset_parameters before relying on that.\n",
    "torch_init = jinit.variance_scaling(\n",
    "    scale=1 / 2, mode=\"fan_in\", distribution=\"uniform\"\n",
    ")\n",
    "# Dense-layer factory that shares the kernel init above and starts biases at zero.\n",
    "TorchLinear = partial(\n",
    "    nn.Dense,\n",
    "    dtype=None,\n",
    "    bias_init=jinit.zeros,\n",
    "    kernel_init=torch_init,\n",
    ")\n",
    "\n",
    "# Hidden-layer width read by sMLP.\n",
    "width = 100\n",
    "\n",
    "class sMLP(nn.Module):\n",
    "    \"\"\"Single-hidden-layer MLP producing one scalar per input row.\n",
    "\n",
    "    Architecture: flatten -> TorchLinear(width) -> sigma -> TorchLinear(1).\n",
    "    \"\"\"\n",
    "\n",
    "    # Activation applied to the hidden layer's pre-activations.\n",
    "    sigma: Callable\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x):\n",
    "        \"\"\"Flatten each row of `x`, run both dense layers, and return the scalar outputs.\"\"\"\n",
    "        flat = x.reshape(x.shape[0], -1)\n",
    "        hidden = self.sigma(TorchLinear(width)(flat))\n",
    "        out = TorchLinear(1)(hidden)\n",
    "        # Drop the trailing size-1 feature axis left by the output layer.\n",
    "        return out[..., 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 1\n",
    "model_rng, train_rng, test_rng, fn_rng, key = random.split(random.PRNGKey(seed),5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_test_loss(d, n, T, lr=0.05, n_test=10000):\n",
    "    \"\"\"Train an sMLP with full-batch gradient descent and return the last recorded test MSE.\n",
    "\n",
    "    The target is a quadratic form plus a cubic polynomial of the projection onto a random\n",
    "    unit vector `beta`. All randomness comes from the module-level keys (`fn_rng`,\n",
    "    `train_rng`, `test_rng`, `model_rng`), so calls with equal arguments reuse the same draws.\n",
    "\n",
    "    Args:\n",
    "        d: input dimension.\n",
    "        n: number of training samples.\n",
    "        T: number of gradient-descent steps handed to `fold`.\n",
    "        lr: gradient-descent step size.\n",
    "        n_test: number of test samples.\n",
    "\n",
    "    Returns:\n",
    "        The final entry of the saved test-loss trace. Each saved loss is evaluated on the\n",
    "        parameters before that step's update. Presumably `fold` stacks the per-step `save`\n",
    "        values along a leading axis; confirm against lax_util.\n",
    "    \"\"\"\n",
    "    beta = random.normal(fn_rng, (d,))\n",
    "    beta = beta / jla.norm(beta)\n",
    "\n",
    "    # Diagonal 0/1 mask over coordinates with index > d/2. Built once here instead of\n",
    "    # inside f_star because it does not depend on x.\n",
    "    # NOTE(review): for even d this keeps d/2 - 1 coordinates, and the sqrt(675.) scale\n",
    "    # below is a fixed constant (675 = 25**2 + 2*25), not a function of d; confirm both.\n",
    "    A = jnp.diag(jnp.array([i > d/2 for i in range(d)]))\n",
    "\n",
    "    def f_star(x):\n",
    "        # Target function for a single input vector x of length d.\n",
    "        proj = beta @ x\n",
    "        return x.T @ A @ x / jnp.sqrt(675.) + (proj**3 - 3*proj)/jnp.sqrt(6)\n",
    "\n",
    "    # Samples are drawn as (d, n) and transposed so that rows are data points.\n",
    "    X_train = random.normal(train_rng, (d, n)).T\n",
    "    y_train = vmap(f_star)(X_train)\n",
    "\n",
    "    X_test = random.normal(test_rng, (d, n_test)).T\n",
    "    y_test = vmap(f_star)(X_test)\n",
    "\n",
    "    model = sMLP(sigma=jax.nn.relu)\n",
    "    # Flatten the parameter pytree so gradient descent acts on a single vector.\n",
    "    init_params, unravel = ravel_pytree(model.init(model_rng, X_train[:1]))\n",
    "\n",
    "    def mse(p, X, y):\n",
    "        # Mean squared error of the model with flat parameter vector p on (X, y).\n",
    "        return jnp.mean((y - model.apply(unravel(p), X))**2)\n",
    "\n",
    "    @jit\n",
    "    def step(params):\n",
    "        # One pass yields both the training loss and its gradient.\n",
    "        loss, grads = value_and_grad(mse)(params, X_train, y_train)\n",
    "        test_loss = mse(params, X_test, y_test)\n",
    "        return dict(state=params - lr*grads, save=(loss, test_loss))\n",
    "\n",
    "    res = fold(step, init_params, steps=T, show_progress=True)\n",
    "    return res['save'][1][-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_best_n(d, threshhold=0.1, T=30000, max_multiple=20):\n",
    "    \"\"\"Return the smallest tried sample size whose test loss is below `threshhold`.\n",
    "\n",
    "    Candidates are n = k * int(d**2) for k = 1 .. max_multiple, tried in increasing order;\n",
    "    each one is trained for T steps via get_test_loss and the search stops at the first pass.\n",
    "\n",
    "    Args:\n",
    "        d: input dimension.\n",
    "        threshhold: test-loss value a candidate must fall below.\n",
    "        T: number of training steps per candidate.\n",
    "        max_multiple: largest multiple of int(d**2) to try.\n",
    "\n",
    "    Returns:\n",
    "        The first passing n, or None when no candidate gets below the threshold\n",
    "        (callers that plot or aggregate the result need to handle that case).\n",
    "    \"\"\"\n",
    "    base = int(d**2)\n",
    "    for multiple in range(1, max_multiple + 1):\n",
    "        n = multiple * base\n",
    "        print(d, n)\n",
    "        if get_test_loss(d, n, T) < threshhold:\n",
    "            return n\n",
    "    # Search exhausted: state the no-result outcome explicitly instead of falling off the end.\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the input dimension and record what get_best_n returns for each value.\n",
    "ns_nostop = []\n",
    "for d in (10, 20, 30, 40, 50, 100):\n",
    "    ns_nostop += [get_best_n(d)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log-log plot of the recorded sample sizes against d, with 6*d^2 and d^3 reference curves.\n",
    "fig, axs = plt.subplots(1, 1, figsize=(6, 4))\n",
    "axs = np.ravel(axs)\n",
    "\n",
    "plt.rcParams['font.size'] = '16'\n",
    "\n",
    "ax = axs[0]\n",
    "plt.sca(ax)\n",
    "# Must list the same dimensions, in the same order, as the sweep that filled ns_nostop.\n",
    "ds = [10, 20, 30, 40, 50, 100]\n",
    "quadratic_ref = [6*d**2 for d in ds]\n",
    "cubic_ref = [d**3 for d in ds]\n",
    "ax.plot(ds, quadratic_ref, linestyle='dotted', color='gray', label=r'$n = 6d^2$')\n",
    "ax.plot(ds, cubic_ref, linestyle='dashed', color='red', label=r'$n = d^3$')\n",
    "ax.plot(ds, ns_nostop, marker='o', color='blue', label='samples')\n",
    "ax.set_yscale(\"log\")\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_xlabel(r'$d$')\n",
    "ax.set_ylabel(r'$n$')\n",
    "ax.legend()\n",
    "# Label the ticks with the plain dimension values instead of powers of ten.\n",
    "ax.set_xticks(ds)\n",
    "ax.set_xticklabels([str(d) for d in ds])\n",
    "fig.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "a55c72e2be434f5765594e9ed5464ac27f0df6c2513a32031c896f1652198cf7"
  },
  "kernelspec": {
   "display_name": "Python 3.8.13 ('jax')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
