{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b9f8c774-389c-4719-9f8d-7de70d9488ec",
   "metadata": {},
   "source": [
    "# Training Loss Curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d786cc46-d82d-4aa2-8071-60abcf228a09",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os\n",
    "\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from src.functions import *\n",
    "from src.pdes import *\n",
    "from src.utils import *\n",
    "\n",
    "from jaxkan.KAN import KAN\n",
    "\n",
    "from flax import nnx\n",
    "import optax\n",
    "\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5033a2f6-a602-4c8a-b70f-2377632c2e16",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3d2242ce-334f-425d-93ed-0115fe456b23",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Function Fitting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bada1a25-d482-4725-9d3a-adc7e02467d4",
   "metadata": {},
   "source": [
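    "We train the two network architectures described in the manuscript (a small and a big KAN) on every function and record the evolution of the training loss under each of the selected initialization schemes (default, Glorot and power-law). Each configuration is trained five times with different model seeds."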
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01beb31c-6160-4cb1-bca0-17f5373645e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup\n",
    "func_dict = {\"f1\": f1, \"f2\": f2, \"f3\": f3, \"f4\": f4, \"f5\": f5}\n",
    "\n",
    "# Number of generated samples per function (before the 80/20 train/test split)\n",
    "N = 5000\n",
    "seed = 42\n",
    "\n",
    "num_epochs = 2000\n",
    "\n",
    "opt_type = optax.adam(learning_rate=0.001)\n",
    "\n",
    "# Power-law initialization exponents (pow_basis -> pow_b1/pow_b2, pow_res -> pow_r1/pow_r2)\n",
    "pow_basis = 1.75\n",
    "pow_res = 0.25\n",
    "\n",
    "# ----------------------------------\n",
    "# Settings shared by all KAN configs\n",
    "# ----------------------------------\n",
    "# Only the grid size G and the initialization scheme differ between the six parameter sets below.\n",
    "kan_shared = {'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True}\n",
    "\n",
    "# Initialization schemes under comparison\n",
    "init_baseline = {'type': 'default'}\n",
    "init_glorot = {'type': 'glorot', 'gain': None, 'distribution': 'uniform'}\n",
    "init_power = {'type': 'power', 'const_b': 1.0, 'const_r': 1.0, 'pow_b1': pow_basis, 'pow_b2': pow_basis, 'pow_r1': pow_res, 'pow_r2': pow_res}\n",
    "\n",
    "# --------------------------\n",
    "# Small architecture details\n",
    "# --------------------------\n",
    "G_small = 5\n",
    "hidden_small = [8, 8]\n",
    "\n",
    "# dict(...) gives each parameter set its own copy of the init scheme, exactly like separate literals\n",
    "params_small_baseline = {'k': 3, 'G': G_small, **kan_shared, 'init_scheme': dict(init_baseline)}\n",
    "params_small_glorot = {'k': 3, 'G': G_small, **kan_shared, 'init_scheme': dict(init_glorot)}\n",
    "params_small_power = {'k': 3, 'G': G_small, **kan_shared, 'init_scheme': dict(init_power)}\n",
    "\n",
    "# ------------------------\n",
    "# Big architecture details\n",
    "# ------------------------\n",
    "G_big = 20\n",
    "hidden_big = [32, 32, 32]\n",
    "\n",
    "params_big_baseline = {'k': 3, 'G': G_big, **kan_shared, 'init_scheme': dict(init_baseline)}\n",
    "params_big_glorot = {'k': 3, 'G': G_big, **kan_shared, 'init_scheme': dict(init_glorot)}\n",
    "params_big_power = {'k': 3, 'G': G_big, **kan_shared, 'init_scheme': dict(init_power)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab47ba8-b2d4-41e7-95c9-8f18dea1eef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize results dict: results[func_name][arch][run][init_name] holds the per-epoch training losses\n",
    "results = dict()\n",
    "\n",
    "# Architectures to train: (label, hidden layer widths, KAN parameters per initialization scheme)\n",
    "architectures = [\n",
    "    ('small', hidden_small, {'baseline': params_small_baseline, 'glorot': params_small_glorot, 'power': params_small_power}),\n",
    "    ('big', hidden_big, {'baseline': params_big_baseline, 'glorot': params_big_glorot, 'power': params_big_power}),\n",
    "]\n",
    "\n",
    "# Display names used in the progress messages\n",
    "init_labels = {'baseline': 'Baseline', 'glorot': 'Glorot', 'power': 'Power-law'}\n",
    "\n",
    "# Independent runs (model seed = seed + run) for confidence\n",
    "runs = [1, 2, 3, 4, 5]\n",
    "\n",
    "\n",
    "def train_func_model(layer_dims, params, model_seed, X_train, y_train):\n",
    "    \"\"\"Train a freshly initialized spline KAN on (X_train, y_train) with `func_fit_step`.\n",
    "\n",
    "    Args:\n",
    "        layer_dims: Layer widths, including the input and output dimensions.\n",
    "        params: `required_parameters` passed to the KAN constructor.\n",
    "        model_seed: Seed for the KAN initialization.\n",
    "        X_train, y_train: Training inputs and targets.\n",
    "\n",
    "    Returns:\n",
    "        (losses, final_loss): the `num_epochs` per-epoch training losses as a 1-D jnp array,\n",
    "        and the loss returned by the last training step.\n",
    "    \"\"\"\n",
    "    model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = params, seed = model_seed)\n",
    "    optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "    # Collect in a Python list: writing into a jnp array with .at[].set() copies the whole array every epoch\n",
    "    losses = []\n",
    "    for _ in range(num_epochs):\n",
    "        loss = func_fit_step(model, optimizer, X_train, y_train)\n",
    "        losses.append(loss)\n",
    "\n",
    "    return jnp.asarray(losses), loss\n",
    "\n",
    "\n",
    "for func_name, function in func_dict.items():\n",
    "    print(f\"Running Experiments for {func_name}.\")\n",
    "    results[func_name] = dict()\n",
    "\n",
    "    # Generate data\n",
    "    x, y = generate_func_data(function, 2, N, seed)\n",
    "\n",
    "    # Split data (in this case we do not care about mse loss, but we're doing it for consistency)\n",
    "    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    # Model input/output\n",
    "    n_in, n_out = X_train.shape[1], y_train.shape[1]\n",
    "\n",
    "    for arch, hidden, arch_params in architectures:\n",
    "        results[func_name][arch] = dict()\n",
    "        layer_dims = [n_in, *hidden, n_out]\n",
    "\n",
    "        print(f\"\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "        for run in runs:\n",
    "            results[func_name][arch][run] = dict()\n",
    "\n",
    "            print(f\"\\t\\tRun No. {run}.\")\n",
    "\n",
    "            # Every initialization scheme uses the same model seed (seed + run) within a run\n",
    "            for init_name, params in arch_params.items():\n",
    "                train_losses, loss = train_func_model(layer_dims, params, seed + run, X_train, y_train)\n",
    "                results[func_name][arch][run][init_name] = train_losses\n",
    "\n",
    "                print(f\"\\t\\t\\t{init_labels[init_name]} model: Final Loss = {loss:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4ad600d-4ee6-489e-8baf-75e3de9945e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save results for further processing\n",
    "# NOTE: the PDE section below reassigns `results`, so run this cell right after the function-fitting experiment.\n",
    "results_dir = 'ff_results/'\n",
    "\n",
    "# Create the output directory on a fresh checkout instead of failing inside open()\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "with open(os.path.join(results_dir, \"losses.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25258fb5-9886-4327-a783-cc7e626740e7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "36f0d1a9-6d2b-4460-8efa-1650bcebd526",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## PDE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "935cea73-0305-46cc-b68f-b7b2698bddac",
   "metadata": {},
   "source": [
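    "We repeat the same procedure for the PDEs (Allen-Cahn, Burgers and Helmholtz). Here the training loss combines the PDE and boundary residuals, each weighted by adaptive RBA weights."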
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47af42b1-15d1-4bb2-8e3c-07b99fad9994",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup\n",
    "pde_dict = {\"allen-cahn\": ac_res, \"burgers\": burgers_res, \"helmholtz\": helmholtz_res}\n",
    "\n",
    "# Passed to get_collocs to control the number of collocation points\n",
    "N_points = 2**6\n",
    "\n",
    "# RBA weight update used in the loss: l_new = RBA_gamma * l + RBA_eta * |r| / max|r|\n",
    "RBA_gamma = 0.999\n",
    "RBA_eta = 0.01\n",
    "\n",
    "seed = 42\n",
    "\n",
    "num_epochs = 5000\n",
    "\n",
    "opt_type = optax.adam(learning_rate=0.001)\n",
    "\n",
    "# Power-law initialization exponents (pow_basis -> pow_b1/pow_b2, pow_res -> pow_r1/pow_r2)\n",
    "pow_basis = 1.75\n",
    "pow_res = 0.25\n",
    "\n",
    "# ----------------------------------\n",
    "# Settings shared by all KAN configs\n",
    "# ----------------------------------\n",
    "# Only the grid size G and the initialization scheme differ between the six parameter sets below.\n",
    "kan_shared = {'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True}\n",
    "\n",
    "# Initialization schemes under comparison\n",
    "init_baseline = {'type': 'default'}\n",
    "init_glorot = {'type': 'glorot', 'gain': None, 'distribution': 'uniform'}\n",
    "init_power = {'type': 'power', 'const_b': 1.0, 'const_r': 1.0, 'pow_b1': pow_basis, 'pow_b2': pow_basis, 'pow_r1': pow_res, 'pow_r2': pow_res}\n",
    "\n",
    "# --------------------------\n",
    "# Small architecture details\n",
    "# --------------------------\n",
    "G_small = 5\n",
    "hidden_small = [8, 8]\n",
    "\n",
    "# dict(...) gives each parameter set its own copy of the init scheme, exactly like separate literals\n",
    "params_small_baseline = {'k': 3, 'G': G_small, **kan_shared, 'init_scheme': dict(init_baseline)}\n",
    "params_small_glorot = {'k': 3, 'G': G_small, **kan_shared, 'init_scheme': dict(init_glorot)}\n",
    "params_small_power = {'k': 3, 'G': G_small, **kan_shared, 'init_scheme': dict(init_power)}\n",
    "\n",
    "# ------------------------\n",
    "# Big architecture details\n",
    "# ------------------------\n",
    "G_big = 20\n",
    "hidden_big = [32, 32, 32]\n",
    "\n",
    "params_big_baseline = {'k': 3, 'G': G_big, **kan_shared, 'init_scheme': dict(init_baseline)}\n",
    "params_big_glorot = {'k': 3, 'G': G_big, **kan_shared, 'init_scheme': dict(init_glorot)}\n",
    "params_big_power = {'k': 3, 'G': G_big, **kan_shared, 'init_scheme': dict(init_power)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03d318cd-7ab5-4e28-b1ec-f67f21f22e6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment\n",
    "# Initialize results dict: results[pde_name][arch][run][init_name] holds the per-epoch training losses\n",
    "results = dict()\n",
    "\n",
    "# Architectures to train: (label, hidden layer widths, KAN parameters per initialization scheme)\n",
    "architectures = [\n",
    "    ('small', hidden_small, {'baseline': params_small_baseline, 'glorot': params_small_glorot, 'power': params_small_power}),\n",
    "    ('big', hidden_big, {'baseline': params_big_baseline, 'glorot': params_big_glorot, 'power': params_big_power}),\n",
    "]\n",
    "\n",
    "# Display names used in the progress messages\n",
    "init_labels = {'baseline': 'Baseline', 'glorot': 'Glorot', 'power': 'Power-law'}\n",
    "\n",
    "# Independent runs (model seed = seed + run) for confidence\n",
    "runs = [1, 2, 3, 4, 5]\n",
    "\n",
    "\n",
    "def train_pde_model(train_step, layer_dims, params, model_seed, pde_collocs, bc_collocs, bc_data):\n",
    "    \"\"\"Train a freshly initialized spline KAN with the RBA-weighted loss via `train_step`.\n",
    "\n",
    "    Args:\n",
    "        train_step: Jitted step returning (loss, new l_E, new l_B) for the current PDE.\n",
    "        layer_dims: Layer widths, including the input and output dimensions.\n",
    "        params: `required_parameters` passed to the KAN constructor.\n",
    "        model_seed: Seed for the KAN initialization.\n",
    "        pde_collocs, bc_collocs, bc_data: Collocation points and boundary targets from get_collocs.\n",
    "\n",
    "    Returns:\n",
    "        (losses, final_loss): the `num_epochs` per-epoch training losses as a 1-D jnp array,\n",
    "        and the loss returned by the last training step.\n",
    "    \"\"\"\n",
    "    # Fresh RBA weights (all ones) for every model: one per PDE collocation point and per boundary point\n",
    "    l_E = jnp.ones((pde_collocs.shape[0], 1))\n",
    "    l_B = jnp.ones((bc_collocs.shape[0], 1))\n",
    "\n",
    "    model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = params, seed = model_seed)\n",
    "    optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "    # Collect in a Python list: writing into a jnp array with .at[].set() copies the whole array every epoch\n",
    "    losses = []\n",
    "    for _ in range(num_epochs):\n",
    "        loss, l_E, l_B = train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "        losses.append(loss)\n",
    "\n",
    "    return jnp.asarray(losses), loss\n",
    "\n",
    "\n",
    "for pde_name, pde_res in pde_dict.items():\n",
    "    print(f\"Running Experiments for {pde_name} equation.\")\n",
    "    results[pde_name] = dict()\n",
    "\n",
    "    # loss_fn reads the global `pde_res` when it is traced, so loss_fn and the jitted train_step are\n",
    "    # redefined for every PDE; a single jitted step could otherwise reuse a trace compiled for a\n",
    "    # previous equation whenever the input shapes match.\n",
    "    def loss_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"RBA-weighted PDE + BC loss; returns (total_loss, (l_E_new, l_B_new)).\"\"\"\n",
    "\n",
    "        # NOTE(review): the updated RBA weights are computed inside the differentiated function, so\n",
    "        # gradients also flow through them; RBA formulations often wrap them in jax.lax.stop_gradient.\n",
    "        # Confirm this is intended before changing it, since it affects the reported loss curves.\n",
    "\n",
    "        # ------------- PDE ---------------------------- #\n",
    "        pde_residuals = pde_res(model, pde_collocs)\n",
    "\n",
    "        # Get new RBA weights: decayed old weights plus residual magnitudes normalized by their max\n",
    "        abs_pde_res = jnp.abs(pde_residuals)\n",
    "        l_E_new = (RBA_gamma*l_E) + (RBA_eta*abs_pde_res/jnp.max(abs_pde_res))\n",
    "\n",
    "        # Multiply by RBA weights\n",
    "        w_resids_pde = l_E_new * pde_residuals\n",
    "\n",
    "        # Get loss\n",
    "        pde_loss = jnp.mean(w_resids_pde**2)\n",
    "\n",
    "        # ------------- BC ----------------------------- #\n",
    "        bc_residuals = model(bc_collocs) - bc_data\n",
    "\n",
    "        # Get new RBA weights\n",
    "        abs_bc_res = jnp.abs(bc_residuals)\n",
    "        l_B_new = (RBA_gamma*l_B) + (RBA_eta*abs_bc_res/jnp.max(abs_bc_res))\n",
    "\n",
    "        # Multiply by RBA weights\n",
    "        w_resids_bc = l_B_new * bc_residuals\n",
    "\n",
    "        # Loss\n",
    "        bc_loss = jnp.mean(w_resids_bc**2)\n",
    "\n",
    "        # ------------- Total --------------------------- #\n",
    "        total_loss = pde_loss + bc_loss\n",
    "\n",
    "        return total_loss, (l_E_new, l_B_new)\n",
    "\n",
    "    @nnx.jit\n",
    "    def train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"One optimizer update; returns (loss, updated l_E, updated l_B).\"\"\"\n",
    "        (loss, (l_E_new, l_B_new)), grads = nnx.value_and_grad(loss_fn, has_aux = True)(model, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "\n",
    "        optimizer.update(grads)\n",
    "\n",
    "        return loss, l_E_new, l_B_new\n",
    "\n",
    "    # Get collocation points\n",
    "    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N_points)\n",
    "\n",
    "    # Model input/output\n",
    "    n_in, n_out = pde_collocs.shape[1], bc_data.shape[1]\n",
    "\n",
    "    for arch, hidden, arch_params in architectures:\n",
    "        results[pde_name][arch] = dict()\n",
    "        layer_dims = [n_in, *hidden, n_out]\n",
    "\n",
    "        print(f\"\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "        for run in runs:\n",
    "            results[pde_name][arch][run] = dict()\n",
    "\n",
    "            print(f\"\\t\\tRun No. {run}.\")\n",
    "\n",
    "            # Every initialization scheme uses the same model seed (seed + run) within a run\n",
    "            for init_name, params in arch_params.items():\n",
    "                train_losses, loss = train_pde_model(train_step, layer_dims, params, seed + run, pde_collocs, bc_collocs, bc_data)\n",
    "                results[pde_name][arch][run][init_name] = train_losses\n",
    "\n",
    "                print(f\"\\t\\t\\t{init_labels[init_name]} model: Final Loss = {loss:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee3244f2-df0e-498a-8d19-813779879486",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save results for further processing\n",
    "results_dir = 'pde_results/'\n",
    "\n",
    "# Create the output directory on a fresh checkout instead of failing inside open()\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "with open(os.path.join(results_dir, \"losses.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1409c858-1e25-40a1-bdf3-6d7794e1698f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
