{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_results(\n",
    "    path: str,\n",
    "    mean_weight: tuple[float, float] = (1e2, 1e3),\n",
    "    energy_scale: float = 1e2,\n",
    "    momentum_scale: float = 1e4,\n",
    ") -> np.ndarray:\n",
    "    \"\"\"Load `results.pkl` from a run directory and summarise it as scaled error metrics.\n",
    "\n",
    "    Args:\n",
    "        path: Directory containing a `results.pkl` file.\n",
    "        mean_weight: Multipliers for channel 0 and channel 1 of the reshaped\n",
    "            trajectory error.\n",
    "        energy_scale: Multiplier for the mean of `results['energy error']`.\n",
    "        momentum_scale: Multiplier for `results['momentum error']`.\n",
    "\n",
    "    Returns:\n",
    "        1-D array `[overall, channel_0, channel_1, energy, momentum]`, where\n",
    "        `overall` is the mean of the weighted trajectory error and `channel_i`\n",
    "        is its mean over samples and the last axis for channel i.\n",
    "    \"\"\"\n",
    "    # NOTE(review): pickle.load can execute arbitrary code; only open trusted result files.\n",
    "    with open(os.path.join(path, 'results.pkl'), 'rb') as f:\n",
    "        results = pickle.load(f)\n",
    "\n",
    "    # View as (n_samples, 2 channels, 9); the stored array size must be a multiple of 18.\n",
    "    traj_error = results['traj error'].reshape(-1, 2, 9)\n",
    "    channel_weight = np.asarray(mean_weight, dtype=float).reshape(1, 2, 1)\n",
    "    traj_error = traj_error * channel_weight\n",
    "\n",
    "    return np.stack([\n",
    "        traj_error.mean(),\n",
    "        *traj_error.mean(axis=0).mean(axis=-1),\n",
    "        energy_scale * results['energy error'].mean(),\n",
    "        momentum_scale * results['momentum error'],\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filt(df: pd.DataFrame, loss: str, exact: bool = False) -> pd.DataFrame:\n",
    "    \"\"\"Return the rows of `df` whose 'loss' label matches `loss`.\n",
    "\n",
    "    Args:\n",
    "        df: Results table with a string-valued 'loss' column.\n",
    "        loss: Label, or fragment of a label, to select (e.g. 'naive', 'momentum_0.05').\n",
    "        exact: If False (default), keep rows whose label contains `loss` as a\n",
    "            substring, so 'ablation' selects every 'ablation_*' setting. If True,\n",
    "            require an exact match, avoiding accidental hits such as\n",
    "            'momentum_0.1' matching 'momentum_0.15'.\n",
    "\n",
    "    Returns:\n",
    "        The matching rows, in the same order as in `df`.\n",
    "    \"\"\"\n",
    "    if exact:\n",
    "        mask = df['loss'] == loss\n",
    "    else:\n",
    "        mask = df['loss'].map(lambda label: loss in label)\n",
    "    return df[mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the sweep logs: one sub-directory per loss setting.\n",
    "root = 'logs'\n",
    "# Expected finished runs per sweep; presumably the hyper-parameter grid size\n",
    "# (results show n_layers in {3, 4, 5} and five hidden sizes) - confirm against the launch script.\n",
    "n_layer_choices, hidden_size_choices = 3, 5\n",
    "n_search = n_layer_choices * hidden_size_choices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collects (loss_setting, model_name, missing_count) tuples for incomplete sweeps.\n",
    "unfinished = list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['jensen_0.05', 'jensen_0.01', 'jensen_0.005', 'jensen_0.001',\n",
       "       'jensen_0.0005', 'jensen_0.0001', 'momentum_0.05', 'momentum_0.5',\n",
       "       'momentum_1.0', 'implicit_energy_0.005', 'implicit_energy_0.0001',\n",
       "       'implicit_energy_0.001', 'momentum_0.1', 'momentum_0.01',\n",
       "       'implicit_energy_0.01', 'implicit_energy_0.05',\n",
       "       'implicit_energy_0.0005', 'momentum_0.005', 'naive',\n",
       "       'ablation_0.0001', 'ablation_0.0005', 'ablation_0.001',\n",
       "       'ablation_0.05', 'ablation_0.005', 'ablation_0.01'], dtype=object)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Walk root/LOSS_SETTING/MODEL_NAME/MODEL_HYPER/*ode_sample*/results.pkl and collect\n",
    "# one row of scaled error metrics per finished evaluation run.\n",
    "results = defaultdict(list)\n",
    "# Reset here so re-running this cell cannot accumulate stale entries.\n",
    "unfinished = []\n",
    "# os.listdir order is arbitrary, so sort every listing for a reproducible row order.\n",
    "for loss_setting in sorted(os.listdir(root)):\n",
    "    setting_dir = os.path.join(root, loss_setting)\n",
    "    if not os.path.isdir(setting_dir):\n",
    "        continue  # skip stray files (e.g. .DS_Store)\n",
    "    for model_name in sorted(os.listdir(setting_dir)):\n",
    "        model_dir = os.path.join(setting_dir, model_name)\n",
    "        if not os.path.isdir(model_dir):\n",
    "            continue\n",
    "        # Count per (loss_setting, model_name) so an incomplete sweep is reported\n",
    "        # against the model it belongs to, not whichever model was listed last.\n",
    "        count_found = 0\n",
    "        for model_hyper in sorted(os.listdir(model_dir)):\n",
    "            hyper_dir = os.path.join(model_dir, model_hyper)\n",
    "            if not os.path.isdir(hyper_dir):\n",
    "                continue\n",
    "            for file in sorted(os.listdir(hyper_dir)):\n",
    "                if 'ode_sample' not in file:\n",
    "                    continue\n",
    "                run_dir = os.path.join(hyper_dir, file)\n",
    "                if not os.path.exists(os.path.join(run_dir, 'results.pkl')):\n",
    "                    continue\n",
    "                # The name has four '---'-separated parts; the last one is discarded.\n",
    "                loss, model, hyper = model_hyper.split('---')[:-1]\n",
    "                # Load metrics before appending so a failure cannot leave columns of unequal length.\n",
    "                performance = get_results(run_dir)\n",
    "                results['loss'].append(loss)\n",
    "                results['model'].append(model)\n",
    "                results['hyper'].append(hyper)\n",
    "                results['dynamics'].append(performance[0])\n",
    "                results['traj'].append(performance[1])\n",
    "                results['vel'].append(performance[2])\n",
    "                results['energy'].append(performance[3])\n",
    "                count_found += 1\n",
    "        if count_found != n_search:\n",
    "            unfinished.append((loss_setting, model_name, n_search - count_found))\n",
    "\n",
    "df = pd.DataFrame(results).sort_values(by='dynamics')\n",
    "settings = df['loss'].unique()\n",
    "settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Incomplete sweeps as (loss_setting, model_name, n_search minus runs found); [] means none.\n",
    "unfinished"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Net number of missing runs, summed over every entry in `unfinished`.\n",
    "sum(entry[2] for entry in unfinished)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>dynamics</th>\n",
       "      <th>traj</th>\n",
       "      <th>vel</th>\n",
       "      <th>energy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>326</th>\n",
       "      <td>jensen_0.05</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-128--n_layers-3--t_embed_size-128</td>\n",
       "      <td>1.169006</td>\n",
       "      <td>1.607245</td>\n",
       "      <td>0.730767</td>\n",
       "      <td>0.506285</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>323</th>\n",
       "      <td>jensen_0.05</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-128--n_layers-4--t_embed_size-128</td>\n",
       "      <td>1.213212</td>\n",
       "      <td>1.665904</td>\n",
       "      <td>0.760520</td>\n",
       "      <td>0.519817</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>317</th>\n",
       "      <td>jensen_0.05</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-128--n_layers-5--t_embed_size-128</td>\n",
       "      <td>1.292567</td>\n",
       "      <td>1.782113</td>\n",
       "      <td>0.803022</td>\n",
       "      <td>0.453251</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>319</th>\n",
       "      <td>jensen_0.05</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-64--n_layers-5--t_embed_size-128</td>\n",
       "      <td>1.457015</td>\n",
       "      <td>2.035956</td>\n",
       "      <td>0.878074</td>\n",
       "      <td>0.490147</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>265</th>\n",
       "      <td>jensen_0.01</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-256--n_layers-3--t_embed_size-128</td>\n",
       "      <td>1.587227</td>\n",
       "      <td>1.724174</td>\n",
       "      <td>1.450280</td>\n",
       "      <td>1.792860</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            loss    model                                          hyper  \\\n",
       "326  jensen_0.05  ParaGRU  hidden_size-128--n_layers-3--t_embed_size-128   \n",
       "323  jensen_0.05  ParaGRU  hidden_size-128--n_layers-4--t_embed_size-128   \n",
       "317  jensen_0.05  ParaGRU  hidden_size-128--n_layers-5--t_embed_size-128   \n",
       "319  jensen_0.05  ParaGRU   hidden_size-64--n_layers-5--t_embed_size-128   \n",
       "265  jensen_0.01  ParaGRU  hidden_size-256--n_layers-3--t_embed_size-128   \n",
       "\n",
       "     dynamics      traj       vel    energy  \n",
       "326  1.169006  1.607245  0.730767  0.506285  \n",
       "323  1.213212  1.665904  0.760520  0.519817  \n",
       "317  1.292567  1.782113  0.803022  0.453251  \n",
       "319  1.457015  2.035956  0.878074  0.490147  \n",
       "265  1.587227  1.724174  1.450280  1.792860  "
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Five lowest-'dynamics' runs across all loss settings (df is sorted ascending by 'dynamics').\n",
    "df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>dynamics</th>\n",
       "      <th>traj</th>\n",
       "      <th>vel</th>\n",
       "      <th>energy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>219</th>\n",
       "      <td>naive</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-256--n_layers-4--t_embed_size-128</td>\n",
       "      <td>2.608460</td>\n",
       "      <td>2.561369</td>\n",
       "      <td>2.655550</td>\n",
       "      <td>3.894118</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>222</th>\n",
       "      <td>naive</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-256--n_layers-5--t_embed_size-128</td>\n",
       "      <td>2.620470</td>\n",
       "      <td>2.569592</td>\n",
       "      <td>2.671349</td>\n",
       "      <td>3.894496</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>218</th>\n",
       "      <td>naive</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-512--n_layers-3--t_embed_size-128</td>\n",
       "      <td>2.678066</td>\n",
       "      <td>2.636862</td>\n",
       "      <td>2.719270</td>\n",
       "      <td>3.542751</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      loss    model                                          hyper  dynamics  \\\n",
       "219  naive  ParaGRU  hidden_size-256--n_layers-4--t_embed_size-128  2.608460   \n",
       "222  naive  ParaGRU  hidden_size-256--n_layers-5--t_embed_size-128  2.620470   \n",
       "218  naive  ParaGRU  hidden_size-512--n_layers-3--t_embed_size-128  2.678066   \n",
       "\n",
       "         traj       vel    energy  \n",
       "219  2.561369  2.655550  3.894118  \n",
       "222  2.569592  2.671349  3.894496  \n",
       "218  2.636862  2.719270  3.542751  "
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First three runs (in 'dynamics' order) whose loss label contains 'naive'.\n",
    "filt(df, loss='naive').head(n=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>dynamics</th>\n",
       "      <th>traj</th>\n",
       "      <th>vel</th>\n",
       "      <th>energy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ablation_0.0001</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-512--n_layers-3--t_embed_size-128</td>\n",
       "      <td>2.624314</td>\n",
       "      <td>2.604886</td>\n",
       "      <td>2.643743</td>\n",
       "      <td>4.221988</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>ablation_0.0001</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-256--n_layers-5--t_embed_size-128</td>\n",
       "      <td>2.633160</td>\n",
       "      <td>2.609928</td>\n",
       "      <td>2.656392</td>\n",
       "      <td>4.119072</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>ablation_0.0001</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-512--n_layers-4--t_embed_size-128</td>\n",
       "      <td>2.634805</td>\n",
       "      <td>2.613062</td>\n",
       "      <td>2.656548</td>\n",
       "      <td>3.873505</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               loss    model                                          hyper  \\\n",
       "9   ablation_0.0001  ParaGRU  hidden_size-512--n_layers-3--t_embed_size-128   \n",
       "12  ablation_0.0001  ParaGRU  hidden_size-256--n_layers-5--t_embed_size-128   \n",
       "14  ablation_0.0001  ParaGRU  hidden_size-512--n_layers-4--t_embed_size-128   \n",
       "\n",
       "    dynamics      traj       vel    energy  \n",
       "9   2.624314  2.604886  2.643743  4.221988  \n",
       "12  2.633160  2.609928  2.656392  4.119072  \n",
       "14  2.634805  2.613062  2.656548  3.873505  "
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First three runs (in 'dynamics' order) whose loss label contains 'ablation'.\n",
    "filt(df, loss='ablation').head(n=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>dynamics</th>\n",
       "      <th>traj</th>\n",
       "      <th>vel</th>\n",
       "      <th>energy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>113</th>\n",
       "      <td>momentum_0.05</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-512--n_layers-5--t_embed_size-128</td>\n",
       "      <td>2.196948</td>\n",
       "      <td>2.140928</td>\n",
       "      <td>2.252968</td>\n",
       "      <td>4.111693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>105</th>\n",
       "      <td>momentum_0.05</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-1024--n_layers-4--t_embed_size-128</td>\n",
       "      <td>2.472061</td>\n",
       "      <td>2.417980</td>\n",
       "      <td>2.526142</td>\n",
       "      <td>3.900317</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>119</th>\n",
       "      <td>momentum_0.05</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-512--n_layers-4--t_embed_size-128</td>\n",
       "      <td>2.472658</td>\n",
       "      <td>2.418895</td>\n",
       "      <td>2.526420</td>\n",
       "      <td>6.897190</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              loss    model                                           hyper  \\\n",
       "113  momentum_0.05  ParaGRU   hidden_size-512--n_layers-5--t_embed_size-128   \n",
       "105  momentum_0.05  ParaGRU  hidden_size-1024--n_layers-4--t_embed_size-128   \n",
       "119  momentum_0.05  ParaGRU   hidden_size-512--n_layers-4--t_embed_size-128   \n",
       "\n",
       "     dynamics      traj       vel    energy  \n",
       "113  2.196948  2.140928  2.252968  4.111693  \n",
       "105  2.472061  2.417980  2.526142  3.900317  \n",
       "119  2.472658  2.418895  2.526420  6.897190  "
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First three runs (in 'dynamics' order) whose loss label contains 'momentum_0.05'.\n",
    "filt(df, loss='momentum_0.05').head(n=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>dynamics</th>\n",
       "      <th>traj</th>\n",
       "      <th>vel</th>\n",
       "      <th>energy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>326</th>\n",
       "      <td>jensen_0.05</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-128--n_layers-3--t_embed_size-128</td>\n",
       "      <td>1.169006</td>\n",
       "      <td>1.607245</td>\n",
       "      <td>0.730767</td>\n",
       "      <td>0.506285</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>323</th>\n",
       "      <td>jensen_0.05</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-128--n_layers-4--t_embed_size-128</td>\n",
       "      <td>1.213212</td>\n",
       "      <td>1.665904</td>\n",
       "      <td>0.760520</td>\n",
       "      <td>0.519817</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>317</th>\n",
       "      <td>jensen_0.05</td>\n",
       "      <td>ParaGRU</td>\n",
       "      <td>hidden_size-128--n_layers-5--t_embed_size-128</td>\n",
       "      <td>1.292567</td>\n",
       "      <td>1.782113</td>\n",
       "      <td>0.803022</td>\n",
       "      <td>0.453251</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            loss    model                                          hyper  \\\n",
       "326  jensen_0.05  ParaGRU  hidden_size-128--n_layers-3--t_embed_size-128   \n",
       "323  jensen_0.05  ParaGRU  hidden_size-128--n_layers-4--t_embed_size-128   \n",
       "317  jensen_0.05  ParaGRU  hidden_size-128--n_layers-5--t_embed_size-128   \n",
       "\n",
       "     dynamics      traj       vel    energy  \n",
       "326  1.169006  1.607245  0.730767  0.506285  \n",
       "323  1.213212  1.665904  0.760520  0.519817  \n",
       "317  1.292567  1.782113  0.803022  0.453251  "
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First three runs (in 'dynamics' order) whose loss label contains 'jensen_0.05'.\n",
    "filt(df, loss='jensen_0.05').head(n=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>dynamics</th>\n",
       "      <th>traj</th>\n",
       "      <th>vel</th>\n",
       "      <th>energy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>246</th>\n",
       "      <td>implicit_energy_0.005</td>\n",
       "      <td>ParaPhyGRU</td>\n",
       "      <td>hidden_size-512--n_layers-4--t_embed_size-128</td>\n",
       "      <td>2.349209</td>\n",
       "      <td>2.274541</td>\n",
       "      <td>2.423877</td>\n",
       "      <td>4.022399</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>251</th>\n",
       "      <td>implicit_energy_0.005</td>\n",
       "      <td>ParaPhyGRU</td>\n",
       "      <td>hidden_size-512--n_layers-3--t_embed_size-128</td>\n",
       "      <td>2.578503</td>\n",
       "      <td>2.533568</td>\n",
       "      <td>2.623439</td>\n",
       "      <td>3.809135</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>244</th>\n",
       "      <td>implicit_energy_0.005</td>\n",
       "      <td>ParaPhyGRU</td>\n",
       "      <td>hidden_size-1024--n_layers-3--t_embed_size-128</td>\n",
       "      <td>2.590293</td>\n",
       "      <td>2.506814</td>\n",
       "      <td>2.673772</td>\n",
       "      <td>5.213148</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      loss       model  \\\n",
       "246  implicit_energy_0.005  ParaPhyGRU   \n",
       "251  implicit_energy_0.005  ParaPhyGRU   \n",
       "244  implicit_energy_0.005  ParaPhyGRU   \n",
       "\n",
       "                                              hyper  dynamics      traj  \\\n",
       "246   hidden_size-512--n_layers-4--t_embed_size-128  2.349209  2.274541   \n",
       "251   hidden_size-512--n_layers-3--t_embed_size-128  2.578503  2.533568   \n",
       "244  hidden_size-1024--n_layers-3--t_embed_size-128  2.590293  2.506814   \n",
       "\n",
       "          vel    energy  \n",
       "246  2.423877  4.022399  \n",
       "251  2.623439  3.809135  \n",
       "244  2.673772  5.213148  "
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First three runs (in 'dynamics' order) whose loss label contains 'implicit_energy_0.005'.\n",
    "filt(df, loss='implicit_energy_0.005').head(n=3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
