{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "29d0d70e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "parent_dir = os.path.dirname(os.getcwd())\n",
    "sys.path.append(parent_dir)\n",
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "83dc0ebb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder, joined under parent_dir by aggregate_results, that holds the evaluation outputs.\n",
    "result_dir = 'eval_results'\n",
    "# Sub-folder of result_dir whose .npz result files are aggregated by the call below.\n",
    "run_name = 'kdv_disindy'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6b041949",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _rmse_with_nan_fill(mse_values):\n",
    "    \"\"\"Stack per-run MSE values, take the square root, and replace NaNs by the largest non-NaN entry.\n",
    "\n",
    "    If every entry is NaN there is no value to fill with, so the array is returned unchanged.\n",
    "    \"\"\"\n",
    "    rmse = np.sqrt(np.stack(mse_values))\n",
    "    nan_mask = np.isnan(rmse)\n",
    "    if nan_mask.any() and not nan_mask.all():\n",
    "        rmse[nan_mask] = np.max(rmse[~nan_mask])\n",
    "    return rmse\n",
    "\n",
    "\n",
    "def aggregate_results(run_name, min_seed=0, max_seed=100, mse_multiplier=1.0):\n",
    "    \"\"\"Print success rates and RMSE statistics aggregated over per-seed result files.\n",
    "\n",
    "    Reads every '.npz' file in <parent_dir>/<result_dir>/<run_name> (parent_dir and\n",
    "    result_dir are notebook-level globals) whose seed lies in [min_seed, max_seed).\n",
    "    From each file the arrays 'correct_form', 'mse', 'correct_form_all' and 'mse_all'\n",
    "    are used.\n",
    "\n",
    "    Args:\n",
    "        run_name: Name of the sub-directory of result_dir to read.\n",
    "        min_seed: Smallest seed to include (inclusive).\n",
    "        max_seed: Upper bound on the seed (exclusive).\n",
    "        mse_multiplier: Factor applied to every printed RMSE mean and standard deviation.\n",
    "\n",
    "    Returns:\n",
    "        None. All statistics are printed.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the stem of a '.npz' file name, minus its first four characters,\n",
    "            is not an integer.\n",
    "    \"\"\"\n",
    "    directory = os.path.join(parent_dir, result_dir, run_name)\n",
    "    cf, mse, cf_all, mse_all = [], [], [], []\n",
    "    # Sorted so that the load order does not depend on the filesystem.\n",
    "    for filename in sorted(os.listdir(directory)):\n",
    "        if not filename.endswith('.npz'):\n",
    "            continue\n",
    "        # The seed is whatever follows the first four characters of the stem\n",
    "        # (presumably a 'seed' prefix - confirm against the script that writes these files).\n",
    "        seed = int(filename.split('.')[0][4:])\n",
    "        if seed >= max_seed or seed < min_seed:\n",
    "            continue\n",
    "        # np.load keeps the .npz archive open; the context manager closes it once the arrays are read.\n",
    "        with np.load(os.path.join(directory, filename)) as res:\n",
    "            cf.append(res['correct_form'])\n",
    "            mse.append(res['mse'])\n",
    "            cf_all.append(res['correct_form_all'])\n",
    "            mse_all.append(res['mse_all'])\n",
    "    print(f'Loaded results from {len(cf)} runs.')\n",
    "    if not cf:\n",
    "        # Nothing matched the seed range; np.stack would reject the empty lists below.\n",
    "        return\n",
    "\n",
    "    n_runs = len(cf)\n",
    "    cf = np.stack(cf)\n",
    "    cf_sum = np.sum(cf, axis=0).astype(int)\n",
    "    for i, n_success in enumerate(cf_sum):\n",
    "        print(f'Equation {i} success rate = {n_success}/{n_runs}')\n",
    "    cf_all = np.asarray(cf_all)\n",
    "    cf_all_sum = np.sum(cf_all).astype(int)\n",
    "    print(f'Joint success rate = {cf_all_sum}/{n_runs}')\n",
    "\n",
    "    rmse = _rmse_with_nan_fill(mse)\n",
    "    for i in range(rmse.shape[1]):\n",
    "        # 'valid' statistics use only the runs flagged in correct_form for this equation.\n",
    "        ok = cf[:, i].astype(bool)\n",
    "        mse_valid = np.mean(rmse[ok, i]) * mse_multiplier\n",
    "        std = np.std(rmse[ok, i]) * mse_multiplier\n",
    "        mse_any = np.mean(rmse[:, i]) * mse_multiplier\n",
    "        std_any = np.std(rmse[:, i]) * mse_multiplier\n",
    "        print(f'Equation {i} RMSE = {mse_valid:.8f} ({std:.8f})')\n",
    "        print(f'Equation {i} RMSE (any) = {mse_any:.8f} ({std_any:.8f})')\n",
    "\n",
    "    rmse_all = _rmse_with_nan_fill(mse_all)\n",
    "    joint_ok = cf_all.astype(bool)\n",
    "    mse_all_valid = np.mean(rmse_all[joint_ok]) * mse_multiplier\n",
    "    std = np.std(rmse_all[joint_ok]) * mse_multiplier\n",
    "    mse_all_any = np.mean(rmse_all) * mse_multiplier\n",
    "    std_any = np.std(rmse_all) * mse_multiplier\n",
    "    print(f'All equations RMSE = {mse_all_valid:.8f} ({std:.8f})')\n",
    "    print(f'All equations RMSE (any) = {mse_all_any:.8f} ({std_any:.8f})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b361732d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded results from 50 runs.\n",
      "Equation 0 success rate = 50/50\n",
      "Joint success rate = 50/50\n",
      "Equation 0 RMSE = 0.02713786 (0.02436040)\n",
      "Equation 0 RMSE (any) = 0.02713786 (0.02436040)\n",
      "All equations RMSE = 0.02713786 (0.02436040)\n",
      "All equations RMSE (any) = 0.02713786 (0.02436040)\n"
     ]
    }
   ],
   "source": [
    "# Aggregate the runs with seeds 0-49 (min_seed defaults to 0; max_seed is exclusive).\n",
    "aggregate_results(run_name, max_seed=50)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
