{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from torch import Tensor\n",
    "\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment logs live under root/<model_name>/<model_hyper>/<dir containing 'predict_on_xt'>/<pkl_file>.\n",
    "root = \"logs\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def adv_1d_loss(pred: Tensor, dx: Tensor, dt: Tensor, beta: float = 0.1) -> float:\n",
    "    \"\"\"Return the RMSE residual of the 1D advection equation u_t + beta * u_x = 0.\n",
    "\n",
    "    Writing the first three sizes of ``pred`` as (B, X, T), dim 2 is differenced with\n",
    "    step ``dt`` and dim 1 with step ``dx`` using central differences; both derivatives\n",
    "    are cropped to the common interior (B, X-2, T-2) before forming the residual.\n",
    "\n",
    "    NOTE(review): the next cell redefines ``adv_1d_loss`` and shadows this version;\n",
    "    keep only one definition.\n",
    "\n",
    "    Args:\n",
    "        pred: tensor with at least 3 dims (dim 0 presumably batch -- TODO confirm).\n",
    "        dx: grid spacing along dim 1.\n",
    "        dt: grid spacing along dim 2.\n",
    "        beta: advection coefficient; defaults to the previously hard-coded 0.1.\n",
    "\n",
    "    Returns:\n",
    "        The residual RMSE as a Python float (``Tensor.item()``), not a Tensor.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if dim 1 or dim 2 has fewer than 3 points; there would be no\n",
    "            interior and the mean would silently be NaN.\n",
    "    \"\"\"\n",
    "    if pred.shape[1] < 3 or pred.shape[2] < 3:\n",
    "        raise ValueError(f'need at least 3 points along dims 1 and 2, got shape {tuple(pred.shape)}')\n",
    "    du_t = (pred[:, :, 2:] - pred[:, :, :-2]) / (2 * dt)  # (B, X, T-2)\n",
    "    du_x = (pred[:, 2:] - pred[:, :-2]) / (2 * dx)        # (B, X-2, T)\n",
    "    residual = du_t[:, 1:-1] + beta * du_x[:, :, 1:-1]    # (B, X-2, T-2)\n",
    "    return residual.square().mean().sqrt().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def adv_1d_loss(pred: Tensor, dx: float = 1 / 256, dt: float = 0.05, beta: float = 0.1) -> float:\n",
    "    \"\"\"Return the RMSE residual of the 1D advection equation u_t + beta * u_x = 0.\n",
    "\n",
    "    Redefines the previous cell's version with default grid spacings, so the loading\n",
    "    loop can call ``adv_1d_loss(data_x)`` with only the prediction tensor.\n",
    "\n",
    "    Writing the first three sizes of ``pred`` as (B, X, T), dim 2 is differenced with\n",
    "    step ``dt`` and dim 1 with step ``dx`` using central differences; both derivatives\n",
    "    are cropped to the common interior (B, X-2, T-2) before forming the residual.\n",
    "\n",
    "    Args:\n",
    "        pred: tensor with at least 3 dims (dim 0 presumably batch -- TODO confirm).\n",
    "        dx: grid spacing along dim 1 (default 1/256).\n",
    "        dt: grid spacing along dim 2 (default 0.05).\n",
    "        beta: advection coefficient; defaults to the previously hard-coded 0.1.\n",
    "\n",
    "    Returns:\n",
    "        The residual RMSE as a Python float (``Tensor.item()``), not a Tensor.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if dim 1 or dim 2 has fewer than 3 points; there would be no\n",
    "            interior and the mean would silently be NaN.\n",
    "    \"\"\"\n",
    "    if pred.shape[1] < 3 or pred.shape[2] < 3:\n",
    "        raise ValueError(f'need at least 3 points along dims 1 and 2, got shape {tuple(pred.shape)}')\n",
    "    du_t = (pred[:, :, 2:] - pred[:, :, :-2]) / (2 * dt)  # (B, X, T-2)\n",
    "    du_x = (pred[:, 2:] - pred[:, :-2]) / (2 * dx)        # (B, X-2, T)\n",
    "    residual = du_t[:, 1:-1] + beta * du_x[:, :, 1:-1]    # (B, X-2, T-2)\n",
    "    return residual.square().mean().sqrt().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['pde_1.0', 'pde_0.1', 'pde_0.01', 'pde_0.001', 'naive'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Walk root/<model_name>/<model_hyper>/<dir containing 'predict_on_xt'>/<pkl_file>, one row per file.\n",
    "# os.listdir order is arbitrary, so listings are sorted to make rows (and df's index) reproducible.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load logs you produced yourself.\n",
    "results = defaultdict(list)\n",
    "for model_name in sorted(os.listdir(root)):\n",
    "    for model_hyper in sorted(os.listdir(os.path.join(root, model_name))):\n",
    "        run_dir = os.path.join(root, model_name, model_hyper)\n",
    "        for file in sorted(os.listdir(run_dir)):\n",
    "            if 'predict_on_xt' not in file:\n",
    "                continue\n",
    "            pred_dir = os.path.join(run_dir, file)\n",
    "            for pkl_file in sorted(os.listdir(pred_dir)):\n",
    "                with open(os.path.join(pred_dir, pkl_file), 'rb') as f:\n",
    "                    data_x = torch.from_numpy(pickle.load(f))\n",
    "                # Previously computed and then discarded; now stored as its own column.\n",
    "                pde_loss = adv_1d_loss(data_x)\n",
    "\n",
    "                # Run names must have exactly four '---'-separated fields; the last one is dropped.\n",
    "                loss, model, hyper = model_hyper.split('---')[:-1]\n",
    "                results['loss'].append(loss)\n",
    "                results['model'].append(model)\n",
    "                results['hyper'].append(hyper)\n",
    "                # The file stem ends in '_<score>'; splitext avoids assuming a 4-character extension.\n",
    "                results['performance'].append(float(os.path.splitext(pkl_file)[0].split('_')[-1]))\n",
    "                results['pde_loss'].append(pde_loss)\n",
    "\n",
    "df = pd.DataFrame(results).sort_values(by='performance')\n",
    "settings = df['loss'].unique()\n",
    "settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>performance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>pde_1.0</td>\n",
       "      <td>GRU</td>\n",
       "      <td>hidden_size-128--n_layers-5</td>\n",
       "      <td>0.23046</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>pde_1.0</td>\n",
       "      <td>GRU</td>\n",
       "      <td>hidden_size-128--n_layers-4</td>\n",
       "      <td>0.23048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>pde_1.0</td>\n",
       "      <td>GRU</td>\n",
       "      <td>hidden_size-128--n_layers-3</td>\n",
       "      <td>0.23053</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       loss model                        hyper  performance\n",
       "42  pde_1.0   GRU  hidden_size-128--n_layers-5      0.23046\n",
       "17  pde_1.0   GRU  hidden_size-128--n_layers-4      0.23048\n",
       "29  pde_1.0   GRU  hidden_size-128--n_layers-3      0.23053"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Three rows with the smallest 'performance' for loss == 'pde_1.0' (df is sorted ascending by 'performance').\n",
    "df.loc[df['loss'].eq('pde_1.0')].head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>hyper</th>\n",
       "      <th>performance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>naive</td>\n",
       "      <td>GRU</td>\n",
       "      <td>hidden_size-128--n_layers-4</td>\n",
       "      <td>0.23806</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>naive</td>\n",
       "      <td>GRU</td>\n",
       "      <td>hidden_size-128--n_layers-3</td>\n",
       "      <td>0.23969</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>naive</td>\n",
       "      <td>GRU</td>\n",
       "      <td>hidden_size-128--n_layers-5</td>\n",
       "      <td>0.24008</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     loss model                        hyper  performance\n",
       "31  naive   GRU  hidden_size-128--n_layers-4      0.23806\n",
       "1   naive   GRU  hidden_size-128--n_layers-3      0.23969\n",
       "7   naive   GRU  hidden_size-128--n_layers-5      0.24008"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Three rows with the smallest 'performance' for loss == 'naive' (df is sorted ascending by 'performance').\n",
    "df.loc[df['loss'].eq('naive')].head(3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
