{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from collections import defaultdict "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding one sub-folder per trained model run; it is scanned to build the results table.\n",
    "root = 'logs'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000, 128, 128)\n"
     ]
    }
   ],
   "source": [
    "# Load one reference sample file; its shape fixes `num_samples`, the normaliser used by the results metric.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only open .pkl files produced by our own runs.\n",
    "sample_path = 'logs/x0---pde_0.001---hidden_size-128---2024_07_07__01_40_43/ode_sample-2024_07_12__00_40_45/samples_all_88.00717163085938.pkl'\n",
    "\n",
    "with open(sample_path, 'rb') as f:\n",
    "    sample_shape = tuple(pickle.load(f).shape)\n",
    "print(sample_shape)\n",
    "\n",
    "# Total number of scalar values in the sample, derived from the data instead of hard-coding 1000 * 128 * 128.\n",
    "# Assumes every run under `root` was sampled with this same shape -- TODO confirm.\n",
    "num_samples = int(np.prod(sample_shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>loss</th>\n",
       "      <th>metric</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>x0</td>\n",
       "      <td>pde_0.001</td>\n",
       "      <td>2.174243</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>x0</td>\n",
       "      <td>naive</td>\n",
       "      <td>2.261169</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>x0</td>\n",
       "      <td>pde_0.01</td>\n",
       "      <td>2.840474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>noise</td>\n",
       "      <td>naive</td>\n",
       "      <td>7.268711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>noise</td>\n",
       "      <td>pde_0.01</td>\n",
       "      <td>8.574709</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>noise</td>\n",
       "      <td>pde_0.001</td>\n",
       "      <td>10.438120</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   model       loss     metric\n",
       "1     x0  pde_0.001   2.174243\n",
       "5     x0      naive   2.261169\n",
       "0     x0   pde_0.01   2.840474\n",
       "3  noise      naive   7.268711\n",
       "2  noise   pde_0.01   8.574709\n",
       "4  noise  pde_0.001  10.438120"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build one row per sampling run found under `root`: the first two '---'-separated fields of the\n",
    "# run folder name ('model', 'loss') and a metric parsed from the single .pkl file name.\n",
    "records = defaultdict(list)\n",
    "for model_folder in sorted(os.listdir(root)):  # sorted -> deterministic row order\n",
    "    model_dir = os.path.join(root, model_folder)\n",
    "    if not os.path.isdir(model_dir):\n",
    "        continue  # skip stray files (e.g. .DS_Store) instead of crashing in os.listdir\n",
    "    for sample_folder in sorted(os.listdir(model_dir)):\n",
    "        sample_dir = os.path.join(model_dir, sample_folder)\n",
    "        if 'sample' not in sample_folder or not os.path.isdir(sample_dir):\n",
    "            continue\n",
    "        sample_pkls = [name for name in os.listdir(sample_dir) if name.endswith('.pkl')]\n",
    "        # Only unambiguous runs are reported: folders with zero or several .pkl files are skipped.\n",
    "        if len(sample_pkls) != 1:\n",
    "            continue\n",
    "        # e.g. 'x0---pde_0.001---hidden_size-128---<timestamp>' -> ('x0', 'pde_0.001')\n",
    "        model_pred, loss = model_folder.split('---')[:2]\n",
    "        # e.g. 'samples_all_88.00717163085938.pkl' -> 88.00717163085938\n",
    "        reported_value = float(os.path.splitext(sample_pkls[0])[0].rsplit('_', 1)[-1])\n",
    "        records['model'].append(model_pred)\n",
    "        records['loss'].append(loss)\n",
    "        # 100 * sqrt(value**2 / num_samples); if value is the L2 norm of the error over all\n",
    "        # num_samples entries this is 100 x RMSE -- TODO confirm against the sampling script.\n",
    "        records['metric'].append(100 * np.sqrt(np.square(reported_value) / num_samples))\n",
    "\n",
    "pd.DataFrame(records).sort_values(by='metric')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
