{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae493094",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "from utils.misc import get_output_dir,read_pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c313c90e",
   "metadata": {},
   "outputs": [],
   "source": [
    "out_dir = get_output_dir()\n",
    "\n",
    "num_params = 50\n",
    "if num_params == 50:\n",
    "    experiment_folder = out_dir / \"ice_experiment\"\n",
    "    methods = [\"FNO_FMPE_always_equispaced_False\",\"spectral_NPE\",\"baseline_raw_FMPE\",\"baseline_spectral_FMPE\",\"raw_NPE\"]\n",
    "elif num_params == 500:\n",
    "    experiment_folder = out_dir / \"ice_experiment_500\"\n",
    "    methods = [\"FNO_FMPE_always_equispaced_False\",\"spectral_NPE\",\"baseline_spectral_FMPE\",\"raw_NPE\"]\n",
    "else:\n",
    "    raise ValueError(\"num_params should be 50 or 500\")\n",
    "\n",
    "\n",
    "nsims = [1_000,10_000, 100_000]\n",
    "\n",
    "df = pd.DataFrame()\n",
    "for method in methods:\n",
    "    for nsim in nsims:\n",
    "        for run in range(1,4):\n",
    "            print(method,nsim,run)\n",
    "\n",
    "            tarp_res = read_pickle(experiment_folder / method / f\"num_sim_{nsim}_run_{run}\" / f\"tarp_results.pkl\")\n",
    "            tarp_absolute_atcs = tarp_res[\"absolute_atcs\"]\n",
    "            sbc_res = read_pickle(experiment_folder / method / f\"num_sim_{nsim}_run_{run}\" / f\"sbc_results.pkl\")\n",
    "            sbc_absolute_atcs = sbc_res[\"absolute_atcs\"]\n",
    "            predictive_mse_res = read_pickle(experiment_folder / method / f\"num_sim_{nsim}_run_{run}\" / f\"predictive_check_results.pkl\")\n",
    "            predictive_mse = predictive_mse_res[\"mses\"]/x.shape[-1]\n",
    "            predictive_mse_real_data_res = read_pickle(experiment_folder / method / f\"num_sim_{nsim}_run_{run}\" / f\"real_layers_predictive_check_results.pkl\")\n",
    "            predictive_mse_real_data = predictive_mse_real_data_res[\"mses\"]\n",
    "            random_seed_path = experiment_folder / method / f\"num_sim_{nsim}_run_{run}\" / f\"random_seed.csv\"\n",
    "            random_seed = int(np.loadtxt(random_seed_path, delimiter=\",\"))\n",
    "            # print(random_seeds[method][nsim][run-1])\n",
    "            print(f\"tarp: {tarp_absolute_atcs}\")\n",
    "            print(f\"sbc: {sbc_absolute_atcs.shape}\")\n",
    "            print(f\"predictive_mse: {predictive_mse.flatten().shape}\")\n",
    "            print(f\"random_seed: {random_seed}\")\n",
    "\n",
    "            df = pd.concat([df, pd.DataFrame({\n",
    "                \"method\": method,\n",
    "                \"nsim\": nsim,\n",
    "                \"random_seed\": random_seed,\n",
    "                \"tarps\": tarp_absolute_atcs,\n",
    "                \"sbcs\": [sbc_absolute_atcs.cpu().numpy()],\n",
    "                \"predictive_mses\": [predictive_mse.cpu().numpy()],\n",
    "                \"predictive_mses_real_data\": [predictive_mse_real_data.cpu().numpy()],\n",
    "            },)], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52ca8d1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85c63b39",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv(experiment_folder / \"summary.csv\")\n",
    "df.to_pickle(experiment_folder / \"summary.pkl\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fourier_nets",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
