{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b9f8c774-389c-4719-9f8d-7de70d9488ec",
   "metadata": {},
   "source": [
    "# An Empirical Investigation of Initialization Strategies for Kolmogorov–Arnold Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c614efc-b2f9-49e5-ad4d-aeb2ddaac4dc",
   "metadata": {},
   "source": [
    "This Jupyter Notebook is the companion to the ICML25 MOSS submission titled \"*An Empirical Investigation of Initialization Strategies for Kolmogorov–Arnold Networks*\". It contains the code that reproduces the results reported in the paper, along with the corresponding plots. The statistics in Table 1 of the manuscript (see [this](#Grid-Search-Results---Data-Analysis) section of the notebook) are computed from the `grid_search.csv` file, which is the output of a grid search over multiple architectures. This file is provided in the Supplementary Material of the submission and is required to reproduce Table 1."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c758fc2c-ee0b-469c-a4be-348ec3ffe185",
   "metadata": {},
   "source": [
    "## Preliminaries"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66a225d4-af8b-449b-aba3-593ac4e7ddf5",
   "metadata": {},
   "source": [
    "Below, we install the required packages and define the target functions, model classes, and utilities used in the rest of the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0547805-598a-4f2f-ab9b-2a5f7e713617",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q jaxkan[gpu]\n",
    "!pip install -q scikit-learn\n",
    "!pip install -q pandas\n",
    "!pip install -q matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f39b38f2-103b-4f71-9c9a-b8eada4b1706",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d786cc46-d82d-4aa2-8071-60abcf228a09",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Union, List\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.scipy.special import i1, i1e, fresnel, erfinv, erf\n",
    "\n",
    "from flax import nnx\n",
    "import optax\n",
    "\n",
    "from jaxkan.KAN import KAN\n",
    "from jaxkan.layers.Spline import SplineLayer\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import matplotlib.lines as mlines\n",
    "from matplotlib.pyplot import get_cmap  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8132ecf-ed47-43eb-8783-c87d7a5a0dd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target functions f1-f5 to be fitted: x has shape (N, 2) and each function returns an array of shape (N, 1)\n",
    "def f1(x):\n",
    "    return x[:, [0]] * x[:, [1]]\n",
    "\n",
    "def f2(x):\n",
    "    return jnp.exp(jnp.sin(jnp.pi * x[:, [0]]) + x[:, [1]]**2)\n",
    "\n",
    "def f3(x):\n",
    "    return i1(x[:, [0]]) + jnp.exp(i1e(x[:, [1]])) + jnp.sin(x[:, [0]] * x[:, [1]])\n",
    "\n",
    "def f4(x):\n",
    "    S, C = fresnel(f3(x) + erfinv(x[:, [1]]))\n",
    "    return S * C\n",
    "\n",
    "def f5(x):\n",
    "    x0, x1 = x[:, [0]], x[:, [1]]\n",
    "    sign = jnp.where(x0 < 0.5, 1, -1)\n",
    "    prod = f1(x)\n",
    "    return x1 * sign + erf(x0) * jnp.where(prod < 1, prod, 1/prod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb26ab1d-09a8-4799-9c4d-1c1938332c92",
   "metadata": {},
   "outputs": [],
   "source": [
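    "# LeCun-Normalized initialization: a spline layer whose basis functions are standardized over the batch\n",
    "# and whose coefficients are scaled to preserve variance, plus a KAN built from such layers\n",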
    "class StdSplineLayer(SplineLayer):\n",
    "\n",
    "    def _initialize_params(self, init_scheme, seed):\n",
    "\n",
    "        key = jax.random.key(seed)\n",
    "\n",
    "        # Distribution of the reference sample (defaults to uniform)\n",
    "        distrib = init_scheme.get(\"distribution\", \"uniform\")\n",
    "\n",
    "        if distrib is None:\n",
    "            distrib = \"uniform\"\n",
    "\n",
    "        # Draw a reference sample of 10^5 points\n",
    "        if distrib == \"uniform\":\n",
    "            sample = jax.random.uniform(key, shape=(100000,), minval=-1.0, maxval=1.0)\n",
    "        elif distrib == \"normal\":\n",
    "            sample = jax.random.normal(key, shape=(100000,))\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown distribution '{distrib}'; expected 'uniform' or 'normal'.\")\n",
    "\n",
    "        # Gain: defaults to the standard deviation of the reference sample\n",
    "        gain = init_scheme.get(\"gain\", None)\n",
    "        if gain is None:\n",
    "            gain = sample.std().item()\n",
    "\n",
    "        # ---- Residual Calculations --------\n",
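    "        # Each output is a sum of n_in * (G + k + 1) terms (G + k basis functions plus one residual per input);\n",
    "        # the output variance is split equally among them\n",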
    "        scale = self.n_in * (self.grid.G + self.k + 1)\n",
    "        # Apply the residual function\n",
    "        y_res = self.residual(sample)\n",
    "        # Calculate the average of residual^2(x)\n",
    "        y_res_sq = y_res**2\n",
    "        y_res_sq_mean = y_res_sq.mean().item()\n",
    "\n",
    "        std_res = gain/jnp.sqrt(scale*y_res_sq_mean)\n",
    "        c_res = nnx.initializers.normal(stddev=std_res)(self.rngs.params(), (self.n_out, self.n_in), jnp.float32)\n",
    "\n",
    "        # ---- Basis Calculations -----------\n",
    "        std_b = gain/jnp.sqrt(scale)\n",
    "        c_basis = nnx.initializers.normal(stddev=std_b)(\n",
    "            self.rngs.params(), (self.n_out, self.n_in, self.grid.G + self.k), jnp.float32\n",
    "        )\n",
    "        \n",
    "        return c_res, c_basis\n",
    "\n",
    "        \n",
    "    def basis(self, x):\n",
    "        basis_splines = super().basis(x)\n",
    "\n",
    "        # Standardize each basis function over the batch axis (zero mean, unit variance)\n",
    "        mean = jnp.mean(basis_splines, axis=0, keepdims=True)\n",
    "        denom = jnp.sqrt(jnp.var(basis_splines, axis=0, keepdims=True) + 1e-5)\n",
    "        basis_splines = (basis_splines - mean) / denom\n",
    "\n",
    "        return basis_splines\n",
    "\n",
    "\n",
    "class StdKAN(nnx.Module):\n",
    "    \n",
    "    def __init__(self, layer_dims: List[int], required_parameters: Union[None, dict] = None, seed: int = 42):\n",
    "            \n",
    "        if required_parameters is None:\n",
    "            raise ValueError(\"required_parameters must be provided as a dictionary.\")\n",
    "        \n",
    "        self.layers = [\n",
    "                StdSplineLayer(\n",
    "                    n_in=layer_dims[i],\n",
    "                    n_out=layer_dims[i + 1],\n",
    "                    **required_parameters,\n",
    "                    seed=seed\n",
    "                )\n",
    "                for i in range(len(layer_dims) - 1)\n",
    "            ]\n",
    "    \n",
    "    def __call__(self, x):\n",
    "\n",
    "        # Pass through each layer of the KAN\n",
    "        for layer in self.layers:\n",
    "            x = layer(x)\n",
    "\n",
    "        return x\n",
    "        "
   ]
  },
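  {
   "cell_type": "markdown",
   "id": "c3e1f7a2-5d84-4b9e-a0c6-2f7b9d41e853",
   "metadata": {},
   "source": [
    "The scalings in `StdSplineLayer._initialize_params` can be motivated by a short variance argument. The following is an illustrative sketch, not a derivation taken from the manuscript. It assumes independent zero-mean coefficients and layer inputs distributed like the reference sample used to estimate $\\mathbb{E}[r(x)^2]$. It also assumes, as an assumption about `jaxkan`'s defaults, external weights equal to one and zero biases at initialization. For a layer with $n_{\\text{in}}$ inputs, grid size $G$ and spline order $k$, each output reads\n",
    "\n",
    "$$y_j = \\sum_{i=1}^{n_{\\text{in}}} \\Big[ c^{\\text{res}}_{ji}\\, r(x_i) + \\sum_{m=1}^{G+k} c^{B}_{jim}\\, \\tilde{B}_m(x_i) \\Big],$$\n",
    "\n",
    "where $r$ is the residual function (`nnx.silu` here) and $\\tilde{B}_m$ are the standardized basis functions returned by `StdSplineLayer.basis`, so that $\\mathbb{E}[\\tilde{B}_m^2] \\approx 1$ over the batch. The code draws $c^{\\text{res}}_{ji} \\sim \\mathcal{N}(0, \\sigma_{\\text{res}}^2)$ and $c^{B}_{jim} \\sim \\mathcal{N}(0, \\sigma_B^2)$ with $\\sigma_{\\text{res}}^2 = g^2 / \\big(n_{\\text{in}}(G+k+1)\\,\\mathbb{E}[r(x)^2]\\big)$ and $\\sigma_B^2 = g^2 / \\big(n_{\\text{in}}(G+k+1)\\big)$, where $g$ is the gain. Each of the $n_{\\text{in}}(G+k+1)$ terms then contributes $g^2 / \\big(n_{\\text{in}}(G+k+1)\\big)$ to the output variance:\n",
    "\n",
    "$$\\operatorname{Var}(y_j) = n_{\\text{in}} \\left[ \\sigma_{\\text{res}}^2\\, \\mathbb{E}[r(x)^2] + (G+k)\\, \\sigma_B^2 \\right] = \\frac{n_{\\text{in}}\\,(G+k+1)\\, g^2}{n_{\\text{in}}(G+k+1)} = g^2.$$\n",
    "\n",
    "With the default gain (the standard deviation of the reference sample), each layer's output variance is therefore approximately that of the reference sample. This is about $1/3$ for the $\\mathcal{U}(-1, 1)$ sample used in this notebook."
   ]
  },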
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33554d0d-32e6-46f5-8e60-d97470c38b89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utilities\n",
    "def generate_func_data(function, dim, N, seed):\n",
    "    \"\"\"Sample N points uniformly from [-1, 1]^dim and evaluate `function` on them.\"\"\"\n",
    "    key = jax.random.key(seed)\n",
    "    x = jax.random.uniform(key, shape=(N, dim), minval=-1.0, maxval=1.0)\n",
    "\n",
    "    y = function(x)\n",
    "\n",
    "    return x, y\n",
    "\n",
    "\n",
    "@nnx.jit\n",
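    "def func_fit_step(model, optimizer, X_train, y_train):\n",
    "    \"\"\"One full-batch optimizer step on the MSE loss; returns the loss computed before the update.\"\"\"\n",
    "\n",
    "    def loss_fn(model):\n",
    "        residual = model(X_train) - y_train\n",
    "        loss = jnp.mean(residual**2)\n",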
    "\n",
    "        return loss\n",
    "\n",
    "    loss, grads = nnx.value_and_grad(loss_fn)(model)\n",
    "    optimizer.update(grads)\n",
    "\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f138d9f1-1e09-408c-8a4f-675519bd003f",
   "metadata": {},
   "source": [
    "## Grid Search Results - Data Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a42791b1-ca37-4dbd-8cf3-a9b210381423",
   "metadata": {},
   "source": [
    "This section analyzes `grid_search.csv` to produce the results in Table 1 of the manuscript. The file `grid_search.csv` (supplied in the Supplementary Material) must be in the same directory as this notebook for the following cells to run. If you only wish to see the results for the trained models, you may skip directly to the section titled [KAN Runs](#KAN-Runs)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ea0ce7c-c525-4305-865a-ab7628f44412",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the results file to perform the analysis\n",
    "gs = pd.read_csv('grid_search.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b89defb4-1a29-4d8d-93f7-de7d5b1c48ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each configuration (including the power-law exponents), keep the single run\n",
    "# whose loss is closest to the median loss across the repeated runs\n",
    "gs_sorted = gs.sort_values(\"loss\")\n",
    "\n",
    "# Grouping columns, including pow_res and pow_basis\n",
    "group_cols = ['method', 'function', 'G', 'width', 'depth', 'pow_res', 'pow_basis']\n",
    "\n",
    "# Median loss of each group, broadcast back to every row of that group\n",
    "median_loss = gs_sorted.groupby(group_cols, dropna=False)['loss'].transform('median')\n",
    "\n",
    "# Absolute distance to the group median; idxmin returns the first such row in sorted order,\n",
    "# so ties are broken predictably\n",
    "dist = (gs_sorted['loss'] - median_loss).abs()\n",
    "median_idx = dist.groupby([gs_sorted[c] for c in group_cols], dropna=False).idxmin()\n",
    "\n",
    "# Keep only the median run of each group (avoids groupby.apply, whose handling of grouping columns changed in pandas 3)\n",
    "mgs = gs_sorted.loc[median_idx].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d9f140a-4513-44f8-911b-3f8f7e09b2c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter to only 'power' method\n",
    "power_df = mgs[mgs['method'] == 'power'].copy()\n",
    "\n",
    "# Group by function and architecture (G, width, depth) and keep the row with minimal loss,\n",
    "# i.e. the best (pow_res, pow_basis) pair for each architecture\n",
    "best_idx = power_df.groupby(['function', 'G', 'width', 'depth'], dropna=False)['loss'].idxmin()\n",
    "best_power_configs = power_df.loc[best_idx].reset_index(drop=True)\n",
    "\n",
    "# Drop pow_res, pow_basis and run, which are no longer needed\n",
    "mgs_nopow = mgs.drop(columns=['pow_res', 'pow_basis', 'run'])\n",
    "\n",
    "# Drop the same columns from best_power_configs\n",
    "best_power_configs_nopow = best_power_configs.drop(columns=['pow_res', 'pow_basis', 'run'])\n",
    "\n",
    "# Filter out original 'power' rows from mgs_nopow\n",
    "non_power_rows = mgs_nopow[mgs_nopow['method'] != 'power']\n",
    "\n",
    "# Combine best 'power' rows with all other methods\n",
    "fgs = pd.concat([non_power_rows, best_power_configs_nopow], ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f90d921-644f-4890-9350-b0ba6366b09d",
   "metadata": {},
   "source": [
    "At this point, the dataframe `fgs` contains a single run per method and architecture: the median run. For the power-law method, only the exponent pair with the lowest median-run loss is kept. For each function and each method, we then count how many architectures outperform the baseline in terms of:\n",
    "\n",
    "a. the final loss:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec13fbed-bedd-4823-9276-49a293e04e2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Extract baseline rows\n",
    "baseline_df = fgs[fgs['method'] == 'baseline'][['function', 'G', 'depth', 'width', 'loss']]\n",
    "baseline_df = baseline_df.rename(columns={'loss': 'baseline_loss'})\n",
    "\n",
    "# Step 2: Filter the methods of interest\n",
    "methods_of_interest = ['lecun_norm', 'lecun_numer', 'power']\n",
    "fgs_comp = fgs[fgs['method'].isin(methods_of_interest)].copy()\n",
    "\n",
    "# Step 3: Merge with baseline on matching config\n",
    "merged = pd.merge(\n",
    "    fgs_comp,\n",
    "    baseline_df,\n",
    "    on=['function', 'G', 'depth', 'width'],\n",
    "    how='inner'\n",
    ")\n",
    "\n",
    "# Step 4: Compare losses\n",
    "merged['beats_baseline'] = merged['loss'] < merged['baseline_loss']\n",
    "\n",
    "# Step 5: Group and count\n",
    "result = (\n",
    "    merged.groupby(['function', 'method'])['beats_baseline']\n",
    "    .sum()\n",
    "    .reset_index(name='num_architectures')\n",
    ")\n",
    "\n",
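    "# Number of architectures per function (taken from f1; assumes every function was run on the same grid)\n",
    "num_base = baseline_df[baseline_df['function']=='f1'].shape[0]\n",
    "result['percentage'] = 100*result['num_architectures']/num_base"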
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "762e32c5-dd98-4d0e-a6da-02a95b3fe9cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05947041-6ec2-49e9-af84-1ceaac9e8e89",
   "metadata": {},
   "source": [
    "b. the final $L^2$ error relative to the reference solution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b23ad248-a31b-4e7b-9bc7-57a220e31368",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Get baseline l2 values\n",
    "baseline_l2 = fgs[fgs['method'] == 'baseline'][['function', 'G', 'depth', 'width', 'l2']]\n",
    "baseline_l2 = baseline_l2.rename(columns={'l2': 'baseline_l2'})\n",
    "\n",
    "# Step 2: Filter the methods of interest\n",
    "fgs_comp_l2 = fgs[fgs['method'].isin(methods_of_interest)].copy()\n",
    "\n",
    "# Step 3: Merge on config\n",
    "merged_l2 = pd.merge(\n",
    "    fgs_comp_l2,\n",
    "    baseline_l2,\n",
    "    on=['function', 'G', 'depth', 'width'],\n",
    "    how='inner'\n",
    ")\n",
    "\n",
    "# Step 4: Compare l2 values\n",
    "merged_l2['beats_baseline_l2'] = merged_l2['l2'] < merged_l2['baseline_l2']\n",
    "\n",
    "# Step 5: Group and count\n",
    "result_l2 = (\n",
    "    merged_l2.groupby(['function', 'method'])['beats_baseline_l2']\n",
    "    .sum()\n",
    "    .reset_index(name='num_architectures')\n",
    ")\n",
    "\n",
    "result_l2['percentage'] = 100*result_l2['num_architectures']/num_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "278abdad-a970-486f-9dd8-efc5541e0d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result_l2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18892e43-aeb6-43a1-882f-cddfc6416a72",
   "metadata": {},
   "source": [
    "Finally, we count the architectures for which each method outperforms the baseline in both the final loss and the relative $L^2$ error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35dbde08-43f9-4ed5-9490-122c5c975ca4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a DataFrame that holds both the loss and the l2 comparison against the baseline\n",
    "\n",
    "# Step 1: Extract baseline loss and l2 values\n",
    "baseline_all = fgs[fgs['method'] == 'baseline'][['function', 'G', 'depth', 'width', 'loss', 'l2']]\n",
    "baseline_all = baseline_all.rename(columns={'loss': 'baseline_loss', 'l2': 'baseline_l2'})\n",
    "\n",
    "# Step 2: Merge with the methods of interest\n",
    "fgs_comp_all = fgs[fgs['method'].isin(methods_of_interest)].copy()\n",
    "merged_all = pd.merge(\n",
    "    fgs_comp_all,\n",
    "    baseline_all,\n",
    "    on=['function', 'G', 'depth', 'width'],\n",
    "    how='inner'\n",
    ")\n",
    "\n",
    "# Step 3: Compare both loss and l2\n",
    "merged_all['beats_both'] = (\n",
    "    (merged_all['loss'] < merged_all['baseline_loss']) &\n",
    "    (merged_all['l2'] < merged_all['baseline_l2'])\n",
    ")\n",
    "\n",
    "# Step 4: Group and count\n",
    "result_both = (\n",
    "    merged_all.groupby(['function', 'method'])['beats_both']\n",
    "    .sum()\n",
    "    .reset_index(name='num_architectures')\n",
    ")\n",
    "\n",
    "result_both['percentage'] = 100*result_both['num_architectures']/num_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b1cd9fa-6a74-429b-848a-ad4270c91153",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result_both)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d2242ce-334f-425d-93ed-0115fe456b23",
   "metadata": {},
   "source": [
    "## KAN Runs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bada1a25-d482-4725-9d3a-adc7e02467d4",
   "metadata": {},
   "source": [
    "After the data analysis, we train the two architectures discussed in the manuscript on each target function. Each architecture is trained with the baseline and with each proposed initialization scheme, using five seeds per configuration. We then plot the evolution of the training loss as the mean and standard error across runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01beb31c-6160-4cb1-bca0-17f5373645e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup\n",
    "func_dict = {\"f1\": f1, \"f2\": f2, \"f3\": f3, \"f4\": f4, \"f5\": f5}\n",
    "\n",
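    "N = 5000  # number of sampled points per function\n",
    "seed = 42\n",
    "\n",
    "num_epochs = 2000  # full-batch training iterations per model\n",
    "\n",
    "opt_type = optax.adam(learning_rate=0.001)\n",
    "\n",
    "# Exponents of the power-law initialization for the basis (pow_b*) and residual (pow_r*) terms\n",
    "pow_basis = 1.75\n",
    "pow_res = 0.25\n",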
    "\n",
    "# --------------------------\n",
    "# Small architecture details\n",
    "# --------------------------\n",
    "G_small = 5\n",
    "hidden_small = [8, 8]\n",
    "\n",
    "params_small_baseline = {'k': 3, 'G': G_small, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                         'init_scheme': {'type': 'default'}}\n",
    "\n",
    "params_small_lecun_numer = {'k': 3, 'G': G_small, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                            'init_scheme': {'type': 'lecun', 'gain': None, 'distribution': 'uniform'}}\n",
    "\n",
    "params_small_lecun_norm = {'k': 3, 'G': G_small, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                           'init_scheme': {'gain': None, 'distribution': 'uniform'}}\n",
    "\n",
    "params_small_power = {'k': 3, 'G': G_small, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                      'init_scheme': {'type': 'power', \"const_b\": 1.0, \"const_r\": 1.0, \"pow_b1\": pow_basis, \"pow_b2\": pow_basis, \"pow_r1\": pow_res, \"pow_r2\": pow_res}}\n",
    "\n",
    "# ------------------------\n",
    "# Big architecture details\n",
    "# ------------------------\n",
    "G_big = 20\n",
    "hidden_big = [32, 32, 32]\n",
    "\n",
    "params_big_baseline = {'k': 3, 'G': G_big, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                         'init_scheme': {'type': 'default'}}\n",
    "\n",
    "params_big_lecun_numer = {'k': 3, 'G': G_big, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                            'init_scheme': {'type': 'lecun', 'gain': None, 'distribution': 'uniform'}}\n",
    "\n",
    "params_big_lecun_norm = {'k': 3, 'G': G_big, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                           'init_scheme': {'gain': None, 'distribution': 'uniform'}}\n",
    "\n",
    "params_big_power = {'k': 3, 'G': G_big, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                      'init_scheme': {'type': 'power', \"const_b\": 1.0, \"const_r\": 1.0, \"pow_b1\": pow_basis, \"pow_b2\": pow_basis, \"pow_r1\": pow_res, \"pow_r2\": pow_res}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab47ba8-b2d4-41e7-95c9-8f18dea1eef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment\n",
    "def train_model(model, X_train, y_train):\n",
    "    \"\"\"Train `model` with opt_type for num_epochs full-batch steps; return the per-iteration losses.\"\"\"\n",
    "    optimizer = nnx.Optimizer(model, opt_type)\n",
    "    losses = []\n",
    "    for epoch in range(num_epochs):\n",
    "        losses.append(func_fit_step(model, optimizer, X_train, y_train))\n",
    "    return jnp.array(losses)\n",
    "\n",
    "\n",
    "# Hidden layers and per-method layer parameters of each architecture\n",
    "arch_configs = {\n",
    "    'small': (hidden_small, {'baseline': params_small_baseline, 'numer': params_small_lecun_numer,\n",
    "                             'norm': params_small_lecun_norm, 'power': params_small_power}),\n",
    "    'big': (hidden_big, {'baseline': params_big_baseline, 'numer': params_big_lecun_numer,\n",
    "                         'norm': params_big_lecun_norm, 'power': params_big_power}),\n",
    "}\n",
    "\n",
    "# Names used in the progress messages\n",
    "method_names = {'baseline': 'Baseline', 'numer': 'LeCun-Numerical', 'norm': 'LeCun-Normalized', 'power': 'Power-law'}\n",
    "\n",
    "# results[func_name][arch][run][method] holds the training-loss history of one model\n",
    "results = dict()\n",
    "\n",
    "for func_name, function in func_dict.items():\n",
    "    print(f\"Running Experiments for {func_name}.\")\n",
    "    results[func_name] = dict()\n",
    "\n",
    "    # Generate data and split it\n",
    "    x, y = generate_func_data(function, 2, N, seed)\n",
    "    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    # Model input/output\n",
    "    n_in, n_out = X_train.shape[1], y_train.shape[1]\n",
    "\n",
    "    for arch, (hidden, arch_params) in arch_configs.items():\n",
    "        results[func_name][arch] = dict()\n",
    "        layer_dims = [n_in, *hidden, n_out]\n",
    "        print(f\"\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "        # Five runs with different seeds, for confidence\n",
    "        for run in [1, 2, 3, 4, 5]:\n",
    "            results[func_name][arch][run] = dict()\n",
    "            print(f\"\\t\\tRun No. {run}.\")\n",
    "\n",
    "            for method, params in arch_params.items():\n",
    "                if method == 'norm':\n",
    "                    # LeCun-Normalized relies on the custom StdKAN class defined above\n",
    "                    model = StdKAN(layer_dims=layer_dims, required_parameters=params, seed=seed+run)\n",
    "                else:\n",
    "                    model = KAN(layer_dims=layer_dims, layer_type='spline', required_parameters=params, seed=seed+run)\n",
    "\n",
    "                train_losses = train_model(model, X_train, y_train)\n",
    "                results[func_name][arch][run][method] = train_losses\n",
    "\n",
    "                print(f\"\\t\\t\\t{method_names[method]} model: Final Loss = {train_losses[-1]:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f54cbe34-13f2-405e-ab4f-58f193d0588b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting\n",
    "cmap = get_cmap(\"Spectral\")\n",
    "spectral_points = np.linspace(0, 1, 12)\n",
    "color_indices = [-3, 3, 1, -1]\n",
    "\n",
    "init_types = ['baseline', 'numer', 'norm', 'power']\n",
    "architectures = ['small', 'big']\n",
    "func_names = list(results.keys())\n",
    "func_plot_names = [r'$f_1(x,y)$', r'$f_2(x,y)$', r'$f_3(x,y)$', r'$f_4(x,y)$', r'$f_5(x,y)$']\n",
    "\n",
    "colors = [cmap(spectral_points[i]) for i in color_indices]\n",
    "custom_colors = dict(zip(init_types, colors))\n",
    "\n",
    "fig, axes = plt.subplots(2, 5, figsize=(25, 10))\n",
    "\n",
    "for col, func_name in enumerate(func_names):\n",
    "    for row, arch in enumerate(architectures):\n",
    "        ax = axes[row, col]\n",
    "        \n",
    "        for init in init_types:\n",
    "            # Collect all runs for this configuration\n",
    "            runs = []\n",
    "            for run in results[func_name][arch]:\n",
    "                arr = np.array(results[func_name][arch][run][init])\n",
    "                runs.append(arr)\n",
    "            runs = np.stack(runs)  # shape: (5, num_epochs)\n",
    "\n",
    "            # Compute mean and standard error\n",
    "            mean = runs.mean(axis=0)\n",
    "            stderr = runs.std(axis=0) / np.sqrt(runs.shape[0])\n",
    "\n",
    "            # Plot mean with stderr shaded area\n",
    "            ax.plot(mean, label=init, color=custom_colors[init])\n",
    "            ax.fill_between(np.arange(num_epochs), mean - stderr, mean + stderr, alpha=0.3, color=custom_colors[init])\n",
    "            \n",
    "            #ax.set_xticks([0, 100, 200, 300, 400, 500])\n",
    "            ax.tick_params(axis='both', labelsize=14)\n",
    "\n",
    "        # Labeling\n",
    "        if row == 0:\n",
    "            ax.set_title(func_plot_names[col], fontsize=18)\n",
    "        if col == 0:\n",
    "            ax.set_ylabel(\"Training Loss\", fontsize=16, labelpad=10)\n",
    "        if row == 1:\n",
    "            ax.set_xlabel(\"Training Iteration\", fontsize=16, labelpad=10)\n",
    "        if col == len(func_names) - 1:\n",
    "            ax.text(1.05, 0.5, r'$G = 5$, depth = 2, width = 8' if row == 0 else r'$G = 20$, depth = 3, width = 32', transform=ax.transAxes,\n",
    "                    fontsize=16, rotation=270, va='center', ha='left')\n",
    "\n",
    "        ax.set_yscale('log')\n",
    "        ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.35)\n",
    "\n",
    "# Construct legend manually\n",
    "handles = [\n",
    "    mlines.Line2D([], [], color=custom_colors['baseline'], label='Baseline', linewidth=3),\n",
    "    mlines.Line2D([], [], color=custom_colors['numer'], label='LeCun–Numerical', linewidth=3),\n",
    "    mlines.Line2D([], [], color=custom_colors['norm'], label='LeCun–Normalized', linewidth=3),\n",
    "    mlines.Line2D([], [], color=custom_colors['power'], label='Power-Law', linewidth=3),\n",
    "]\n",
    "\n",
    "# Add global legend\n",
    "fig.legend(handles=handles, loc=\"lower center\", ncol=4, fontsize=18, frameon=False, bbox_to_anchor=(0.5, -0.05))\n",
    "\n",
    "plt.subplots_adjust(hspace=0.2, wspace=0.2, bottom=0.1)\n",
    "\n",
    "#fig.savefig(\"losses.pdf\", bbox_inches='tight')\n",
    "\n",
    "plt.show()\n"
   ]
   }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
