{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "49a91547-b8f2-4953-8ddb-6606e95b1f90",
   "metadata": {},
   "source": [
    "# PDE Solving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea1b453-8334-418f-a0fb-b18621fe7d2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from src.pdes import *\n",
    "from src.utils import *\n",
    "from src.std_kan import StdKAN\n",
    "\n",
    "from jaxkan.KAN import KAN\n",
    "\n",
    "import optax\n",
    "from flax import nnx\n",
    "\n",
    "import os\n",
    "\n",
    "# Folder that collects the result files of all experiments; an existing folder is reused\n",
    "results_dir = \"pde_results\"\n",
    "os.makedirs(results_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc38056f-f103-4610-ac12-3ac5209cf905",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "20d1fe14-a0a9-4cc5-bd9d-b09afefa2712",
   "metadata": {},
   "source": [
    "## Grid Search Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9308aa86-8ebc-4e69-b0c8-2a4749254cee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PDE name -> residual function that the loss calls as pde_res(model, pde_collocs)\n",
    "pde_dict = {\n",
    "    \"allen-cahn\": ac_res,\n",
    "    \"burgers\": burgers_res,\n",
    "    \"helmholtz\": helmholtz_res,\n",
    "}\n",
    "\n",
    "# Handed to get_collocs when the collocation points are generated\n",
    "N_points = 64\n",
    "\n",
    "# RBA weight update: gamma scales the previous weights, eta scales the max-normalised residuals\n",
    "RBA_gamma = 0.999\n",
    "RBA_eta = 0.01\n",
    "\n",
    "# Base seed; each run builds its model with seed + run\n",
    "seed = 42\n",
    "\n",
    "# Number of train_step calls per trained model\n",
    "num_epochs = 5000\n",
    "\n",
    "# Optimizer definition handed to nnx.Optimizer for every model\n",
    "opt_type = optax.adam(learning_rate=1e-3)\n",
    "\n",
    "# Grid-search axes: KAN parameter G, hidden-layer width, number of hidden layers\n",
    "G_values = [5, 10, 20]\n",
    "widths = [2, 4, 8, 16, 32, 64]\n",
    "depths = [1, 2, 3, 4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ede15a2-e90a-4c12-89fd-9ca736c433b7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "53e4c6ad-f76f-403f-b7c0-06419ec1ba91",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Baseline Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8840357b-bc21-4b65-a606-4e61c4bfb166",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The baseline experiment logs to <results_dir>/baseline.txt\n",
    "experiment_name = \"baseline\"\n",
    "results_file = os.path.join(results_dir, experiment_name + \".txt\")\n",
    "\n",
    "# Column names of the rows appended by the training loop\n",
    "header = \"pde, G, width, depth, run, loss, l2\"\n",
    "\n",
    "# Only a missing file gets the header line; an existing file is left untouched\n",
    "if not os.path.exists(results_file):\n",
    "    with open(results_file, \"w\") as file:\n",
    "        file.write(f\"{header}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "577067f9-76b1-47d7-a1a5-27505014ab80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline grid search: every (PDE, G, depth, width) setting is trained for five seeded runs\n",
    "for pde_name, pde_res in pde_dict.items():\n",
    "    print(f\"Running Experiments for {pde_name} equation.\")\n",
    "\n",
    "    def loss_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"RBA-weighted loss: a PDE-residual term plus a boundary-data term.\n",
    "\n",
    "        l_E and l_B are the current RBA weights of the PDE and boundary points.\n",
    "        Returns (total_loss, (l_E_new, l_B_new)), the updated weights being the aux output.\n",
    "        \"\"\"\n",
    "        # PDE term: residuals of the current equation at the collocation points\n",
    "        pde_residuals = pde_res(model, pde_collocs)\n",
    "        abs_pde_res = jnp.abs(pde_residuals)\n",
    "        # Scale the previous weights by gamma and add eta times the max-normalised residual magnitudes\n",
    "        l_E_new = (RBA_gamma*l_E) + (RBA_eta*abs_pde_res/jnp.max(abs_pde_res))\n",
    "        pde_loss = jnp.mean((l_E_new * pde_residuals)**2)\n",
    "\n",
    "        # Boundary term: model output minus the prescribed data, weighted the same way\n",
    "        bc_residuals = model(bc_collocs) - bc_data\n",
    "        abs_bc_res = jnp.abs(bc_residuals)\n",
    "        l_B_new = (RBA_gamma*l_B) + (RBA_eta*abs_bc_res/jnp.max(abs_bc_res))\n",
    "        bc_loss = jnp.mean((l_B_new * bc_residuals)**2)\n",
    "\n",
    "        return pde_loss + bc_loss, (l_E_new, l_B_new)\n",
    "\n",
    "    @nnx.jit\n",
    "    def train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"Apply one optimizer update and return the loss with the new RBA weights.\"\"\"\n",
    "        grad_fn = nnx.value_and_grad(loss_fn, has_aux = True)\n",
    "        (loss, (l_E_new, l_B_new)), grads = grad_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "        optimizer.update(grads)\n",
    "        return loss, l_E_new, l_B_new\n",
    "\n",
    "    # Reference solution (with its evaluation coordinates) and training points for this PDE\n",
    "    refsol, coords = get_ref(pde_name)\n",
    "    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N_points)\n",
    "\n",
    "    # Network input/output sizes are read off the data\n",
    "    n_in, n_out = pde_collocs.shape[1], bc_data.shape[1]\n",
    "\n",
    "    for G in G_values:\n",
    "        print(f\"\\tUsing G = {G}.\")\n",
    "\n",
    "        for depth in depths:\n",
    "            for width in widths:\n",
    "\n",
    "                # depth hidden layers, each with width nodes\n",
    "                hidden = [width]*depth\n",
    "                layer_dims = [n_in, *hidden, n_out]\n",
    "\n",
    "                req_params = {\n",
    "                    'k': 3,\n",
    "                    'G': G,\n",
    "                    'grid_range': (-1.0, 1.0),\n",
    "                    'grid_e': 1.0,\n",
    "                    'residual': nnx.silu,\n",
    "                    'external_weights': True,\n",
    "                    'add_bias': True,\n",
    "                    'init_scheme': {'type': 'default'},\n",
    "                }\n",
    "\n",
    "                print(f\"\\t\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "                for run in range(1, 6):\n",
    "\n",
    "                    # Every run starts from unit RBA weights, one per collocation point\n",
    "                    l_E = jnp.ones((pde_collocs.shape[0], 1))\n",
    "                    l_B = jnp.ones((bc_collocs.shape[0], 1))\n",
    "\n",
    "                    # The run index offsets the seed so that the runs differ\n",
    "                    model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = req_params, seed = seed+run)\n",
    "                    optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "                    # train_loss keeps the value returned by the last step\n",
    "                    for epoch in range(num_epochs):\n",
    "                        train_loss, l_E, l_B = train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "\n",
    "                    # Relative L2 error of the prediction against the reference solution\n",
    "                    output = model(coords).reshape(refsol.shape)\n",
    "                    l2error = jnp.linalg.norm(output-refsol)/jnp.linalg.norm(refsol)\n",
    "\n",
    "                    # One row per run, appended right away so finished runs are already on disk\n",
    "                    new_row = f\"{pde_name}, {G}, {width}, {depth}, {run}, {train_loss}, {l2error}\"\n",
    "                    with open(results_file, \"a\") as rfile:\n",
    "                        rfile.write(f\"{new_row}\\n\")\n",
    "\n",
    "                    print(f\"\\t\\t\\t{run}. Final loss: {train_loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efcd8a93-c67a-449b-a4ef-47fc23fb6730",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6a57dd2d-e0ad-4f42-9db8-42e5650feba1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## LeCun-like Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4760812c-a415-4e45-9cfe-0b3b12d37ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The lecun experiment logs to <results_dir>/lecun.txt\n",
    "experiment_name = \"lecun\"\n",
    "results_file = os.path.join(results_dir, experiment_name + \".txt\")\n",
    "\n",
    "# Column names of the rows appended by the training loop\n",
    "header = \"pde, G, width, depth, run, loss, l2\"\n",
    "\n",
    "# Only a missing file gets the header line; an existing file is left untouched\n",
    "if not os.path.exists(results_file):\n",
    "    with open(results_file, \"w\") as file:\n",
    "        file.write(f\"{header}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e32a3b5d-9089-4f95-92af-d7bbb73e5091",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid search with the 'lecun' init scheme: every (PDE, G, depth, width) setting gets five seeded runs\n",
    "for pde_name, pde_res in pde_dict.items():\n",
    "    print(f\"Running Experiments for {pde_name} equation.\")\n",
    "\n",
    "    def loss_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"RBA-weighted loss: a PDE-residual term plus a boundary-data term.\n",
    "\n",
    "        l_E and l_B are the current RBA weights of the PDE and boundary points.\n",
    "        Returns (total_loss, (l_E_new, l_B_new)), the updated weights being the aux output.\n",
    "        \"\"\"\n",
    "        # PDE term: residuals of the current equation at the collocation points\n",
    "        pde_residuals = pde_res(model, pde_collocs)\n",
    "        abs_pde_res = jnp.abs(pde_residuals)\n",
    "        # Scale the previous weights by gamma and add eta times the max-normalised residual magnitudes\n",
    "        l_E_new = (RBA_gamma*l_E) + (RBA_eta*abs_pde_res/jnp.max(abs_pde_res))\n",
    "        pde_loss = jnp.mean((l_E_new * pde_residuals)**2)\n",
    "\n",
    "        # Boundary term: model output minus the prescribed data, weighted the same way\n",
    "        bc_residuals = model(bc_collocs) - bc_data\n",
    "        abs_bc_res = jnp.abs(bc_residuals)\n",
    "        l_B_new = (RBA_gamma*l_B) + (RBA_eta*abs_bc_res/jnp.max(abs_bc_res))\n",
    "        bc_loss = jnp.mean((l_B_new * bc_residuals)**2)\n",
    "\n",
    "        return pde_loss + bc_loss, (l_E_new, l_B_new)\n",
    "\n",
    "    @nnx.jit\n",
    "    def train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"Apply one optimizer update and return the loss with the new RBA weights.\"\"\"\n",
    "        grad_fn = nnx.value_and_grad(loss_fn, has_aux = True)\n",
    "        (loss, (l_E_new, l_B_new)), grads = grad_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "        optimizer.update(grads)\n",
    "        return loss, l_E_new, l_B_new\n",
    "\n",
    "    # Reference solution (with its evaluation coordinates) and training points for this PDE\n",
    "    refsol, coords = get_ref(pde_name)\n",
    "    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N_points)\n",
    "\n",
    "    # Network input/output sizes are read off the data\n",
    "    n_in, n_out = pde_collocs.shape[1], bc_data.shape[1]\n",
    "\n",
    "    for G in G_values:\n",
    "        print(f\"\\tUsing G = {G}.\")\n",
    "\n",
    "        for depth in depths:\n",
    "            for width in widths:\n",
    "\n",
    "                # depth hidden layers, each with width nodes\n",
    "                hidden = [width]*depth\n",
    "                layer_dims = [n_in, *hidden, n_out]\n",
    "\n",
    "                req_params = {\n",
    "                    'k': 3,\n",
    "                    'G': G,\n",
    "                    'grid_range': (-1.0, 1.0),\n",
    "                    'grid_e': 1.0,\n",
    "                    'residual': nnx.silu,\n",
    "                    'external_weights': True,\n",
    "                    'add_bias': True,\n",
    "                    'init_scheme': {'type': 'lecun', 'gain': None, 'distribution':'uniform'},\n",
    "                }\n",
    "\n",
    "                print(f\"\\t\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "                for run in range(1, 6):\n",
    "\n",
    "                    # Every run starts from unit RBA weights, one per collocation point\n",
    "                    l_E = jnp.ones((pde_collocs.shape[0], 1))\n",
    "                    l_B = jnp.ones((bc_collocs.shape[0], 1))\n",
    "\n",
    "                    # The run index offsets the seed so that the runs differ\n",
    "                    model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = req_params, seed = seed+run)\n",
    "                    optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "                    # train_loss keeps the value returned by the last step\n",
    "                    for epoch in range(num_epochs):\n",
    "                        train_loss, l_E, l_B = train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "\n",
    "                    # Relative L2 error of the prediction against the reference solution\n",
    "                    output = model(coords).reshape(refsol.shape)\n",
    "                    l2error = jnp.linalg.norm(output-refsol)/jnp.linalg.norm(refsol)\n",
    "\n",
    "                    # One row per run, appended right away so finished runs are already on disk\n",
    "                    new_row = f\"{pde_name}, {G}, {width}, {depth}, {run}, {train_loss}, {l2error}\"\n",
    "                    with open(results_file, \"a\") as rfile:\n",
    "                        rfile.write(f\"{new_row}\\n\")\n",
    "\n",
    "                    print(f\"\\t\\t\\t{run}. Final loss: {train_loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "882b2fed-2e40-49f5-b865-6ffcd172206d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "0059349d-a7b8-4937-ba73-0c485390399a",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Glorot-like Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c52a808b-4f64-4913-9d4b-7ccc26bb83b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The glorot experiment logs to <results_dir>/glorot.txt\n",
    "experiment_name = \"glorot\"\n",
    "results_file = os.path.join(results_dir, experiment_name + \".txt\")\n",
    "\n",
    "# Column names of the rows appended by the training loop\n",
    "header = \"pde, G, width, depth, run, loss, l2\"\n",
    "\n",
    "# Only a missing file gets the header line; an existing file is left untouched\n",
    "if not os.path.exists(results_file):\n",
    "    with open(results_file, \"w\") as file:\n",
    "        file.write(f\"{header}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "949bece4-c46b-439f-a241-21f022e277d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid search with the 'glorot' init scheme: every (PDE, G, depth, width) setting gets five seeded runs\n",
    "for pde_name, pde_res in pde_dict.items():\n",
    "    print(f\"Running Experiments for {pde_name} equation.\")\n",
    "\n",
    "    def loss_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"RBA-weighted loss: a PDE-residual term plus a boundary-data term.\n",
    "\n",
    "        l_E and l_B are the current RBA weights of the PDE and boundary points.\n",
    "        Returns (total_loss, (l_E_new, l_B_new)), the updated weights being the aux output.\n",
    "        \"\"\"\n",
    "        # PDE term: residuals of the current equation at the collocation points\n",
    "        pde_residuals = pde_res(model, pde_collocs)\n",
    "        abs_pde_res = jnp.abs(pde_residuals)\n",
    "        # Scale the previous weights by gamma and add eta times the max-normalised residual magnitudes\n",
    "        l_E_new = (RBA_gamma*l_E) + (RBA_eta*abs_pde_res/jnp.max(abs_pde_res))\n",
    "        pde_loss = jnp.mean((l_E_new * pde_residuals)**2)\n",
    "\n",
    "        # Boundary term: model output minus the prescribed data, weighted the same way\n",
    "        bc_residuals = model(bc_collocs) - bc_data\n",
    "        abs_bc_res = jnp.abs(bc_residuals)\n",
    "        l_B_new = (RBA_gamma*l_B) + (RBA_eta*abs_bc_res/jnp.max(abs_bc_res))\n",
    "        bc_loss = jnp.mean((l_B_new * bc_residuals)**2)\n",
    "\n",
    "        return pde_loss + bc_loss, (l_E_new, l_B_new)\n",
    "\n",
    "    @nnx.jit\n",
    "    def train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"Apply one optimizer update and return the loss with the new RBA weights.\"\"\"\n",
    "        grad_fn = nnx.value_and_grad(loss_fn, has_aux = True)\n",
    "        (loss, (l_E_new, l_B_new)), grads = grad_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "        optimizer.update(grads)\n",
    "        return loss, l_E_new, l_B_new\n",
    "\n",
    "    # Reference solution (with its evaluation coordinates) and training points for this PDE\n",
    "    refsol, coords = get_ref(pde_name)\n",
    "    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N_points)\n",
    "\n",
    "    # Network input/output sizes are read off the data\n",
    "    n_in, n_out = pde_collocs.shape[1], bc_data.shape[1]\n",
    "\n",
    "    for G in G_values:\n",
    "        print(f\"\\tUsing G = {G}.\")\n",
    "\n",
    "        for depth in depths:\n",
    "            for width in widths:\n",
    "\n",
    "                # depth hidden layers, each with width nodes\n",
    "                hidden = [width]*depth\n",
    "                layer_dims = [n_in, *hidden, n_out]\n",
    "\n",
    "                req_params = {\n",
    "                    'k': 3,\n",
    "                    'G': G,\n",
    "                    'grid_range': (-1.0, 1.0),\n",
    "                    'grid_e': 1.0,\n",
    "                    'residual': nnx.silu,\n",
    "                    'external_weights': True,\n",
    "                    'add_bias': True,\n",
    "                    'init_scheme': {'type': 'glorot', 'gain': None, 'distribution':'uniform', 'sample_size': 10000},\n",
    "                }\n",
    "\n",
    "                print(f\"\\t\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "                for run in range(1, 6):\n",
    "\n",
    "                    # Every run starts from unit RBA weights, one per collocation point\n",
    "                    l_E = jnp.ones((pde_collocs.shape[0], 1))\n",
    "                    l_B = jnp.ones((bc_collocs.shape[0], 1))\n",
    "\n",
    "                    # The run index offsets the seed so that the runs differ\n",
    "                    model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = req_params, seed = seed+run)\n",
    "                    optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "                    # train_loss keeps the value returned by the last step\n",
    "                    for epoch in range(num_epochs):\n",
    "                        train_loss, l_E, l_B = train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "\n",
    "                    # Relative L2 error of the prediction against the reference solution\n",
    "                    output = model(coords).reshape(refsol.shape)\n",
    "                    l2error = jnp.linalg.norm(output-refsol)/jnp.linalg.norm(refsol)\n",
    "\n",
    "                    # One row per run, appended right away so finished runs are already on disk\n",
    "                    new_row = f\"{pde_name}, {G}, {width}, {depth}, {run}, {train_loss}, {l2error}\"\n",
    "                    with open(results_file, \"a\") as rfile:\n",
    "                        rfile.write(f\"{new_row}\\n\")\n",
    "\n",
    "                    print(f\"\\t\\t\\t{run}. Final loss: {train_loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65da9502-fe6a-476d-b1c9-af6166e435b8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d154f77f-ac46-4e9b-ac19-89aa0a9b24d1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Custom standardization results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77484c08-65b2-4ea2-b049-4f8fdba7b258",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The std experiment logs to <results_dir>/std.txt\n",
    "experiment_name = \"std\"\n",
    "results_file = os.path.join(results_dir, experiment_name + \".txt\")\n",
    "\n",
    "# Column names of the rows appended by the training loop\n",
    "header = \"pde, G, width, depth, run, loss, l2\"\n",
    "\n",
    "# Only a missing file gets the header line; an existing file is left untouched\n",
    "if not os.path.exists(results_file):\n",
    "    with open(results_file, \"w\") as file:\n",
    "        file.write(f\"{header}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da00aedc-c22a-4e74-a84e-52e2541ddc49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid search for the StdKAN model: every (PDE, G, depth, width) setting gets five seeded runs\n",
    "for pde_name, pde_res in pde_dict.items():\n",
    "    print(f\"Running Experiments for {pde_name} equation.\")\n",
    "\n",
    "    def loss_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"RBA-weighted loss: a PDE-residual term plus a boundary-data term.\n",
    "\n",
    "        l_E and l_B are the current RBA weights of the PDE and boundary points.\n",
    "        Returns (total_loss, (l_E_new, l_B_new)), the updated weights being the aux output.\n",
    "        \"\"\"\n",
    "        # PDE term: residuals of the current equation at the collocation points\n",
    "        pde_residuals = pde_res(model, pde_collocs)\n",
    "        abs_pde_res = jnp.abs(pde_residuals)\n",
    "        # Scale the previous weights by gamma and add eta times the max-normalised residual magnitudes\n",
    "        l_E_new = (RBA_gamma*l_E) + (RBA_eta*abs_pde_res/jnp.max(abs_pde_res))\n",
    "        pde_loss = jnp.mean((l_E_new * pde_residuals)**2)\n",
    "\n",
    "        # Boundary term: model output minus the prescribed data, weighted the same way\n",
    "        bc_residuals = model(bc_collocs) - bc_data\n",
    "        abs_bc_res = jnp.abs(bc_residuals)\n",
    "        l_B_new = (RBA_gamma*l_B) + (RBA_eta*abs_bc_res/jnp.max(abs_bc_res))\n",
    "        bc_loss = jnp.mean((l_B_new * bc_residuals)**2)\n",
    "\n",
    "        return pde_loss + bc_loss, (l_E_new, l_B_new)\n",
    "\n",
    "    @nnx.jit\n",
    "    def train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"Apply one optimizer update and return the loss with the new RBA weights.\"\"\"\n",
    "        grad_fn = nnx.value_and_grad(loss_fn, has_aux = True)\n",
    "        (loss, (l_E_new, l_B_new)), grads = grad_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "        optimizer.update(grads)\n",
    "        return loss, l_E_new, l_B_new\n",
    "\n",
    "    # Reference solution (with its evaluation coordinates) and training points for this PDE\n",
    "    refsol, coords = get_ref(pde_name)\n",
    "    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N_points)\n",
    "\n",
    "    # Network input/output sizes are read off the data\n",
    "    n_in, n_out = pde_collocs.shape[1], bc_data.shape[1]\n",
    "\n",
    "    for G in G_values:\n",
    "        print(f\"\\tUsing G = {G}.\")\n",
    "\n",
    "        for depth in depths:\n",
    "            for width in widths:\n",
    "\n",
    "                # depth hidden layers, each with width nodes\n",
    "                hidden = [width]*depth\n",
    "                layer_dims = [n_in, *hidden, n_out]\n",
    "\n",
    "                # No 'type' entry in init_scheme here: these parameters are consumed by StdKAN\n",
    "                req_params = {\n",
    "                    'k': 3,\n",
    "                    'G': G,\n",
    "                    'grid_range': (-1.0, 1.0),\n",
    "                    'grid_e': 1.0,\n",
    "                    'residual': nnx.silu,\n",
    "                    'external_weights': True,\n",
    "                    'add_bias': True,\n",
    "                    'init_scheme': {'gain': None, 'distribution':'uniform'},\n",
    "                }\n",
    "\n",
    "                print(f\"\\t\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "                for run in range(1, 6):\n",
    "\n",
    "                    # Every run starts from unit RBA weights, one per collocation point\n",
    "                    l_E = jnp.ones((pde_collocs.shape[0], 1))\n",
    "                    l_B = jnp.ones((bc_collocs.shape[0], 1))\n",
    "\n",
    "                    # StdKAN is the model under test in this experiment. It used to be replaced by a\n",
    "                    # plain KAN right after construction, so the StdKAN was never trained or evaluated.\n",
    "                    model = StdKAN(layer_dims = layer_dims, required_parameters = req_params, seed = seed+run)\n",
    "                    optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "                    # train_loss keeps the value returned by the last step\n",
    "                    for epoch in range(num_epochs):\n",
    "                        train_loss, l_E, l_B = train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "\n",
    "                    # Relative L2 error of the prediction against the reference solution\n",
    "                    output = model(coords).reshape(refsol.shape)\n",
    "                    l2error = jnp.linalg.norm(output-refsol)/jnp.linalg.norm(refsol)\n",
    "\n",
    "                    # One row per run, appended right away so finished runs are already on disk\n",
    "                    new_row = f\"{pde_name}, {G}, {width}, {depth}, {run}, {train_loss}, {l2error}\"\n",
    "                    with open(results_file, \"a\") as rfile:\n",
    "                        rfile.write(f\"{new_row}\\n\")\n",
    "\n",
    "                    print(f\"\\t\\t\\t{run}. Final loss: {train_loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d877226-6584-4224-8ec2-3b09b3449f42",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "50ef0e81-1eb7-4ef3-aa67-353737417bdb",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Empirical Power Law Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b467cd5e-ae50-4dba-bd74-179dac4066d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exponents scanned by the power-law experiment: 0.0 to 2.0 in steps of 0.25.\n",
    "# pows_basis fills the pow_b1/pow_b2 entries of init_scheme, pows_res fills pow_r1/pow_r2.\n",
    "pows_basis = [0.25 * step for step in range(9)]\n",
    "pows_res = [0.25 * step for step in range(9)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfbf9cd1-c316-43a0-8540-884b893f2a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The power experiment logs to <results_dir>/power.txt\n",
    "experiment_name = \"power\"\n",
    "results_file = os.path.join(results_dir, experiment_name + \".txt\")\n",
    "\n",
    "# Column names of the rows appended by the training loop (includes the two exponents)\n",
    "header = \"pde, G, width, depth, pow_basis, pow_res, run, loss, l2\"\n",
    "\n",
    "# Only a missing file gets the header line; an existing file is left untouched\n",
    "if not os.path.exists(results_file):\n",
    "    with open(results_file, \"w\") as file:\n",
    "        file.write(f\"{header}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b7063fb-c806-4e86-8383-cdafce872251",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid search with the 'power' init scheme: every (PDE, G, depth, width, pow_basis, pow_res)\n",
    "# setting is trained for three seeded runs\n",
    "for pde_name, pde_res in pde_dict.items():\n",
    "    print(f\"Running Experiments for {pde_name} equation.\")\n",
    "\n",
    "    def loss_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"RBA-weighted loss: a PDE-residual term plus a boundary-data term.\n",
    "\n",
    "        l_E and l_B are the current RBA weights of the PDE and boundary points.\n",
    "        Returns (total_loss, (l_E_new, l_B_new)), the updated weights being the aux output.\n",
    "        \"\"\"\n",
    "        # PDE term: residuals of the current equation at the collocation points\n",
    "        pde_residuals = pde_res(model, pde_collocs)\n",
    "        abs_pde_res = jnp.abs(pde_residuals)\n",
    "        # Scale the previous weights by gamma and add eta times the max-normalised residual magnitudes\n",
    "        l_E_new = (RBA_gamma*l_E) + (RBA_eta*abs_pde_res/jnp.max(abs_pde_res))\n",
    "        pde_loss = jnp.mean((l_E_new * pde_residuals)**2)\n",
    "\n",
    "        # Boundary term: model output minus the prescribed data, weighted the same way\n",
    "        bc_residuals = model(bc_collocs) - bc_data\n",
    "        abs_bc_res = jnp.abs(bc_residuals)\n",
    "        l_B_new = (RBA_gamma*l_B) + (RBA_eta*abs_bc_res/jnp.max(abs_bc_res))\n",
    "        bc_loss = jnp.mean((l_B_new * bc_residuals)**2)\n",
    "\n",
    "        return pde_loss + bc_loss, (l_E_new, l_B_new)\n",
    "\n",
    "    @nnx.jit\n",
    "    def train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data):\n",
    "        \"\"\"Apply one optimizer update and return the loss with the new RBA weights.\"\"\"\n",
    "        grad_fn = nnx.value_and_grad(loss_fn, has_aux = True)\n",
    "        (loss, (l_E_new, l_B_new)), grads = grad_fn(model, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "        optimizer.update(grads)\n",
    "        return loss, l_E_new, l_B_new\n",
    "\n",
    "    # Reference solution (with its evaluation coordinates) and training points for this PDE\n",
    "    refsol, coords = get_ref(pde_name)\n",
    "    pde_collocs, bc_collocs, bc_data = get_collocs(pde_name, N_points)\n",
    "\n",
    "    # Network input/output sizes are read off the data\n",
    "    n_in, n_out = pde_collocs.shape[1], bc_data.shape[1]\n",
    "\n",
    "    for G in G_values:\n",
    "        print(f\"\\tUsing G = {G}.\")\n",
    "\n",
    "        for depth in depths:\n",
    "            for width in widths:\n",
    "\n",
    "                # depth hidden layers, each with width nodes\n",
    "                hidden = [width]*depth\n",
    "                layer_dims = [n_in, *hidden, n_out]\n",
    "\n",
    "                print(f\"\\t\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "                for pow_basis in pows_basis:\n",
    "\n",
    "                    for pow_res in pows_res:\n",
    "\n",
    "                        # The same exponent is used for both basis entries and for both residual entries\n",
    "                        req_params = {\n",
    "                            'k': 3,\n",
    "                            'G': G,\n",
    "                            'grid_range': (-1.0, 1.0),\n",
    "                            'grid_e': 1.0,\n",
    "                            'residual': nnx.silu,\n",
    "                            'external_weights': True,\n",
    "                            'add_bias': True,\n",
    "                            'init_scheme': {\n",
    "                                'type': 'power',\n",
    "                                \"const_b\": 1.0,\n",
    "                                \"const_r\": 1.0,\n",
    "                                \"pow_b1\": pow_basis,\n",
    "                                \"pow_b2\": pow_basis,\n",
    "                                \"pow_r1\": pow_res,\n",
    "                                \"pow_r2\": pow_res,\n",
    "                            },\n",
    "                        }\n",
    "\n",
    "                        print(f\"\\t\\t\\tWorking with pow_basis = {pow_basis} and pow_res = {pow_res}.\")\n",
    "\n",
    "                        for run in range(1, 4):\n",
    "\n",
    "                            # Every run starts from unit RBA weights, one per collocation point\n",
    "                            l_E = jnp.ones((pde_collocs.shape[0], 1))\n",
    "                            l_B = jnp.ones((bc_collocs.shape[0], 1))\n",
    "\n",
    "                            # The run index offsets the seed so that the runs differ\n",
    "                            model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = req_params, seed = seed+run)\n",
    "                            optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "                            # train_loss keeps the value returned by the last step\n",
    "                            for epoch in range(num_epochs):\n",
    "                                train_loss, l_E, l_B = train_step(model, optimizer, l_E, l_B, pde_collocs, bc_collocs, bc_data)\n",
    "\n",
    "                            # Relative L2 error of the prediction against the reference solution\n",
    "                            output = model(coords).reshape(refsol.shape)\n",
    "                            l2error = jnp.linalg.norm(output-refsol)/jnp.linalg.norm(refsol)\n",
    "\n",
    "                            # One row per run, appended right away so finished runs are already on disk\n",
    "                            new_row = f\"{pde_name}, {G}, {width}, {depth}, {pow_basis}, {pow_res}, {run}, {train_loss}, {l2error}\"\n",
    "                            with open(results_file, \"a\") as rfile:\n",
    "                                rfile.write(f\"{new_row}\\n\")\n",
    "\n",
    "                            print(f\"\\t\\t\\t\\t{run}. Final loss: {train_loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acfec727-5640-4c6a-9d40-f86b10d26b1e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
