{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75f50336",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from utils.misc import get_data_dir, get_output_dir, read_pickle\n",
    "\n",
    "# Select CUDA when it is available, otherwise fall back to CPU.\n",
    "# NOTE(review): `device` and `data_dir` are not referenced by the later cells of this notebook.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Project output / data directories, resolved by the utils.misc helpers.\n",
    "out_dir = get_output_dir()\n",
    "data_dir = get_data_dir()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feed25a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the per-run evaluation artifacts (TARP, SBC, predictive check, seed)\n",
    "# for every (method, eval_points, initial_conditions, nsim) combination into a\n",
    "# single summary DataFrame `df` with one row per combination.\n",
    "nsims = [1_000, 10_000, 100_000]\n",
    "eval_points = [20, 40]\n",
    "# Run folders encode the initial-condition flag as the literal strings True/False.\n",
    "init_conditions = [\"True\", \"False\"]\n",
    "methods = [\"fno\", \"simformer\"]\n",
    "# Wrap in Path() so this also works if get_output_dir() returns a plain str.\n",
    "experiment_folder = Path(out_dir) / \"sir_experiment\" / \"FNO_FMPE\"\n",
    "\n",
    "# Collect single-row frames and concatenate once at the end: calling pd.concat\n",
    "# inside the loop is quadratic and seeds the result with an empty DataFrame,\n",
    "# which recent pandas versions may warn about.\n",
    "frames = []\n",
    "for method in methods:\n",
    "    for evals in eval_points:\n",
    "        for init_cond in init_conditions:\n",
    "            for nsim in nsims:\n",
    "                run_dir = experiment_folder / f\"num_sim_{nsim}_eval_points_{evals}_initial_conditions_{init_cond}\"\n",
    "\n",
    "                tarp_absolute_atcs = read_pickle(run_dir / f\"{method}_tarp_results.pkl\")[\"absolute_atcs\"]\n",
    "                sbc_absolute_atcs = read_pickle(run_dir / f\"{method}_sbc_results.pkl\")[\"absolute_atcs\"]\n",
    "                predictive_mse = read_pickle(run_dir / f\"{method}_predictive_check_results.pkl\")[\"mses\"]\n",
    "                # int() only succeeds if the seed file holds a single value.\n",
    "                random_seed = int(np.loadtxt(run_dir / \"random_seed.csv\", delimiter=\",\"))\n",
    "\n",
    "                print(f\"method: {method}, num_sim: {nsim}, eval_points: {evals}, initial_conditions: {init_cond}\")\n",
    "                print(\"predictive mse shape: \", predictive_mse.shape)\n",
    "                print(f\"tarp: {tarp_absolute_atcs}\")\n",
    "                print(f\"sbc: {sbc_absolute_atcs.shape}\")\n",
    "                print(f\"predictive_mse: {predictive_mse.flatten().mean()}\")\n",
    "                print(f\"random_seed: {random_seed}\")\n",
    "\n",
    "                # Array-valued results are wrapped in one-element lists so each frame has\n",
    "                # exactly one row holding the whole array in a single cell.\n",
    "                # NOTE(review): tarps is stored as-is (no .cpu().numpy()), so it is presumably\n",
    "                # a scalar per run - confirm against the code that writes the TARP results.\n",
    "                frames.append(pd.DataFrame({\n",
    "                    \"method\": method,\n",
    "                    \"eval_num\": evals,\n",
    "                    \"simformer_initial_condition\": init_cond,\n",
    "                    \"nsim\": nsim,\n",
    "                    \"random_seed\": random_seed,\n",
    "                    \"tarps\": tarp_absolute_atcs,\n",
    "                    \"sbcs\": [sbc_absolute_atcs.cpu().numpy()],\n",
    "                    \"predictive_mses\": [predictive_mse.cpu().numpy()],\n",
    "                }))\n",
    "\n",
    "# pd.concat([]) raises, so fall back to an empty frame if no runs were configured.\n",
    "df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89483b9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full summary table: one row per (method, eval_points, initial_conditions, nsim) run.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "778797cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the summary in the experiment folder. The pickle keeps the array-valued\n",
    "# columns intact; the CSV is a human-readable convenience copy.\n",
    "# NOTE(review): to_csv writes array cells ('sbcs', 'predictive_mses') as their string\n",
    "# form, which numpy abbreviates with '...' above its print threshold (1000 elements by\n",
    "# default) - load summary.pkl for any downstream analysis.\n",
    "summary_stem = experiment_folder / \"summary\"\n",
    "df.to_csv(summary_stem.with_suffix(\".csv\"), index=False)\n",
    "df.to_pickle(summary_stem.with_suffix(\".pkl\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fourier_nets",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
