{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "691a29e3-d4c4-4428-8711-b0ec0ba096fa",
   "metadata": {},
   "source": [
    "# Grid Search Results Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea616686-40d9-4f54-846c-165422806339",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Function Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04434c09-b568-4f6b-baff-482f941cf1ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "from glob import glob\n",
    "\n",
    "import warnings\n",
    "from pandas.errors import ParserWarning\n",
    "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
    "warnings.filterwarnings(\"ignore\", category=ParserWarning)\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
    "\n",
    "results_dir = 'ff_results/'\n",
    "all_files = glob(os.path.join(results_dir, '*.txt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc5de925-684f-4c62-a214-945c98cf4dea",
   "metadata": {},
   "outputs": [],
   "source": [
    "method_dict = {'lecun' : 'lecun_numer', 'glorot': 'glorot', 'std' : 'lecun_norm', 'power' : 'power', 'baseline' : 'baseline'}\n",
    "\n",
    "# List to collect dataframes\n",
    "dfs = []\n",
    "\n",
    "for file_path in all_files:\n",
    "    method = os.path.splitext(os.path.basename(file_path))[0]\n",
    "    df = pd.read_csv(file_path, sep=', ')\n",
    "    df['method'] = method_dict[method]\n",
    "    \n",
    "    # Ensure all expected columns are present\n",
    "    for col in ['pow_res', 'pow_basis']:\n",
    "        if col not in df.columns:\n",
    "            df[col] = pd.NA\n",
    "    \n",
    "    dfs.append(df)\n",
    "\n",
    "# Combine all into one DataFrame\n",
    "gsdf = pd.concat(dfs, ignore_index=True)\n",
    "\n",
    "gs = gsdf[['method', 'function', 'G', 'width', 'depth', 'pow_res', 'pow_basis', 'run', 'loss', 'l2']]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d83b2ec-ae8c-4b68-8786-07a1684b38b2",
   "metadata": {},
   "source": [
    "Save the data to a single csv for possible further processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f32a5616-911c-4d11-9fd8-f7cc9e4810a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "gs.to_csv(os.path.join(results_dir, 'grid_search.csv'), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "221da946-4375-48bd-a1b4-cfc6461fc9db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Isolate the run with the median performance for confidence\n",
    "gs_sorted = gs.sort_values(\"loss\")\n",
    "\n",
    "# Grouping columns, including pow_res and pow_basis\n",
    "group_cols = ['method', 'function', 'G', 'width', 'depth', 'pow_res', 'pow_basis']\n",
    "\n",
    "# Define a function to get the row with the median loss\n",
    "def get_median_row(group):\n",
    "    median_loss = group['loss'].median()\n",
    "    # Use idxmin on absolute difference to median to break ties predictably\n",
    "    idx = (group['loss'] - median_loss).abs().idxmin()\n",
    "    return group.loc[[idx]]\n",
    "\n",
    "# Apply the function group-wise and reset the index\n",
    "mgs = gs_sorted.groupby(group_cols, dropna=False, group_keys=False).apply(get_median_row).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63564855-f11c-42c8-b0bc-eb32a0f46eab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter to only 'power' method\n",
    "power_df = mgs[mgs['method'] == 'power'].copy()\n",
    "\n",
    "# Group by function and architecture (G, width, depth), and find row with minimal loss\n",
    "best_power_configs = (\n",
    "    power_df\n",
    "    .groupby(['function', 'G', 'width', 'depth'], dropna=False, group_keys=False)\n",
    "    .apply(lambda g: g.loc[g['loss'].idxmin()])\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "# Drop pow_res and pow_basis from the whole filtered set\n",
    "mgs_nopow = mgs.drop(columns=['pow_res', 'pow_basis', 'run'])\n",
    "\n",
    "# Drop pow_res and pow_basis from best_power_configs too\n",
    "best_power_configs_nopow = best_power_configs.drop(columns=['pow_res', 'pow_basis', 'run'])\n",
    "\n",
    "# Filter out original 'power' rows from mgs_nopow\n",
    "non_power_rows = mgs_nopow[mgs_nopow['method'] != 'power']\n",
    "\n",
    "# Combine best 'power' rows with all other methods\n",
    "fgs = pd.concat([non_power_rows, best_power_configs_nopow], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "353776a5-1400-48d3-894c-d0136d7a8e7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fgs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cc4f257-315b-42ba-b32d-4d7562c1cb4e",
   "metadata": {},
   "source": [
    "At this point we have a dataframe called `fgs` with a single run per architecture, corresponding to the median results. For each function and each method, we proceed to calculate how many instances outperform the baseline in terms of:\n",
    "\n",
    "a. the final loss:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c34c8aa-ed3a-4400-8779-ec71400e1a7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Extract baseline rows\n",
    "baseline_df = fgs[fgs['method'] == 'baseline'][['function', 'G', 'depth', 'width', 'loss']]\n",
    "baseline_df = baseline_df.rename(columns={'loss': 'baseline_loss'})\n",
    "\n",
    "# Step 2: Filter the methods of interest\n",
    "methods_of_interest = ['glorot', 'lecun_norm', 'lecun_numer', 'power']\n",
    "fgs_comp = fgs[fgs['method'].isin(methods_of_interest)].copy()\n",
    "\n",
    "# Step 3: Merge with baseline on matching config\n",
    "merged = pd.merge(\n",
    "    fgs_comp,\n",
    "    baseline_df,\n",
    "    on=['function', 'G', 'depth', 'width'],\n",
    "    how='inner'\n",
    ")\n",
    "\n",
    "# Step 4: Compare losses\n",
    "merged['beats_baseline'] = merged['loss'] < merged['baseline_loss']\n",
    "\n",
    "# Step 5: Group and count\n",
    "result = (\n",
    "    merged.groupby(['function', 'method'])['beats_baseline']\n",
    "    .sum()\n",
    "    .reset_index(name='num_architectures')\n",
    ")\n",
    "\n",
    "num_base = baseline_df[baseline_df['function']=='f1'].shape[0]\n",
    "result['percentage'] = 100*result['num_architectures']/num_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f98e1b0d-f90c-4b56-a4e2-c7aae43478dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05c28b30-5e37-4f50-8318-3fb7dd19a923",
   "metadata": {},
   "source": [
    "b. the final $L^2$ error relative to the reference solution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d82e2ad-81cd-4e1f-93ed-b4d29baca8a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Get baseline l2 values\n",
    "baseline_l2 = fgs[fgs['method'] == 'baseline'][['function', 'G', 'depth', 'width', 'l2']]\n",
    "baseline_l2 = baseline_l2.rename(columns={'l2': 'baseline_l2'})\n",
    "\n",
    "# Step 2: Filter the methods of interest again if needed\n",
    "fgs_comp_l2 = fgs[fgs['method'].isin(methods_of_interest)].copy()\n",
    "\n",
    "# Step 3: Merge on config\n",
    "merged_l2 = pd.merge(\n",
    "    fgs_comp_l2,\n",
    "    baseline_l2,\n",
    "    on=['function', 'G', 'depth', 'width'],\n",
    "    how='inner'\n",
    ")\n",
    "\n",
    "# Step 4: Compare l2 values\n",
    "merged_l2['beats_baseline_l2'] = merged_l2['l2'] < merged_l2['baseline_l2']\n",
    "\n",
    "# Step 5: Group and count\n",
    "result_l2 = (\n",
    "    merged_l2.groupby(['function', 'method'])['beats_baseline_l2']\n",
    "    .sum()\n",
    "    .reset_index(name='num_architectures')\n",
    ")\n",
    "\n",
    "result_l2['percentage'] = 100*result_l2['num_architectures']/num_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "969e932f-1d78-4382-963d-508ec8cbed02",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result_l2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d24975ae-be09-42ee-b0e3-61ba4114612a",
   "metadata": {},
   "source": [
    "Finally, let's find the number of architectures that minimize the loss and the relative $L^2$ error at the same time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d403e47-bb70-4f59-a54c-f840c913e319",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the merged DataFrame that contains both loss and l2 comparisons\n",
    "# First, make sure both baseline_loss and baseline_l2 are available\n",
    "\n",
    "# Step 1: Merge baseline loss and l2 together\n",
    "baseline_all = fgs[fgs['method'] == 'baseline'][['function', 'G', 'depth', 'width', 'loss', 'l2']]\n",
    "baseline_all = baseline_all.rename(columns={'loss': 'baseline_loss', 'l2': 'baseline_l2'})\n",
    "\n",
    "# Step 2: Merge with the methods of interest\n",
    "fgs_comp_all = fgs[fgs['method'].isin(methods_of_interest)].copy()\n",
    "merged_all = pd.merge(\n",
    "    fgs_comp_all,\n",
    "    baseline_all,\n",
    "    on=['function', 'G', 'depth', 'width'],\n",
    "    how='inner'\n",
    ")\n",
    "\n",
    "# Step 3: Compare both loss and l2\n",
    "merged_all['beats_both'] = (\n",
    "    (merged_all['loss'] < merged_all['baseline_loss']) &\n",
    "    (merged_all['l2'] < merged_all['baseline_l2'])\n",
    ")\n",
    "\n",
    "# Step 4: Group and count\n",
    "result_both = (\n",
    "    merged_all.groupby(['function', 'method'])['beats_both']\n",
    "    .sum()\n",
    "    .reset_index(name='num_architectures')\n",
    ")\n",
    "\n",
    "result_both['percentage'] = 100*result_both['num_architectures']/num_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34a25e4e-fa93-4f35-8638-305888818abb",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result_both)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b817546-a4f6-4bf4-a370-b05297e72f7a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "53c298b3-1d8b-4389-9b90-0060cc2fab8f",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## PDE Solving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59b027da-37c4-45a4-99d7-d93e021af3fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = 'pde_results/'\n",
    "all_files = glob(os.path.join(results_dir, '*.txt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "077250a4-043f-442c-969f-befe2f74ecb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "method_dict = {'glorot': 'glorot', 'lecun' : 'lecun_numer', 'std' : 'lecun_norm', 'power' : 'power', 'baseline' : 'baseline'}\n",
    "\n",
    "# List to collect dataframes\n",
    "dfs = []\n",
    "\n",
    "for file_path in all_files:\n",
    "    method = os.path.splitext(os.path.basename(file_path))[0]\n",
    "    df = pd.read_csv(file_path, sep=', ')\n",
    "    df['method'] = method_dict[method]\n",
    "    \n",
    "    # Ensure all expected columns are present\n",
    "    for col in ['pow_res', 'pow_basis']:\n",
    "        if col not in df.columns:\n",
    "            df[col] = pd.NA\n",
    "    \n",
    "    dfs.append(df)\n",
    "\n",
    "# Combine all into one DataFrame\n",
    "gsdf = pd.concat(dfs, ignore_index=True)\n",
    "\n",
    "gs = gsdf[['method', 'pde', 'G', 'width', 'depth', 'pow_res', 'pow_basis', 'run', 'loss', 'l2']]\n",
    "\n",
    "rename_dict = {'allen-cahn':'ac', 'burgers':'burgers', 'helmholtz':'helmholtz'}\n",
    "\n",
    "gs['pde'] = gs['pde'].replace(rename_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28652060-c672-4222-89fb-4fe54b82e5f5",
   "metadata": {},
   "source": [
    "Save the data to a single csv for possible further processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3fd256c-1026-41ee-8703-007c04ac0707",
   "metadata": {},
   "outputs": [],
   "source": [
    "gs.to_csv(os.path.join(results_dir, 'grid_search.csv'), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88d7645a-91dd-4129-bab9-4f92a4ce394f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Isolate the run with the median performance for confidence\n",
    "gs_sorted = gs.sort_values(\"loss\")\n",
    "\n",
    "# Grouping columns, including pow_res and pow_basis\n",
    "group_cols = ['method', 'pde', 'G', 'width', 'depth', 'pow_res', 'pow_basis']\n",
    "\n",
    "# Define a function to get the row with the median loss\n",
    "def get_median_row(group):\n",
    "    s = group['loss'].dropna()\n",
    "    if s.empty:\n",
    "        return group.iloc[0:0]  # drop this experiment (no valid loss)\n",
    "    med = s.median()\n",
    "    idx = (s - med).abs().idxmin()\n",
    "    return group.loc[[idx]]\n",
    "\n",
    "# Apply the function group-wise and reset the index\n",
    "mgs = gs_sorted.groupby(group_cols, dropna=False, group_keys=False).apply(get_median_row).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d5f8b3b-fe72-4a79-8353-62b90e409cf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter to only 'power' method\n",
    "power_df = mgs[mgs['method'] == 'power'].copy()\n",
    "\n",
    "# Group by function and architecture (G, width, depth), and find row with minimal loss\n",
    "best_power_configs = (\n",
    "    power_df\n",
    "    .groupby(['pde', 'G', 'width', 'depth'], dropna=False, group_keys=False)\n",
    "    .apply(lambda g: g.loc[g['loss'].idxmin()])\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "# Drop pow_res and pow_basis from the whole filtered set\n",
    "mgs_nopow = mgs.drop(columns=['pow_res', 'pow_basis', 'run'])\n",
    "\n",
    "# Drop pow_res and pow_basis from best_power_configs too\n",
    "best_power_configs_nopow = best_power_configs.drop(columns=['pow_res', 'pow_basis', 'run'])\n",
    "\n",
    "# Filter out original 'power' rows from mgs_nopow\n",
    "non_power_rows = mgs_nopow[mgs_nopow['method'] != 'power']\n",
    "\n",
    "# Combine best 'power' rows with all other methods\n",
    "fgs = pd.concat([non_power_rows, best_power_configs_nopow], ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9738f258-8451-43cd-bd89-d9d90fad387e",
   "metadata": {},
   "source": [
    "At this point we have a dataframe called `fgs` with a single run per architecture, corresponding to the median results. For each pde and each method, we proceed to calculate how many instances outperform the baseline in terms of:\n",
    "\n",
    "a. the final loss:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc341b64-d211-41f1-b5c0-e33834773465",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Extract baseline rows\n",
    "baseline_df = fgs[fgs['method'] == 'baseline'][['pde', 'G', 'depth', 'width', 'loss']]\n",
    "baseline_df = baseline_df.rename(columns={'loss': 'baseline_loss'})\n",
    "\n",
    "# Step 2: Filter the methods of interest\n",
    "methods_of_interest = ['lecun_norm', 'lecun_numer', 'glorot', 'power']\n",
    "fgs_comp = fgs[fgs['method'].isin(methods_of_interest)].copy()\n",
    "\n",
    "# Step 3: Merge with baseline on matching config\n",
    "merged = pd.merge(\n",
    "    fgs_comp,\n",
    "    baseline_df,\n",
    "    on=['pde', 'G', 'depth', 'width'],\n",
    "    how='inner'\n",
    ")\n",
    "\n",
    "# Step 4: Compare losses\n",
    "merged['beats_baseline'] = merged['loss'] < merged['baseline_loss']\n",
    "\n",
    "# Step 5: Group and count\n",
    "result = (\n",
    "    merged.groupby(['pde', 'method'])['beats_baseline']\n",
    "    .sum()\n",
    "    .reset_index(name='num_architectures')\n",
    ")\n",
    "\n",
    "num_base = baseline_df[baseline_df['pde']=='ac'].shape[0]\n",
    "result['percentage'] = 100*result['num_architectures']/num_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6de559-b85f-4d0f-9afb-dc38cc72324b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37b75a43-98ca-4b3e-bd88-37eaa13e9bdf",
   "metadata": {},
   "source": [
    "b. the final $L^2$ error relative to the reference solution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29c95d4b-94dd-4554-978d-464fb4e29f2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Get baseline l2 values\n",
    "baseline_l2 = fgs[fgs['method'] == 'baseline'][['pde', 'G', 'depth', 'width', 'l2']]\n",
    "baseline_l2 = baseline_l2.rename(columns={'l2': 'baseline_l2'})\n",
    "\n",
    "# Step 2: Filter the methods of interest again if needed\n",
    "fgs_comp_l2 = fgs[fgs['method'].isin(methods_of_interest)].copy()\n",
    "\n",
    "# Step 3: Merge on config\n",
    "merged_l2 = pd.merge(\n",
    "    fgs_comp_l2,\n",
    "    baseline_l2,\n",
    "    on=['pde', 'G', 'depth', 'width'],\n",
    "    how='inner'\n",
    ")\n",
    "\n",
    "# Step 4: Compare l2 values\n",
    "merged_l2['beats_baseline_l2'] = merged_l2['l2'] < merged_l2['baseline_l2']\n",
    "\n",
    "# Step 5: Group and count\n",
    "result_l2 = (\n",
    "    merged_l2.groupby(['pde', 'method'])['beats_baseline_l2']\n",
    "    .sum()\n",
    "    .reset_index(name='num_architectures')\n",
    ")\n",
    "\n",
    "result_l2['percentage'] = 100*result_l2['num_architectures']/num_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8c3402d-1328-4d74-8b6a-f1b16e2d5326",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result_l2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6b3dedc-f755-4838-ae6d-8c8abb2292ae",
   "metadata": {},
   "source": [
    "Finally, let's find the number of architectures that minimize the loss and the relative $L^2$ error at the same time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f1004e5-03a7-4dc6-aff7-2d05457b48ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the merged DataFrame that contains both loss and l2 comparisons\n",
    "# First, make sure both baseline_loss and baseline_l2 are available\n",
    "\n",
    "# Step 1: Merge baseline loss and l2 together\n",
    "baseline_all = fgs[fgs['method'] == 'baseline'][['pde', 'G', 'depth', 'width', 'loss', 'l2']]\n",
    "baseline_all = baseline_all.rename(columns={'loss': 'baseline_loss', 'l2': 'baseline_l2'})\n",
    "\n",
    "# Step 2: Merge with the methods of interest\n",
    "fgs_comp_all = fgs[fgs['method'].isin(methods_of_interest)].copy()\n",
    "merged_all = pd.merge(\n",
    "    fgs_comp_all,\n",
    "    baseline_all,\n",
    "    on=['pde', 'G', 'depth', 'width'],\n",
    "    how='inner'\n",
    ")\n",
    "\n",
    "# Step 3: Compare both loss and l2\n",
    "merged_all['beats_both'] = (\n",
    "    (merged_all['loss'] < merged_all['baseline_loss']) &\n",
    "    (merged_all['l2'] < merged_all['baseline_l2'])\n",
    ")\n",
    "\n",
    "# Step 4: Group and count\n",
    "result_both = (\n",
    "    merged_all.groupby(['pde', 'method'])['beats_both']\n",
    "    .sum()\n",
    "    .reset_index(name='num_architectures')\n",
    ")\n",
    "\n",
    "result_both['percentage'] = 100*result_both['num_architectures']/num_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0d95292-1d8d-4512-b8e0-9ce4b3d1c355",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result_both)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89a707f8-b538-48cc-b70d-f67fb20200ca",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
