{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=7\n"
     ]
    }
   ],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=7\n",
    "import jax\n",
    "from jax import numpy as jnp, vmap, jit, random, lax, value_and_grad\n",
    "from jax.numpy import linalg as jla\n",
    "from jax.flatten_util import ravel_pytree\n",
    "from neural_tangents import taylor_expand\n",
    "\n",
    "from lax_util import fold, laxmap\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100 1000 10000\n"
     ]
    }
   ],
   "source": [
    "d = 100\n",
    "n = int(d**1.5)\n",
    "m = 10000\n",
    "T = 20000\n",
    "lr = 1.\n",
    "print(d, n, m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "seeds = [1, 11, 111, 1111, 11111]\n",
    "lambs = [0.0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "for seed in seeds:\n",
    "    # Each seed is an independent run; res_tot maps lamb -> metrics saved by fold.\n",
    "    res_tot = {}\n",
    "    model_rng, train_rng, test_rng, fn_rng, key = random.split(random.PRNGKey(seed), 5)\n",
    "\n",
    "    # Unit-norm first-layer columns, duplicated and paired with opposite-sign\n",
    "    # output weights, so the two halves cancel and net(x, 0) is zero.\n",
    "    half = m // 2\n",
    "    W_0 = random.normal(model_rng, (d, half))\n",
    "    W_0 = W_0 / jla.norm(W_0, axis=0)\n",
    "    W_0 = jnp.concatenate([W_0, W_0], axis=1)\n",
    "    a = jnp.concatenate([jnp.ones(half), -jnp.ones(half)]) / jnp.sqrt(m)\n",
    "    sigma = lambda z: jax.nn.sigmoid(z - 1)\n",
    "\n",
    "    @jit\n",
    "    def net(x, W):\n",
    "        \"\"\"Return a @ sigma((W_0 + W).T @ x), where W is the flat (d*m,) offset from W_0.\"\"\"\n",
    "        W_mat = W.reshape(d, m)\n",
    "        return a @ sigma((W_0 + W_mat).T @ x)\n",
    "\n",
    "    @jit\n",
    "    def lin_plus_quad(x, W):\n",
    "        \"\"\"Degree-2 Taylor expansion of net(x, .) around W = 0, evaluated at W.\"\"\"\n",
    "        return taylor_expand(lambda params: net(x, params), jnp.zeros(d*m), 2)(W)\n",
    "\n",
    "    def linearize(x, W):\n",
    "        \"\"\"Degree-1 Taylor expansion of net(x, .) around W = 0, evaluated at W.\"\"\"\n",
    "        return taylor_expand(lambda params: net(x, params), jnp.zeros(d*m), 1)(W)\n",
    "\n",
    "    def ntk_feature(x):\n",
    "        \"\"\"Gradient of net(x, .) with respect to the flat weights, taken at W = 0.\"\"\"\n",
    "        return jax.grad(lambda params: net(x, params))(jnp.zeros(d*m))\n",
    "\n",
    "    # generate low-rank quad + lin\n",
    "    beta = random.normal(fn_rng, (d,))\n",
    "    beta = beta / jla.norm(beta)\n",
    "\n",
    "    def f_star(x):\n",
    "        \"\"\"Target function: quadratic plus linear in the projection beta @ x.\"\"\"\n",
    "        z = beta @ x\n",
    "        return (z**2 - 1 + z) / jnp.sqrt(3)\n",
    "\n",
    "    X_train = random.normal(train_rng, (d, n))\n",
    "    y_train = vmap(f_star)(X_train.T)\n",
    "    loss_fn = lambda W: jnp.mean((y_train - lin_plus_quad(X_train, W))**2)\n",
    "\n",
    "    X_test = random.normal(test_rng, (d, 10000))\n",
    "    y_test = vmap(f_star)(X_test.T)\n",
    "    test_loss_fn = lambda W: jnp.mean((y_test - lin_plus_quad(X_test, W))**2)\n",
    "\n",
    "    n_k = d  # how many directions to keep?\n",
    "    # compute Jacobian: one row of ntk_feature per training point\n",
    "    J = vmap(ntk_feature)(X_train.T)\n",
    "    U, S, Vh = jla.svd(J, full_matrices=False)\n",
    "\n",
    "    # Split J into its top-n_k singular directions and the remainder.\n",
    "    # Truncating the factors avoids building an n x n diagonal mask matrix.\n",
    "    V_top = Vh[:n_k]\n",
    "    J_project = (U[:, :n_k] * S[:n_k]) @ V_top\n",
    "    J_large = J - J_project\n",
    "\n",
    "    linear_fn = lambda W: jla.norm(linearize(X_train, W))**2 / n\n",
    "\n",
    "    def reg_fn(W):\n",
    "        \"\"\"Mean squared value of J_large @ W over the n training points.\"\"\"\n",
    "        JW = J_large @ W\n",
    "        return JW @ JW / n\n",
    "\n",
    "    reg_weight = 0.01  # fixed weight on the reg_fn gradient in every update\n",
    "    n_probe = 100      # number of fresh Gaussian inputs drawn per step for r1\n",
    "\n",
    "    @jit\n",
    "    def reg_step(state, lamb):\n",
    "        \"\"\"Take one gradient step on loss_fn + reg_weight * reg_fn + lamb * r1.\n",
    "\n",
    "        state is the pair (params, key). Returns a dict whose 'state' entry is\n",
    "        the updated (params, key) and whose 'save' entry holds\n",
    "        (loss, test_loss, reg, linear) evaluated at the incoming params.\n",
    "        NOTE(review): presumably fold threads 'state' and collects 'save' --\n",
    "        confirm against lax_util.\n",
    "        \"\"\"\n",
    "        params, key = state\n",
    "\n",
    "        # randomized R_1: mean squared linearized output on fresh random\n",
    "        # inputs, after removing the component of W along the top-n_k right\n",
    "        # singular vectors. It must be defined here because the update below\n",
    "        # differentiates it.\n",
    "        key, subkey = random.split(key)\n",
    "        X_sample = random.normal(subkey, (d, n_probe))\n",
    "\n",
    "        def r1(W):\n",
    "            W_perp = W - V_top.T @ (V_top @ W)\n",
    "            return jnp.mean(linearize(X_sample, W_perp)**2)\n",
    "\n",
    "        # value_and_grad gives the metric and its gradient in a single pass.\n",
    "        loss, loss_grad = value_and_grad(loss_fn)(params)\n",
    "        reg, reg_grad = value_and_grad(reg_fn)(params)\n",
    "        test_loss = test_loss_fn(params)\n",
    "        linear = linear_fn(params)\n",
    "\n",
    "        grads = loss_grad + reg_weight * reg_grad + lamb * jax.grad(r1)(params)\n",
    "        params = params - lr * grads\n",
    "\n",
    "        return dict(state=(params, key), save=(loss, test_loss, reg, linear))\n",
    "\n",
    "    # Every lamb starts from zero weights and from the same PRNG key.\n",
    "    params = jnp.zeros((d*m))\n",
    "\n",
    "    for lamb in lambs:\n",
    "        res = fold(lambda state: reg_step(state, lamb), (params, key), steps=T, show_progress=True)\n",
    "        res_tot[lamb] = res['save']\n",
    "\n",
    "    # The file is a pickle despite the .npy suffix, so read it back with\n",
    "    # pickle.load. The with-block closes the handle after writing.\n",
    "    filename = 'seed' + str(seed) + 'data.npy'\n",
    "    with open(filename, 'wb') as fh:\n",
    "        pickle.dump(res_tot, fh)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "a55c72e2be434f5765594e9ed5464ac27f0df6c2513a32031c896f1652198cf7"
  },
  "kernelspec": {
   "display_name": "Python 3.8.13 ('jax')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
