{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "root = 'logs'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filt(df: pd.DataFrame, loss: str) -> pd.DataFrame:\n",
    "    return df[df['loss'].map(lambda x: x.__contains__(loss))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect one record per evaluated run. Expected layout:\n",
    "#   <root>/<model_name>/<model_hyper>/<run dir whose name contains 'sample'>/results.pkl\n",
    "# `model_hyper` must split on '---' into exactly five fields; the middle three are\n",
    "# the loss label, the model name and the hyper-parameter string.\n",
    "COLUMNS = ['loss', 'model', 'hyper', 'pred error', 'motum error', 'energy error']\n",
    "# NOTE(review): display-only rescaling factors; confirm the intended units.\n",
    "MOTUM_SCALE = 10.0\n",
    "PRED_SCALE = 1e2\n",
    "\n",
    "records = []\n",
    "# sorted() makes the row order (and hence tie-breaking in the sort below) reproducible.\n",
    "for model_name in sorted(os.listdir(root)):\n",
    "    model_dir = os.path.join(root, model_name)\n",
    "    if not os.path.isdir(model_dir):\n",
    "        continue  # ignore stray files such as .DS_Store\n",
    "    for model_hyper in sorted(os.listdir(model_dir)):\n",
    "        hyper_dir = os.path.join(model_dir, model_hyper)\n",
    "        if not os.path.isdir(hyper_dir):\n",
    "            continue\n",
    "        for run in sorted(os.listdir(hyper_dir)):\n",
    "            if 'sample' not in run:\n",
    "                continue\n",
    "            pkl_path = os.path.join(hyper_dir, run, 'results.pkl')\n",
    "            if not os.path.exists(pkl_path):\n",
    "                continue\n",
    "\n",
    "            fields = model_hyper.split('---')[1:-1]\n",
    "            if len(fields) != 3:\n",
    "                raise ValueError(f'cannot parse loss/model/hyper from {model_hyper!r}')\n",
    "            loss, model, hyper = fields\n",
    "\n",
    "            # pickle.load can execute arbitrary code: only point `root` at trusted logs.\n",
    "            with open(pkl_path, 'rb') as f:\n",
    "                results_dict = pickle.load(f)\n",
    "\n",
    "            records.append({\n",
    "                'loss': loss,\n",
    "                'model': model,\n",
    "                'hyper': hyper,\n",
    "                # Only 'pred error' is averaged here; the other two are stored as-is\n",
    "                # (presumably already scalars - TODO confirm).\n",
    "                'pred error': results_dict['pred error'].mean(),\n",
    "                'motum error': results_dict['motum error'],\n",
    "                'energy error': results_dict['energy error'],\n",
    "            })\n",
    "\n",
    "# Explicit columns keep the frame (and the sort) valid even when no run was found.\n",
    "df = pd.DataFrame(records, columns=COLUMNS).sort_values(by='pred error')\n",
    "df['motum error'] = MOTUM_SCALE * df['motum error']\n",
    "df['pred error'] = PRED_SCALE * df['pred error']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>pred error</th>\n",
       "      <th>motum error</th>\n",
       "      <th>energy error</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>naive</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-512--t_embed_size-64--gnn_hidd...</td>\n",
       "      <td>5.192987</td>\n",
       "      <td>5.351126</td>\n",
       "      <td>1.089164</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>naive</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-256--t_embed_size-64--gnn_hidd...</td>\n",
       "      <td>5.195068</td>\n",
       "      <td>5.346899</td>\n",
       "      <td>1.080578</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>naive</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-1024--t_embed_size-64--gnn_hid...</td>\n",
       "      <td>5.203237</td>\n",
       "      <td>5.396400</td>\n",
       "      <td>1.089882</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     loss     model                                              hyper  \\\n",
       "14  naive  EGNN_GRU  rnn_hidden_size-512--t_embed_size-64--gnn_hidd...   \n",
       "28  naive  EGNN_GRU  rnn_hidden_size-256--t_embed_size-64--gnn_hidd...   \n",
       "4   naive  EGNN_GRU  rnn_hidden_size-1024--t_embed_size-64--gnn_hid...   \n",
       "\n",
       "    pred error  motum error  energy error  \n",
       "14    5.192987     5.351126      1.089164  \n",
       "28    5.195068     5.346899      1.080578  \n",
       "4     5.203237     5.396400      1.089882  "
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Three lowest-'pred error' runs whose loss label is exactly 'naive' (df is sorted by pred error).\n",
    "df.loc[df['loss'].eq('naive')].head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>pred error</th>\n",
       "      <th>motum error</th>\n",
       "      <th>energy error</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>momentum_0.5</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-256--t_embed_size-64--gnn_hidd...</td>\n",
       "      <td>5.091948</td>\n",
       "      <td>0.368733</td>\n",
       "      <td>0.744856</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>momentum_0.5</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-512--t_embed_size-64--gnn_hidd...</td>\n",
       "      <td>5.099092</td>\n",
       "      <td>0.433579</td>\n",
       "      <td>0.765258</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>momentum_0.5</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-1024--t_embed_size-64--gnn_hid...</td>\n",
       "      <td>5.141414</td>\n",
       "      <td>0.557017</td>\n",
       "      <td>0.782539</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            loss     model                                              hyper  \\\n",
       "6   momentum_0.5  EGNN_GRU  rnn_hidden_size-256--t_embed_size-64--gnn_hidd...   \n",
       "25  momentum_0.5  EGNN_GRU  rnn_hidden_size-512--t_embed_size-64--gnn_hidd...   \n",
       "18  momentum_0.5  EGNN_GRU  rnn_hidden_size-1024--t_embed_size-64--gnn_hid...   \n",
       "\n",
       "    pred error  motum error  energy error  \n",
       "6     5.091948     0.368733      0.744856  \n",
       "25    5.099092     0.433579      0.765258  \n",
       "18    5.141414     0.557017      0.782539  "
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Three lowest-'pred error' runs whose loss label is exactly 'momentum_0.5'.\n",
    "df.loc[df['loss'].eq('momentum_0.5')].head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>pred error</th>\n",
       "      <th>motum error</th>\n",
       "      <th>energy error</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>implicit_energy_0.1</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-256--t_embed_size-64--gnn_hidd...</td>\n",
       "      <td>5.161522</td>\n",
       "      <td>5.303237</td>\n",
       "      <td>1.054893</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>implicit_energy_0.1</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-1024--t_embed_size-64--gnn_hid...</td>\n",
       "      <td>5.180918</td>\n",
       "      <td>5.390288</td>\n",
       "      <td>1.087981</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>implicit_energy_0.1</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-512--t_embed_size-64--gnn_hidd...</td>\n",
       "      <td>5.187491</td>\n",
       "      <td>5.360503</td>\n",
       "      <td>1.090479</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   loss     model  \\\n",
       "34  implicit_energy_0.1  EGNN_GRU   \n",
       "21  implicit_energy_0.1  EGNN_GRU   \n",
       "15  implicit_energy_0.1  EGNN_GRU   \n",
       "\n",
       "                                                hyper  pred error  \\\n",
       "34  rnn_hidden_size-256--t_embed_size-64--gnn_hidd...    5.161522   \n",
       "21  rnn_hidden_size-1024--t_embed_size-64--gnn_hid...    5.180918   \n",
       "15  rnn_hidden_size-512--t_embed_size-64--gnn_hidd...    5.187491   \n",
       "\n",
       "    motum error  energy error  \n",
       "34     5.303237      1.054893  \n",
       "21     5.390288      1.087981  \n",
       "15     5.360503      1.090479  "
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Three lowest-'pred error' runs whose loss label contains 'implicit_energy_0.1'.\n",
    "df.pipe(filt, 'implicit_energy_0.1').head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>pred error</th>\n",
       "      <th>motum error</th>\n",
       "      <th>energy error</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>jensen_0.0001</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-512--t_embed_size-64--gnn_hidd...</td>\n",
       "      <td>5.208777</td>\n",
       "      <td>5.388190</td>\n",
       "      <td>1.097947</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>jensen_0.001</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-512--t_embed_size-64--gnn_hidd...</td>\n",
       "      <td>5.212376</td>\n",
       "      <td>5.392790</td>\n",
       "      <td>1.101861</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>jensen_0.1</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-512--t_embed_size-64--gnn_hidd...</td>\n",
       "      <td>5.212894</td>\n",
       "      <td>5.389690</td>\n",
       "      <td>1.100138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>jensen_0.01</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-1024--t_embed_size-64--gnn_hid...</td>\n",
       "      <td>5.213667</td>\n",
       "      <td>5.420824</td>\n",
       "      <td>1.105263</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>jensen_0.1</td>\n",
       "      <td>EGNN_GRU</td>\n",
       "      <td>rnn_hidden_size-1024--t_embed_size-64--gnn_hid...</td>\n",
       "      <td>5.214410</td>\n",
       "      <td>5.413630</td>\n",
       "      <td>1.104956</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             loss     model  \\\n",
       "33  jensen_0.0001  EGNN_GRU   \n",
       "17   jensen_0.001  EGNN_GRU   \n",
       "13     jensen_0.1  EGNN_GRU   \n",
       "2     jensen_0.01  EGNN_GRU   \n",
       "0      jensen_0.1  EGNN_GRU   \n",
       "\n",
       "                                                hyper  pred error  \\\n",
       "33  rnn_hidden_size-512--t_embed_size-64--gnn_hidd...    5.208777   \n",
       "17  rnn_hidden_size-512--t_embed_size-64--gnn_hidd...    5.212376   \n",
       "13  rnn_hidden_size-512--t_embed_size-64--gnn_hidd...    5.212894   \n",
       "2   rnn_hidden_size-1024--t_embed_size-64--gnn_hid...    5.213667   \n",
       "0   rnn_hidden_size-1024--t_embed_size-64--gnn_hid...    5.214410   \n",
       "\n",
       "    motum error  energy error  \n",
       "33     5.388190      1.097947  \n",
       "17     5.392790      1.101861  \n",
       "13     5.389690      1.100138  \n",
       "2      5.420824      1.105263  \n",
       "0      5.413630      1.104956  "
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Five lowest-'pred error' runs across every loss label containing 'jensen'.\n",
    "df.pipe(filt, 'jensen').head(5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
