{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Produces qualitative sample rollouts of the trained emulators in 1D\n",
    "\n",
    "First, melt the sample rollouts (this is not done by default):\n",
    "\n",
    "```bash\n",
    "python run.py experiments/broad_comparison_1d.py --dont_melt_loss --dont_melt_metrics --melt_sample_rollouts\n",
    "```\n",
    "\n",
    "Then produce the reference rollouts (the script must be executed from within this folder):\n",
    "\n",
    "```bash\n",
    "python get_ref_sample_rollouts.py\n",
    "```\n",
    "\n",
    "(Important: this has to be executed on the same machine and with the same JAX version as the training, so that the same initial conditions are drawn.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "# Force JAX onto the CPU backend. This runs before APEBench is imported,\n",
    "# which might otherwise set the platform to \"gpu\".\n",
    "jax.config.update(\"jax_platform_name\", \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the local APEBench checkout importable. The path is resolved relative to the\n",
    "# current working directory, i.e. this assumes the notebook is run from its own folder.\n",
    "sys.path.append(\"../../apebench/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import apebench"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "from tqdm.autonotebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folders for the per-seed figures and for the across-seed figures\n",
    "for figure_dir in (\"img/sample_rollouts/\", \"img/sample_rollouts/across_seeds\"):\n",
    "    os.makedirs(figure_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the melted sample rollouts (produced by the `run.py ... --melt_sample_rollouts`\n",
    "# command at the top). Should load in ~10 sec\n",
    "sample_rollout_data = pd.read_csv(\"../../melted/broad_comparison_1d/sample_rollout.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Takes ~30 sec\n",
    "# The serialized rollouts may contain Python-style `nan`/`inf` tokens, which\n",
    "# `json.loads` rejects; map them to `NaN`/`Infinity` before parsing.\n",
    "_INF_TOKEN = re.compile(r\"\\binf\\b\")\n",
    "_NAN_TOKEN = re.compile(r\"\\bnan\\b\")\n",
    "\n",
    "\n",
    "def _parse_rollout(serialized: str) -> np.ndarray:\n",
    "    \"\"\"Turn one serialized rollout string into a NumPy array.\"\"\"\n",
    "    json_text = _NAN_TOKEN.sub(\"NaN\", _INF_TOKEN.sub(\"Infinity\", serialized))\n",
    "    return np.array(json.loads(json_text))\n",
    "\n",
    "\n",
    "sample_rollout_data[\"sample_rollout_numpy\"] = sample_rollout_data[\"sample_rollout\"].apply(\n",
    "    _parse_rollout\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add one \"Ref\" row per (scenario, seed). Every seed gets the same reference\n",
    "# trajectory (entry 0 of the stored array, labelled sample_index 0), so that each\n",
    "# (seed, sample_index, scenario) group has a reference to be compared against.\n",
    "ref_frames = []\n",
    "for scene in [\n",
    "    \"1d_diff_disp\",\n",
    "    \"1d_diff_burgers\",\n",
    "    \"1d_diff_kdv\",\n",
    "    \"1d_diff_ks_cons\",\n",
    "    \"1d_diff_ks\",\n",
    "]:\n",
    "    ref_rollout = np.load(f\"ref_sample_rollouts/{scene}.npy\")[0]\n",
    "    ref_frames.extend(\n",
    "        pd.DataFrame(\n",
    "            {\n",
    "                \"seed\": seed_value,\n",
    "                \"net\": \"Ref\",\n",
    "                \"scenario\": scene,\n",
    "                \"task\": \"predict\",\n",
    "                \"train\": \"one\",\n",
    "                # single-element lists: each frame holds exactly one row\n",
    "                \"scenario_kwargs\": [{}],\n",
    "                \"sample_index\": 0,\n",
    "                \"sample_rollout\": [[]],\n",
    "                \"sample_rollout_numpy\": [ref_rollout],\n",
    "            }\n",
    "        )\n",
    "        for seed_value in sample_rollout_data.seed.unique()\n",
    "    )\n",
    "\n",
    "ref_data_merged = pd.concat(ref_frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepend the reference rows. ignore_index gives a fresh 0..n-1 index: every\n",
    "# per-seed reference frame is indexed from 0, so keeping the old labels would\n",
    "# leave many duplicates. Note: re-running this cell appends the Ref rows again.\n",
    "sample_rollout_data = pd.concat([ref_data_merged, sample_rollout_data], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _diff_to_reference(group: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"Attach `sample_rollout_numpy_diff`: each rollout minus the group's \"Ref\" rollout.\n",
    "\n",
    "    Raises a ValueError naming the group if it has no \"Ref\" row, instead of\n",
    "    subtracting an empty reference selection.\n",
    "    \"\"\"\n",
    "    ref_rollouts = group.loc[group[\"net\"] == \"Ref\", \"sample_rollout_numpy\"]\n",
    "    if ref_rollouts.empty:\n",
    "        raise ValueError(\n",
    "            \"No 'Ref' rollout for (seed, sample_index, scenario) = \"\n",
    "            f\"{getattr(group, 'name', '?')}\"\n",
    "        )\n",
    "    ref_rollout = ref_rollouts.iloc[0]  # first Ref row if there are several\n",
    "    return group.assign(\n",
    "        sample_rollout_numpy_diff=group[\"sample_rollout_numpy\"].apply(\n",
    "            lambda rollout: rollout - ref_rollout\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "# Each (seed, sample_index, scenario) group holds the emulator rollouts plus the\n",
    "# matching reference row; the Ref row's own diff is all zeros.\n",
    "sample_rollout_data = sample_rollout_data.groupby(\n",
    "    [\"seed\", \"sample_index\", \"scenario\"],\n",
    "    group_keys=False,\n",
    ").apply(_diff_to_reference)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def assign_order(df: pd.DataFrame, column_name: str, order: list):\n",
    "    \"\"\"Return a copy of `df` in which `column_name` is an ordered Categorical.\n",
    "\n",
    "    The categories (and their order) are exactly `order`; values not listed in\n",
    "    `order` become NaN. The input frame is left untouched.\n",
    "    \"\"\"\n",
    "    return df.assign(\n",
    "        **{\n",
    "            column_name: lambda data: pd.Categorical(\n",
    "                data[column_name], categories=order, ordered=True\n",
    "            )\n",
    "        }\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display order of the facets: reference first, then the emulator architectures\n",
    "NET_ORDER = [\n",
    "    \"Ref\",\n",
    "    \"Conv;34;10;relu\",\n",
    "    \"UNet;12;2;relu\",\n",
    "    \"Res;26;8;relu\",\n",
    "    \"FNO;12;18;4;gelu\",\n",
    "    \"Dil;2;32;2;relu\",\n",
    "]\n",
    "SCENARIO_ORDER = [\n",
    "    \"1d_diff_disp\",\n",
    "    \"1d_diff_burgers\",\n",
    "    \"1d_diff_kdv\",\n",
    "    \"1d_diff_ks_cons\",\n",
    "    \"1d_diff_ks\",\n",
    "]\n",
    "\n",
    "# Nets/scenarios missing from these lists become NaN and get no facet of their own\n",
    "sample_rollout_data = sample_rollout_data.pipe(assign_order, \"net\", NET_ORDER).pipe(\n",
    "    assign_order, \"scenario\", SCENARIO_ORDER\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise magnitude of the deviation from the reference rollout\n",
    "sample_rollout_data[\"sample_rollout_numpy_diff_abs\"] = np.abs(\n",
    "    sample_rollout_data[\"sample_rollout_numpy_diff\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ordered categoricals sort in their declared order, so the rows follow the\n",
    "# facet layout: scenario, then net, then seed and sample index.\n",
    "sample_rollout_data = sample_rollout_data.sort_values(\n",
    "    by=[\"scenario\", \"net\", \"seed\", \"sample_index\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (A commented-out seed-0-only preview used to live here; every seed is plotted\n",
    "# and saved by the loop in the next cell.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [04:33<00:00,  5.48s/it]\n"
     ]
    }
   ],
   "source": [
    "# Runtime ~5min\n",
    "# One figure per seed: rows are the networks (reference included), columns the scenarios.\n",
    "for seed in tqdm(sample_rollout_data.seed.unique()):\n",
    "    seed_rows = sample_rollout_data.query(f\"seed == {seed}\")\n",
    "    facet = sns.FacetGrid(seed_rows, row=\"net\", col=\"scenario\", aspect=1.5)\n",
    "    # Heatmap of the facet's first rollout (index 0 along axis 1, presumably the\n",
    "    # channel), transposed for display\n",
    "    facet.map(\n",
    "        lambda rollouts, **kwargs: plt.imshow(\n",
    "            rollouts.values[0][:, 0, :].T,\n",
    "            cmap=\"RdBu_r\",\n",
    "            vmin=-1,\n",
    "            vmax=1,\n",
    "            aspect=\"auto\",\n",
    "        ),\n",
    "        \"sample_rollout_numpy\",\n",
    "    )\n",
    "    facet.figure.subplots_adjust(top=0.9)\n",
    "    facet.figure.suptitle(f\"Sample Rollouts, seed={seed}\")\n",
    "\n",
    "    # Save the figure, then close it so the figures do not pile up in memory\n",
    "    plt.savefig(f\"img/sample_rollouts/sample_rollouts_seed_{seed}.pdf\")\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [04:29<00:00,  5.40s/it]\n"
     ]
    }
   ],
   "source": [
    "# Runtime ~5min\n",
    "# The reference is left out (its diff to itself is zero). Its category is dropped as\n",
    "# well: seaborn lays out facet rows by *all* categories of a categorical column, so an\n",
    "# unused \"Ref\" category would still produce an empty first row.\n",
    "non_ref_data = sample_rollout_data.query(\"net != 'Ref'\").assign(\n",
    "    net=lambda df: df[\"net\"].cat.remove_unused_categories()\n",
    ")\n",
    "\n",
    "for seed in tqdm(sample_rollout_data.seed.unique()):\n",
    "    facet = sns.FacetGrid(\n",
    "        non_ref_data.query(f\"seed == {seed}\"),\n",
    "        row=\"net\",\n",
    "        col=\"scenario\",\n",
    "        aspect=1.5,\n",
    "    )\n",
    "    # Heatmap of the facet's first rollout (index 0 along axis 1, presumably the\n",
    "    # channel), transposed for display\n",
    "    facet.map(\n",
    "        lambda x, **kwargs: (\n",
    "            plt.imshow(\n",
    "                x.values[0][:, 0, :].T, cmap=\"RdBu_r\", vmin=-1, vmax=1, aspect=\"auto\"\n",
    "            )\n",
    "        ),\n",
    "        \"sample_rollout_numpy_diff\",\n",
    "    )\n",
    "\n",
    "    facet.figure.subplots_adjust(top=0.9)\n",
    "    facet.figure.suptitle(f\"Sample Rollouts Diff, seed={seed}\")\n",
    "\n",
    "    # Save and close this grid's own figure rather than pyplot's \"current\" one\n",
    "    facet.figure.savefig(f\"img/sample_rollouts/sample_rollouts_diff_seed_{seed}.pdf\")\n",
    "    plt.close(facet.figure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [04:23<00:00,  5.27s/it]\n"
     ]
    }
   ],
   "source": [
    "# Runtime ~5min\n",
    "# The reference is left out (its diff to itself is zero). Its category is dropped as\n",
    "# well: seaborn lays out facet rows by *all* categories of a categorical column, so an\n",
    "# unused \"Ref\" category would still produce an empty first row.\n",
    "non_ref_data = sample_rollout_data.query(\"net != 'Ref'\").assign(\n",
    "    net=lambda df: df[\"net\"].cat.remove_unused_categories()\n",
    ")\n",
    "\n",
    "for seed in tqdm(sample_rollout_data.seed.unique()):\n",
    "    facet = sns.FacetGrid(\n",
    "        non_ref_data.query(f\"seed == {seed}\"),\n",
    "        row=\"net\",\n",
    "        col=\"scenario\",\n",
    "        aspect=1.5,\n",
    "    )\n",
    "    # Heatmap of the facet's first rollout (index 0 along axis 1, presumably the\n",
    "    # channel), transposed for display\n",
    "    facet.map(\n",
    "        lambda x, **kwargs: (\n",
    "            plt.imshow(\n",
    "                x.values[0][:, 0, :].T, cmap=\"RdBu_r\", vmin=-1, vmax=1, aspect=\"auto\"\n",
    "            )\n",
    "        ),\n",
    "        \"sample_rollout_numpy_diff_abs\",\n",
    "    )\n",
    "\n",
    "    facet.figure.subplots_adjust(top=0.9)\n",
    "    facet.figure.suptitle(f\"Sample Rollouts Abs-Diff, seed={seed}\")\n",
    "\n",
    "    # Save and close this grid's own figure rather than pyplot's \"current\" one\n",
    "    facet.figure.savefig(f\"img/sample_rollouts/sample_rollouts_abs_diff_seed_{seed}.pdf\")\n",
    "    plt.close(facet.figure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [03:54<00:00, 46.87s/it]\n"
     ]
    }
   ],
   "source": [
    "# Runtime ~5min\n",
    "# One figure per (scenario, emulator) showing the rollouts of all seeds side by side.\n",
    "emulator_nets = [net for net in sample_rollout_data.net.unique() if net != \"Ref\"]\n",
    "for scenario in tqdm(sample_rollout_data.scenario.unique()):\n",
    "    for net in emulator_nets:\n",
    "        subset = sample_rollout_data[\n",
    "            (sample_rollout_data[\"scenario\"] == scenario)\n",
    "            & (sample_rollout_data[\"net\"] == net)\n",
    "        ]\n",
    "        facet = sns.FacetGrid(subset, col=\"seed\", col_wrap=5, aspect=1.5)\n",
    "        facet.map(\n",
    "            lambda rollouts, **kwargs: plt.imshow(\n",
    "                rollouts.values[0][:, 0, :].T,\n",
    "                cmap=\"RdBu_r\",\n",
    "                vmin=-1,\n",
    "                vmax=1,\n",
    "                aspect=\"auto\",\n",
    "            ),\n",
    "            \"sample_rollout_numpy\",\n",
    "        )\n",
    "        facet.figure.subplots_adjust(top=0.9)\n",
    "        facet.figure.suptitle(f\"Sample Rollouts, {scenario}, {net}\")\n",
    "\n",
    "        # Save figure\n",
    "        plt.savefig(\n",
    "            f\"img/sample_rollouts/across_seeds/sample_rollouts_{scenario}_{net}.pdf\"\n",
    "        )\n",
    "        plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [03:59<00:00, 47.98s/it]\n"
     ]
    }
   ],
   "source": [
    "# Runtime ~5min\n",
    "# One figure per (scenario, emulator) showing the diff to the reference for all seeds.\n",
    "emulator_nets = [net for net in sample_rollout_data.net.unique() if net != \"Ref\"]\n",
    "for scenario in tqdm(sample_rollout_data.scenario.unique()):\n",
    "    for net in emulator_nets:\n",
    "        subset = sample_rollout_data[\n",
    "            (sample_rollout_data[\"scenario\"] == scenario)\n",
    "            & (sample_rollout_data[\"net\"] == net)\n",
    "        ]\n",
    "        facet = sns.FacetGrid(subset, col=\"seed\", col_wrap=5, aspect=1.5)\n",
    "        facet.map(\n",
    "            lambda rollouts, **kwargs: plt.imshow(\n",
    "                rollouts.values[0][:, 0, :].T,\n",
    "                cmap=\"RdBu_r\",\n",
    "                vmin=-1,\n",
    "                vmax=1,\n",
    "                aspect=\"auto\",\n",
    "            ),\n",
    "            \"sample_rollout_numpy_diff\",\n",
    "        )\n",
    "        facet.figure.subplots_adjust(top=0.9)\n",
    "        facet.figure.suptitle(f\"Sample Rollouts Diff, {scenario}, {net}\")\n",
    "\n",
    "        # Save figure\n",
    "        plt.savefig(\n",
    "            f\"img/sample_rollouts/across_seeds/sample_rollouts_diff_{scenario}_{net}.pdf\"\n",
    "        )\n",
    "        plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [04:05<00:00, 49.15s/it]\n"
     ]
    }
   ],
   "source": [
    "# Runtime ~5min\n",
    "# One figure per (scenario, emulator) showing the absolute diff for all seeds.\n",
    "emulator_nets = [net for net in sample_rollout_data.net.unique() if net != \"Ref\"]\n",
    "for scenario in tqdm(sample_rollout_data.scenario.unique()):\n",
    "    for net in emulator_nets:\n",
    "        subset = sample_rollout_data[\n",
    "            (sample_rollout_data[\"scenario\"] == scenario)\n",
    "            & (sample_rollout_data[\"net\"] == net)\n",
    "        ]\n",
    "        facet = sns.FacetGrid(subset, col=\"seed\", col_wrap=5, aspect=1.5)\n",
    "        facet.map(\n",
    "            lambda rollouts, **kwargs: plt.imshow(\n",
    "                rollouts.values[0][:, 0, :].T,\n",
    "                cmap=\"RdBu_r\",\n",
    "                vmin=-1,\n",
    "                vmax=1,\n",
    "                aspect=\"auto\",\n",
    "            ),\n",
    "            \"sample_rollout_numpy_diff_abs\",\n",
    "        )\n",
    "        facet.figure.subplots_adjust(top=0.9)\n",
    "        facet.figure.suptitle(f\"Sample Rollouts Abs Diff, {scenario}, {net}\")\n",
    "\n",
    "        # Save figure\n",
    "        plt.savefig(\n",
    "            f\"img/sample_rollouts/across_seeds/sample_rollouts_diff_abs_{scenario}_{net}.pdf\"\n",
    "        )\n",
    "        plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_gpu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
