{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "58774633-7f11-4ffd-a43e-7404cf64ffa4",
   "metadata": {},
   "source": [
    "# Supplementary Plots"
   ]
  },
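  {
   "cell_type": "markdown",
   "id": "0c6d49fa-2b98-448d-8993-04fcc8826fcc",
   "metadata": {},
   "source": [
    "This notebook generates all the plots shown in the supplementary material of the manuscript. It assumes that experiment notebooks 1–6 have already run successfully: it reads `ff_results/grid_search.csv`, `pde_results/grid_search.csv` and the PDE solution files in `data/`, and any figure it saves is written to `supplementary_plots/`."
   ]
  },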
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f25e33d-7ace-4a77-b415-7b5f1f961f57",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os\n",
    "\n",
    "from src.functions import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "\n",
    "from matplotlib.ticker import ScalarFormatter\n",
    "from matplotlib.colors import LogNorm\n",
    "\n",
    "import seaborn as sns\n",
    "\n",
    "import warnings\n",
    "from pandas.errors import ParserWarning\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f71e13b-64c3-4f2f-a325-7850ea439ba2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for every figure saved by this notebook (created if missing)\n",
    "plots_dir = \"supplementary_plots\"\n",
    "os.makedirs(name=plots_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9994135d-005e-4d9d-9ce2-bf0f41e12140",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bed6138f-ef03-44ad-b198-09efcf56bd21",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Function Fitting Benchmarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16f54e0b-ea58-45c6-a182-fd30d2464858",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = sns.color_palette(\"crest\", as_cmap=True)\n",
    "\n",
    "# Regular N x N evaluation grid over [-1, 1] x [-1, 1]\n",
    "N = 256\n",
    "lin = np.linspace(-1.0, 1.0, N)\n",
    "X, Y = np.meshgrid(lin, lin)\n",
    "grid = np.column_stack((X.ravel(), Y.ravel()))  # (N*N, 2) array of (x, y) points\n",
    "\n",
    "# Evaluate each benchmark function on the flattened grid and fold it back to N x N\n",
    "functions = [fn(grid).reshape(N, N) for fn in (f1, f2, f3, f4, f5)]\n",
    "titles = [rf'$f_{k}(x,y)$' for k in range(1, 6)]\n",
    "\n",
    "# Font sizes for the function-fitting figure\n",
    "TITLE_FS = 14\n",
    "LABEL_FS = 12\n",
    "TICK_FS  = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "826c47c8-8294-4d2b-9fe3-16c9ca66f987",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Five panels on a 2x7 GridSpec: f1, f2, f3 on the top row, f4 and f5 centred below.\n",
    "# NOTE(review): no panel uses column 6 of the 7-column grid; GridSpec(2, 6) may have\n",
    "# been intended. Kept at 7 so the saved figure is unchanged.\n",
    "fig = plt.figure(figsize=(14, 6))\n",
    "# Named grid_spec (not gs) to keep it distinct from the grid-search tables loaded below\n",
    "grid_spec = gridspec.GridSpec(2, 7, height_ratios=[1, 1])\n",
    "\n",
    "# (row, first column, end column) of each panel, in the same order as `functions`.\n",
    "# Top row: f1 (0:2), f2 (2:4), f3 (4:6); bottom row: f4 (1:3), f5 (3:5), i.e. the\n",
    "# bottom pair spans columns 1:5 and is centred below f2.\n",
    "panel_slots = [(0, 0, 2), (0, 2, 4), (0, 4, 6), (1, 1, 3), (1, 3, 5)]\n",
    "\n",
    "for (row, col_start, col_end), field, title in zip(panel_slots, functions, titles):\n",
    "    ax = fig.add_subplot(grid_spec[row, col_start:col_end])\n",
    "    im = ax.pcolormesh(X, Y, field, shading='auto', cmap=cmap)\n",
    "    ax.set_title(title, fontsize=TITLE_FS)\n",
    "    ax.set_xlabel('x', fontsize=LABEL_FS)\n",
    "    ax.set_ylabel('y', fontsize=LABEL_FS)\n",
    "    ax.tick_params(axis='both', labelsize=TICK_FS)\n",
    "    cbar = fig.colorbar(im, ax=ax)\n",
    "    cbar.ax.tick_params(labelsize=TICK_FS)\n",
    "\n",
    "plt.tight_layout()\n",
    "fig.savefig(os.path.join(plots_dir, \"functions.png\"), dpi=300, bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "671b998b-1dff-4f80-be33-3c341ece9a4d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "2b57042d-38a9-4a7e-9f28-13cfb6efceb3",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## PDE Benchmarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5b5f5c3-91ef-4856-b748-4b465853a1de",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = sns.color_palette(\"crest\", as_cmap=True)\n",
    "\n",
    "\n",
    "def load_pde_solution(filename, coord_keys):\n",
    "    \"\"\"Read two coordinate arrays and the transposed \"usol\" field from data/<filename>.\n",
    "\n",
    "    The .npz archive is opened in a ``with`` block so its file handle is released as\n",
    "    soon as the arrays have been read (np.load on an .npz keeps the file open until\n",
    "    the returned NpzFile is closed).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    filename : str\n",
    "        Name of the .npz file inside the \"data\" directory.\n",
    "    coord_keys : tuple of str\n",
    "        Keys of the two coordinate arrays to return, in order.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of numpy.ndarray\n",
    "        (first coordinate, second coordinate, usol.T)\n",
    "    \"\"\"\n",
    "    with np.load(os.path.join(\"data\", filename)) as archive:\n",
    "        first = archive[coord_keys[0]]\n",
    "        second = archive[coord_keys[1]]\n",
    "        solution = archive[\"usol\"].T\n",
    "    return first, second, solution\n",
    "\n",
    "\n",
    "# NOTE(review): the .T assumes each \"usol\" is stored with the first coordinate along\n",
    "# axis 0, so the transpose matches the (len(second), len(first)) shape produced by\n",
    "# np.meshgrid(first, second, indexing=\"xy\") — confirm against the data files.\n",
    "\n",
    "# Allen–Cahn (t, x, usol)\n",
    "t_ac, x_ac, U_ac = load_pde_solution(\"allen-cahn.npz\", (\"t\", \"x\"))\n",
    "T_ac, X_ac = np.meshgrid(t_ac, x_ac, indexing=\"xy\")\n",
    "\n",
    "# Burgers (t, x, usol)\n",
    "t_bg, x_bg, U_bg = load_pde_solution(\"burgers.npz\", (\"t\", \"x\"))\n",
    "T_bg, X_bg = np.meshgrid(t_bg, x_bg, indexing=\"xy\")\n",
    "\n",
    "# Helmholtz (x, y, usol)\n",
    "x_hz, y_hz, U_hz = load_pde_solution(\"helmholtz.npz\", (\"x\", \"y\"))\n",
    "X_hz, Y_hz = np.meshgrid(x_hz, y_hz, indexing=\"xy\")\n",
    "\n",
    "# Font sizes for the PDE figure\n",
    "TITLE_FS = 18\n",
    "LABEL_FS = 16\n",
    "TICK_FS  = 14"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d1f10d2-e002-4417-ad45-020f22d5a140",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: shape of the Helmholtz y-coordinate array\n",
    "np.shape(y_hz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd28ee3-8d8b-46b8-8344-7f95b1e90560",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_pde, axes = plt.subplots(1, 3, figsize=(16, 4), constrained_layout=True)\n",
    "\n",
    "# Per panel: mesh coordinates, field, title, x/y axis labels, colorbar title\n",
    "panels = [\n",
    "    (T_ac, X_ac, U_ac, \"Allen–Cahn\", \"t\", \"x\", r\"$u(x,t)$\"),\n",
    "    (T_bg, X_bg, U_bg, \"Burgers\", \"t\", \"x\", r\"$u(x,t)$\"),\n",
    "    (X_hz, Y_hz, U_hz, \"Helmholtz\", \"x\", \"y\", r\"$u(x,y)$\"),\n",
    "]\n",
    "\n",
    "# Draw all three fields first, then attach one colorbar per panel\n",
    "ims = []\n",
    "for ax, (xx, yy, field, name, xlab, ylab, _cbar_title) in zip(axes, panels):\n",
    "    mesh = ax.pcolormesh(xx, yy, field, shading=\"auto\", cmap=cmap)\n",
    "    ax.set_title(name, fontsize=TITLE_FS)\n",
    "    ax.set_xlabel(xlab, fontsize=LABEL_FS)\n",
    "    ax.set_ylabel(ylab, fontsize=LABEL_FS)\n",
    "    ax.tick_params(axis=\"both\", labelsize=TICK_FS)\n",
    "    ims.append(mesh)\n",
    "\n",
    "for ax, mesh, panel in zip(axes, ims, panels):\n",
    "    cbar = fig_pde.colorbar(mesh, ax=ax, fraction=0.046, pad=0.04)\n",
    "    cbar.ax.set_title(panel[-1], pad=6, fontsize=LABEL_FS)\n",
    "    cbar.ax.tick_params(labelsize=TICK_FS)\n",
    "\n",
    "fig_pde.savefig(os.path.join(plots_dir, \"pdes.png\"), dpi=300, bbox_inches=\"tight\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94481b20-7e44-453a-958d-6c8aaf41fd88",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f319e4e0-a868-4c2f-96e0-e67a42d2b37c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Full Grid-Search Plots: Function Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd02bad6-e415-4db6-b218-b977953e4d50",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = 'ff_results/'\n",
    "gs_raw = pd.read_csv(os.path.join(results_dir, 'grid_search.csv'), sep=',')\n",
    "\n",
    "# Each configuration has several runs (the 'run' column); keep the run whose loss is\n",
    "# closest to the group median. Sorting by loss first means ties in idxmin below go to\n",
    "# the lower-loss run.\n",
    "gs_sorted = gs_raw.sort_values(\"loss\")\n",
    "\n",
    "# Grouping columns, including pow_res and pow_basis\n",
    "group_cols = ['method', 'function', 'G', 'width', 'depth', 'pow_res', 'pow_basis']\n",
    "\n",
    "\n",
    "def get_median_row(group):\n",
    "    \"\"\"Return the one-row frame whose loss is closest to the group's median loss.\n",
    "\n",
    "    NaN losses are ignored. A group with no valid loss at all yields an empty frame,\n",
    "    so that configuration is dropped instead of making idxmin fail on an all-NaN\n",
    "    Series.\n",
    "    \"\"\"\n",
    "    s = group['loss'].dropna()\n",
    "    if s.empty:\n",
    "        return group.iloc[0:0]  # drop this experiment (no valid loss)\n",
    "    med = s.median()\n",
    "    idx = (s - med).abs().idxmin()\n",
    "    return group.loc[[idx]]\n",
    "\n",
    "\n",
    "# Apply the function group-wise; dropna=False keeps groups whose keys contain NaN\n",
    "mgs = gs_sorted.groupby(group_cols, dropna=False, group_keys=False).apply(get_median_row).reset_index(drop=True)\n",
    "\n",
    "# Filter to only 'power' method\n",
    "power_df = mgs[mgs['method'] == 'power'].copy()\n",
    "\n",
    "# For each (function, G, width, depth) keep the (pow_res, pow_basis) row with the lowest\n",
    "# loss; idxmin + .loc selects those exact rows of power_df without a per-group lambda\n",
    "best_idx = power_df.groupby(['function', 'G', 'width', 'depth'], dropna=False)['loss'].idxmin()\n",
    "best_power_configs = power_df.loc[best_idx].reset_index(drop=True)\n",
    "\n",
    "# Drop the power-law hyperparameters and the run id from the median-run table ...\n",
    "mgs_nopow = mgs.drop(columns=['pow_res', 'pow_basis', 'run'])\n",
    "\n",
    "# ... and from the best power-law configurations\n",
    "best_power_configs_nopow = best_power_configs.drop(columns=['pow_res', 'pow_basis', 'run'])\n",
    "\n",
    "# Replace the per-(pow_res, pow_basis) 'power' rows with the best one per architecture\n",
    "non_power_rows = mgs_nopow[mgs_nopow['method'] != 'power']\n",
    "fgs = pd.concat([non_power_rows, best_power_configs_nopow], ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f15de8c-a25c-43a0-9d84-94f9585df8fb",
   "metadata": {},
   "source": [
    "### Different Initialization Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7f8a47f-de5c-4fca-80a4-e0e34104ed17",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_all_inits_func(G_val, df=None, save=False):\n",
    "    \"\"\"Plot depth x width loss heatmaps for every (function, method) pair at grid size G.\n",
    "\n",
    "    Figure rows are the benchmark functions and columns the initialization methods,\n",
    "    both in order of first appearance in ``df``. Each panel is a heatmap of ``loss``\n",
    "    with depth on the y-axis (4 at the top) and width on the x-axis, with its own\n",
    "    colorbar in scientific notation.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G_val : int\n",
    "        Grid size to plot; only rows with ``df['G'] == G_val`` are used.\n",
    "    df : pandas.DataFrame, optional\n",
    "        Results table with 'function', 'method', 'G', 'width', 'depth' and 'loss'\n",
    "        columns. Defaults to the global ``fgs`` read at call time; pass the\n",
    "        function-fitting table explicitly if ``fgs`` has since been rebuilt from\n",
    "        another results directory.\n",
    "    save : bool, default False\n",
    "        If True, also write the figure to ``plots_dir/ff_G{G_val}.pdf``.\n",
    "    \"\"\"\n",
    "    if df is None:\n",
    "        df = fgs\n",
    "\n",
    "    TITLE_FS = 22\n",
    "    LABEL_FS = 20\n",
    "    TICK_FS  = 18\n",
    "    CBAR_LABEL_FS = 20\n",
    "    CBAR_TICK_FS  = 18\n",
    "    DEPTH_ORDER = [4, 3, 2, 1]  # heatmap rows, top to bottom\n",
    "\n",
    "    # Set plot config\n",
    "    colormap = sns.color_palette(\"crest_r\", as_cmap=True)\n",
    "    figsize = (25, 20)\n",
    "\n",
    "    method_dict = {\"baseline\": \"Baseline\", \"lecun_numer\": \"LeCun-Numerical\", \"lecun_norm\": \"LeCun-Normalized\", \"glorot\": \"Glorot\", \"power\": \"Power Law\"}\n",
    "    func_dict = {\"f1\": r\"$f_1(x,y)$\", \"f2\": r\"$f_2(x,y)$\", \"f3\": r\"$f_3(x,y)$\", \"f4\": r\"$f_4(x,y)$\", \"f5\": r\"$f_5(x,y)$\"}\n",
    "\n",
    "    funcs = df['function'].unique()\n",
    "    methods = df['method'].unique()\n",
    "    last_row, last_col = len(funcs) - 1, len(methods) - 1\n",
    "\n",
    "    # Grid sized from the data; squeeze=False keeps `axes` 2-D for a single row/column\n",
    "    fig, axes = plt.subplots(nrows=len(funcs), ncols=len(methods), figsize=figsize,\n",
    "                             constrained_layout=True, squeeze=False)\n",
    "    fig.set_constrained_layout_pads(w_pad=0.1, h_pad=0.1, wspace=0.07, hspace=0.07)\n",
    "\n",
    "    subset = df[df['G'] == G_val]\n",
    "\n",
    "    for i, func in enumerate(funcs):\n",
    "        func_subset = subset[subset['function'] == func]\n",
    "\n",
    "        for j, method in enumerate(methods):\n",
    "            ax = axes[i, j]\n",
    "            heat_data = func_subset[func_subset['method'] == method].pivot(\n",
    "                index='depth', columns='width', values='loss'\n",
    "            )\n",
    "            heat_data = heat_data.reindex(index=DEPTH_ORDER)\n",
    "            sns.heatmap(heat_data, ax=ax, cmap=colormap, cbar=True, annot=False)\n",
    "\n",
    "            # Colorbar ticks in scientific notation with a shared exponent label\n",
    "            colorbar = ax.collections[0].colorbar\n",
    "            colorbar.formatter = ScalarFormatter(useMathText=True)\n",
    "            colorbar.formatter.set_powerlimits((-1, 1))\n",
    "            colorbar.update_ticks()\n",
    "            colorbar.ax.tick_params(labelsize=CBAR_TICK_FS)\n",
    "            colorbar.ax.yaxis.offsetText.set_fontsize(CBAR_TICK_FS)\n",
    "            colorbar.ax.yaxis.offsetText.set_horizontalalignment('center')\n",
    "\n",
    "            ax.tick_params(axis='both', labelsize=TICK_FS)\n",
    "\n",
    "            # Method names as column titles; function names on the right-most colorbars\n",
    "            if i == 0:\n",
    "                ax.set_title(method_dict.get(method, method), fontsize=TITLE_FS, pad=10)\n",
    "            if j == last_col:\n",
    "                colorbar.set_label(func_dict.get(func, func), rotation=270, labelpad=35, fontsize=CBAR_LABEL_FS)\n",
    "\n",
    "            # Axis labels and tick labels only on the outer edges of the grid\n",
    "            if j == 0:\n",
    "                ax.set_ylabel('Depth', fontsize=LABEL_FS)\n",
    "            else:\n",
    "                ax.set_ylabel('')\n",
    "                ax.set_yticklabels([])\n",
    "            if i == last_row:\n",
    "                ax.set_xlabel('Width', fontsize=LABEL_FS)\n",
    "            else:\n",
    "                ax.set_xlabel('')\n",
    "                ax.set_xticklabels([])\n",
    "\n",
    "    if save:\n",
    "        fig.savefig(os.path.join(plots_dir, f\"ff_G{G_val}.pdf\"), bbox_inches='tight')\n",
    "\n",
    "    plt.show()\n",
    "    # Release the figure; this function is called in a loop over grid sizes\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52e660e2-b4ba-4535-a86a-afa81b8393e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One initialization-comparison figure per grid size G\n",
    "for G_val in (5, 10, 20, 40):\n",
    "    plot_all_inits_func(G_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129a9ce9-a715-4669-b4b7-f0e78c7bd39d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone re-render of the G = 5 figure\n",
    "plot_all_inits_func(G_val=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b251e9-f4cc-438e-88d5-3e6e4397f67d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a54125e7-c45a-4dde-b4f9-8accb22dd877",
   "metadata": {},
   "source": [
    "### Power-law Grid-Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35b312b3-e4a9-4e92-8fa6-862fae3e9bdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_power_func(G_val, func, df=None, save=True):\n",
    "    \"\"\"Plot the power-law (alpha, beta) grid search for one benchmark function at grid size G.\n",
    "\n",
    "    One heatmap per architecture: figure rows are widths and columns are depths (both\n",
    "    ascending). Each heatmap shows ``loss`` over pow_res (alpha, x-axis) and pow_basis\n",
    "    (beta, y-axis, largest at the top) on its own logarithmic colour scale.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G_val : int\n",
    "        Grid size to plot.\n",
    "    func : str\n",
    "        Benchmark function key as stored in the 'function' column, e.g. 'f3'.\n",
    "    df : pandas.DataFrame, optional\n",
    "        Per-configuration results including 'pow_res' and 'pow_basis'. Defaults to\n",
    "        the global ``mgs`` read at call time; pass the function-fitting table\n",
    "        explicitly if ``mgs`` has since been rebuilt from another results directory.\n",
    "    save : bool, default True\n",
    "        Write the figure to ``plots_dir/G{G_val}_{func}.pdf``.\n",
    "\n",
    "    Emits a warning and returns without plotting if there are no matching rows.\n",
    "    \"\"\"\n",
    "    if df is None:\n",
    "        df = mgs\n",
    "\n",
    "    TITLE_FS = 22\n",
    "    LABEL_FS = 20\n",
    "    TICK_FS  = 18\n",
    "    CBAR_LABEL_FS = 20\n",
    "    CBAR_TICK_FS  = 18\n",
    "    # alpha/beta values that keep a tick label; all other tick labels are blanked\n",
    "    LABELLED_TICKS = (0.0, 0.5, 1.0, 1.5, 2.0)\n",
    "\n",
    "    # Set plot config\n",
    "    colormap = sns.color_palette(\"crest_r\", as_cmap=True)\n",
    "    figsize = (20, 25)\n",
    "\n",
    "    # One combined mask instead of chained .loc[...] calls whose second mask was built\n",
    "    # on the unfiltered frame and only worked through index alignment\n",
    "    mask = (df['method'] == 'power') & (df['function'] == func) & (df['G'] == G_val)\n",
    "    subset = df.loc[mask, ['G', 'width', 'depth', 'pow_res', 'pow_basis', 'loss']].reset_index(drop=True)\n",
    "    if subset.empty:\n",
    "        warnings.warn(f\"No power-law results for function={func!r} at G={G_val}; nothing plotted.\")\n",
    "        return\n",
    "\n",
    "    def tick_label(tick):\n",
    "        \"\"\"Return the matching LABELLED_TICKS value as a string, or '' if none matches.\n",
    "\n",
    "        np.isclose tolerates float noise such as 1.5000000000000002.\n",
    "        \"\"\"\n",
    "        value = float(tick.get_text())\n",
    "        for target in LABELLED_TICKS:\n",
    "            if np.isclose(value, target):\n",
    "                return str(target)\n",
    "        return ''\n",
    "\n",
    "    widths = sorted(subset['width'].unique())\n",
    "    depths = sorted(subset['depth'].unique())\n",
    "    last_row, last_col = len(widths) - 1, len(depths) - 1\n",
    "\n",
    "    # Grid sized from the data; squeeze=False keeps `axes` 2-D for a single row/column\n",
    "    fig, axes = plt.subplots(nrows=len(widths), ncols=len(depths), figsize=figsize,\n",
    "                             constrained_layout=True, squeeze=False)\n",
    "    fig.set_constrained_layout_pads(w_pad=0.1, h_pad=0.1, wspace=0.07, hspace=0.07)\n",
    "\n",
    "    for i, width in enumerate(widths):\n",
    "        width_subset = subset[subset['width'] == width]\n",
    "\n",
    "        for j, depth in enumerate(depths):\n",
    "            ax = axes[i, j]\n",
    "            heat_data = width_subset[width_subset['depth'] == depth].pivot(\n",
    "                index='pow_basis', columns='pow_res', values='loss'\n",
    "            )\n",
    "            heat_data = heat_data.reindex(index=sorted(heat_data.index, reverse=True))\n",
    "\n",
    "            values = heat_data.to_numpy(dtype=float)\n",
    "            if values.size == 0 or np.isnan(values).all():\n",
    "                # No runs for this (width, depth): np.nanmin would fail, leave the panel blank\n",
    "                ax.set_axis_off()\n",
    "                continue\n",
    "\n",
    "            # Per-panel log colour scale; LogNorm needs a strictly positive vmin\n",
    "            local_vmin = np.nanmin(values)\n",
    "            local_vmax = np.nanmax(values)\n",
    "            if local_vmin <= 0:\n",
    "                local_vmin = 1e-12\n",
    "            norm = LogNorm(vmin=local_vmin, vmax=local_vmax)\n",
    "\n",
    "            sns.heatmap(heat_data, ax=ax, cmap=colormap, cbar=True, norm=norm)\n",
    "\n",
    "            ax.set_xticklabels([tick_label(lbl) for lbl in ax.get_xticklabels()])\n",
    "            ax.set_yticklabels([tick_label(lbl) for lbl in ax.get_yticklabels()])\n",
    "            ax.tick_params(axis='both', labelsize=TICK_FS)\n",
    "\n",
    "            colorbar = ax.collections[0].colorbar\n",
    "            colorbar.ax.tick_params(labelsize=CBAR_TICK_FS)\n",
    "\n",
    "            # Depth as column titles; width on the right-most colorbars\n",
    "            if i == 0:\n",
    "                ax.set_title(f\"Depth = {depth}\", fontsize=TITLE_FS, pad=10)\n",
    "            if j == last_col:\n",
    "                colorbar.set_label(f\"Width = {width}\", rotation=270, labelpad=35, fontsize=CBAR_LABEL_FS)\n",
    "\n",
    "            # Axis labels and tick labels only on the outer edges of the grid\n",
    "            if j == 0:\n",
    "                ax.set_ylabel(r\"$\\beta$\", fontsize=LABEL_FS)\n",
    "            else:\n",
    "                ax.set_ylabel('')\n",
    "                ax.set_yticklabels([])\n",
    "            if i == last_row:\n",
    "                ax.set_xlabel(r\"$\\alpha$\", fontsize=LABEL_FS)\n",
    "            else:\n",
    "                ax.set_xlabel('')\n",
    "                ax.set_xticklabels([])\n",
    "\n",
    "    if save:\n",
    "        fig.savefig(os.path.join(plots_dir, f\"G{G_val}_{func}.pdf\"), bbox_inches='tight')\n",
    "\n",
    "    plt.show()\n",
    "    # Release the figure; this function is called in a loop over G and functions\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b395444b-1f4d-4808-b4c3-f984e5507e88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One power-law grid-search figure per (grid size, benchmark function) pair\n",
    "for G_val in (5, 10, 20, 40):\n",
    "    for func in ('f1', 'f2', 'f3', 'f4', 'f5'):\n",
    "        plot_power_func(G_val, func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3da5c1ce-fd8e-476e-8dc1-64d3407d0b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone re-render of the G = 5, f3 figure\n",
    "plot_power_func(G_val=5, func='f3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b11c901-996d-444f-a2c1-db756b2047b6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "5d7257df-4015-43b8-ac47-4826ec1e7b5e",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Full Grid-Search Plots: PDE Solving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f6e0ee4-5574-4402-b914-e417425c7ab4",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = 'pde_results/'\n",
    "gs_raw = pd.read_csv(os.path.join(results_dir, 'grid_search.csv'), sep=',')\n",
    "\n",
    "# Each configuration has several runs (the 'run' column); keep the run whose loss is\n",
    "# closest to the group median. Sorting by loss first means ties in idxmin below go to\n",
    "# the lower-loss run.\n",
    "gs_sorted = gs_raw.sort_values(\"loss\")\n",
    "\n",
    "# Grouping columns, including pow_res and pow_basis\n",
    "group_cols = ['method', 'pde', 'G', 'width', 'depth', 'pow_res', 'pow_basis']\n",
    "\n",
    "\n",
    "def get_median_row(group):\n",
    "    \"\"\"Return the one-row frame whose loss is closest to the group's median loss.\n",
    "\n",
    "    NaN losses are ignored. A group with no valid loss at all yields an empty frame,\n",
    "    so that configuration is dropped instead of making idxmin fail on an all-NaN\n",
    "    Series.\n",
    "    \"\"\"\n",
    "    s = group['loss'].dropna()\n",
    "    if s.empty:\n",
    "        return group.iloc[0:0]  # drop this experiment (no valid loss)\n",
    "    med = s.median()\n",
    "    idx = (s - med).abs().idxmin()\n",
    "    return group.loc[[idx]]\n",
    "\n",
    "\n",
    "# Apply the function group-wise; dropna=False keeps groups whose keys contain NaN\n",
    "mgs = gs_sorted.groupby(group_cols, dropna=False, group_keys=False).apply(get_median_row).reset_index(drop=True)\n",
    "\n",
    "# Filter to only 'power' method\n",
    "power_df = mgs[mgs['method'] == 'power'].copy()\n",
    "\n",
    "# For each (pde, G, width, depth) keep the (pow_res, pow_basis) row with the lowest\n",
    "# loss; idxmin + .loc selects those exact rows of power_df without a per-group lambda\n",
    "best_idx = power_df.groupby(['pde', 'G', 'width', 'depth'], dropna=False)['loss'].idxmin()\n",
    "best_power_configs = power_df.loc[best_idx].reset_index(drop=True)\n",
    "\n",
    "# Drop the power-law hyperparameters and the run id from the median-run table ...\n",
    "mgs_nopow = mgs.drop(columns=['pow_res', 'pow_basis', 'run'])\n",
    "\n",
    "# ... and from the best power-law configurations\n",
    "best_power_configs_nopow = best_power_configs.drop(columns=['pow_res', 'pow_basis', 'run'])\n",
    "\n",
    "# Replace the per-(pow_res, pow_basis) 'power' rows with the best one per architecture\n",
    "non_power_rows = mgs_nopow[mgs_nopow['method'] != 'power']\n",
    "fgs = pd.concat([non_power_rows, best_power_configs_nopow], ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94e5d8a8-cc30-4bfb-86a7-6b2ef604c8fe",
   "metadata": {},
   "source": [
    "### Different Initialization Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59f047c0-8b22-489c-9155-d5b2bc24509f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_all_inits_pde(G_val, df=None, save=False):\n",
    "    \"\"\"Plot depth x width loss heatmaps for every (PDE, method) pair at grid size G.\n",
    "\n",
    "    Figure rows are the PDEs and columns the initialization methods, both in order\n",
    "    of first appearance in ``df``. Each panel is a heatmap of ``loss`` with depth on\n",
    "    the y-axis (4 at the top) and width on the x-axis, with its own colorbar in\n",
    "    scientific notation.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G_val : int\n",
    "        Grid size to plot; only rows with ``df['G'] == G_val`` are used.\n",
    "    df : pandas.DataFrame, optional\n",
    "        Results table with 'pde', 'method', 'G', 'width', 'depth' and 'loss'\n",
    "        columns. Defaults to the global ``fgs`` read at call time.\n",
    "    save : bool, default False\n",
    "        If True, also write the figure to ``plots_dir/pde_G{G_val}.pdf``.\n",
    "    \"\"\"\n",
    "    if df is None:\n",
    "        df = fgs\n",
    "\n",
    "    TITLE_FS = 22\n",
    "    LABEL_FS = 20\n",
    "    TICK_FS  = 18\n",
    "    CBAR_LABEL_FS = 20\n",
    "    CBAR_TICK_FS  = 18\n",
    "    DEPTH_ORDER = [4, 3, 2, 1]  # heatmap rows, top to bottom\n",
    "\n",
    "    # Set plot config\n",
    "    colormap = sns.color_palette(\"crest_r\", as_cmap=True)\n",
    "    figsize = (25, 12)\n",
    "\n",
    "    method_dict = {\"baseline\": \"Baseline\", \"lecun_numer\": \"LeCun-Numerical\", \"lecun_norm\": \"LeCun-Normalized\", \"glorot\": \"Glorot\", \"power\": \"Power Law\"}\n",
    "    pde_dict = {\"ac\": \"Allen–Cahn\", \"burgers\": \"Burgers\", \"helmholtz\": \"Helmholtz\"}\n",
    "\n",
    "    pdes = df['pde'].unique()\n",
    "    methods = df['method'].unique()\n",
    "    last_row, last_col = len(pdes) - 1, len(methods) - 1\n",
    "\n",
    "    # Grid sized from the data; squeeze=False keeps `axes` 2-D for a single row/column\n",
    "    fig, axes = plt.subplots(nrows=len(pdes), ncols=len(methods), figsize=figsize,\n",
    "                             constrained_layout=True, squeeze=False)\n",
    "    fig.set_constrained_layout_pads(w_pad=0.1, h_pad=0.1, wspace=0.07, hspace=0.07)\n",
    "\n",
    "    subset = df[df['G'] == G_val]\n",
    "\n",
    "    for i, pde in enumerate(pdes):\n",
    "        pde_subset = subset[subset['pde'] == pde]\n",
    "\n",
    "        for j, method in enumerate(methods):\n",
    "            ax = axes[i, j]\n",
    "            heat_data = pde_subset[pde_subset['method'] == method].pivot(\n",
    "                index='depth', columns='width', values='loss'\n",
    "            )\n",
    "            heat_data = heat_data.reindex(index=DEPTH_ORDER)\n",
    "            sns.heatmap(heat_data, ax=ax, cmap=colormap, cbar=True, annot=False)\n",
    "\n",
    "            # Colorbar ticks in scientific notation with a shared exponent label\n",
    "            colorbar = ax.collections[0].colorbar\n",
    "            colorbar.formatter = ScalarFormatter(useMathText=True)\n",
    "            colorbar.formatter.set_powerlimits((-1, 1))\n",
    "            colorbar.update_ticks()\n",
    "            colorbar.ax.tick_params(labelsize=CBAR_TICK_FS)\n",
    "            colorbar.ax.yaxis.offsetText.set_fontsize(CBAR_TICK_FS)\n",
    "            colorbar.ax.yaxis.offsetText.set_horizontalalignment('center')\n",
    "\n",
    "            ax.tick_params(axis='both', labelsize=TICK_FS)\n",
    "\n",
    "            # Method names as column titles; PDE names on the right-most colorbars\n",
    "            if i == 0:\n",
    "                ax.set_title(method_dict.get(method, method), fontsize=TITLE_FS, pad=10)\n",
    "            if j == last_col:\n",
    "                colorbar.set_label(pde_dict.get(pde, pde), rotation=270, labelpad=35, fontsize=CBAR_LABEL_FS)\n",
    "\n",
    "            # Axis labels and tick labels only on the outer edges of the grid\n",
    "            if j == 0:\n",
    "                ax.set_ylabel('Depth', fontsize=LABEL_FS)\n",
    "            else:\n",
    "                ax.set_ylabel('')\n",
    "                ax.set_yticklabels([])\n",
    "            if i == last_row:\n",
    "                ax.set_xlabel('Width', fontsize=LABEL_FS)\n",
    "            else:\n",
    "                ax.set_xlabel('')\n",
    "                ax.set_xticklabels([])\n",
    "\n",
    "    if save:\n",
    "        fig.savefig(os.path.join(plots_dir, f\"pde_G{G_val}.pdf\"), bbox_inches='tight')\n",
    "\n",
    "    plt.show()\n",
    "    # Release the figure; this function is called in a loop over grid sizes\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af04adaf-f0a1-4fc0-9070-8add73cff871",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One initialization-comparison figure per grid size G\n",
    "for G_val in (5, 10, 20):\n",
    "    plot_all_inits_pde(G_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46316cd1-db5c-4d43-ad88-f9d7b3ec7453",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone re-render of the G = 5 figure\n",
    "plot_all_inits_pde(G_val=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25514341-f282-4422-ab68-554a87c49d0c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "8cb01ea8-a377-4081-85cc-726de57ada29",
   "metadata": {},
   "source": [
    "### Power-law Grid-Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17c05949-0f16-4eef-9702-09265d30b523",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_power_pde(G_val, pde, df=None, save=True):\n",
    "    \"\"\"Plot the power-law (alpha, beta) grid search for one PDE at grid size G.\n",
    "\n",
    "    One heatmap per architecture: figure rows are widths and columns are depths (both\n",
    "    ascending). Each heatmap shows ``loss`` over pow_res (alpha, x-axis) and pow_basis\n",
    "    (beta, y-axis, largest at the top) on its own logarithmic colour scale.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G_val : int\n",
    "        Grid size to plot.\n",
    "    pde : str\n",
    "        PDE key as stored in the 'pde' column, e.g. 'ac'.\n",
    "    df : pandas.DataFrame, optional\n",
    "        Per-configuration results including 'pow_res' and 'pow_basis'. Defaults to\n",
    "        the global ``mgs`` read at call time.\n",
    "    save : bool, default True\n",
    "        Write the figure to ``plots_dir/G{G_val}_{pde}.pdf``.\n",
    "\n",
    "    Emits a warning and returns without plotting if there are no matching rows.\n",
    "    \"\"\"\n",
    "    if df is None:\n",
    "        df = mgs\n",
    "\n",
    "    TITLE_FS = 22\n",
    "    LABEL_FS = 20\n",
    "    TICK_FS  = 18\n",
    "    CBAR_LABEL_FS = 20\n",
    "    CBAR_TICK_FS  = 18\n",
    "    # alpha/beta values that keep a tick label; all other tick labels are blanked\n",
    "    LABELLED_TICKS = (0.0, 0.5, 1.0, 1.5, 2.0)\n",
    "\n",
    "    # Set plot config\n",
    "    colormap = sns.color_palette(\"crest_r\", as_cmap=True)\n",
    "    figsize = (20, 25)\n",
    "\n",
    "    # One combined mask instead of chained .loc[...] calls whose second mask was built\n",
    "    # on the unfiltered frame and only worked through index alignment\n",
    "    mask = (df['method'] == 'power') & (df['pde'] == pde) & (df['G'] == G_val)\n",
    "    subset = df.loc[mask, ['G', 'width', 'depth', 'pow_res', 'pow_basis', 'loss']].reset_index(drop=True)\n",
    "    if subset.empty:\n",
    "        warnings.warn(f\"No power-law results for pde={pde!r} at G={G_val}; nothing plotted.\")\n",
    "        return\n",
    "\n",
    "    def tick_label(tick):\n",
    "        \"\"\"Return the matching LABELLED_TICKS value as a string, or '' if none matches.\n",
    "\n",
    "        np.isclose tolerates float noise such as 1.5000000000000002.\n",
    "        \"\"\"\n",
    "        value = float(tick.get_text())\n",
    "        for target in LABELLED_TICKS:\n",
    "            if np.isclose(value, target):\n",
    "                return str(target)\n",
    "        return ''\n",
    "\n",
    "    widths = sorted(subset['width'].unique())\n",
    "    depths = sorted(subset['depth'].unique())\n",
    "    last_row, last_col = len(widths) - 1, len(depths) - 1\n",
    "\n",
    "    # Grid sized from the data; squeeze=False keeps `axes` 2-D for a single row/column\n",
    "    fig, axes = plt.subplots(nrows=len(widths), ncols=len(depths), figsize=figsize,\n",
    "                             constrained_layout=True, squeeze=False)\n",
    "    fig.set_constrained_layout_pads(w_pad=0.1, h_pad=0.1, wspace=0.07, hspace=0.07)\n",
    "\n",
    "    for i, width in enumerate(widths):\n",
    "        width_subset = subset[subset['width'] == width]\n",
    "\n",
    "        for j, depth in enumerate(depths):\n",
    "            ax = axes[i, j]\n",
    "            heat_data = width_subset[width_subset['depth'] == depth].pivot(\n",
    "                index='pow_basis', columns='pow_res', values='loss'\n",
    "            )\n",
    "            heat_data = heat_data.reindex(index=sorted(heat_data.index, reverse=True))\n",
    "\n",
    "            values = heat_data.to_numpy(dtype=float)\n",
    "            if values.size == 0 or np.isnan(values).all():\n",
    "                # No runs for this (width, depth): np.nanmin would fail, leave the panel blank\n",
    "                ax.set_axis_off()\n",
    "                continue\n",
    "\n",
    "            # Per-panel log colour scale; LogNorm needs a strictly positive vmin\n",
    "            local_vmin = np.nanmin(values)\n",
    "            local_vmax = np.nanmax(values)\n",
    "            if local_vmin <= 0:\n",
    "                local_vmin = 1e-12\n",
    "            norm = LogNorm(vmin=local_vmin, vmax=local_vmax)\n",
    "\n",
    "            sns.heatmap(heat_data, ax=ax, cmap=colormap, cbar=True, norm=norm)\n",
    "\n",
    "            ax.set_xticklabels([tick_label(lbl) for lbl in ax.get_xticklabels()])\n",
    "            ax.set_yticklabels([tick_label(lbl) for lbl in ax.get_yticklabels()])\n",
    "            ax.tick_params(axis='both', labelsize=TICK_FS)\n",
    "\n",
    "            colorbar = ax.collections[0].colorbar\n",
    "            colorbar.ax.tick_params(labelsize=CBAR_TICK_FS)\n",
    "\n",
    "            # Depth as column titles; width on the right-most colorbars\n",
    "            if i == 0:\n",
    "                ax.set_title(f\"Depth = {depth}\", fontsize=TITLE_FS, pad=10)\n",
    "            if j == last_col:\n",
    "                colorbar.set_label(f\"Width = {width}\", rotation=270, labelpad=35, fontsize=CBAR_LABEL_FS)\n",
    "\n",
    "            # Axis labels and tick labels only on the outer edges of the grid\n",
    "            if j == 0:\n",
    "                ax.set_ylabel(r\"$\\beta$\", fontsize=LABEL_FS)\n",
    "            else:\n",
    "                ax.set_ylabel('')\n",
    "                ax.set_yticklabels([])\n",
    "            if i == last_row:\n",
    "                ax.set_xlabel(r\"$\\alpha$\", fontsize=LABEL_FS)\n",
    "            else:\n",
    "                ax.set_xlabel('')\n",
    "                ax.set_xticklabels([])\n",
    "\n",
    "    if save:\n",
    "        fig.savefig(os.path.join(plots_dir, f\"G{G_val}_{pde}.pdf\"), bbox_inches='tight')\n",
    "\n",
    "    plt.show()\n",
    "    # Release the figure; this function is called in a loop over G and PDEs\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c78e341-8e46-456b-b52d-b1b7d4511abf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One power-law grid-search figure per (grid size, PDE) pair\n",
    "for G_val in (5, 10, 20):\n",
    "    for pde in ('ac', 'burgers', 'helmholtz'):\n",
    "        plot_power_pde(G_val, pde)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0222d2c9-e918-4b33-9127-f214771d9a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone re-render of the G = 5, Allen–Cahn figure\n",
    "plot_power_pde(G_val=5, pde='ac')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df08650e-4a16-47ac-9a32-2ec84841654c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
