{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Produces qualitative sample rollouts of the trained emulators in 3D\n",
    "\n",
    "First, melt the sample rollouts (this is not done by default):\n",
    "\n",
    "```bash\n",
    "python run.py experiments/broad_comparison_3d.py --dont_melt_loss --dont_melt_metrics --melt_sample_rollouts\n",
    "```\n",
    "\n",
    "Then produce the reference rollouts (the script must be run from within this folder):\n",
    "\n",
    "```bash\n",
    "python get_ref_sample_rollouts.py\n",
    "```\n",
    "\n",
    "(Important: this has to be run on the same machine and with the same JAX version that were used for training, to ensure that the same initial conditions are drawn.)\n",
    "\n",
    "\n",
    "Processing 5 seeds requires at least 50GB of RAM!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "# APEBench might otherwise set this to \"gpu\"\n",
    "jax.config.update(\"jax_platform_name\", \"cpu\")\n",
    "\n",
    "import jax.numpy as jnp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../../apebench/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import apebench"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "from tqdm.autonotebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folders for the rendered animations\n",
    "for output_dir in (\n",
    "    \"img/sample_rollouts/\",\n",
    "    \"img/sample_rollouts/across_seeds\",\n",
    "):\n",
    "    os.makedirs(output_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Takes ~15 min\n",
    "# sample_rollout_data = pd.read_csv(\"../../melted/broad_comparison_3d/sample_rollout.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cut data mass a bit to make it easier to work with, otherwise processing times\n",
    "# are excessive\n",
    "# sample_rollout_data = sample_rollout_data.query(\"seed < 5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b9cd4c1532d44b0982f44939c2d45229",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Read the (very large) CSV in small chunks so that tqdm can show progress\n",
    "sample_rollout_data_chunked = pd.read_csv(\n",
    "    \"../../melted/broad_comparison_3d/sample_rollout_data_five_seeds.csv\",\n",
    "    chunksize=5,\n",
    ")\n",
    "\n",
    "sample_rollout_data_list = list(tqdm(sample_rollout_data_chunked))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stitch the chunks back into a single frame\n",
    "sample_rollout_data = pd.concat(sample_rollout_data_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_rollout(serialized: str) -> np.ndarray:\n",
    "    \"\"\"Decode one JSON-serialized rollout into a numpy array.\n",
    "\n",
    "    Bare ``inf``/``nan`` tokens are not valid JSON, so they are first rewritten\n",
    "    to the ``Infinity``/``NaN`` literals that ``json.loads`` accepts.\n",
    "    \"\"\"\n",
    "    json_text = re.sub(r\"\\binf\\b\", \"Infinity\", serialized)\n",
    "    json_text = re.sub(r\"\\bnan\\b\", \"NaN\", json_text)\n",
    "    return np.array(json.loads(json_text))\n",
    "\n",
    "\n",
    "# Takes ~1:30h for all 20 seeds -> ~5 min for 1 seed\n",
    "sample_rollout_data[\"sample_rollout_numpy\"] = sample_rollout_data[\n",
    "    \"sample_rollout\"\n",
    "].apply(parse_rollout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_reference_row(scene: str, seed, reference_rollout) -> pd.DataFrame:\n",
    "    \"\"\"Single-row frame that presents the reference solution as the pseudo-net 'Ref'.\"\"\"\n",
    "    return pd.DataFrame(\n",
    "        {\n",
    "            \"seed\": seed,\n",
    "            \"net\": \"Ref\",\n",
    "            \"scenario\": scene,\n",
    "            \"task\": \"predict\",\n",
    "            \"train\": \"one\",\n",
    "            \"scenario_kwargs\": [{}],\n",
    "            \"sample_index\": 0,\n",
    "            \"sample_rollout\": [[]],\n",
    "            \"sample_rollout_numpy\": [reference_rollout],\n",
    "        }\n",
    "    )\n",
    "\n",
    "\n",
    "reference_frames = []\n",
    "for scene in [\n",
    "    \"3d_phy_unbal_adv\",\n",
    "    \"3d_diff_burgers\",  # Uses the three-channel version in 3D\n",
    "    \"3d_diff_ks\",\n",
    "    \"3d_phy_gs\",\n",
    "    \"3d_phy_sh\",\n",
    "]:\n",
    "    # The first stored reference trajectory is reused for every seed\n",
    "    reference_rollout = np.load(f\"ref_sample_rollouts/{scene}.npy\")[0]\n",
    "    reference_frames.extend(\n",
    "        build_reference_row(scene, seed, reference_rollout)\n",
    "        for seed in sample_rollout_data.seed.unique()\n",
    "    )\n",
    "\n",
    "ref_data_merged = pd.concat(reference_frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepend the reference rows so that \"Ref\" is handled like any other net\n",
    "sample_rollout_data = pd.concat([ref_data_merged, sample_rollout_data])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The raw JSON strings are not needed anymore once they have been parsed\n",
    "sample_rollout_data = sample_rollout_data.drop(columns=[\"sample_rollout\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release the per-chunk frames, which still hold the raw JSON strings\n",
    "del sample_rollout_data_chunked, sample_rollout_data_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def subtract_reference(group: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"Add the difference of every rollout in ``group`` to the group's 'Ref' rollout.\n",
    "\n",
    "    Raises an IndexError if ``group`` has no row with ``net == 'Ref'``.\n",
    "    \"\"\"\n",
    "    ref_rollout = group.loc[group[\"net\"] == \"Ref\", \"sample_rollout_numpy\"].iloc[0]\n",
    "    return group.assign(\n",
    "        sample_rollout_numpy_diff=group[\"sample_rollout_numpy\"].apply(\n",
    "            lambda rollout: rollout - ref_rollout\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "# Selecting the columns explicitly (grouping columns included) keeps seed,\n",
    "# sample_index and scenario in the result. Relying on `apply` to implicitly\n",
    "# pass the grouping columns is deprecated in pandas, and they would be dropped\n",
    "# in a future version.\n",
    "sample_rollout_data = sample_rollout_data.groupby(\n",
    "    [\"seed\", \"sample_index\", \"scenario\"],\n",
    "    group_keys=False,\n",
    ")[list(sample_rollout_data.columns)].apply(subtract_reference)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def assign_order(df: pd.DataFrame, column_name: str, order: list):\n",
    "    \"\"\"Return a copy of ``df`` in which ``column_name`` is an ordered categorical.\n",
    "\n",
    "    The categories (and hence the sort order) follow ``order``; values that are\n",
    "    not listed in ``order`` become NaN.\n",
    "    \"\"\"\n",
    "    return df.assign(\n",
    "        **{\n",
    "            column_name: lambda frame: pd.Categorical(\n",
    "                frame[column_name], categories=order, ordered=True\n",
    "            )\n",
    "        }\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display/sorting order of the nets: the reference first, then the emulators\n",
    "net_order = [\n",
    "    \"Ref\",\n",
    "    \"Conv;26;12;relu\",  # 202'489 params, 13 receptive field per direction\n",
    "    \"Res;25;6;relu\",  # 202'876 params, 12 receptive field per direction\n",
    "    \"UNet;11;2;relu\",  # 200'322 params, 29 receptive field per direction\n",
    "    \"Dil;2;27;2;relu\",  # 197'722 params, 20 receptive field per direction\n",
    "    \"FNO;5;7;4;gelu\",  # 196'246 params, inf receptive field\n",
    "]\n",
    "\n",
    "# Display/sorting order of the scenarios\n",
    "scenario_order = [\n",
    "    \"3d_phy_unbal_adv\",\n",
    "    \"3d_diff_burgers\",  # Uses the three-channel version in 3D\n",
    "    \"3d_diff_ks\",\n",
    "    \"3d_phy_gs\",\n",
    "    \"3d_phy_sh\",\n",
    "]\n",
    "\n",
    "sample_rollout_data = assign_order(sample_rollout_data, \"net\", net_order)\n",
    "sample_rollout_data = assign_order(sample_rollout_data, \"scenario\", scenario_order)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise absolute deviation from the reference rollout\n",
    "diff_rollouts = sample_rollout_data[\"sample_rollout_numpy_diff\"]\n",
    "sample_rollout_data[\"sample_rollout_numpy_diff_abs\"] = diff_rollouts.apply(np.abs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# scenario and net are ordered categoricals, so they sort by their category order\n",
    "sort_keys = [\"scenario\", \"net\", \"seed\", \"sample_index\"]\n",
    "sample_rollout_data = sample_rollout_data.sort_values(sort_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# extracted_data = np.stack(\n",
    "#     sample_rollout_data.query(\"scenario == '3d_diff_ks'\")[\"sample_rollout_numpy\"].values\n",
    "# )[:, :10]\n",
    "# list(sample_rollout_data.query(\"scenario == '3d_diff_ks'\")[\"net\"])\n",
    "# ani = apebench.exponax.viz.animate_state_3d_facet(\n",
    "#     extracted_data,\n",
    "#     facet_over_channels=False,\n",
    "#     grid=(2, 3),\n",
    "#     figsize=(12, 8),\n",
    "#     titles=list(sample_rollout_data.query(\"scenario == '3d_diff_ks'\")[\"net\"]),\n",
    "# )\n",
    "# ani.save(\"Hello___.mp4\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_per_seed_animations(\n",
    "    data: pd.DataFrame, rollout_column: str, file_prefix: str\n",
    ") -> None:\n",
    "    \"\"\"Save one animation per (seed, scenario) with all nets side by side.\n",
    "\n",
    "    Args:\n",
    "        data: Frame with the columns seed, scenario, net and ``rollout_column``.\n",
    "            Presumably holds one row per net for each (seed, scenario), i.e., at\n",
    "            most six panels (\"Ref\" + five emulators) for the 2x3 grid.\n",
    "        rollout_column: Column holding the trajectories to animate.\n",
    "        file_prefix: File name prefix of the videos in img/sample_rollouts/.\n",
    "    \"\"\"\n",
    "    for seed in tqdm(data.seed.unique()):\n",
    "        for scenario in data.scenario.unique():\n",
    "            sub_df = data.query(f\"seed == {seed}\").query(f\"scenario == '{scenario}'\")\n",
    "            # Keep only index 0 of axis 2 (presumably the first channel)\n",
    "            extracted_trajectories = np.stack(sub_df[rollout_column].values)[\n",
    "                :, :, 0:1\n",
    "            ]\n",
    "            ani = apebench.exponax.viz.animate_state_3d_facet(\n",
    "                extracted_trajectories,\n",
    "                facet_over_channels=False,\n",
    "                grid=(2, 3),\n",
    "                figsize=(12, 8),\n",
    "                vlim=(-0.1, 0.1) if scenario == \"3d_diff_burgers\" else (-1, 1),\n",
    "                titles=sub_df[\"net\"].values,\n",
    "            )\n",
    "            ani.save(\n",
    "                f\"img/sample_rollouts/{file_prefix}_scene={scenario}_seed={seed}.mp4\"\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_per_seed_animations(sample_rollout_data, \"sample_rollout_numpy\", \"sample_rollouts\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_per_seed_animations(\n",
    "    sample_rollout_data, \"sample_rollout_numpy_diff\", \"sample_rollouts_diff\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_per_seed_animations(\n",
    "    sample_rollout_data,\n",
    "    \"sample_rollout_numpy_diff_abs\",\n",
    "    \"sample_rollouts_diff_abs\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_across_seed_animations(\n",
    "    data: pd.DataFrame, rollout_column: str, file_prefix: str\n",
    ") -> None:\n",
    "    \"\"\"Save one animation per (net, scenario) with all seeds side by side.\n",
    "\n",
    "    The \"Ref\" pseudo-net is skipped. The panels are laid out in a single row\n",
    "    with one panel per trajectory of the (net, scenario) subset, i.e., one per\n",
    "    seed when every seed contributes a single sample.\n",
    "\n",
    "    Args:\n",
    "        data: Frame with the columns seed, scenario, net and ``rollout_column``.\n",
    "        rollout_column: Column holding the trajectories to animate.\n",
    "        file_prefix: File name prefix of the videos in\n",
    "            img/sample_rollouts/across_seeds/.\n",
    "    \"\"\"\n",
    "    for net in tqdm(data.net.unique()):\n",
    "        if net == \"Ref\":\n",
    "            continue\n",
    "        for scenario in data.scenario.unique():\n",
    "            sub_df = data.query(f\"scenario == '{scenario}'\").query(f\"net == '{net}'\")\n",
    "            # Keep only index 0 of axis 2 (presumably the first channel)\n",
    "            extracted_trajectories = np.stack(sub_df[rollout_column].values)[\n",
    "                :, :, 0:1\n",
    "            ]\n",
    "            ani = apebench.exponax.viz.animate_state_3d_facet(\n",
    "                extracted_trajectories,\n",
    "                facet_over_channels=False,\n",
    "                grid=(1, len(extracted_trajectories)),\n",
    "                figsize=(12, 4),\n",
    "                vlim=(-0.1, 0.1) if scenario == \"3d_diff_burgers\" else (-1, 1),\n",
    "                titles=sub_df[\"seed\"].values,\n",
    "            )\n",
    "            ani.save(\n",
    "                f\"img/sample_rollouts/across_seeds/{file_prefix}_net={net}_scene={scenario}.mp4\"\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_across_seed_animations(\n",
    "    sample_rollout_data, \"sample_rollout_numpy\", \"sample_rollouts\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_across_seed_animations(\n",
    "    sample_rollout_data, \"sample_rollout_numpy_diff\", \"sample_rollouts_diff\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_across_seed_animations(\n",
    "    sample_rollout_data,\n",
    "    \"sample_rollout_numpy_diff_abs\",\n",
    "    \"sample_rollouts_diff_abs\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_gpu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
