{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "49a91547-b8f2-4953-8ddb-6606e95b1f90",
   "metadata": {},
   "source": [
    "# Feynman Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea1b453-8334-418f-a0fb-b18621fe7d2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from src.feynman import *\n",
    "from src.utils import *\n",
    "\n",
    "from jaxkan.KAN import KAN\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import optax\n",
    "from flax import nnx\n",
    "\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import warnings\n",
    "from pandas.errors import ParserWarning\n",
    "warnings.filterwarnings(\"ignore\", category=ParserWarning)\n",
    "\n",
    "# Create the directory if it doesn't exist\n",
    "results_dir = \"feynman_results\"\n",
    "os.makedirs(results_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc38056f-f103-4610-ac12-3ac5209cf905",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "20d1fe14-a0a9-4cc5-bd9d-b09afefa2712",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9308aa86-8ebc-4e69-b0c8-2a4749254cee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup\n",
    "func_dict = {\"f1\": f1, \"f2\": f2, \"f3\": f3, \"f4\": f4, \"f5\": f5, \"f6\": f6, \"f7\": f7, \"f8\": f8, \"f9\": f9, \"f10\": f10,\n",
    "             \"f11\": f11, \"f12\": f12, \"f13\": f13, \"f14\": f14, \"f15\": f15, \"f16\": f16, \"f17\": f17, \"f18\": f18, \"f19\": f19, \"f20\": f20}\n",
    "\n",
    "func_dims = {\"f1\": 2, \"f2\": 3, \"f3\": 2, \"f4\": 2, \"f5\": 2, \"f6\": 2, \"f7\": 2, \"f8\": 2, \"f9\": 3, \"f10\": 2,\n",
    "             \"f11\": 2, \"f12\": 2, \"f13\": 2, \"f14\": 3, \"f15\": 3, \"f16\": 2, \"f17\": 2, \"f18\": 3, \"f19\": 2, \"f20\": 3}\n",
    "\n",
    "N = 5000\n",
    "seed = 42\n",
    "\n",
    "num_epochs = 2000\n",
    "\n",
    "opt_type = optax.adam(learning_rate=0.001)\n",
    "\n",
    "pow_basis = 1.75\n",
    "pow_res = 0.25\n",
    "\n",
    "# --------------------------\n",
    "# Small architecture details\n",
    "# --------------------------\n",
    "G_small = 5\n",
    "hidden_small = [8, 8]\n",
    "\n",
    "params_small_baseline = {'k': 3, 'G': G_small, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                         'init_scheme': {'type': 'default'}}\n",
    "\n",
    "params_small_glorot = {'k': 3, 'G': G_small, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                            'init_scheme': {'type': 'glorot', 'gain': None, 'distribution': 'uniform'}}\n",
    "\n",
    "params_small_power = {'k': 3, 'G': G_small, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                      'init_scheme': {'type': 'power', \"const_b\": 1.0, \"const_r\": 1.0, \"pow_b1\": pow_basis, \"pow_b2\": pow_basis, \"pow_r1\": pow_res, \"pow_r2\": pow_res}}\n",
    "\n",
    "# ------------------------\n",
    "# Big architecture details\n",
    "# ------------------------\n",
    "G_big = 20\n",
    "hidden_big = [32, 32, 32]\n",
    "\n",
    "params_big_baseline = {'k': 3, 'G': G_big, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                         'init_scheme': {'type': 'default'}}\n",
    "\n",
    "params_big_glorot = {'k': 3, 'G': G_big, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                            'init_scheme': {'type': 'glorot', 'gain': None, 'distribution': 'uniform'}}\n",
    "\n",
    "params_big_power = {'k': 3, 'G': G_big, 'grid_range': (-1.0, 1.0), 'grid_e': 1.0, 'residual': nnx.silu, 'external_weights': True, 'add_bias': True,\n",
    "                      'init_scheme': {'type': 'power', \"const_b\": 1.0, \"const_r\": 1.0, \"pow_b1\": pow_basis, \"pow_b2\": pow_basis, \"pow_r1\": pow_res, \"pow_r2\": pow_res}}"
   ]
  },
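  {
   "cell_type": "markdown",
   "id": "3f1c9b2a-77d4-4c6e-a1b0-5e9d2c4f8a61",
   "metadata": {},
   "source": [
    "*Optional sanity check, not part of the original experiment:* the next cell instantiates one small and one big KAN for a hypothetical 2-input, 1-output function and compares their trainable-parameter counts. It assumes `nnx.state(model, nnx.Param)` exposes the parameters as a pytree of arrays, as is typical for `flax.nnx` modules; treat it as a rough sketch of the relative model sizes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b4e6d0c-2a19-4f3b-9c75-d1e0a6b7c842",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: compare trainable-parameter counts of the two architectures.\n",
    "# Assumption: nnx.state(model, nnx.Param) yields the trainable parameters as a pytree of arrays.\n",
    "def count_params(model):\n",
    "    params = nnx.state(model, nnx.Param)\n",
    "    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))\n",
    "\n",
    "demo_dims_small = [2, *hidden_small, 1]   # hypothetical 2-input, 1-output function\n",
    "demo_dims_big = [2, *hidden_big, 1]\n",
    "\n",
    "small_kan = KAN(layer_dims=demo_dims_small, layer_type='spline', required_parameters=params_small_baseline, seed=seed)\n",
    "big_kan = KAN(layer_dims=demo_dims_big, layer_type='spline', required_parameters=params_big_baseline, seed=seed)\n",
    "\n",
    "print(f\"Small architecture {demo_dims_small}, G={G_small}: {count_params(small_kan):,} parameters\")\n",
    "print(f\"Big architecture {demo_dims_big}, G={G_big}: {count_params(big_kan):,} parameters\")"
   ]
  },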
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ede15a2-e90a-4c12-89fd-9ca736c433b7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "53e4c6ad-f76f-403f-b7c0-06419ec1ba91",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8840357b-bc21-4b65-a606-4e61c4bfb166",
   "metadata": {},
   "outputs": [],
   "source": [
    "small_experiment = \"feynman_small\"\n",
    "small_file = os.path.join(results_dir, f\"{small_experiment}.txt\")\n",
    "\n",
    "big_experiment = \"feynman_big\"\n",
    "big_file = os.path.join(results_dir, f\"{big_experiment}.txt\")\n",
    "\n",
    "# Define the headers\n",
    "header = \"function, method, run, loss, l2\"\n",
    "\n",
    "# Check if the file exists and write the header if it doesn't\n",
    "if not os.path.exists(small_file):\n",
    "    with open(small_file, \"w\") as file:\n",
    "        file.write(header + \"\\n\")\n",
    "\n",
    "if not os.path.exists(big_file):\n",
    "    with open(big_file, \"w\") as file:\n",
    "        file.write(header + \"\\n\")"
   ]
  },
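  {
   "cell_type": "markdown",
   "id": "c2d5a8e1-6b3f-4a92-8d07-4f1e9c6b2a53",
   "metadata": {},
   "source": [
    "The experiment cell below repeats the same train/evaluate/log pattern for every initialization scheme and both architectures. As a reading aid, here is that per-run pattern in isolation; `train_and_eval` is a hypothetical helper (it is not part of `src.utils` and is not called below, where the same steps are written out explicitly)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e7b1f4d-0c28-4d6a-b395-7a2c8e5d1f06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper showing the per-run pattern used (inline) in the experiment cell below.\n",
    "def train_and_eval(layer_dims, params, X_train, y_train, function, func_dim, run):\n",
    "    model = KAN(layer_dims=layer_dims, layer_type='spline', required_parameters=params, seed=seed + run)\n",
    "    opt = nnx.Optimizer(model, opt_type)\n",
    "    for epoch in range(num_epochs):\n",
    "        loss = func_fit_step(model, opt, X_train, y_train)   # one training step on (X_train, y_train)\n",
    "    l2error = feyn_fit_eval(model, function, func_dim)       # relative L2 error, as computed by feyn_fit_eval\n",
    "    return loss, l2error"
   ]
  },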
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81de6cf8-4ec2-4870-a1e8-1785f62ede0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "for func_name in func_dict.keys():\n",
    "    print(f\"Running Experiments for {func_name}.\")\n",
    "    function = func_dict[func_name]\n",
    "    func_dim = func_dims[func_name]\n",
    "\n",
    "    # Generate data\n",
    "    x, y = generate_feyn_data(function, func_dim, N, seed)\n",
    "\n",
    "    # Split data, only for uniformity with previous experiments, we don't use the test set anywhere\n",
    "    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=seed)\n",
    "\n",
    "    # Model input/output\n",
    "    n_in, n_out = X_train.shape[1], y_train.shape[1]\n",
    "\n",
    "    # Small architecture\n",
    "    layer_dims = [n_in, *hidden_small, n_out]\n",
    "\n",
    "    print(f\"\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "    # For confidence\n",
    "    for run in [1, 2, 3, 4, 5]:\n",
    "\n",
    "        print(f\"\\t\\tRun No. {run}.\")\n",
    "\n",
    "        # Baseline\n",
    "        base_model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = params_small_baseline, seed = seed+run)\n",
    "        base_opt = nnx.Optimizer(base_model, opt_type)\n",
    "\n",
    "        for epoch in range(num_epochs):\n",
    "            loss = func_fit_step(base_model, base_opt, X_train, y_train)\n",
    "\n",
    "        l2error = feyn_fit_eval(base_model, function, func_dim)\n",
    "\n",
    "        # Log results\n",
    "        new_row = f\"{func_name}, baseline, {run}, {loss}, {l2error}\"\n",
    "                        \n",
    "        # Append the row to the file\n",
    "        with open(small_file, \"a\") as rfile:\n",
    "            rfile.write(new_row + \"\\n\")\n",
    "\n",
    "        print(f\"\\t\\t\\tBaseline model. Final loss = {loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")\n",
    "\n",
    "        # Glorot\n",
    "        glorot_model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = params_small_glorot, seed = seed+run)\n",
    "        glorot_opt = nnx.Optimizer(glorot_model, opt_type)\n",
    "\n",
    "        for epoch in range(num_epochs):\n",
    "            loss = func_fit_step(glorot_model, glorot_opt, X_train, y_train)\n",
    "\n",
    "        l2error = feyn_fit_eval(glorot_model, function, func_dim)\n",
    "\n",
    "        # Log results\n",
    "        new_row = f\"{func_name}, glorot, {run}, {loss}, {l2error}\"\n",
    "                        \n",
    "        # Append the row to the file\n",
    "        with open(small_file, \"a\") as rfile:\n",
    "            rfile.write(new_row + \"\\n\")\n",
    "\n",
    "        print(f\"\\t\\t\\tGlorot model. Final loss = {loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")\n",
    "\n",
    "        # Power Law\n",
    "        power_model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = params_small_power, seed = seed+run)\n",
    "        power_opt = nnx.Optimizer(power_model, opt_type)\n",
    "\n",
    "        for epoch in range(num_epochs):\n",
    "            loss = func_fit_step(power_model, power_opt, X_train, y_train)\n",
    "\n",
    "        l2error = feyn_fit_eval(power_model, function, func_dim)\n",
    "\n",
    "        # Log results\n",
    "        new_row = f\"{func_name}, power, {run}, {loss}, {l2error}\"\n",
    "                        \n",
    "        # Append the row to the file\n",
    "        with open(small_file, \"a\") as rfile:\n",
    "            rfile.write(new_row + \"\\n\")\n",
    "\n",
    "        print(f\"\\t\\t\\tPower-law model. Final loss = {loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")\n",
    "\n",
    "    # Big architecture\n",
    "    layer_dims = [n_in, *hidden_big, n_out]\n",
    "\n",
    "    print(f\"\\tTraining model with dimensions {layer_dims}.\")\n",
    "\n",
    "    # For confidence\n",
    "    for run in [1, 2, 3, 4, 5]:\n",
    "\n",
    "        print(f\"\\t\\tRun No. {run}.\")\n",
    "\n",
    "        # Baseline\n",
    "        base_model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = params_big_baseline, seed = seed+run)\n",
    "        base_opt = nnx.Optimizer(base_model, opt_type)\n",
    "\n",
    "        for epoch in range(num_epochs):\n",
    "            loss = func_fit_step(base_model, base_opt, X_train, y_train)\n",
    "\n",
    "        l2error = feyn_fit_eval(base_model, function, func_dim)\n",
    "\n",
    "        # Log results\n",
    "        new_row = f\"{func_name}, baseline, {run}, {loss}, {l2error}\"\n",
    "                        \n",
    "        # Append the row to the file\n",
    "        with open(big_file, \"a\") as rfile:\n",
    "            rfile.write(new_row + \"\\n\")\n",
    "\n",
    "        print(f\"\\t\\t\\tBaseline model. Final loss = {loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")\n",
    "\n",
    "        # Glorot\n",
    "        glorot_model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = params_big_glorot, seed = seed+run)\n",
    "        glorot_opt = nnx.Optimizer(glorot_model, opt_type)\n",
    "\n",
    "        for epoch in range(num_epochs):\n",
    "            loss = func_fit_step(glorot_model, glorot_opt, X_train, y_train)\n",
    "\n",
    "        l2error = feyn_fit_eval(glorot_model, function, func_dim)\n",
    "\n",
    "        # Log results\n",
    "        new_row = f\"{func_name}, glorot, {run}, {loss}, {l2error}\"\n",
    "                        \n",
    "        # Append the row to the file\n",
    "        with open(big_file, \"a\") as rfile:\n",
    "            rfile.write(new_row + \"\\n\")\n",
    "\n",
    "        print(f\"\\t\\t\\tGlorot model. Final loss = {loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")\n",
    "\n",
    "        # Power Law\n",
    "        power_model = KAN(layer_dims = layer_dims, layer_type = 'spline', required_parameters = params_big_power, seed = seed+run)\n",
    "        power_opt = nnx.Optimizer(power_model, opt_type)\n",
    "\n",
    "        for epoch in range(num_epochs):\n",
    "            loss = func_fit_step(power_model, power_opt, X_train, y_train)\n",
    "\n",
    "        l2error = feyn_fit_eval(power_model, function, func_dim)\n",
    "\n",
    "        # Log results\n",
    "        new_row = f\"{func_name}, power, {run}, {loss}, {l2error}\"\n",
    "                        \n",
    "        # Append the row to the file\n",
    "        with open(big_file, \"a\") as rfile:\n",
    "            rfile.write(new_row + \"\\n\")\n",
    "\n",
    "        print(f\"\\t\\t\\tPower-law model. Final loss = {loss:.2e} \\tRel. L2 Error: {l2error:.2e}\")"
   ]
  },
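  {
   "cell_type": "markdown",
   "id": "5a0d3c7e-8f21-4b64-a9d2-6c4e1b8f0a37",
   "metadata": {},
   "source": [
    "*Optional:* a quick preview of the logged results, assuming the experiment above has written at least one row. The log separates fields with a comma followed by a space, hence the multi-character separator (and the `ParserWarning` filter in the imports cell)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6f2b9a4-1e57-4c08-8b3a-0f9c5d7e2a61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional preview of the appended results (assumes at least one completed run)\n",
    "pd.read_csv(small_file, sep=', ', engine='python').head()"
   ]
  },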
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df147013-1e27-4a5d-aeb0-82ebb54b9272",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "89b0422d-0090-428d-93de-0c1da6a433e4",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c60eddd-2add-4e40-b9ad-7a1e8ce64947",
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = \"loss\"   # l2 or loss\n",
    "agg = \"median\"    # median or mean\n",
    "\n",
    "sdf = pd.read_csv(small_file, sep=', ')\n",
    "s_res = sdf.groupby([\"function\", \"method\"])[metric].agg(agg).reset_index()\n",
    "s_table = s_res.pivot(index=\"function\", columns=\"method\", values=metric).reset_index()\n",
    "\n",
    "bdf = pd.read_csv(big_file, sep=', ')\n",
    "b_res = bdf.groupby([\"function\", \"method\"])[metric].agg(agg).reset_index()\n",
    "b_table = b_res.pivot(index=\"function\", columns=\"method\", values=metric).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "621949f8-9b47-45d8-b5ef-38a3b98d633b",
   "metadata": {},
   "outputs": [],
   "source": [
    "s_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea8d31c1-1830-42ca-8c16-b540e48dcea3",
   "metadata": {},
   "outputs": [],
   "source": [
    "b_table"
   ]
  },
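  {
   "cell_type": "markdown",
   "id": "7c9e2f5b-3a48-4d17-b6e0-8d1a4c6f9b25",
   "metadata": {},
   "source": [
    "*Optional summary, not part of the original analysis:* the sketch below counts, for each architecture, how often each initialization scheme attains the lowest aggregated value of the selected metric across the functions. It assumes the pivot tables above contain one column per method (`baseline`, `glorot`, `power`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b8d4e6a-9f02-4c53-a7d8-2e6b0c9f4d71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: per-function 'wins' for each init scheme (lowest aggregated metric is best).\n",
    "# Assumes s_table / b_table have one column per method: baseline, glorot, power.\n",
    "methods = [\"baseline\", \"glorot\", \"power\"]\n",
    "\n",
    "small_wins = s_table[methods].idxmin(axis=1).value_counts()\n",
    "big_wins = b_table[methods].idxmin(axis=1).value_counts()\n",
    "\n",
    "print(f\"Best-{metric} counts (small architecture):\\n{small_wins}\\n\")\n",
    "print(f\"Best-{metric} counts (big architecture):\\n{big_wins}\")"
   ]
  },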
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f339922b-1aab-45e6-8d9a-3f9273721607",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45f0ffcf-15de-45bc-bec7-314279ea7c31",
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = \"l2\"   # l2 or loss\n",
    "agg = \"median\"    # median or mean\n",
    "\n",
    "sdf = pd.read_csv(small_file, sep=', ')\n",
    "s_res = sdf.groupby([\"function\", \"method\"])[metric].agg(agg).reset_index()\n",
    "s_table = s_res.pivot(index=\"function\", columns=\"method\", values=metric).reset_index()\n",
    "\n",
    "bdf = pd.read_csv(big_file, sep=', ')\n",
    "b_res = bdf.groupby([\"function\", \"method\"])[metric].agg(agg).reset_index()\n",
    "b_table = b_res.pivot(index=\"function\", columns=\"method\", values=metric).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e51716bb-bf71-4aba-9fb8-66a55c087c14",
   "metadata": {},
   "outputs": [],
   "source": [
    "s_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "685f1d27-39c3-4115-a619-93cefc63103c",
   "metadata": {},
   "outputs": [],
   "source": [
    "b_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6704e7cf-35b4-40e4-b75f-f614440a678d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
