{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "58774633-7f11-4ffd-a43e-7404cf64ffa4",
   "metadata": {},
   "source": [
    "# Main Plots"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c6d49fa-2b98-448d-8993-04fcc8826fcc",
   "metadata": {},
   "source": [
    "Given that all previous notebooks (1-6) have run successfully, this notebook generates all the plots shown in the main text using the previous notebooks' outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f25e33d-7ace-4a77-b415-7b5f1f961f57",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.lines as mlines\n",
    "\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f71e13b-64c3-4f2f-a325-7850ea439ba2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the directory if it doesn't exist\n",
    "plots_dir = \"main_plots\"\n",
    "os.makedirs(plots_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9994135d-005e-4d9d-9ce2-bf0f41e12140",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bed6138f-ef03-44ad-b198-09efcf56bd21",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Training Curves: Function Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5865c6d5-8077-4945-9be1-79bc14c57742",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = 'ff_results/'\n",
    "\n",
    "with open(os.path.join(results_dir, \"losses.pkl\"), \"rb\") as f:\n",
    "    results = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e441e78-564e-475c-971f-c9fed49f1d76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting\n",
    "num_epochs = 2000\n",
    "cmap = sns.color_palette(\"crest\", as_cmap=True)\n",
    "spectral_points = np.linspace(0, 1, 20)\n",
    "color_indices = [0, 10, -1]\n",
    "\n",
    "TITLE_FS = 22\n",
    "LABEL_FS = 20\n",
    "TICK_FS  = 18\n",
    "\n",
    "init_types = ['baseline', 'glorot', 'power']\n",
    "architectures = ['small', 'big']\n",
    "func_names = list(results.keys())\n",
    "func_plot_names = [r'$f_1(x,y)$', r'$f_2(x,y)$', r'$f_3(x,y)$', r'$f_4(x,y)$', r'$f_5(x,y)$']\n",
    "\n",
    "colors = [cmap(spectral_points[i]) for i in color_indices]\n",
    "custom_colors = dict(zip(init_types, colors))\n",
    "\n",
    "fig, axes = plt.subplots(2, 5, figsize=(25, 10))\n",
    "\n",
    "for col, func_name in enumerate(func_names):\n",
    "    for row, arch in enumerate(architectures):\n",
    "        ax = axes[row, col]\n",
    "        \n",
    "        for init in init_types:\n",
    "            # Collect all runs for this configuration\n",
    "            runs = []\n",
    "            for run in results[func_name][arch]:\n",
    "                arr = np.array(results[func_name][arch][run][init])\n",
    "                runs.append(arr)\n",
    "            runs = np.stack(runs)\n",
    "\n",
    "            # Compute mean and standard error\n",
    "            mean = runs.mean(axis=0)\n",
    "            stderr = runs.std(axis=0) / np.sqrt(runs.shape[0])\n",
    "\n",
    "            # Plot mean with stderr shaded area\n",
    "            ax.plot(mean, label=init, color=custom_colors[init])\n",
    "            ax.fill_between(np.arange(num_epochs), mean - stderr, mean + stderr, alpha=0.3, color=custom_colors[init])\n",
    "            \n",
    "            ax.tick_params(axis='both', labelsize=TICK_FS)\n",
    "\n",
    "        # Labeling\n",
    "        if row == 0:\n",
    "            ax.set_title(func_plot_names[col], fontsize=TITLE_FS)\n",
    "        if col == 0:\n",
    "            ax.set_ylabel(\"Training Loss\", fontsize=LABEL_FS, labelpad=10)\n",
    "        if row == 1:\n",
    "            ax.set_xlabel(\"Training Iteration\", fontsize=LABEL_FS, labelpad=10)\n",
    "        if col == len(func_names) - 1:\n",
    "            ax.text(1.10, 0.5, r'$G = 5$, depth = 2, width = 8' if row == 0 else r'$G = 20$, depth = 3, width = 32', transform=ax.transAxes,\n",
    "                    fontsize=TICK_FS, rotation=270, va='center', ha='left')\n",
    "\n",
    "        ax.set_yscale('log')\n",
    "        ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)\n",
    "\n",
    "# Construct legend manually\n",
    "handles = [\n",
    "    mlines.Line2D([], [], color=custom_colors['baseline'], label='Baseline', linewidth=3),\n",
    "    mlines.Line2D([], [], color=custom_colors['glorot'], label='Glorot', linewidth=3),\n",
    "    mlines.Line2D([], [], color=custom_colors['power'], label='Power-Law', linewidth=3),\n",
    "]\n",
    "\n",
    "# Add global legend\n",
    "fig.legend(handles=handles, loc=\"lower center\", ncol=4, fontsize=LABEL_FS, frameon=False, bbox_to_anchor=(0.5, -0.05))\n",
    "\n",
    "plt.subplots_adjust(hspace=0.35, wspace=0.3, bottom=0.1)\n",
    "\n",
    "fig.savefig(os.path.join(plots_dir, \"ff_losses.pdf\"), bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49e2ded6-3f8c-4c05-aa45-338df5c4a754",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "62d7f83d-6b0c-4a3b-b480-aa801ceee9cd",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Training Curves: PDEs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b86409de-08ce-43d1-a40d-29ee6c9df207",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = 'pde_results/'\n",
    "\n",
    "with open(os.path.join(results_dir, \"losses.pkl\"), \"rb\") as f:\n",
    "    results = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a21372e-30ae-49d5-9768-1f2e08c8345e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting\n",
    "num_epochs = 5000\n",
    "cmap = sns.color_palette(\"crest\", as_cmap=True)\n",
    "spectral_points = np.linspace(0, 1, 20)\n",
    "color_indices = [0, 10, -1]\n",
    "\n",
    "TITLE_FS = 18\n",
    "LABEL_FS = 16\n",
    "TICK_FS  = 14\n",
    "\n",
    "init_types = ['baseline', 'glorot', 'power']\n",
    "architectures = ['small', 'big']\n",
    "func_names = list(results.keys())\n",
    "func_plot_names = ['Allen-Cahn', 'Burgers', 'Helmholtz']\n",
    "\n",
    "colors = [cmap(spectral_points[i]) for i in color_indices]\n",
    "custom_colors = dict(zip(init_types, colors))\n",
    "\n",
    "fig, axes = plt.subplots(2, 3, figsize=(20, 7))\n",
    "\n",
    "for col, func_name in enumerate(func_names):\n",
    "    for row, arch in enumerate(architectures):\n",
    "        ax = axes[row, col]\n",
    "        \n",
    "        for init in init_types:\n",
    "            # Collect all runs for this configuration\n",
    "            runs = []\n",
    "            for run in results[func_name][arch]:\n",
    "                arr = np.array(results[func_name][arch][run][init])\n",
    "                runs.append(arr)\n",
    "            runs = np.stack(runs)\n",
    "\n",
    "            # Compute mean and standard error\n",
    "            mean = runs.mean(axis=0)\n",
    "            stderr = runs.std(axis=0) / np.sqrt(runs.shape[0])\n",
    "\n",
    "            # Plot mean with stderr shaded area\n",
    "            ax.plot(mean, label=init, color=custom_colors[init])\n",
    "            ax.fill_between(np.arange(num_epochs), mean - stderr, mean + stderr, alpha=0.3, color=custom_colors[init])\n",
    "            \n",
    "            ax.tick_params(axis='both', labelsize=TICK_FS)\n",
    "\n",
    "        # Labeling\n",
    "        if row == 0:\n",
    "            ax.set_title(func_plot_names[col], fontsize=TITLE_FS)\n",
    "        if col == 0:\n",
    "            ax.set_ylabel(\"Training Loss\", fontsize=LABEL_FS, labelpad=10)\n",
    "        if row == 1:\n",
    "            ax.set_xlabel(\"Training Iteration\", fontsize=LABEL_FS, labelpad=10)\n",
    "        if col == len(func_names) - 1:\n",
    "            ax.text(1.10, 0.5, '            G = 5\\ndepth = 2, width = 8' if row == 0 else '            G = 20\\ndepth = 3, width = 32', transform=ax.transAxes,\n",
    "                    fontsize=LABEL_FS, rotation=270, va='center', ha='left')\n",
    "        \n",
    "        ax.set_yscale('log')\n",
    "        ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)\n",
    "\n",
    "# Construct legend manually\n",
    "handles = [\n",
    "    mlines.Line2D([], [], color=custom_colors['baseline'], label='Baseline', linewidth=3),\n",
    "    mlines.Line2D([], [], color=custom_colors['glorot'], label='Glorot', linewidth=3),\n",
    "    mlines.Line2D([], [], color=custom_colors['power'], label='Power-Law', linewidth=3),\n",
    "]\n",
    "\n",
    "# Add global legend\n",
    "fig.legend(handles=handles, loc=\"lower center\", ncol=4, fontsize=LABEL_FS, frameon=False, bbox_to_anchor=(0.5, -0.08))\n",
    "\n",
    "plt.subplots_adjust(hspace=0.25, wspace=0.2, bottom=0.1)\n",
    "\n",
    "fig.savefig(os.path.join(plots_dir, \"pde_losses.pdf\"), bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "365beae6-e0b7-46c2-92f3-d98140ae1019",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "aad484ab-26c2-4e6d-90b7-9ac997b8cf83",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## NTK Plots: Function Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "707b8ddd-d211-4094-932e-d8986772084d",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = 'ff_results/'\n",
    "\n",
    "with open(os.path.join(results_dir, \"ntk.pkl\"), \"rb\") as f:\n",
    "    results = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "041eb70f-91e2-4466-8858-f8255c764763",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ntk_row(results, func_name):\n",
    "\n",
    "    TITLE_FS = 14\n",
    "    LABEL_FS = 12\n",
    "    TICK_FS  = 10\n",
    "\n",
    "    palette = sns.color_palette(\"crest\", 20)\n",
    "    c_init  = palette[-1]\n",
    "    c_mid   = palette[10]\n",
    "    c_final = palette[0]\n",
    "\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(16, 3), sharex=True, sharey=False)\n",
    "\n",
    "    for col, init in enumerate([\"Baseline\", \"Glorot\", \"Power\"]):\n",
    "        ax = axes[col]\n",
    "        rec = results[func_name][\"big\"].get(init)\n",
    "        if rec is None:\n",
    "            ax.set_visible(False)\n",
    "            continue\n",
    "\n",
    "        specs = [np.asarray(e) for e in rec[\"spec_list\"]]\n",
    "        taus  = rec[\"tau_list\"]\n",
    "        idx   = np.arange(1, specs[0].size + 1)\n",
    "\n",
    "        # Initialization\n",
    "        ax.plot(idx, specs[0], color=c_init, lw=2, label=\"Initialization\")\n",
    "        # Intermediates\n",
    "        if len(specs) > 2:\n",
    "            for lam in specs[1:-1]:\n",
    "                ax.plot(idx, lam, \"--\", color=c_mid, alpha=0.7, label=\"Intermediate Iterations\" if col==0 else None)\n",
    "        # Final\n",
    "        ax.plot(idx, specs[-1], \"--\", color=c_final, lw=2, label=\"Final Iteration\")\n",
    "\n",
    "        ax.set_xscale(\"log\"); ax.set_yscale(\"log\")\n",
    "        title = f\"{init}\" if init != \"Power\" else \"Power-Law\"\n",
    "        ax.set_title(title, fontsize=TITLE_FS)\n",
    "        if col == 0:\n",
    "            ax.set_ylabel(\"Eigenvalues\", fontsize=LABEL_FS)\n",
    "        ax.set_xlabel(\"Indices\", fontsize=LABEL_FS)\n",
    "        ax.tick_params(axis='both', labelsize=TICK_FS)\n",
    "\n",
    "        ax.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)\n",
    "\n",
    "        # Right-side row annotation (like your style)\n",
    "        #if col == len(INIT_ORDER) - 1:\n",
    "        #    ax.text(1.05, 0.5, \"G = 20, depth = 3, width = 32\", transform=ax.transAxes,\n",
    "        #            fontsize=14, rotation=270, va='center', ha='left')\n",
    "\n",
    "    # Global legend outside\n",
    "    handles = [\n",
    "        mlines.Line2D([], [], color=c_init,  label=\"Initialization\",        linewidth=2),\n",
    "        mlines.Line2D([], [], color=c_mid,   label=\"Intermediate Iterations\", linewidth=2, linestyle=\"--\"),\n",
    "        mlines.Line2D([], [], color=c_final, label=\"Final Iteration\",       linewidth=2, linestyle=\"--\"),\n",
    "    ]\n",
    "    \n",
    "    fig.legend(handles=handles, loc=\"lower center\", ncol=3, fontsize=LABEL_FS,\n",
    "               frameon=False, bbox_to_anchor=(0.5, -0.15))\n",
    "\n",
    "    plt.subplots_adjust(hspace=0.2, wspace=0.2, bottom=0.18)\n",
    "\n",
    "    fig.savefig(os.path.join(plots_dir, f\"ntk_{func_name}.pdf\"), bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef497197-d97b-4f26-a1f2-d300e8e7a079",
   "metadata": {},
   "outputs": [],
   "source": [
    "for func_name in [\"f1\", \"f2\", \"f3\", \"f4\", \"f5\"]:\n",
    "    plot_ntk_row(results, func_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a03a18-fab4-477d-80b6-d27ce09dfcc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_ntk_row(results, \"f3\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0799da7-6ca8-45ff-a509-e0edf2624d89",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "382f920c-85e5-4fec-ac1d-f161bb269ec6",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## NTK Plots: PDEs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cf25bea-232e-4e9c-990b-9be63fe08c5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = 'pde_results/'\n",
    "\n",
    "with open(os.path.join(results_dir, \"ntk.pkl\"), \"rb\") as f:\n",
    "    results = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "971f1434-1036-40cc-b176-204055b36bbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ntk_rows(results, pde_name):\n",
    "\n",
    "    TITLE_FS = 14\n",
    "    LABEL_FS = 12\n",
    "    TICK_FS  = 10\n",
    "\n",
    "    palette = sns.color_palette(\"crest\", 20)\n",
    "    c_init  = palette[-1]\n",
    "    c_mid   = palette[10]\n",
    "    c_final = palette[0]\n",
    "\n",
    "    fig, axes = plt.subplots(2, 3, figsize=(16, 6), sharex=False, sharey=False)\n",
    "\n",
    "    for col, init in enumerate([\"Baseline\", \"Glorot\", \"Power\"]):\n",
    "        rec = results[pde_name][\"big\"].get(init)\n",
    "        if rec is None:\n",
    "            axes[0, col].set_visible(False)\n",
    "            axes[1, col].set_visible(False)\n",
    "            continue\n",
    "\n",
    "        # ---- Row 0: PDE spectra ----\n",
    "        ax0 = axes[0, col]\n",
    "        pde_specs = [np.asarray(e) for e in rec[\"specE_list\"]]\n",
    "        taus      = rec[\"tau_list\"]\n",
    "        idx       = np.arange(1, pde_specs[0].size + 1)\n",
    "\n",
    "        ax0.plot(idx, pde_specs[0], color=c_init, lw=2, label=\"Initialization\")\n",
    "        if len(pde_specs) > 2:\n",
    "            for lam in pde_specs[1:-1]:\n",
    "                ax0.plot(idx, lam, \"--\", color=c_mid, alpha=0.7,\n",
    "                         label=\"Intermediate Iterations\" if (col==0) else None)\n",
    "        ax0.plot(idx, pde_specs[-1], \"--\", color=c_final, lw=3, label=\"Final Iteration\")\n",
    "\n",
    "        ax0.set_xscale(\"log\"); ax0.set_yscale(\"log\")\n",
    "        title = f\"{init}\" if init != \"Power\" else \"Power-Law\"\n",
    "        ax0.set_title(title, fontsize=TITLE_FS)\n",
    "        if col == 0:\n",
    "            ax0.set_ylabel(\"Eigenvalues (PDE)\", fontsize=LABEL_FS)\n",
    "        ax0.tick_params(axis='both', labelsize=TICK_FS)\n",
    "        ax0.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)\n",
    "\n",
    "        # ---- Row 1: BC spectra ----\n",
    "        ax1 = axes[1, col]\n",
    "        bc_specs = [np.asarray(e) for e in rec[\"specB_list\"]]\n",
    "        idx_b    = np.arange(1, bc_specs[0].size + 1)\n",
    "\n",
    "        ax1.plot(idx_b, bc_specs[0], color=c_init, lw=2)\n",
    "        if len(bc_specs) > 2:\n",
    "            for lam in bc_specs[1:-1]:\n",
    "                ax1.plot(idx_b, lam, \"--\", color=c_mid, alpha=0.7)\n",
    "        ax1.plot(idx_b, bc_specs[-1], \"--\", color=c_final, lw=2)\n",
    "\n",
    "        ax1.set_xscale(\"log\"); ax1.set_yscale(\"log\")\n",
    "        if col == 0:\n",
    "            ax1.set_ylabel(\"Eigenvalues (BC)\", fontsize=LABEL_FS)\n",
    "        ax1.set_xlabel(\"Indices\", fontsize=LABEL_FS)\n",
    "        ax1.tick_params(axis='both', labelsize=TICK_FS)\n",
    "        ax1.grid(True, which='both', linestyle='--', linewidth=0.25, alpha=0.55)\n",
    "\n",
    "    # Global legend outside (one set for all panels)\n",
    "    handles = [\n",
    "        mlines.Line2D([], [], color=c_init,  label=\"Initialization\",          linewidth=2),\n",
    "        mlines.Line2D([], [], color=c_mid,   label=\"Intermediate Iterations\", linewidth=2, linestyle=\"--\"),\n",
    "        mlines.Line2D([], [], color=c_final, label=\"Final Iteration\",         linewidth=2, linestyle=\"--\"),\n",
    "    ]\n",
    "    \n",
    "    fig.legend(handles=handles, loc=\"lower center\", ncol=3, fontsize=LABEL_FS,\n",
    "               frameon=False, bbox_to_anchor=(0.5, -0.02))\n",
    "\n",
    "    plt.subplots_adjust(hspace=0.25, wspace=0.2, bottom=0.16)\n",
    "\n",
    "    fig.savefig(os.path.join(plots_dir, f\"ntk_{pde_name}.pdf\"), bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffda1afa-fcd9-49c5-9b2c-28e61f198645",
   "metadata": {},
   "outputs": [],
   "source": [
    "for pde_name in [\"allen-cahn\", \"burgers\", \"helmholtz\"]:\n",
    "    plot_ntk_rows(results, pde_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd28ee3-8d8b-46b8-8344-7f95b1e90560",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_ntk_rows(results, 'allen-cahn')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9bec0e4-17da-4499-b550-7071be8a286f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
