{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Produces qualitative sample rollouts of the trained emulators in 2D\n",
    "\n",
    "First, melt the sample rollouts (this is not done by default):\n",
    "\n",
    "```bash\n",
    "python run.py experiments/broad_comparison_2d.py --dont_melt_loss --dont_melt_metrics --melt_sample_rollouts\n",
    "```\n",
    "\n",
    "Then produce the reference rollouts (the script has to be run from within this folder):\n",
    "\n",
    "```bash\n",
    "python get_ref_sample_rollouts.py\n",
    "```\n",
    "\n",
    "(Important: this script has to be executed on the same machine the training was done on, and with the same JAX version, to ensure that the same initial conditions are drawn.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "# APEBench might otherwise set this to \"gpu\"\n",
    "jax.config.update(\"jax_platform_name\", \"cpu\")\n",
    "\n",
    "import jax.numpy as jnp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Append the relative directory \"../../apebench/\" to the module search path.\n",
    "# A relative entry is resolved against the kernel's current working directory.\n",
    "# NOTE(review): presumably this makes the `apebench` import in the next cell\n",
    "# pick up the local checkout - confirm.\n",
    "sys.path.append(\"../../apebench/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import apebench"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "from tqdm.autonotebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the output folders for the rendered animations exist\n",
    "for output_dir in (\n",
    "    \"img/sample_rollouts/\",\n",
    "    \"img/sample_rollouts/across_seeds\",\n",
    "):\n",
    "    os.makedirs(output_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear the list of CUDA devices visible to this process.\n",
    "# NOTE(review): this is set after `jax` and `apebench` have already been\n",
    "# imported in earlier cells - confirm the variable is still honoured here.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loading the melted sample rollouts takes ~10 min.\n",
    "# Alternative input the author kept at hand (not used):\n",
    "#   \"../../melted/broad_comparison_2d/sample_rollout_one_seed.csv\"\n",
    "melted_rollout_csv = \"../../melted/broad_comparison_2d/sample_rollout.csv\"\n",
    "sample_rollout_data = pd.read_csv(melted_rollout_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the rows with seed < 5; with all seeds the processing times of the\n",
    "# following cells are excessive.\n",
    "# `.copy()` detaches the subset from the full table, so that adding columns to\n",
    "# it in later cells works on an independent DataFrame and does not raise\n",
    "# pandas' SettingWithCopyWarning.\n",
    "sample_rollout_data = sample_rollout_data.query(\"seed < 5\").copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Takes ~4 min per seed\n",
    "def _rollout_text_to_array(rollout_text):\n",
    "    \"\"\"Parse one serialized rollout string into a NumPy array.\"\"\"\n",
    "    # `json.loads` understands the spellings `Infinity` and `NaN`, so the\n",
    "    # lowercase tokens are rewritten first (`inf` first, then `nan`).\n",
    "    json_text = re.sub(r\"\\binf\\b\", \"Infinity\", rollout_text)\n",
    "    json_text = re.sub(r\"\\bnan\\b\", \"NaN\", json_text)\n",
    "    return np.array(json.loads(json_text))\n",
    "\n",
    "\n",
    "serialized_rollouts = sample_rollout_data[\"sample_rollout\"]\n",
    "sample_rollout_data[\"sample_rollout_numpy\"] = serialized_rollouts.apply(\n",
    "    _rollout_text_to_array\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scenarios for which a reference rollout is loaded from `ref_sample_rollouts/`\n",
    "reference_scenarios = [\n",
    "    \"2d_phy_aniso_diff\",\n",
    "    \"2d_diff_burgers\",  # Uses the two-channel version in 2D\n",
    "    \"2d_diff_ks\",\n",
    "    \"2d_phy_kolm_flow\",\n",
    "    \"2d_phy_gs_theta\",\n",
    "]\n",
    "available_seeds = sample_rollout_data.seed.unique()\n",
    "\n",
    "reference_frames = []\n",
    "for scene in reference_scenarios:\n",
    "    # Only the first entry of the stored array is used as the reference\n",
    "    ref = np.load(f\"ref_sample_rollouts/{scene}.npy\")[0]\n",
    "    # One single-row frame per seed: the same reference is attached to every\n",
    "    # seed under the net name \"Ref\"\n",
    "    reference_frames.extend(\n",
    "        pd.DataFrame(\n",
    "            {\n",
    "                \"seed\": s,\n",
    "                \"net\": \"Ref\",\n",
    "                \"scenario\": scene,\n",
    "                \"task\": \"predict\",\n",
    "                \"train\": \"one\",\n",
    "                \"scenario_kwargs\": [{}],\n",
    "                \"sample_index\": 0,\n",
    "                \"sample_rollout\": [[]],\n",
    "                \"sample_rollout_numpy\": [ref],\n",
    "            }\n",
    "        )\n",
    "        for s in available_seeds\n",
    "    )\n",
    "\n",
    "ref_data_merged = pd.concat(reference_frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the reference rows in front of the emulator rows.\n",
    "# `ignore_index=True` renumbers the rows: every reference frame is a single\n",
    "# row labelled 0, so keeping the labels would leave many duplicated index\n",
    "# values in the combined table, which makes index-aligned operations on it\n",
    "# ambiguous.\n",
    "sample_rollout_data = pd.concat(\n",
    "    [ref_data_merged, sample_rollout_data], ignore_index=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every (seed, sample_index, scenario) group, subtract the rollout of the\n",
    "# group's \"Ref\" row from the rollout of each row and store the result in the\n",
    "# new column `sample_rollout_numpy_diff`.\n",
    "# `.values[0:1]` keeps the first \"Ref\" entry as a length-1 array, which is\n",
    "# then combined with all rows of the group.\n",
    "# NOTE(review): a group without a \"Ref\" row yields an empty slice here; the\n",
    "# reference rows are only created with sample_index 0 - confirm that every\n",
    "# group contains one.\n",
    "sample_rollout_data = sample_rollout_data.groupby(\n",
    "    [\"seed\", \"sample_index\", \"scenario\"],\n",
    "    group_keys=False,\n",
    ").apply(\n",
    "    lambda df: df.assign(\n",
    "        sample_rollout_numpy_diff=lambda x: x[\"sample_rollout_numpy\"]\n",
    "        - x.query(\"net == 'Ref'\")[\"sample_rollout_numpy\"].values[0:1]\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def assign_order(df: pd.DataFrame, column_name: str, order: list):\n",
    "    changing_dict = {}\n",
    "    changing_dict[column_name] = lambda data: pd.Categorical(\n",
    "        data[column_name], categories=order, ordered=True\n",
    "    )\n",
    "    df = df.assign(**changing_dict)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Category order for the `net` column\n",
    "net_order = [\n",
    "    \"Ref\",\n",
    "    \"Conv;26;11;relu\",  # 61'595 params, 12 receptive field per direction\n",
    "    \"UNet;10;2;relu\",  # 55'661 params, 29 receptive field per direction\n",
    "    \"Res;26;5;relu\",  # 61'179 params, 10 receptive field per direction\n",
    "    \"FNO;10;6;4;gelu\",  # 57'787 params, inf receptive field per direction\n",
    "    \"Dil;2;26;2;relu\",  # 61'699 params, 20 receptive field per direction\n",
    "]\n",
    "# Category order for the `scenario` column\n",
    "scenario_order = [\n",
    "    \"2d_phy_aniso_diff\",\n",
    "    \"2d_diff_burgers\",  # Uses the two-channel version in 2D\n",
    "    \"2d_diff_ks\",\n",
    "    \"2d_phy_kolm_flow\",\n",
    "    \"2d_phy_gs_theta\",\n",
    "]\n",
    "\n",
    "# Turn both columns into ordered categoricals, `net` first, then `scenario`\n",
    "sample_rollout_data = sample_rollout_data.pipe(\n",
    "    assign_order, \"net\", net_order\n",
    ").pipe(assign_order, \"scenario\", scenario_order)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute value of the signed difference to the reference\n",
    "signed_diffs = sample_rollout_data[\"sample_rollout_numpy_diff\"]\n",
    "sample_rollout_data[\"sample_rollout_numpy_diff_abs\"] = signed_diffs.apply(np.abs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order the rows by scenario and net (both ordered categoricals at this\n",
    "# point), then by seed and sample index\n",
    "sort_keys = [\"scenario\", \"net\", \"seed\", \"sample_index\"]\n",
    "sample_rollout_data = sample_rollout_data.sort_values(by=sort_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# extracted_data = np.stack(sample_rollout_data.query(\"scenario == '2d_diff_ks'\")[\"sample_rollout_numpy\"].values)#[:, :20]\n",
    "# sample_rollout_data.query(\"scenario == '2d_phy_aniso_diff'\")[\"net\"]\n",
    "# extracted_data.shape\n",
    "# ani = apebench.exponax.viz.animate_state_2d_facet(\n",
    "#     jnp.array(extracted_data),\n",
    "#     facet_over_channels=False,\n",
    "#     grid=(2, 3),\n",
    "#     figsize=(12, 8),\n",
    "#     titles=list(sample_rollout_data.query(\"scenario == '2d_phy_aniso_diff'\")[\"net\"]),\n",
    "# )\n",
    "# ani.save(\"hello.mp4\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/5 [00:00<?, ?it/s]CUDA backend failed to initialize: Unable to load CUDA. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
      "100%|██████████| 5/5 [37:38<00:00, 451.78s/it]\n"
     ]
    }
   ],
   "source": [
    "# One animation of the raw rollouts per (seed, scenario); every selected row\n",
    "# becomes one facet that is titled with its `net` entry.\n",
    "for seed in tqdm(sample_rollout_data.seed.unique()):\n",
    "    rows_of_seed = sample_rollout_data.query(f\"seed == {seed}\")\n",
    "    for scenario in sample_rollout_data.scenario.unique():\n",
    "        rows = rows_of_seed.query(f\"scenario == '{scenario}'\")\n",
    "        # Stack the rollouts of all rows and keep only index 0 of the third\n",
    "        # axis (presumably the channel axis)\n",
    "        stacked = np.stack(rows[\"sample_rollout_numpy\"].values)\n",
    "        first_channel = stacked[:, :, 0:1]\n",
    "        animation = apebench.exponax.viz.animate_state_2d_facet(\n",
    "            first_channel,\n",
    "            facet_over_channels=False,\n",
    "            grid=(2, 3),\n",
    "            figsize=(12, 8),\n",
    "            titles=rows[\"net\"].values,\n",
    "        )\n",
    "        out_path = (\n",
    "            f\"img/sample_rollouts/sample_rollouts_scene={scenario}_seed={seed}.mp4\"\n",
    "        )\n",
    "        animation.save(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [37:06<00:00, 445.33s/it]\n"
     ]
    }
   ],
   "source": [
    "# One animation of the signed difference to the reference per\n",
    "# (seed, scenario); every selected row becomes one facet titled with its `net`.\n",
    "for seed in tqdm(sample_rollout_data.seed.unique()):\n",
    "    rows_of_seed = sample_rollout_data.query(f\"seed == {seed}\")\n",
    "    for scenario in sample_rollout_data.scenario.unique():\n",
    "        rows = rows_of_seed.query(f\"scenario == '{scenario}'\")\n",
    "        # Stack the differences of all rows and keep only index 0 of the third\n",
    "        # axis (presumably the channel axis)\n",
    "        stacked = np.stack(rows[\"sample_rollout_numpy_diff\"].values)\n",
    "        first_channel = stacked[:, :, 0:1]\n",
    "        animation = apebench.exponax.viz.animate_state_2d_facet(\n",
    "            first_channel,\n",
    "            facet_over_channels=False,\n",
    "            grid=(2, 3),\n",
    "            figsize=(12, 8),\n",
    "            titles=rows[\"net\"].values,\n",
    "        )\n",
    "        out_path = f\"img/sample_rollouts/sample_rollouts_diff_scene={scenario}_seed={seed}.mp4\"\n",
    "        animation.save(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [36:50<00:00, 442.00s/it]\n"
     ]
    }
   ],
   "source": [
    "# One animation of the absolute difference to the reference per\n",
    "# (seed, scenario); every selected row becomes one facet titled with its `net`.\n",
    "for seed in tqdm(sample_rollout_data.seed.unique()):\n",
    "    rows_of_seed = sample_rollout_data.query(f\"seed == {seed}\")\n",
    "    for scenario in sample_rollout_data.scenario.unique():\n",
    "        rows = rows_of_seed.query(f\"scenario == '{scenario}'\")\n",
    "        # Stack the absolute differences of all rows and keep only index 0 of\n",
    "        # the third axis (presumably the channel axis)\n",
    "        stacked = np.stack(rows[\"sample_rollout_numpy_diff_abs\"].values)\n",
    "        first_channel = stacked[:, :, 0:1]\n",
    "        animation = apebench.exponax.viz.animate_state_2d_facet(\n",
    "            first_channel,\n",
    "            facet_over_channels=False,\n",
    "            grid=(2, 3),\n",
    "            figsize=(12, 8),\n",
    "            titles=rows[\"net\"].values,\n",
    "        )\n",
    "        out_path = f\"img/sample_rollouts/sample_rollouts_diff_abs_scene={scenario}_seed={seed}.mp4\"\n",
    "        animation.save(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6/6 [19:02<00:00, 190.44s/it]\n"
     ]
    }
   ],
   "source": [
    "# One animation of the raw rollouts per (net, scenario) with the seeds next to\n",
    "# each other; every selected row becomes one facet titled with its seed.\n",
    "# The \"Ref\" rows are the same reference for every seed, so they are skipped.\n",
    "for net in tqdm(sample_rollout_data.net.unique()):\n",
    "    if net == \"Ref\":\n",
    "        continue\n",
    "    for scenario in sample_rollout_data.scenario.unique():\n",
    "        rows = sample_rollout_data.query(f\"scenario == '{scenario}'\").query(\n",
    "            f\"net == '{net}'\"\n",
    "        )\n",
    "        # Stack the rollouts of all rows and keep only index 0 of the third\n",
    "        # axis (presumably the channel axis)\n",
    "        stacked = np.stack(rows[\"sample_rollout_numpy\"].values)\n",
    "        first_channel = stacked[:, :, 0:1]\n",
    "        animation = apebench.exponax.viz.animate_state_2d_facet(\n",
    "            first_channel,\n",
    "            facet_over_channels=False,\n",
    "            grid=(1, 5),\n",
    "            # Same figure size as the across-seed animations of the difference\n",
    "            # and the absolute difference, which use the identical 1x5 layout\n",
    "            figsize=(15, 4),\n",
    "            titles=rows[\"seed\"].values,\n",
    "        )\n",
    "        out_path = f\"img/sample_rollouts/across_seeds/sample_rollouts_net={net}_scene={scenario}.mp4\"\n",
    "        animation.save(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/6 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "CUDA backend failed to initialize: Unable to load CUDA. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
      "100%|██████████| 6/6 [24:56<00:00, 249.38s/it]\n"
     ]
    }
   ],
   "source": [
    "# One animation of the signed difference to the reference per (net, scenario)\n",
    "# with the seeds next to each other; every selected row becomes one facet\n",
    "# titled with its seed. \"Ref\" is skipped.\n",
    "for net in tqdm(sample_rollout_data.net.unique()):\n",
    "    if net == \"Ref\":\n",
    "        continue\n",
    "    for scenario in sample_rollout_data.scenario.unique():\n",
    "        rows = sample_rollout_data.query(f\"scenario == '{scenario}'\").query(\n",
    "            f\"net == '{net}'\"\n",
    "        )\n",
    "        # Stack the differences of all rows and keep only index 0 of the third\n",
    "        # axis (presumably the channel axis)\n",
    "        stacked = np.stack(rows[\"sample_rollout_numpy_diff\"].values)\n",
    "        first_channel = stacked[:, :, 0:1]\n",
    "        animation = apebench.exponax.viz.animate_state_2d_facet(\n",
    "            first_channel,\n",
    "            facet_over_channels=False,\n",
    "            grid=(1, 5),\n",
    "            figsize=(15, 4),\n",
    "            titles=rows[\"seed\"].values,\n",
    "        )\n",
    "        out_path = f\"img/sample_rollouts/across_seeds/sample_rollouts_diff_net={net}_scene={scenario}.mp4\"\n",
    "        animation.save(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6/6 [25:07<00:00, 251.33s/it]\n"
     ]
    }
   ],
   "source": [
    "# One animation of the absolute difference to the reference per\n",
    "# (net, scenario) with the seeds next to each other; every selected row\n",
    "# becomes one facet titled with its seed. \"Ref\" is skipped.\n",
    "for net in tqdm(sample_rollout_data.net.unique()):\n",
    "    if net == \"Ref\":\n",
    "        continue\n",
    "    for scenario in sample_rollout_data.scenario.unique():\n",
    "        rows = sample_rollout_data.query(f\"scenario == '{scenario}'\").query(\n",
    "            f\"net == '{net}'\"\n",
    "        )\n",
    "        # Stack the absolute differences of all rows and keep only index 0 of\n",
    "        # the third axis (presumably the channel axis)\n",
    "        stacked = np.stack(rows[\"sample_rollout_numpy_diff_abs\"].values)\n",
    "        first_channel = stacked[:, :, 0:1]\n",
    "        animation = apebench.exponax.viz.animate_state_2d_facet(\n",
    "            first_channel,\n",
    "            facet_over_channels=False,\n",
    "            grid=(1, 5),\n",
    "            figsize=(15, 4),\n",
    "            titles=rows[\"seed\"].values,\n",
    "        )\n",
    "        out_path = f\"img/sample_rollouts/across_seeds/sample_rollouts_diff_abs_net={net}_scene={scenario}.mp4\"\n",
    "        animation.save(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_gpu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
