{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "49a91547-b8f2-4953-8ddb-6606e95b1f90",
   "metadata": {},
   "source": [
    "# Function Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea1b453-8334-418f-a0fb-b18621fe7d2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# NOTE(review): the wildcard imports presumably provide f1..f5, generate_func_data,\n",
    "# func_fit_step and func_fit_eval used further down; which module defines which is not visible here.\n",
    "from src.functions import *\n",
    "from src.utils import *\n",
    "from src.std_kan import StdKAN\n",
    "\n",
    "from jaxkan.KAN import KAN\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "import optax\n",
    "from flax import nnx\n",
    "\n",
    "import os\n",
    "\n",
    "# Folder (relative to the working directory) that holds one results log per experiment;\n",
    "# it is created if missing and reused as-is when it already exists\n",
    "results_dir = \"ff_results\"\n",
    "os.makedirs(results_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc38056f-f103-4610-ac12-3ac5209cf905",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "20d1fe14-a0a9-4cc5-bd9d-b09afefa2712",
   "metadata": {},
   "source": [
    "## Grid Search Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9308aa86-8ebc-4e69-b0c8-2a4749254cee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target functions keyed by the label that is written to the results logs\n",
    "# (f1..f5 presumably come from the wildcard imports - confirm in src)\n",
    "func_dict = {\"f1\": f1, \"f2\": f2, \"f3\": f3, \"f4\": f4, \"f5\": f5}\n",
    "\n",
    "# N is forwarded to generate_func_data (presumably the number of samples - confirm in src).\n",
    "# seed is used for data generation, for the train/test split and, offset by the run index,\n",
    "# as the model seed.\n",
    "N = 5000\n",
    "seed = 42\n",
    "\n",
    "# Training settings shared by every experiment: number of training steps and the optimizer\n",
    "num_epochs = 2000\n",
    "opt_type = optax.adam(learning_rate=1e-3)\n",
    "\n",
    "# Architecture grid: values passed as 'G' to the layers, hidden-layer widths (2 ... 64)\n",
    "# and numbers of hidden layers (1 ... 4)\n",
    "G_values = [5, 10, 20, 40]\n",
    "widths = [2**power for power in range(1, 7)]\n",
    "depths = list(range(1, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ede15a2-e90a-4c12-89fd-9ca736c433b7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "53e4c6ad-f76f-403f-b7c0-06419ec1ba91",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Baseline Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8840357b-bc21-4b65-a606-4e61c4bfb166",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names of the results log\n",
    "header = \"function, G, width, depth, run, loss, mse, l2\"\n",
    "\n",
    "# The baseline experiment logs to results_dir/baseline.txt\n",
    "experiment_name = \"baseline\"\n",
    "results_file = os.path.join(results_dir, f\"{experiment_name}.txt\")\n",
    "\n",
    "# Only a missing log is created (with the header line); an existing one is left untouched\n",
    "if not os.path.exists(results_file):\n",
    "    with open(results_file, \"w\") as file:\n",
    "        file.write(f\"{header}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "577067f9-76b1-47d7-a1a5-27505014ab80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline grid search: for every target function, G, depth and width, train five\n",
    "# differently seeded KANs with the 'default' init scheme and append one row per run\n",
    "# to results_file.\n",
    "for func_name, function in func_dict.items():\n",
    "    print(f\"Running Experiments for {func_name} function.\")\n",
    "\n",
    "    # Generate the data set and hold out 20% of it for the test MSE\n",
    "    x, y = generate_func_data(function, 2, N, seed)\n",
    "    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    # Network input/output sizes are read off the data\n",
    "    n_in = X_train.shape[1]\n",
    "    n_out = y_train.shape[1]\n",
    "\n",
    "    for G in G_values:\n",
    "        print(f\"\\tUsing G = {G}.\")\n",
    "\n",
    "        for depth in depths:\n",
    "            for width in widths:\n",
    "\n",
    "                # depth hidden layers of identical width between input and output\n",
    "                layer_dims = [n_in] + [width] * depth + [n_out]\n",
    "\n",
    "                req_params = {\n",
    "                    'k': 3,\n",
    "                    'G': G,\n",
    "                    'grid_range': (-1.0, 1.0),\n",
    "                    'grid_e': 1.0,\n",
    "                    'residual': nnx.silu,\n",
    "                    'external_weights': True,\n",
    "                    'add_bias': True,\n",
    "                    'init_scheme': {'type': 'default'},\n",
    "                }\n",
    "\n",
    "                print(f\"\\t\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "                for run in range(1, 6):\n",
    "\n",
    "                    # Each repetition gets its own model seed (base seed + run index)\n",
    "                    model = KAN(\n",
    "                        layer_dims=layer_dims,\n",
    "                        layer_type='spline',\n",
    "                        required_parameters=req_params,\n",
    "                        seed=seed + run,\n",
    "                    )\n",
    "                    optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "                    # Every step is given the whole training set; only the last step's loss is kept\n",
    "                    for epoch in range(num_epochs):\n",
    "                        train_loss = func_fit_step(model, optimizer, X_train, y_train)\n",
    "\n",
    "                    # Held-out MSE, plus the error measure computed by func_fit_eval\n",
    "                    test_mse = mean_squared_error(y_test, model(X_test))\n",
    "                    l2error = func_fit_eval(model, function, 2, resolution=200, make_plot=False)\n",
    "\n",
    "                    # Append one comma-separated row per run to the results log\n",
    "                    new_row = f\"{func_name}, {G}, {width}, {depth}, {run}, {train_loss}, {test_mse}, {l2error}\"\n",
    "                    with open(results_file, \"a\") as rfile:\n",
    "                        rfile.write(f\"{new_row}\\n\")\n",
    "\n",
    "                    print(f\"\\t\\t\\t{run}. Final loss: {train_loss:.2e} \\tTest MSE: {test_mse:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efcd8a93-c67a-449b-a4ef-47fc23fb6730",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6a57dd2d-e0ad-4f42-9db8-42e5650feba1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## LeCun Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4760812c-a415-4e45-9cfe-0b3b12d37ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names of the results log\n",
    "header = \"function, G, width, depth, run, loss, mse, l2\"\n",
    "\n",
    "# The LeCun experiment logs to results_dir/lecun.txt\n",
    "experiment_name = \"lecun\"\n",
    "results_file = os.path.join(results_dir, f\"{experiment_name}.txt\")\n",
    "\n",
    "# Only a missing log is created (with the header line); an existing one is left untouched\n",
    "if not os.path.exists(results_file):\n",
    "    with open(results_file, \"w\") as file:\n",
    "        file.write(f\"{header}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d8ea06f-c278-48bc-ab3d-75e151509c06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LeCun grid search: same sweep as the other experiments, with the 'lecun' init scheme\n",
    "# (uniform distribution, no explicit gain). Five seeds per configuration, one logged row per run.\n",
    "for func_name, function in func_dict.items():\n",
    "    print(f\"Running Experiments for {func_name} function.\")\n",
    "\n",
    "    # Generate the data set and hold out 20% of it for the test MSE\n",
    "    x, y = generate_func_data(function, 2, N, seed)\n",
    "    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    # Network input/output sizes are read off the data\n",
    "    n_in = X_train.shape[1]\n",
    "    n_out = y_train.shape[1]\n",
    "\n",
    "    for G in G_values:\n",
    "        print(f\"\\tUsing G = {G}.\")\n",
    "\n",
    "        for depth in depths:\n",
    "            for width in widths:\n",
    "\n",
    "                # depth hidden layers of identical width between input and output\n",
    "                layer_dims = [n_in] + [width] * depth + [n_out]\n",
    "\n",
    "                req_params = {\n",
    "                    'k': 3,\n",
    "                    'G': G,\n",
    "                    'grid_range': (-1.0, 1.0),\n",
    "                    'grid_e': 1.0,\n",
    "                    'residual': nnx.silu,\n",
    "                    'external_weights': True,\n",
    "                    'add_bias': True,\n",
    "                    'init_scheme': {'type': 'lecun', 'gain': None, 'distribution': 'uniform'},\n",
    "                }\n",
    "\n",
    "                print(f\"\\t\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "                for run in range(1, 6):\n",
    "\n",
    "                    # Each repetition gets its own model seed (base seed + run index)\n",
    "                    model = KAN(\n",
    "                        layer_dims=layer_dims,\n",
    "                        layer_type='spline',\n",
    "                        required_parameters=req_params,\n",
    "                        seed=seed + run,\n",
    "                    )\n",
    "                    optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "                    # Every step is given the whole training set; only the last step's loss is kept\n",
    "                    for epoch in range(num_epochs):\n",
    "                        train_loss = func_fit_step(model, optimizer, X_train, y_train)\n",
    "\n",
    "                    # Held-out MSE, plus the error measure computed by func_fit_eval\n",
    "                    test_mse = mean_squared_error(y_test, model(X_test))\n",
    "                    l2error = func_fit_eval(model, function, 2, resolution=200, make_plot=False)\n",
    "\n",
    "                    # Append one comma-separated row per run to the results log\n",
    "                    new_row = f\"{func_name}, {G}, {width}, {depth}, {run}, {train_loss}, {test_mse}, {l2error}\"\n",
    "                    with open(results_file, \"a\") as rfile:\n",
    "                        rfile.write(f\"{new_row}\\n\")\n",
    "\n",
    "                    print(f\"\\t\\t\\t{run}. Final loss: {train_loss:.2e} \\tTest MSE: {test_mse:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "882b2fed-2e40-49f5-b865-6ffcd172206d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "0059349d-a7b8-4937-ba73-0c485390399a",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Glorot Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c52a808b-4f64-4913-9d4b-7ccc26bb83b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names of the results log\n",
    "header = \"function, G, width, depth, run, loss, mse, l2\"\n",
    "\n",
    "# The Glorot experiment logs to results_dir/glorot.txt\n",
    "experiment_name = \"glorot\"\n",
    "results_file = os.path.join(results_dir, f\"{experiment_name}.txt\")\n",
    "\n",
    "# Only a missing log is created (with the header line); an existing one is left untouched\n",
    "if not os.path.exists(results_file):\n",
    "    with open(results_file, \"w\") as file:\n",
    "        file.write(f\"{header}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cac30139-6b86-462e-9b63-b800d0262857",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Glorot grid search: same sweep as the other experiments, with the 'glorot' init scheme\n",
    "# (uniform distribution, no explicit gain, sample_size 10000). Five seeds per configuration,\n",
    "# one logged row per run.\n",
    "for func_name, function in func_dict.items():\n",
    "    print(f\"Running Experiments for {func_name} function.\")\n",
    "\n",
    "    # Generate the data set and hold out 20% of it for the test MSE\n",
    "    x, y = generate_func_data(function, 2, N, seed)\n",
    "    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    # Network input/output sizes are read off the data\n",
    "    n_in = X_train.shape[1]\n",
    "    n_out = y_train.shape[1]\n",
    "\n",
    "    for G in G_values:\n",
    "        print(f\"\\tUsing G = {G}.\")\n",
    "\n",
    "        for depth in depths:\n",
    "            for width in widths:\n",
    "\n",
    "                # depth hidden layers of identical width between input and output\n",
    "                layer_dims = [n_in] + [width] * depth + [n_out]\n",
    "\n",
    "                req_params = {\n",
    "                    'k': 3,\n",
    "                    'G': G,\n",
    "                    'grid_range': (-1.0, 1.0),\n",
    "                    'grid_e': 1.0,\n",
    "                    'residual': nnx.silu,\n",
    "                    'external_weights': True,\n",
    "                    'add_bias': True,\n",
    "                    'init_scheme': {'type': 'glorot', 'gain': None, 'distribution': 'uniform', 'sample_size': 10000},\n",
    "                }\n",
    "\n",
    "                print(f\"\\t\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "                for run in range(1, 6):\n",
    "\n",
    "                    # Each repetition gets its own model seed (base seed + run index)\n",
    "                    model = KAN(\n",
    "                        layer_dims=layer_dims,\n",
    "                        layer_type='spline',\n",
    "                        required_parameters=req_params,\n",
    "                        seed=seed + run,\n",
    "                    )\n",
    "                    optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "                    # Every step is given the whole training set; only the last step's loss is kept\n",
    "                    for epoch in range(num_epochs):\n",
    "                        train_loss = func_fit_step(model, optimizer, X_train, y_train)\n",
    "\n",
    "                    # Held-out MSE, plus the error measure computed by func_fit_eval\n",
    "                    test_mse = mean_squared_error(y_test, model(X_test))\n",
    "                    l2error = func_fit_eval(model, function, 2, resolution=200, make_plot=False)\n",
    "\n",
    "                    # Append one comma-separated row per run to the results log\n",
    "                    new_row = f\"{func_name}, {G}, {width}, {depth}, {run}, {train_loss}, {test_mse}, {l2error}\"\n",
    "                    with open(results_file, \"a\") as rfile:\n",
    "                        rfile.write(f\"{new_row}\\n\")\n",
    "\n",
    "                    print(f\"\\t\\t\\t{run}. Final loss: {train_loss:.2e} \\tTest MSE: {test_mse:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65da9502-fe6a-476d-b1c9-af6166e435b8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d154f77f-ac46-4e9b-ac19-89aa0a9b24d1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Custom Standardization Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77484c08-65b2-4ea2-b049-4f8fdba7b258",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names of the results log\n",
    "header = \"function, G, width, depth, run, loss, mse, l2\"\n",
    "\n",
    "# The custom-standardization experiment logs to results_dir/std.txt\n",
    "experiment_name = \"std\"\n",
    "results_file = os.path.join(results_dir, f\"{experiment_name}.txt\")\n",
    "\n",
    "# Only a missing log is created (with the header line); an existing one is left untouched\n",
    "if not os.path.exists(results_file):\n",
    "    with open(results_file, \"w\") as file:\n",
    "        file.write(f\"{header}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b55af037-8a1e-442e-8b95-b0ad48f1629c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom-standardization grid search: same sweep as the other experiments, but the model\n",
    "# is a StdKAN and the init scheme carries no 'type' entry. Five seeds per configuration,\n",
    "# one logged row per run.\n",
    "for func_name, function in func_dict.items():\n",
    "    print(f\"Running Experiments for {func_name} function.\")\n",
    "\n",
    "    # Generate the data set and hold out 20% of it for the test MSE\n",
    "    x, y = generate_func_data(function, 2, N, seed)\n",
    "    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    # Network input/output sizes are read off the data\n",
    "    n_in = X_train.shape[1]\n",
    "    n_out = y_train.shape[1]\n",
    "\n",
    "    for G in G_values:\n",
    "        print(f\"\\tUsing G = {G}.\")\n",
    "\n",
    "        for depth in depths:\n",
    "            for width in widths:\n",
    "\n",
    "                # depth hidden layers of identical width between input and output\n",
    "                layer_dims = [n_in] + [width] * depth + [n_out]\n",
    "\n",
    "                req_params = {\n",
    "                    'k': 3,\n",
    "                    'G': G,\n",
    "                    'grid_range': (-1.0, 1.0),\n",
    "                    'grid_e': 1.0,\n",
    "                    'residual': nnx.silu,\n",
    "                    'external_weights': True,\n",
    "                    'add_bias': True,\n",
    "                    'init_scheme': {'gain': None, 'distribution': 'uniform'},\n",
    "                }\n",
    "\n",
    "                print(f\"\\t\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "                for run in range(1, 6):\n",
    "\n",
    "                    # Each repetition gets its own model seed (base seed + run index)\n",
    "                    model = StdKAN(\n",
    "                        layer_dims=layer_dims,\n",
    "                        required_parameters=req_params,\n",
    "                        seed=seed + run,\n",
    "                    )\n",
    "                    optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "                    # Every step is given the whole training set; only the last step's loss is kept\n",
    "                    for epoch in range(num_epochs):\n",
    "                        train_loss = func_fit_step(model, optimizer, X_train, y_train)\n",
    "\n",
    "                    # Held-out MSE, plus the error measure computed by func_fit_eval\n",
    "                    test_mse = mean_squared_error(y_test, model(X_test))\n",
    "                    l2error = func_fit_eval(model, function, 2, resolution=200, make_plot=False)\n",
    "\n",
    "                    # Append one comma-separated row per run to the results log\n",
    "                    new_row = f\"{func_name}, {G}, {width}, {depth}, {run}, {train_loss}, {test_mse}, {l2error}\"\n",
    "                    with open(results_file, \"a\") as rfile:\n",
    "                        rfile.write(f\"{new_row}\\n\")\n",
    "\n",
    "                    print(f\"\\t\\t\\t{run}. Final loss: {train_loss:.2e} \\tTest MSE: {test_mse:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d877226-6584-4224-8ec2-3b09b3449f42",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "50ef0e81-1eb7-4ef3-aa67-353737417bdb",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Empirical Power Law Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b467cd5e-ae50-4dba-bd74-179dac4066d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exponents swept by the power-law experiment: 0.0, 0.25, ..., 2.0 (nine values, all exact floats).\n",
    "# pows_basis and pows_res hold the same values but stay separate lists so either sweep can be edited alone.\n",
    "pows_basis = [0.25 * step for step in range(9)]\n",
    "pows_res = list(pows_basis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87243712-9d57-4092-9e50-f9b8fa97206c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names of the results log; the two exponent columns are pow_basis, then pow_res\n",
    "header = \"function, G, width, depth, pow_basis, pow_res, run, loss, mse, l2\"\n",
    "\n",
    "# The power-law experiment logs to results_dir/power.txt\n",
    "experiment_name = \"power\"\n",
    "results_file = os.path.join(results_dir, f\"{experiment_name}.txt\")\n",
    "\n",
    "# Only a missing log is created (with the header line); an existing one is left untouched\n",
    "if not os.path.exists(results_file):\n",
    "    with open(results_file, \"w\") as file:\n",
    "        file.write(f\"{header}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f82c0bd-01b4-48a4-83cf-887448617c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Power-law grid search: on top of the architecture grid, every (pow_basis, pow_res)\n",
    "# pair is trained with three seeds using the 'power' init scheme, and one row per run\n",
    "# is appended to results_file.\n",
    "for func_name, function in func_dict.items():\n",
    "    print(f\"Running Experiments for {func_name} function.\")\n",
    "\n",
    "    # Generate the data set and hold out 20% of it for the test MSE\n",
    "    x, y = generate_func_data(function, 2, N, seed)\n",
    "    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    # Network input/output sizes are read off the data\n",
    "    n_in, n_out = X_train.shape[1], y_train.shape[1]\n",
    "\n",
    "    for G in G_values:\n",
    "        print(f\"\\tUsing G = {G}.\")\n",
    "\n",
    "        for depth in depths:\n",
    "            for width in widths:\n",
    "\n",
    "                # depth hidden layers of identical width between input and output\n",
    "                hidden = [width] * depth\n",
    "                layer_dims = [n_in, *hidden, n_out]\n",
    "\n",
    "                print(f\"\\t\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "                for pow_basis in pows_basis:\n",
    "                    for pow_res in pows_res:\n",
    "\n",
    "                        # pow_basis feeds both pow_b1 and pow_b2, pow_res both pow_r1 and pow_r2;\n",
    "                        # the two constants stay fixed at 1.0\n",
    "                        req_params = {\n",
    "                            'k': 3,\n",
    "                            'G': G,\n",
    "                            'grid_range': (-1.0, 1.0),\n",
    "                            'grid_e': 1.0,\n",
    "                            'residual': nnx.silu,\n",
    "                            'external_weights': True,\n",
    "                            'add_bias': True,\n",
    "                            'init_scheme': {\n",
    "                                'type': 'power',\n",
    "                                \"const_b\": 1.0,\n",
    "                                \"const_r\": 1.0,\n",
    "                                \"pow_b1\": pow_basis,\n",
    "                                \"pow_b2\": pow_basis,\n",
    "                                \"pow_r1\": pow_res,\n",
    "                                \"pow_r2\": pow_res,\n",
    "                            },\n",
    "                        }\n",
    "\n",
    "                        print(f\"\\t\\t\\tWorking with pow_basis = {pow_basis} and pow_res = {pow_res}.\")\n",
    "\n",
    "                        for run in range(1, 4):\n",
    "\n",
    "                            # Each repetition gets its own model seed (base seed + run index)\n",
    "                            model = KAN(layer_dims=layer_dims, layer_type='spline', required_parameters=req_params, seed=seed + run)\n",
    "                            optimizer = nnx.Optimizer(model, opt_type)\n",
    "\n",
    "                            # Every step is given the whole training set; only the last step's loss is kept\n",
    "                            for epoch in range(num_epochs):\n",
    "                                train_loss = func_fit_step(model, optimizer, X_train, y_train)\n",
    "\n",
    "                            # Held-out MSE, plus the error measure computed by func_fit_eval\n",
    "                            y_pred = model(X_test)\n",
    "                            test_mse = mean_squared_error(y_test, y_pred)\n",
    "                            l2error = func_fit_eval(model, function, 2, resolution=200, make_plot=False)\n",
    "\n",
    "                            # Row columns follow the file header: pow_basis comes BEFORE pow_res.\n",
    "                            # An earlier version wrote the two exponents in the opposite order; both sweeps\n",
    "                            # cover the same values, so the swap cannot be detected from the file alone.\n",
    "                            # NOTE(review): rows logged before this fix still carry the swapped order.\n",
    "                            new_row = f\"{func_name}, {G}, {width}, {depth}, {pow_basis}, {pow_res}, {run}, {train_loss}, {test_mse}, {l2error}\"\n",
    "\n",
    "                            with open(results_file, \"a\") as rfile:\n",
    "                                rfile.write(new_row + \"\\n\")\n",
    "\n",
    "                            print(f\"\\t\\t\\t\\t{run}. Final loss: {train_loss:.2e} \\tTest MSE: {test_mse:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb239a2-8c99-4cce-a0a1-24e06f6997e1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
