{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "176b5eec",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-17 08:55:00.105153: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "from jax import random\n",
    "import jax\n",
    "import neural_tangents as nt\n",
    "from neural_tangents import stax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8575d9bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model: `stax.serial` wrapping a single `stax.Dense` layer\n",
    "# (first argument 10, parameterization='ntk'). The call's result is\n",
    "# unpacked into the three names init_fn, apply_fn and kernel_fn.\n",
    "init_fn, apply_fn, kernel_fn = stax.serial(stax.Dense(10, parameterization='ntk'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dca60328",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-17 08:55:17.861268: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:397] There was an error before creating cudnn handle: cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.\n",
      "2023-02-17 08:55:17.861407: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR\n",
      "2023-02-17 08:55:17.862945: E external/org_tensorflow/tensorflow/compiler/xla/status_macros.cc:57] INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:626) dnn != nullptr \n",
      "*** Begin stack trace ***\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\t\n",
      "\tPyCFunction_Call\n",
      "\t_PyObject_MakeTpCall\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\tPyObject_Call\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\tPyObject_Call\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\tPyObject_Call\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\tPyObject_Call\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\tPyObject_Call\n",
      "\t\n",
      "\tPyObject_Call\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\tPyObject_Call\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyObject_FastCallDict\n",
      "\t\n",
      "\tPyObject_Call\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\tPyObject_Call\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\tPyEval_EvalCode\n",
      "\t\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyFunction_Vectorcall\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t_PyFunction_Vectorcall\n",
      "\t\n",
      "\tPyObject_Call\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t_PyEval_EvalCodeWithName\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t\n",
      "\t_PyEval_EvalFrameDefault\n",
      "\t\n",
      "\t\n",
      "*** End stack trace ***\n",
      "\n"
     ]
    },
    {
     "ename": "XlaRuntimeError",
     "evalue": "INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:626) dnn != nullptr ",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mXlaRuntimeError\u001b[0m                           Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 4\u001b[0m\n\u001b[1;32m      2\u001b[0m n_test \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m      3\u001b[0m d \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m100\u001b[39m \n\u001b[0;32m----> 4\u001b[0m key1, key2, key3 \u001b[38;5;241m=\u001b[39m random\u001b[38;5;241m.\u001b[39msplit(\u001b[43mrandom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPRNGKey\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m, num\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m)\n\u001b[1;32m      5\u001b[0m output_shape, params \u001b[38;5;241m=\u001b[39m init_fn(key1, input_shape\u001b[38;5;241m=\u001b[39m(d,))\n\u001b[1;32m      6\u001b[0m delta \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mnumpy\u001b[38;5;241m.\u001b[39msqrt(d\u001b[38;5;241m/\u001b[39mn_train) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m10\u001b[39m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/random.py:136\u001b[0m, in \u001b[0;36mPRNGKey\u001b[0;34m(seed)\u001b[0m\n\u001b[1;32m    133\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m np\u001b[38;5;241m.\u001b[39mndim(seed):\n\u001b[1;32m    134\u001b[0m   \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPRNGKey accepts a scalar seed, but was given an array of\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    135\u001b[0m                   \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshape \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp\u001b[38;5;241m.\u001b[39mshape(seed)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m != (). Use jax.vmap for batching\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 136\u001b[0m key \u001b[38;5;241m=\u001b[39m \u001b[43mprng\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mseed_with_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimpl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    137\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _return_prng_keys(\u001b[38;5;28;01mTrue\u001b[39;00m, key)\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/prng.py:267\u001b[0m, in \u001b[0;36mseed_with_impl\u001b[0;34m(impl, seed)\u001b[0m\n\u001b[1;32m    266\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mseed_with_impl\u001b[39m(impl: PRNGImpl, seed: \u001b[38;5;28mint\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m PRNGKeyArray:\n\u001b[0;32m--> 267\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mrandom_seed\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimpl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mimpl\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/prng.py:570\u001b[0m, in \u001b[0;36mrandom_seed\u001b[0;34m(seeds, impl)\u001b[0m\n\u001b[1;32m    568\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    569\u001b[0m   seeds_arr \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39masarray(seeds)\n\u001b[0;32m--> 570\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mrandom_seed_p\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseeds_arr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimpl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mimpl\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/core.py:343\u001b[0m, in \u001b[0;36mPrimitive.bind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m    340\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbind\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mparams):\n\u001b[1;32m    341\u001b[0m   \u001b[38;5;28;01massert\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m config\u001b[38;5;241m.\u001b[39mjax_enable_checks \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m    342\u001b[0m           \u001b[38;5;28mall\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(arg, Tracer) \u001b[38;5;129;01mor\u001b[39;00m valid_jaxtype(arg) \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m args)), args\n\u001b[0;32m--> 343\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind_with_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfind_top_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/core.py:346\u001b[0m, in \u001b[0;36mPrimitive.bind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m    345\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbind_with_trace\u001b[39m(\u001b[38;5;28mself\u001b[39m, trace, args, params):\n\u001b[0;32m--> 346\u001b[0m   out \u001b[38;5;241m=\u001b[39m \u001b[43mtrace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprocess_primitive\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mmap\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfull_raise\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    347\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mmap\u001b[39m(full_lower, out) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmultiple_results \u001b[38;5;28;01melse\u001b[39;00m full_lower(out)\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/core.py:728\u001b[0m, in \u001b[0;36mEvalTrace.process_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m    727\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprocess_primitive\u001b[39m(\u001b[38;5;28mself\u001b[39m, primitive, tracers, params):\n\u001b[0;32m--> 728\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mprimitive\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimpl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtracers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/prng.py:582\u001b[0m, in \u001b[0;36mrandom_seed_impl\u001b[0;34m(seeds, impl)\u001b[0m\n\u001b[1;32m    580\u001b[0m \u001b[38;5;129m@random_seed_p\u001b[39m\u001b[38;5;241m.\u001b[39mdef_impl\n\u001b[1;32m    581\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrandom_seed_impl\u001b[39m(seeds, \u001b[38;5;241m*\u001b[39m, impl):\n\u001b[0;32m--> 582\u001b[0m   base_arr \u001b[38;5;241m=\u001b[39m \u001b[43mrandom_seed_impl_base\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseeds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimpl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mimpl\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    583\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m PRNGKeyArray(impl, base_arr)\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/prng.py:587\u001b[0m, in \u001b[0;36mrandom_seed_impl_base\u001b[0;34m(seeds, impl)\u001b[0m\n\u001b[1;32m    585\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrandom_seed_impl_base\u001b[39m(seeds, \u001b[38;5;241m*\u001b[39m, impl):\n\u001b[1;32m    586\u001b[0m   seed \u001b[38;5;241m=\u001b[39m iterated_vmap_unary(seeds\u001b[38;5;241m.\u001b[39mndim, impl\u001b[38;5;241m.\u001b[39mseed)\n\u001b[0;32m--> 587\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mseed\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseeds\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/prng.py:822\u001b[0m, in \u001b[0;36mthreefry_seed\u001b[0;34m(seed)\u001b[0m\n\u001b[1;32m    819\u001b[0m   \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPRNG key seed must be an integer; got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mseed\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    820\u001b[0m convert \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m k: lax\u001b[38;5;241m.\u001b[39mreshape(lax\u001b[38;5;241m.\u001b[39mconvert_element_type(k, np\u001b[38;5;241m.\u001b[39muint32), [\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m    821\u001b[0m k1 \u001b[38;5;241m=\u001b[39m convert(\n\u001b[0;32m--> 822\u001b[0m     \u001b[43mlax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshift_right_logical\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlax_internal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_const\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m32\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m    823\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m jax\u001b[38;5;241m.\u001b[39mnumpy_dtype_promotion(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstandard\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m    824\u001b[0m   \u001b[38;5;66;03m# TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit\u001b[39;00m\n\u001b[1;32m    825\u001b[0m   \u001b[38;5;66;03m# inputs. We should avoid this.\u001b[39;00m\n\u001b[1;32m    826\u001b[0m   k2 \u001b[38;5;241m=\u001b[39m convert(jnp\u001b[38;5;241m.\u001b[39mbitwise_and(seed, np\u001b[38;5;241m.\u001b[39muint32(\u001b[38;5;241m0xFFFFFFFF\u001b[39m)))\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py:511\u001b[0m, in \u001b[0;36mshift_right_logical\u001b[0;34m(x, y)\u001b[0m\n\u001b[1;32m    509\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mshift_right_logical\u001b[39m(x: ArrayLike, y: ArrayLike) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Array:\n\u001b[1;32m    510\u001b[0m \u001b[38;5;250m  \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Elementwise logical right shift: :math:`x \\gg y`.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 511\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mshift_right_logical_p\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/core.py:343\u001b[0m, in \u001b[0;36mPrimitive.bind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m    340\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbind\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mparams):\n\u001b[1;32m    341\u001b[0m   \u001b[38;5;28;01massert\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m config\u001b[38;5;241m.\u001b[39mjax_enable_checks \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m    342\u001b[0m           \u001b[38;5;28mall\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(arg, Tracer) \u001b[38;5;129;01mor\u001b[39;00m valid_jaxtype(arg) \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m args)), args\n\u001b[0;32m--> 343\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind_with_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfind_top_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/core.py:346\u001b[0m, in \u001b[0;36mPrimitive.bind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m    345\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbind_with_trace\u001b[39m(\u001b[38;5;28mself\u001b[39m, trace, args, params):\n\u001b[0;32m--> 346\u001b[0m   out \u001b[38;5;241m=\u001b[39m \u001b[43mtrace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprocess_primitive\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mmap\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfull_raise\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    347\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mmap\u001b[39m(full_lower, out) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmultiple_results \u001b[38;5;28;01melse\u001b[39;00m full_lower(out)\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/core.py:728\u001b[0m, in \u001b[0;36mEvalTrace.process_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m    727\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprocess_primitive\u001b[39m(\u001b[38;5;28mself\u001b[39m, primitive, tracers, params):\n\u001b[0;32m--> 728\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mprimitive\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimpl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtracers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/dispatch.py:122\u001b[0m, in \u001b[0;36mapply_primitive\u001b[0;34m(prim, *args, **params)\u001b[0m\n\u001b[1;32m    120\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mapply_primitive\u001b[39m(prim, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mparams):\n\u001b[1;32m    121\u001b[0m \u001b[38;5;250m  \u001b[39m\u001b[38;5;124;03m\"\"\"Impl rule that compiles and runs a single primitive 'prim' using XLA.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m   compiled_fun \u001b[38;5;241m=\u001b[39m \u001b[43mxla_primitive_callable\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43munsafe_map\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg_spec\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    123\u001b[0m \u001b[43m                                        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    124\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m compiled_fun(\u001b[38;5;241m*\u001b[39margs)\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/util.py:253\u001b[0m, in \u001b[0;36mcache.<locals>.wrap.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    251\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m f(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    252\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 253\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcached\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_trace_context\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/util.py:246\u001b[0m, in \u001b[0;36mcache.<locals>.wrap.<locals>.cached\u001b[0;34m(_, *args, **kwargs)\u001b[0m\n\u001b[1;32m    244\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mlru_cache(max_size)\n\u001b[1;32m    245\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcached\u001b[39m(_, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 246\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/dispatch.py:201\u001b[0m, in \u001b[0;36mxla_primitive_callable\u001b[0;34m(prim, *arg_specs, **params)\u001b[0m\n\u001b[1;32m    199\u001b[0m   \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    200\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m out,\n\u001b[0;32m--> 201\u001b[0m compiled \u001b[38;5;241m=\u001b[39m \u001b[43m_xla_callable_uncached\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlu\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrap_init\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprim_fun\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    202\u001b[0m \u001b[43m                                  \u001b[49m\u001b[43mprim\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdonated_invars\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43marg_specs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m prim\u001b[38;5;241m.\u001b[39mmultiple_results:\n\u001b[1;32m    204\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mlambda\u001b[39;00m \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw: compiled(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw)[\u001b[38;5;241m0\u001b[39m]\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/dispatch.py:354\u001b[0m, in \u001b[0;36m_xla_callable_uncached\u001b[0;34m(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)\u001b[0m\n\u001b[1;32m    351\u001b[0m   computation \u001b[38;5;241m=\u001b[39m sharded_lowering(fun, device, backend, name, donated_invars,\n\u001b[1;32m    352\u001b[0m                                  \u001b[38;5;28;01mFalse\u001b[39;00m, keep_unused, \u001b[38;5;241m*\u001b[39marg_specs)\n\u001b[1;32m    353\u001b[0m   allow_prop \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;01mTrue\u001b[39;00m] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mlen\u001b[39m(computation\u001b[38;5;241m.\u001b[39mcompile_args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mglobal_out_avals\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[0;32m--> 354\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcomputation\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_allow_propagation_to_outputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_prop\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39munsafe_call\n\u001b[1;32m    355\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    356\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m lower_xla_callable(fun, device, backend, name, donated_invars, \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m    357\u001b[0m                             keep_unused, \u001b[38;5;241m*\u001b[39marg_specs)\u001b[38;5;241m.\u001b[39mcompile()\u001b[38;5;241m.\u001b[39munsafe_call\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py:3203\u001b[0m, in \u001b[0;36mMeshComputation.compile\u001b[0;34m(self, _allow_propagation_to_outputs, _allow_compile_replicated)\u001b[0m\n\u001b[1;32m   3199\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompile\u001b[39m(\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m   3200\u001b[0m             _allow_propagation_to_outputs: Optional[Sequence[\u001b[38;5;28mbool\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   3201\u001b[0m             _allow_compile_replicated: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m MeshExecutable:\n\u001b[1;32m   3202\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_executable \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 3203\u001b[0m     executable \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_compile_unloaded\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3204\u001b[0m \u001b[43m        \u001b[49m\u001b[43m_allow_propagation_to_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_allow_compile_replicated\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3205\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(executable, UnloadedMeshExecutable):\n\u001b[1;32m   3206\u001b[0m       executable \u001b[38;5;241m=\u001b[39m executable\u001b[38;5;241m.\u001b[39mload()\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py:3174\u001b[0m, in \u001b[0;36mMeshComputation._compile_unloaded\u001b[0;34m(self, _allow_propagation_to_outputs, _allow_compile_replicated)\u001b[0m\n\u001b[1;32m   3172\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m MeshExecutable\u001b[38;5;241m.\u001b[39mfrom_trivial_jaxpr(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompile_args)\n\u001b[1;32m   3173\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3174\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mUnloadedMeshExecutable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_hlo\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3175\u001b[0m \u001b[43m      \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3176\u001b[0m \u001b[43m      \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_hlo\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3177\u001b[0m \u001b[43m      \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3178\u001b[0m \u001b[43m      \u001b[49m\u001b[43m_allow_propagation_to_outputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_allow_propagation_to_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3179\u001b[0m \u001b[43m      \u001b[49m\u001b[43m_allow_compile_replicated\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_allow_compile_replicated\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py:3456\u001b[0m, in \u001b[0;36mUnloadedMeshExecutable.from_hlo\u001b[0;34m(name, computation, mesh, global_in_avals, global_out_avals, in_shardings, out_shardings, spmd_lowering, tuple_args, in_is_global, auto_spmd_lowering, _allow_propagation_to_outputs, _allow_compile_replicated, unordered_effects, ordered_effects, host_callbacks, keepalive, kept_var_idx, backend, device_assignment, committed, pmap_nreps)\u001b[0m\n\u001b[1;32m   3452\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   3453\u001b[0m   \u001b[38;5;28;01mwith\u001b[39;00m dispatch\u001b[38;5;241m.\u001b[39mlog_elapsed_time(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFinished XLA compilation of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   3454\u001b[0m                                  \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124min \u001b[39m\u001b[38;5;132;01m{elapsed_time}\u001b[39;00m\u001b[38;5;124m sec\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m   3455\u001b[0m                                  event\u001b[38;5;241m=\u001b[39mdispatch\u001b[38;5;241m.\u001b[39mBACKEND_COMPILE_EVENT):\n\u001b[0;32m-> 3456\u001b[0m     xla_executable \u001b[38;5;241m=\u001b[39m \u001b[43mdispatch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile_or_get_cached\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3457\u001b[0m \u001b[43m        \u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcomputation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompile_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhost_callbacks\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3459\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m auto_spmd_lowering:\n\u001b[1;32m   3460\u001b[0m     
\u001b[38;5;28;01massert\u001b[39;00m mesh \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/dispatch.py:1081\u001b[0m, in \u001b[0;36mcompile_or_get_cached\u001b[0;34m(backend, computation, compile_options, host_callbacks)\u001b[0m\n\u001b[1;32m   1077\u001b[0m     _cache_write(serialized_computation, compile_time, module_name,\n\u001b[1;32m   1078\u001b[0m                  compile_options, backend, compiled, host_callbacks)\n\u001b[1;32m   1079\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m compiled\n\u001b[0;32m-> 1081\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbackend_compile\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mserialized_computation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompile_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1082\u001b[0m \u001b[43m                       \u001b[49m\u001b[43mhost_callbacks\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/profiler.py:314\u001b[0m, in \u001b[0;36mannotate_function.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    311\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m    312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    313\u001b[0m   \u001b[38;5;28;01mwith\u001b[39;00m TraceAnnotation(name, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdecorator_kwargs):\n\u001b[0;32m--> 314\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    315\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m wrapper\n",
      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/jax/_src/dispatch.py:1026\u001b[0m, in \u001b[0;36mbackend_compile\u001b[0;34m(backend, built_c, options, host_callbacks)\u001b[0m\n\u001b[1;32m   1021\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m backend\u001b[38;5;241m.\u001b[39mcompile(built_c, compile_options\u001b[38;5;241m=\u001b[39moptions,\n\u001b[1;32m   1022\u001b[0m                          host_callbacks\u001b[38;5;241m=\u001b[39mhost_callbacks)\n\u001b[1;32m   1023\u001b[0m \u001b[38;5;66;03m# Some backends don't have `host_callbacks` option yet\u001b[39;00m\n\u001b[1;32m   1024\u001b[0m \u001b[38;5;66;03m# TODO(sharadmv): remove this fallback when all backends allow `compile`\u001b[39;00m\n\u001b[1;32m   1025\u001b[0m \u001b[38;5;66;03m# to take in `host_callbacks`\u001b[39;00m\n\u001b[0;32m-> 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbackend\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbuilt_c\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompile_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mXlaRuntimeError\u001b[0m: INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:626) dnn != nullptr "
     ]
    }
   ],
   "source": [
    "# Problem sizes: number of training points, number of test points, input dimension.\n",
    "n_train, n_test, d = 3, 1, 100\n",
    "\n",
    "# Three PRNG keys derived from one seed; key1 initialises the network parameters here.\n",
    "key1, key2, key3 = random.split(random.PRNGKey(1), num=3)\n",
    "output_shape, params = init_fn(key1, input_shape=(d,))\n",
    "\n",
    "# Zero-mean Gaussian whose covariance is the identity except for entry (0, 0),\n",
    "# which is raised to 1 + delta.\n",
    "delta = 10 + jax.numpy.sqrt(d / n_train)\n",
    "mean = jax.numpy.zeros(d)\n",
    "cov = jax.numpy.identity(d, dtype=None).at[0, 0].set(1 + delta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da5eef69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split key2 so that inputs and targets come from separate random streams;\n",
    "# feeding the same key to two samplers makes their draws correlated.\n",
    "x_key, y_key = random.split(key2)\n",
    "x_train = jax.random.multivariate_normal(x_key, mean, cov, (n_train,))\n",
    "y_train = random.uniform(y_key, shape=(n_train, 1))  # training targets\n",
    "x_test = jax.random.multivariate_normal(key3, mean, cov, (n_test,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6fcd15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def returnsNTkKernel(x_train, kernel_fn, x=None, train=True):\n",
    "    \"\"\"Evaluate kernel_fn in 'ntk' mode against the training inputs.\n",
    "\n",
    "    With train=True the kernel is taken between x_train and itself;\n",
    "    otherwise it is taken between x (first argument) and x_train.\n",
    "    \"\"\"\n",
    "    first_argument = x_train if train else x\n",
    "    return kernel_fn(first_argument, x_train, 'ntk')\n",
    "\n",
    "def EigenDecompositonOfKernel(ntk_train_train):\n",
    "    \"\"\"Return (eigenvalues, eigenvectors) of a symmetric matrix via jax.numpy.linalg.eigh.\"\"\"\n",
    "    decomposition = jax.numpy.linalg.eigh(ntk_train_train)\n",
    "    return decomposition[0], decomposition[1]\n",
    "\n",
    "def feature(y_train, ntk_test_train, eigenvalues, eigenvectors, i):\n",
    "    \"\"\"Contribution of the i-th eigen-direction to the kernel prediction.\n",
    "\n",
    "    Computes ntk_test_train @ v_i @ v_i.T @ y_train / lambda_i, where\n",
    "    (lambda_i, v_i) is the i-th eigenpair. Summed over all i this equals\n",
    "    ntk_test_train @ inv(K) @ y_train for the decomposed matrix K.\n",
    "\n",
    "    Args:\n",
    "        y_train: training targets (column vector in this notebook).\n",
    "        ntk_test_train: kernel between the evaluation points and the training points.\n",
    "        eigenvalues: eigenvalues as returned by eigh.\n",
    "        eigenvectors: eigenvector matrix as returned by eigh (one eigenvector per column).\n",
    "        i: index of the eigenpair to use.\n",
    "    \"\"\"\n",
    "    lambda_i = eigenvalues[i]\n",
    "    # eigh stores eigenvectors as columns, so eigenpair i is column i (not row i).\n",
    "    v_i = eigenvectors[:, i].reshape(-1, 1)\n",
    "    return (ntk_test_train @ v_i @ v_i.T @ y_train) / lambda_i\n",
    "\n",
    "def returnsXPerturbed(x_test, epsilon, key):\n",
    "    \"\"\"Add to each row of x_test a random vector rescaled to Euclidean norm epsilon.\n",
    "\n",
    "    Gaussian noise with the shape of x_test is drawn from key, every row is\n",
    "    divided by its norm (taken along axis 1) and multiplied by epsilon, and\n",
    "    the result is added to x_test.\n",
    "    \"\"\"\n",
    "    direction = random.normal(key, x_test.shape)\n",
    "    row_norms = jax.numpy.linalg.norm(direction, axis=1)\n",
    "    norm_column = row_norms.reshape(row_norms.shape[0], 1)\n",
    "    return x_test + epsilon * direction / norm_column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a48b55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train-train NTK (the helper's default mode), followed by its eigendecomposition.\n",
    "ntk_train_train = returnsNTkKernel(x_train, kernel_fn)\n",
    "eigenvalues, eigenvectors = EigenDecompositonOfKernel(ntk_train_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a571224e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spectrum of the (1/d)-scaled Gram matrix x_train @ x_train.T of the training inputs,\n",
    "# kept for comparison with the NTK eigenvalues.\n",
    "cov_eigenvalues, cov_eigenvectors = EigenDecompositonOfKernel((1/d)*x_train@x_train.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ac067a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "norms = {}\n",
    "ntk_test_train_clean = returnsNTkKernel(x_train, kernel_fn, x_test, train=False)\n",
    "\n",
    "# The clean per-eigen-direction features do not depend on epsilon: compute them once.\n",
    "clean_features = jax.numpy.array([\n",
    "    abs(feature(y_train, ntk_test_train_clean, eigenvalues, eigenvectors, i))\n",
    "    for i in range(n_train)\n",
    "])\n",
    "\n",
    "# key3 was already consumed to sample x_test; drawing the noise with it again would\n",
    "# repeat the Gaussian draw underlying x_test, so derive a separate subkey for the noise.\n",
    "_, noise_key = random.split(key3)\n",
    "\n",
    "for epsilon in [0.5, 0.7, 1, 2]:\n",
    "    # Same noise direction for every epsilon; only its length changes.\n",
    "    x_perturbed = returnsXPerturbed(x_test, epsilon, noise_key)\n",
    "    ntk_test_train_perturbed = returnsNTkKernel(x_train, kernel_fn, x_perturbed, train=False)\n",
    "    perturbed_features = jax.numpy.array([\n",
    "        abs(feature(y_train, ntk_test_train_perturbed, eigenvalues, eigenvectors, i))\n",
    "        for i in range(n_train)\n",
    "    ])\n",
    "    difference = perturbed_features - clean_features\n",
    "    # One entry per eigen-direction (each feature is a 1x1 array since n_test == 1).\n",
    "    norms[str(epsilon)] = difference.reshape(difference.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bebb977",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "\n",
    "# Two equal-width panels: per-eigen-direction feature change (left) and the two spectra (right).\n",
    "# plt.subplots avoids reaching into matplotlib.gridspec, which was never imported explicitly.\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n",
    "\n",
    "for epsilon, gap in norms.items():\n",
    "    ax1.plot(eigenvalues, abs(gap), label=epsilon)\n",
    "ax1.set_xlabel(\"Eigenvalues\")\n",
    "ax1.set_ylabel(\"||f_x_perturbed-f_x||\")\n",
    "ax1.legend(title=\"epsilon\")\n",
    "\n",
    "ax2.plot(eigenvalues, marker='+', label=\"kernel\")\n",
    "ax2.plot(cov_eigenvalues, marker='.', label=\"cov\")\n",
    "ax2.set_xlabel(\"Eigenvalue index\")\n",
    "ax2.set_ylabel(\"Eigenvalues\")\n",
    "ax2.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6591ab19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the NTK eigenvalues computed by EigenDecompositonOfKernel.\n",
    "eigenvalues"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2039f178",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found existing installation: neural-tangents 0.6.2\n",
      "Uninstalling neural-tangents-0.6.2:\n",
      "  Successfully uninstalled neural-tangents-0.6.2\n"
     ]
    }
   ],
   "source": [
    "# WARNING: this removes neural_tangents, which the first cell imports; after running it,\n",
    "# a fresh kernel cannot execute this notebook until the package is reinstalled.\n",
    "# %pip (rather than !pip) guarantees the command targets the running kernel's environment.\n",
    "%pip uninstall neural_tangents --yes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c7c05c7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
