{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b9f8c774-389c-4719-9f8d-7de70d9488ec",
   "metadata": {},
   "source": [
    "# NTK Results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c758fc2c-ee0b-469c-a4be-348ec3ffe185",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Preliminaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f39b38f2-103b-4f71-9c9a-b8eada4b1706",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import config\n",
    "config.update(\"jax_enable_x64\", True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d786cc46-d82d-4aa2-8071-60abcf228a09",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from src.utils import *\n",
    "from src.functions import *\n",
    "from src.pdes import *\n",
    "from src.ntk import *\n",
    "\n",
    "from flax import nnx\n",
    "import optax\n",
    "\n",
    "from jaxkan.KAN import KAN\n",
    "\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5033a2f6-a602-4c8a-b70f-2377632c2e16",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "cb2ff320-751b-4432-834e-79a99b03996d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Function Fitting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff46d037-c901-4f99-8319-67e57e6515ab",
   "metadata": {},
   "source": [
    "We first study the NTK in the function-fitting setting. PDEs are treated separately below, because physics-informed training uses its own NTK formulation (separate PDE-residual and boundary-condition parts)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bc94f52-3850-4153-9bda-431a9d01cf89",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8398a49e-0f58-438c-aab5-454e4803894d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Data ----\n",
    "N = 5000        # total points generated, before the 80/20 train/test split\n",
    "n_ntk = 256     # training points subsampled to build the NTK matrix\n",
    "seed = 42\n",
    "\n",
    "# ---- Training schedule ----\n",
    "num_epochs = 2001\n",
    "checkpoints = [0, 500, 1000, 1500, 2000]  # labels at which the NTK spectrum is recorded\n",
    "opt_type = optax.adam(learning_rate=0.001)\n",
    "\n",
    "# ---- Power-law initialization exponents (basis pow_b* / residual pow_r*) ----\n",
    "pow_basis = 1.75\n",
    "pow_res = 0.25\n",
    "\n",
    "# ---- Model input/output ----\n",
    "n_in, n_out = 2, 1\n",
    "\n",
    "# ---- Studied functions: (name, callable) pairs f1..f5 ----\n",
    "funcs = [(f\"f{i}\", fn) for i, fn in enumerate((f1, f2, f3, f4, f5), start=1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d8e6c1d-8ea2-4e1f-ba61-3df4359d16bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by every spline-KAN configuration below; only the grid size 'G'\n",
    "# and the 'init_scheme' differ. The resulting dicts have the same keys, in the\n",
    "# same order, as spelling each configuration out in full.\n",
    "_shared_kan_settings = {'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu,\n",
    "                        'external_weights': True, 'add_bias': True}\n",
    "\n",
    "# --------------------------\n",
    "# Small architecture details\n",
    "# --------------------------\n",
    "G_small = 5\n",
    "hidden_small = [8, 8]\n",
    "\n",
    "# Each configuration gets its own freshly built 'init_scheme' dict.\n",
    "params_small_baseline = {'k': 3, 'G': G_small, **_shared_kan_settings,\n",
    "                         'init_scheme': {'type': 'default'}}\n",
    "params_small_glorot = {'k': 3, 'G': G_small, **_shared_kan_settings,\n",
    "                       'init_scheme': {'type': 'glorot', 'gain': None, 'distribution': 'uniform'}}\n",
    "params_small_power = {'k': 3, 'G': G_small, **_shared_kan_settings,\n",
    "                      'init_scheme': {'type': 'power', \"const_b\": 1.0, \"const_r\": 1.0,\n",
    "                                      \"pow_b1\": pow_basis, \"pow_b2\": pow_basis,\n",
    "                                      \"pow_r1\": pow_res, \"pow_r2\": pow_res}}\n",
    "\n",
    "# ------------------------\n",
    "# Big architecture details\n",
    "# ------------------------\n",
    "G_big = 20\n",
    "hidden_big = [32, 32, 32]\n",
    "\n",
    "params_big_baseline = {'k': 3, 'G': G_big, **_shared_kan_settings,\n",
    "                       'init_scheme': {'type': 'default'}}\n",
    "params_big_glorot = {'k': 3, 'G': G_big, **_shared_kan_settings,\n",
    "                     'init_scheme': {'type': 'glorot', 'gain': None, 'distribution': 'uniform'}}\n",
    "params_big_power = {'k': 3, 'G': G_big, **_shared_kan_settings,\n",
    "                    'init_scheme': {'type': 'power', \"const_b\": 1.0, \"const_r\": 1.0,\n",
    "                                    \"pow_b1\": pow_basis, \"pow_b2\": pow_basis,\n",
    "                                    \"pow_r1\": pow_res, \"pow_r2\": pow_res}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb63ad3a-45f3-4773-b75d-1388d5f4c1a2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "7322cbe5-12ac-451c-927f-b83a9a243f87",
   "metadata": {},
   "source": [
    "### Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aabd316a-31e0-4e45-9867-63827c56cc86",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(func, N, n_ntk, seed, dim=2, test_size=0.2):\n",
    "    \"\"\"Generate a function-fitting dataset and a fixed subset of points for the NTK.\n",
    "\n",
    "    Args:\n",
    "        func: target function, forwarded to ``generate_func_data``.\n",
    "        N: number of points generated before the train/test split.\n",
    "        n_ntk: number of training points subsampled for the NTK; capped at the\n",
    "            training-set size so sampling without replacement cannot fail.\n",
    "        seed: seed for data generation, the split and the NTK subsample.\n",
    "        dim: input dimension forwarded to ``generate_func_data`` (default 2).\n",
    "        test_size: fraction held out by ``train_test_split`` (default 0.2).\n",
    "\n",
    "    Returns:\n",
    "        (X_train, y_train, X_ntk). The held-out part is discarded: the split is\n",
    "        only done for continuity with a train/test workflow.\n",
    "\n",
    "    NOTE: the PDE section below redefines ``get_data`` with another signature;\n",
    "    re-run this cell before re-running the function-fitting experiments.\n",
    "    \"\"\"\n",
    "    x, y = generate_func_data(func, dim, N, seed)\n",
    "\n",
    "    X_train, _X_test, y_train, _y_test = train_test_split(\n",
    "        x, y, test_size=test_size, random_state=seed)\n",
    "\n",
    "    # Cap the subset size, mirroring the PDE get_data, so replace=False never fails.\n",
    "    n_sub = min(n_ntk, X_train.shape[0])\n",
    "    key_ntk = jax.random.PRNGKey(seed)\n",
    "    idx = jax.random.choice(key_ntk, X_train.shape[0], shape=(n_sub,), replace=False)\n",
    "    X_ntk = X_train[idx]\n",
    "\n",
    "    return X_train, y_train, X_ntk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "643f2354-d8b3-45ed-b50f-347fa278a4a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _ntk_spectrum_stats(model, X_ntk):\n",
    "    \"\"\"Eigen-analysis of the NTK matrix of ``model`` on ``X_ntk``.\n",
    "\n",
    "    The kernel comes from ``ntk_matrix`` and is passed through ``stabilize_kernel``.\n",
    "\n",
    "    Returns:\n",
    "        (lam, cond, eff_rank): eigenvalues sorted largest-first, the condition\n",
    "        number from ``cond_from_eigs`` and the effective rank\n",
    "        (sum lam)^2 / sum lam^2 as a Python float.\n",
    "    \"\"\"\n",
    "    K = stabilize_kernel(ntk_matrix(model, X_ntk))\n",
    "    lam = jnp.sort(jnp.linalg.eigvalsh(K))[::-1]\n",
    "    # The 1e-12 keeps the ratio finite if all eigenvalues vanish.\n",
    "    eff_rank = (lam.sum() ** 2) / (jnp.sum(lam ** 2) + 1e-12)\n",
    "    return lam, cond_from_eigs(lam), float(eff_rank)\n",
    "\n",
    "\n",
    "def run_experiment(func, model, opt, X_train, y_train, X_ntk, title, epochs=None, ckpts=None):\n",
    "    \"\"\"Train a function-fitting model and track its NTK spectrum during training.\n",
    "\n",
    "    Args:\n",
    "        func: target function, used for the final L^2 evaluation.\n",
    "        model: KAN model, trained in place.\n",
    "        opt: ``nnx.Optimizer`` wrapping ``model``.\n",
    "        X_train, y_train: training data passed to ``func_fit_step``.\n",
    "        X_ntk: fixed points on which the NTK is evaluated.\n",
    "        title: label used in the printed summary.\n",
    "        epochs: number of training steps. Defaults to the global ``num_epochs``,\n",
    "            read at call time -- note that the PDE section overwrites that global.\n",
    "        ckpts: checkpoint labels; every entry after the first is recorded.\n",
    "            Defaults to the global ``checkpoints`` (same caveat as ``epochs``).\n",
    "\n",
    "    Returns:\n",
    "        (spec_list, tau_list, conds, ranks): eigenvalues (largest first),\n",
    "        checkpoint labels, condition numbers and effective ranks. Index 0 is\n",
    "        always the untrained model (tau = 0).\n",
    "    \"\"\"\n",
    "    n_epochs = num_epochs if epochs is None else epochs\n",
    "    # tau = 0 is measured explicitly below, so only later checkpoints are matched.\n",
    "    record_at = set((checkpoints if ckpts is None else ckpts)[1:])\n",
    "\n",
    "    spec_list, tau_list = [], []\n",
    "    conds, ranks = [], []\n",
    "\n",
    "    def record(tau):\n",
    "        lam, cond, eff_rank = _ntk_spectrum_stats(model, X_ntk)\n",
    "        spec_list.append(lam)\n",
    "        tau_list.append(tau)\n",
    "        conds.append(cond)\n",
    "        ranks.append(eff_rank)\n",
    "\n",
    "    # τ = 0 (before any updates)\n",
    "    record(0)\n",
    "\n",
    "    loss = float(\"nan\")  # reported as NaN if no training step runs\n",
    "    for epoch in range(n_epochs):\n",
    "        loss = func_fit_step(model, opt, X_train, y_train)\n",
    "\n",
    "        # NOTE(review): label `epoch` is recorded after epoch + 1 updates; kept as-is\n",
    "        # so the tau convention is unchanged.\n",
    "        if epoch in record_at:\n",
    "            record(epoch)\n",
    "\n",
    "    l2error = func_fit_eval(model, func, 2, 200)\n",
    "\n",
    "    print(f\"\\t{title} Model Metrics:\")\n",
    "    print(f\"\\tCond Number: τ=0 → {conds[0]:.2e}, τ={tau_list[-1]} → {conds[-1]:.2e}\")\n",
    "    print(f\"\\tEffective Rank: τ=0 → {ranks[0]:.2f}, τ={tau_list[-1]} → {ranks[-1]:.2f}\")\n",
    "    print(f\"\\tFinal Loss = {loss:.2e}\\t L^2 Error = {l2error:.2e}\\n\")\n",
    "\n",
    "    return spec_list, tau_list, conds, ranks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aab4a5c-3ba8-4da2-b406-3896cb8e2c19",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "362150b7-bbd5-43e8-801a-23d056e14081",
   "metadata": {},
   "source": [
    "### Main Routine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e81bec26-7d27-4a32-a16c-f219cd4b0258",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each architecture is paired with its init-scheme parameter dicts; within an\n",
    "# architecture the dicts differ only in 'init_scheme'.\n",
    "architectures = [\n",
    "    (\"small\", hidden_small, {\"Baseline\": params_small_baseline,\n",
    "                             \"Glorot\": params_small_glorot,\n",
    "                             \"Power\": params_small_power}),\n",
    "    (\"big\", hidden_big, {\"Baseline\": params_big_baseline,\n",
    "                         \"Glorot\": params_big_glorot,\n",
    "                         \"Power\": params_big_power}),\n",
    "]\n",
    "\n",
    "results = dict()\n",
    "\n",
    "for func_name, func in funcs:\n",
    "\n",
    "    results[func_name] = dict()\n",
    "\n",
    "    # Data and NTK subset are shared by every model trained for this function\n",
    "    X_train, y_train, X_ntk = get_data(func, N, n_ntk, seed)\n",
    "\n",
    "    for arch_name, hidden_dims, scheme_params in architectures:\n",
    "        layer_dims = [n_in, *hidden_dims, n_out]\n",
    "        results[func_name][arch_name] = dict()\n",
    "\n",
    "        print(f\"Training model with dimensions {layer_dims} for function {func_name}.\")\n",
    "\n",
    "        for scheme_name, kan_params in scheme_params.items():\n",
    "            # Same seed for every model, so only the initialization scheme changes\n",
    "            model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = kan_params, seed = seed+1)\n",
    "            optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "            spec_list, tau_list, conds, ranks = run_experiment(func, model, optimizer, X_train, y_train, X_ntk, scheme_name)\n",
    "            results[func_name][arch_name][scheme_name] = {\"spec_list\": spec_list, \"tau_list\": tau_list,\n",
    "                                                          \"conds\": conds, \"ranks\": ranks}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f680a60b-1a5d-4dd0-933d-87584e82ef13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save results for further processing; create the directory first so that a\n",
    "# fresh checkout does not fail with FileNotFoundError.\n",
    "results_dir = 'ff_results/'\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "with open(os.path.join(results_dir, \"ntk.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08648145-42a3-4532-adc9-a8f10211d8f8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "fa8c0629-1b31-42d9-92b9-9ad1459a3f52",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## PDE Solving"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44300211-8480-4a03-9958-a851b2149e76",
   "metadata": {},
   "source": [
    "We then extend these experiments to PDEs, solved with the PIKAN (physics-informed KAN) framework using residual-based attention (RBA) weights on the PDE and boundary residuals."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2674420-a76d-4c26-bf84-97c39ebff216",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bfbb22b-c61b-449d-947c-d9a798f5c06e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: this cell overwrites N, seed, num_epochs, checkpoints, opt_type and the\n",
    "# params_* dicts that the function-fitting section above also uses.\n",
    "\n",
    "# ---- Problems: (name, residual function) ----\n",
    "pdes = [(\"allen-cahn\", ac_res), (\"burgers\", burgers_res), (\"helmholtz\", helmholtz_res)]\n",
    "\n",
    "# ---- Collocation points and NTK subsets ----\n",
    "N = 2**6           # forwarded to get_collocs\n",
    "n_ntk_pde = 256    # PDE collocation points used for the residual part of the NTK\n",
    "n_ntk_bc = 32      # boundary points used for the boundary part of the NTK\n",
    "\n",
    "# ---- RBA weight update: l <- RBA_gamma * l + RBA_eta * |r| / max|r| ----\n",
    "RBA_gamma = 0.999\n",
    "RBA_eta = 0.01\n",
    "\n",
    "seed = 42\n",
    "\n",
    "# ---- Training schedule ----\n",
    "num_epochs = 5001\n",
    "checkpoints = [0, 1000, 2000, 3000, 4000, 5000]\n",
    "\n",
    "n_in, n_out = 2, 1\n",
    "\n",
    "opt_type = optax.adam(learning_rate=0.001)\n",
    "\n",
    "# ---- Power-law initialization exponents (basis pow_b* / residual pow_r*) ----\n",
    "pow_basis = 1.75\n",
    "pow_res = 0.25\n",
    "\n",
    "# Settings shared by every spline-KAN configuration below; only the grid size 'G'\n",
    "# and the 'init_scheme' differ. The resulting dicts have the same keys, in the\n",
    "# same order, as spelling each configuration out in full.\n",
    "_shared_kan_settings = {'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu,\n",
    "                        'external_weights': True, 'add_bias': True}\n",
    "\n",
    "# --------------------------\n",
    "# Small architecture details\n",
    "# --------------------------\n",
    "G_small = 5\n",
    "hidden_small = [8, 8]\n",
    "\n",
    "# Each configuration gets its own freshly built 'init_scheme' dict.\n",
    "params_small_baseline = {'k': 3, 'G': G_small, **_shared_kan_settings,\n",
    "                         'init_scheme': {'type': 'default'}}\n",
    "params_small_glorot = {'k': 3, 'G': G_small, **_shared_kan_settings,\n",
    "                       'init_scheme': {'type': 'glorot', 'gain': None, 'distribution': 'uniform'}}\n",
    "params_small_power = {'k': 3, 'G': G_small, **_shared_kan_settings,\n",
    "                      'init_scheme': {'type': 'power', \"const_b\": 1.0, \"const_r\": 1.0,\n",
    "                                      \"pow_b1\": pow_basis, \"pow_b2\": pow_basis,\n",
    "                                      \"pow_r1\": pow_res, \"pow_r2\": pow_res}}\n",
    "\n",
    "# ------------------------\n",
    "# Big architecture details\n",
    "# ------------------------\n",
    "G_big = 20\n",
    "hidden_big = [32, 32, 32]\n",
    "\n",
    "params_big_baseline = {'k': 3, 'G': G_big, **_shared_kan_settings,\n",
    "                       'init_scheme': {'type': 'default'}}\n",
    "params_big_glorot = {'k': 3, 'G': G_big, **_shared_kan_settings,\n",
    "                     'init_scheme': {'type': 'glorot', 'gain': None, 'distribution': 'uniform'}}\n",
    "params_big_power = {'k': 3, 'G': G_big, **_shared_kan_settings,\n",
    "                    'init_scheme': {'type': 'power', \"const_b\": 1.0, \"const_r\": 1.0,\n",
    "                                    \"pow_b1\": pow_basis, \"pow_b2\": pow_basis,\n",
    "                                    \"pow_r1\": pow_res, \"pow_r2\": pow_res}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1738e70d-8fa1-41b5-b5fc-24ea2d4dcdcc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "13e41c54-f80e-42eb-bb0b-c2ce80f4d954",
   "metadata": {},
   "source": [
    "### Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1764dd9d-5e0d-478a-8faa-ec68105daf83",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(pde_name, N, n_ntk_pde=256, n_ntk_bc=32, seed=42):\n",
    "    \"\"\"Load the reference solution, training data and fixed NTK index subsets for a PDE.\n",
    "\n",
    "    NOTE: this shadows the function-fitting ``get_data`` defined earlier.\n",
    "\n",
    "    Args:\n",
    "        pde_name: problem name passed to ``get_ref`` and ``get_collocs``.\n",
    "        N: forwarded to ``get_collocs``.\n",
    "        n_ntk_pde, n_ntk_bc: sizes of the PDE / boundary subsets used for the NTK,\n",
    "            capped at the number of available points.\n",
    "        seed: seed for the subset selection.\n",
    "\n",
    "    Returns:\n",
    "        (refsol, coords, pde_collocs, bc_collocs, bc_data, idx_pde, idx_bc)\n",
    "    \"\"\"\n",
    "    refsol, coords = get_ref(pde_name)\n",
    "    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N)\n",
    "\n",
    "    # Independent subkeys for the two draws: JAX keys must not be reused, otherwise\n",
    "    # the PDE and BC subsets are not independent samples.\n",
    "    key_pde, key_bc = jax.random.split(jax.random.PRNGKey(seed))\n",
    "\n",
    "    n_pde = pde_collocs.shape[0]\n",
    "    n_bc = bc_collocs.shape[0]\n",
    "    idx_pde = jax.random.choice(key_pde, n_pde, shape=(min(n_ntk_pde, n_pde),), replace=False)\n",
    "    idx_bc = jax.random.choice(key_bc, n_bc, shape=(min(n_ntk_bc, n_bc),), replace=False)\n",
    "\n",
    "    return refsol, coords, pde_collocs, bc_collocs, bc_data, idx_pde, idx_bc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9650b9be-fced-48cc-99ae-f60fa114d714",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_experiment_pde(pde_res_fn, model, opt, refsol, coords, pde_collocs, bc_collocs, bc_data, idx_pde, idx_bc, title,\n",
    "                       train_step_fn=None, epochs=None, ckpts=None):\n",
    "    \"\"\"Train a PIKAN with RBA weights and track the PDE / BC NTK spectra.\n",
    "\n",
    "    Args:\n",
    "        pde_res_fn: residual function ``(model, x) -> residuals`` used for the NTK.\n",
    "        model, opt: KAN model and its ``nnx.Optimizer``; updated in place.\n",
    "        refsol, coords: reference solution and the coordinates the model is\n",
    "            evaluated on for the final relative L^2 error.\n",
    "        pde_collocs, bc_collocs, bc_data: training collocation / boundary data.\n",
    "        idx_pde, idx_bc: indices of the points used to build the NTK.\n",
    "        title: label used in the printed summary.\n",
    "        train_step_fn: ``(model, opt, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "            -> (loss, l_E, l_B)``. Defaults to the global ``train_step`` defined in\n",
    "            the main loop; it should train with the same residual as ``pde_res_fn``.\n",
    "        epochs, ckpts: number of training steps and checkpoint labels (entries\n",
    "            after the first are recorded). Default to the globals ``num_epochs`` /\n",
    "            ``checkpoints``, read at call time.\n",
    "\n",
    "    Returns:\n",
    "        (specE_list, specB_list, tau_list, condsE, condsB); index 0 is the\n",
    "        untrained model (tau = 0).\n",
    "    \"\"\"\n",
    "    step = train_step if train_step_fn is None else train_step_fn\n",
    "    n_epochs = num_epochs if epochs is None else epochs\n",
    "    # tau = 0 is measured explicitly, so only later checkpoints are matched.\n",
    "    record_at = set((checkpoints if ckpts is None else ckpts)[1:])\n",
    "\n",
    "    # Fixed NTK evaluation points\n",
    "    X_pde_ntk = pde_collocs[idx_pde]\n",
    "    X_bc_ntk = bc_collocs[idx_bc]\n",
    "    Y_bc_ntk = bc_data[idx_bc]\n",
    "\n",
    "    # RBA weights start at 1 for every training point; the train step updates them.\n",
    "    l_E = jnp.ones((pde_collocs.shape[0], 1))\n",
    "    l_B = jnp.ones((bc_collocs.shape[0], 1))\n",
    "\n",
    "    specE_list, specB_list, tau_list = [], [], []\n",
    "    condsE, condsB = [], []\n",
    "\n",
    "    def record(tau, l_E, l_B):\n",
    "        # Weight the NTK by the current RBA weights of the subsampled points.\n",
    "        wE = l_E[idx_pde].ravel()\n",
    "        wB = l_B[idx_bc].ravel()\n",
    "        # Use the residual passed in, not a notebook-level global.\n",
    "        lamE, lamB = pinntk_diag_spectra_weighted(model, pde_res_fn, X_pde_ntk, X_bc_ntk, Y_bc_ntk, wE, wB)\n",
    "        specE_list.append(lamE)\n",
    "        specB_list.append(lamB)\n",
    "        tau_list.append(tau)\n",
    "        condsE.append(cond_from_eigs(lamE))\n",
    "        condsB.append(cond_from_eigs(lamB))\n",
    "\n",
    "    # τ = 0\n",
    "    record(0, l_E, l_B)\n",
    "\n",
    "    loss = float(\"nan\")  # reported as NaN if no training step runs\n",
    "    for epoch in range(n_epochs):\n",
    "        loss, l_E, l_B = step(model, opt, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "\n",
    "        if epoch in record_at:\n",
    "            record(epoch, l_E, l_B)\n",
    "\n",
    "    output = model(coords).reshape(refsol.shape)\n",
    "    l2error = jnp.linalg.norm(output - refsol) / jnp.linalg.norm(refsol)\n",
    "\n",
    "    print(f\"\\t{title} Metrics:\")\n",
    "    print(f\"\\tPDE Cond#: τ=0 → {condsE[0]:.2e}, τ={tau_list[-1]} → {condsE[-1]:.2e}\")\n",
    "    print(f\"\\tBC  Cond#: τ=0 → {condsB[0]:.2e}, τ={tau_list[-1]} → {condsB[-1]:.2e}\")\n",
    "    print(f\"\\tFinal Loss = {loss:.2e}\\t L^2 Error = {l2error:.2e}\\n\")\n",
    "\n",
    "    return specE_list, specB_list, tau_list, condsE, condsB\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a701982-e645-454e-a409-efdd4a0f1fcc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b4b5b919-5adc-4ece-9a6c-0ffdc0a18724",
   "metadata": {},
   "source": [
    "### Main Routine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7796def-49c4-4971-92e8-fa94f92d1937",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The small architecture is skipped by default (it used to be disabled by wrapping\n",
    "# its code in a string literal). Set RUN_SMALL = True to train it before the big one.\n",
    "RUN_SMALL = False\n",
    "\n",
    "architectures = [(\"big\", hidden_big, {\"Baseline\": params_big_baseline,\n",
    "                                      \"Glorot\": params_big_glorot,\n",
    "                                      \"Power\": params_big_power})]\n",
    "if RUN_SMALL:\n",
    "    architectures.insert(0, (\"small\", hidden_small, {\"Baseline\": params_small_baseline,\n",
    "                                                     \"Glorot\": params_small_glorot,\n",
    "                                                     \"Power\": params_small_power}))\n",
    "\n",
    "results = dict()\n",
    "\n",
    "for pde_name, pde_res in pdes:\n",
    "\n",
    "    results[pde_name] = dict()\n",
    "\n",
    "    # Loss for this PDE: RBA-weighted mean squared PDE and BC residuals. It closes\n",
    "    # over the loop variable `pde_res`, so it is redefined for every PDE.\n",
    "    def loss_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "\n",
    "        # ------------- PDE ---------------------------- #\n",
    "        pde_residuals = pde_res(model, pde_collocs)\n",
    "\n",
    "        # Get new RBA weights\n",
    "        abs_pde_res = jnp.abs(pde_residuals)\n",
    "        l_E_new = (RBA_gamma*l_E) + (RBA_eta*abs_pde_res/jnp.max(abs_pde_res))\n",
    "\n",
    "        # Multiply by RBA weights\n",
    "        w_resids_pde = l_E_new * pde_residuals\n",
    "\n",
    "        # Get loss\n",
    "        pde_loss = jnp.mean(w_resids_pde**2)\n",
    "\n",
    "        # ------------- BC ----------------------------- #\n",
    "        bc_residuals = model(bc_collocs) - bc_data\n",
    "\n",
    "        # Get new RBA weights\n",
    "        abs_bc_res = jnp.abs(bc_residuals)\n",
    "        l_B_new = (RBA_gamma*l_B) + (RBA_eta*abs_bc_res/jnp.max(abs_bc_res))\n",
    "\n",
    "        # Multiply by RBA weights\n",
    "        w_resids_bc = l_B_new * bc_residuals\n",
    "\n",
    "        # Loss\n",
    "        bc_loss = jnp.mean(w_resids_bc**2)\n",
    "\n",
    "        # ------------- Total --------------------------- #\n",
    "        total_loss = pde_loss + bc_loss\n",
    "\n",
    "        return total_loss, (l_E_new, l_B_new)\n",
    "\n",
    "    # Train step, re-jitted for every PDE. run_experiment_pde picks it up as a\n",
    "    # notebook-level global, so it must stay defined at cell level.\n",
    "    @nnx.jit\n",
    "    def train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "\n",
    "        (loss, (l_E_new, l_B_new)), grads = nnx.value_and_grad(loss_fn, has_aux = True)(model, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "\n",
    "        optimizer.update(grads)\n",
    "\n",
    "        return loss, l_E_new, l_B_new\n",
    "\n",
    "    # Get the data (shared by every model trained for this PDE)\n",
    "    refsol, coords, pde_collocs, bc_collocs, bc_data, idx_pde, idx_bc = get_data(pde_name, N, n_ntk_pde, n_ntk_bc, seed)\n",
    "\n",
    "    for arch_name, hidden_dims, scheme_params in architectures:\n",
    "        layer_dims = [n_in, *hidden_dims, n_out]\n",
    "        results[pde_name][arch_name] = dict()\n",
    "\n",
    "        print(f\"Training model with dimensions {layer_dims} for PDE {pde_name}.\")\n",
    "\n",
    "        for scheme_name, kan_params in scheme_params.items():\n",
    "            model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = kan_params, seed = seed+1)\n",
    "            optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "            specE_list, specB_list, tau_list, condsE, condsB = run_experiment_pde(pde_res, model, optimizer, refsol, coords, pde_collocs,\n",
    "                                                                                  bc_collocs, bc_data, idx_pde, idx_bc, scheme_name)\n",
    "\n",
    "            results[pde_name][arch_name][scheme_name] = {\"specE_list\": specE_list, \"specB_list\": specB_list,\n",
    "                                                         \"tau_list\": tau_list, \"condsE\": condsE, \"condsB\": condsB}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18c94652-9138-4d4b-bd8b-63a2c12e05f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save results for further processing; create the directory first so that a\n",
    "# fresh checkout does not fail with FileNotFoundError.\n",
    "results_dir = 'pde_results/'\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "with open(os.path.join(results_dir, \"ntk.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "168cbf82-7e5e-42f4-ba0a-8ef9ae38dde8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
