{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fcc7baef",
   "metadata": {},
   "source": [
    "# Evaluation: main table, turbulence evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "127352b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "632e93db",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"6\"\n",
    "device = \"cuda\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "798e92a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Sequence, Optional, Callable, Dict\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from scipy.stats import pearsonr\n",
    "from collections import defaultdict\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import warnings\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "\n",
    "warnings.filterwarnings(\"ignore\", message=\"__array_wrap__ must accept context*\")\n",
    "\n",
    "from neural_fields.trad import zfp_recon, wavelet_recon, pca_recon, jpeg2000_recon\n",
    "from neural_fields.data import CycloneNFDataset\n",
    "from neural_fields.nf_utils import sample_field, load_nf, compress_weights, endpoint_error\n",
    "from neural_fields.gk_losses import diagnostics\n",
    "from dataset.cyclone_diff import CycloneDiffusionDataset\n",
    "\n",
    "from train.integrals import FluxIntegral\n",
    "from utils import load_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e3358d2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d997f1e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "NAME_TIMESTEPS = [100, 120, 140, 160, 180, 200]\n",
    "TIMESTEPS = [0, 1, 2, 3, 4, 5]  # these are actually [100, 120, 140, 160, 180, 200]\n",
    "TRAJECTORIES = [\n",
    "    \"iteration_13\",\n",
    "    \"iteration_115\",\n",
    "    \"iteration_131\",\n",
    "    \"iteration_134\",\n",
    "    \"iteration_146\",\n",
    "    \"iteration_148\",\n",
    "    \"iteration_160\",\n",
    "    \"iteration_200\",\n",
    "    \"iteration_210\",\n",
    "    \"iteration_212\",\n",
    "]\n",
    "\n",
    "BASE_PATH = \"<path>\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c384ecb5",
   "metadata": {},
   "source": [
    "## Load all models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6525ffd8",
   "metadata": {},
   "source": [
    "### Neural fields"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bd4d70a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "CKP_DIR = f\"{BASE_PATH}/nf_ckps_hf\"\n",
    "\n",
    "nfs = {}\n",
    "for traj in TRAJECTORIES:\n",
    "    nfs[traj] = {}\n",
    "    for t, tn in zip(TIMESTEPS, NAME_TIMESTEPS):\n",
    "        ckp_name = f\"mlp_{traj}_t{tn}_x1167.pt\"\n",
    "        nfs[traj][t] = load_nf(f\"{CKP_DIR}/{ckp_name}\", device)\n",
    "\n",
    "int_nfs = {}\n",
    "for traj in TRAJECTORIES:\n",
    "    int_nfs[traj] = {}\n",
    "    for t, tn in zip(TIMESTEPS, NAME_TIMESTEPS):\n",
    "        ckp_name = f\"int_mlp_{traj}_t{tn}_x1167.pt\"\n",
    "        int_nfs[traj][t] = load_nf(f\"{CKP_DIR}/{ckp_name}\", device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "379f2b2f",
   "metadata": {},
   "source": [
    "### Autoencoders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab44f101",
   "metadata": {},
   "outputs": [],
   "source": [
    "int_ae = f\"{BASE_PATH}/autoencoders/ae/best.pth\" \n",
    "int_ae, _, int_ae_cfg = load_model(int_ae, device=device, load_peft=True)\n",
    "int_ae.eval()\n",
    "\n",
    "cr_77368_peft = f\"{BASE_PATH}/autoencoders/vqvae/best.pth\" \n",
    "int_vqvae, _, int_vqvae_cfg = load_model(cr_77368_peft, device=device, load_peft=True)\n",
    "int_vqvae.eval()\n",
    "pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c57d7c6d",
   "metadata": {},
   "source": [
    "## Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2327ddd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ml_eval(\n",
    "    pred_df: torch.Tensor,\n",
    "    gt_df: torch.Tensor,\n",
    "    pred_phi: torch.Tensor,\n",
    "    gt_phi: torch.Tensor,\n",
    "    pred_eflux: torch.Tensor,\n",
    "    gt_eflux: torch.Tensor,\n",
    "    compressed_size: int = None,\n",
    "):\n",
    "    metrics = {}\n",
    "\n",
    "    metrics[\"mse\"] = ((pred_df.cpu() - gt_df.cpu()) ** 2).mean()\n",
    "    metrics[\"l1\"] = (pred_df.cpu() - gt_df.cpu()).abs().mean()\n",
    "    metrics[\"psnr\"] = 10 * torch.log10(gt_df.max() ** 2 / metrics[\"mse\"])\n",
    "    metrics[\"phi_mse\"] = ((pred_phi - gt_phi) ** 2).mean()\n",
    "    metrics[\"phi_l1\"] = (pred_phi - gt_phi).abs().mean()\n",
    "    metrics[\"phi_psnr\"] = 10 * torch.log10(gt_phi.max() ** 2 / metrics[\"phi_mse\"])\n",
    "    metrics[\"eflux_l1\"] = (pred_eflux - gt_eflux).abs().mean()\n",
    "    \n",
    "    if compressed_size is not None:\n",
    "        num_pixels = pred_df.numel()\n",
    "        metrics[\"bpp\"] = (compressed_size * 8) / num_pixels\n",
    "\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ffef51a",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def run_eval_diagnostics(\n",
    "    trajectories: str,\n",
    "    timesteps: Sequence[int],\n",
    "    model: Optional[Callable] = None,\n",
    "    model_name: Optional[str] = None,\n",
    "    train_norm_stats: Optional[Dict] = None\n",
    "):\n",
    "    if model_name is None:\n",
    "        model_name = \"GT\"\n",
    "    ml_metrics = defaultdict(list)\n",
    "    phy_diag = {}\n",
    "    for traj in tqdm(trajectories):\n",
    "        data = CycloneNFDataset(\n",
    "            traj, timesteps=timesteps, path=f\"{BASE_PATH}/h5\", normalize=None, realpotens=True\n",
    "        )\n",
    "\n",
    "        DS = 0.0625\n",
    "\n",
    "        gt_dfs = [data.df[:, t] for t in range(data.df.shape[1])] if data.ndim > 5 else [data.df]\n",
    "        compressed_size = None\n",
    "        if model_name == \"GT\":\n",
    "            dfs = gt_dfs\n",
    "\n",
    "        if \"NF\" in model_name:\n",
    "            # nf reconstruct\n",
    "            dfs = []\n",
    "            compressed_size = 0\n",
    "            for t in range(len(timesteps)):\n",
    "                # TODO load config\n",
    "                real_t = timesteps[t]\n",
    "                if \"ZFP\" in model_name:\n",
    "                    nf = model[traj][real_t].to(device)\n",
    "                    nf, _, nf_zfp_size = compress_weights(nf, tolerance=1e-3)\n",
    "                    compressed_size += nf_zfp_size\n",
    "                if \"ZipNN\" in model_name:\n",
    "                    nf = model[traj][real_t].to(device)\n",
    "                    nf, _, nf_zipnn_size = compress_weights(nf, method=\"zipnn\")\n",
    "                    compressed_size += nf_zipnn_size\n",
    "                else:\n",
    "                    nf = model[traj][real_t].to(device)\n",
    "                    compressed_size += sum(p.nbytes for p in nf.parameters())\n",
    "                # create new normalized dataset\n",
    "                nf_data = CycloneNFDataset(\n",
    "                    traj,\n",
    "                    path=f\"{BASE_PATH}/h5\",\n",
    "                    timesteps=real_t,\n",
    "                    normalize=\"zscore\",\n",
    "                    normalize_coords=False,\n",
    "                    realpotens=True,\n",
    "                )\n",
    "                # automatically denormalizes\n",
    "                dfs.append(sample_field(nf, nf_data, device).cpu())\n",
    "\n",
    "                torch.cuda.empty_cache()\n",
    "\n",
    "        if \"AE\" in model_name or \"VQ-VAE\" in model_name:\n",
    "            dfs = []\n",
    "            compressed_size = 0\n",
    "\n",
    "            for t in range(data.df.shape[1]):\n",
    "                ae, cfg = model\n",
    "                ae = ae.to(device)\n",
    "                # bit triky here, need the original config and datasets (normalization)\n",
    "                \n",
    "                valdata_ae = CycloneDiffusionDataset(\n",
    "                    path=f\"{BASE_PATH}/h5\",\n",
    "                    split=\"train\",\n",
    "                    input_fields=[\"df\", \"phi\", \"flux\"],\n",
    "                    random_seed=cfg.seed,\n",
    "                    normalization=cfg.dataset.normalization,\n",
    "                    normalization_scope=cfg.dataset.normalization_scope,\n",
    "                    normalization_stats=train_norm_stats,\n",
    "                    trajectories=[f\"{traj}.h5\"],\n",
    "                    separate_zf=cfg.dataset.separate_zf,\n",
    "                    real_potens=True,\n",
    "                    stage=cfg.stage,\n",
    "                    conditions=[\"itg\", \"dg\", \"s_hat\", \"q\"],\n",
    "                )\n",
    "                # ae reconstruct\n",
    "                sample = valdata_ae[t]\n",
    "                df = sample.df.unsqueeze(0).to(device)\n",
    "                condition = sample.conditioning.unsqueeze(0).to(device)\n",
    "                ae_df = ae(df, condition=condition)[\"df\"].cpu().squeeze(0)\n",
    "                # important: denormalize\n",
    "                ae_df = valdata_ae.denormalize(0, df=ae_df)\n",
    "                if ae_df.shape[0] == 4:\n",
    "                    ae_df = ae_df[[0, 1]] + ae_df[[2, 3]]\n",
    "                dfs.append(ae_df)\n",
    "\n",
    "                torch.cuda.empty_cache()\n",
    "\n",
    "                if \"VQ-VAE\" in model_name:\n",
    "                    model_output = ae(df, condition=condition)\n",
    "                    indices = model_output[\"vq_indices\"]\n",
    "\n",
    "                    # (codebook_size=8192 fits in int16)\n",
    "                    indices_int16 = indices.to(torch.int16)\n",
    "                    compressed_size += indices_int16.nbytes\n",
    "                else:\n",
    "                    # Regular AE\n",
    "                    latent = ae.encode(df, condition=condition)[0]\n",
    "                    compressed_size += latent.nbytes\n",
    "\n",
    "        if model_name in [\"ZFP\", \"Wavelet\", \"PCA\", \"JPEG2000\"]:\n",
    "            dfs = []\n",
    "            compressed_size = 0\n",
    "            for t in range(len(timesteps)):\n",
    "                tdf = data.full_df[:, t] if data.ndim > 5 else data.full_df\n",
    "                recon, _, cs = model(tdf)\n",
    "                dfs.append(recon)\n",
    "                compressed_size += cs\n",
    "        if compressed_size:\n",
    "            ml_metrics[\"cr\"].append(data.full_df.nbytes / compressed_size)\n",
    "\n",
    "        geom = {k: v.unsqueeze(0) for k, v in data.geom.items()}\n",
    "        integral = FluxIntegral(\n",
    "            real_potens=True, flux_fields=True, spectral_potens=True\n",
    "        )\n",
    "\n",
    "        # physics metrics\n",
    "        phy_diag[traj] = defaultdict(list)\n",
    "        phy_diag_gt = []\n",
    "        for pred_df, gt_df in zip(dfs, gt_dfs):\n",
    "            pred_phi, (_, pred_eflux, _) = integral(geom, pred_df.unsqueeze(0))\n",
    "            pred_d = diagnostics(pred_phi.squeeze(), pred_eflux.squeeze(), ds=DS)\n",
    "            pred_d = {k: v.numpy() for k, v in pred_d.items()}\n",
    "            phy_diag[traj][\"phi\"].append(np.abs(pred_phi[:, 7, :].squeeze()))\n",
    "            phy_diag[traj][\"kxspec\"].append(pred_d[\"kxspec\"])\n",
    "            phy_diag[traj][\"kyspec\"].append(pred_d[\"kyspec\"])\n",
    "            phy_diag[traj][\"qspec\"].append(pred_d[\"qspec\"])\n",
    "\n",
    "            if model_name != \"GT\":\n",
    "                gt_phi, (_, gt_eflux, _) = integral(geom, gt_df.unsqueeze(0))\n",
    "                gt_d = diagnostics(gt_phi.squeeze(), gt_eflux.squeeze(), ds=DS)\n",
    "                gt_d = {k: v.numpy() for k, v in gt_d.items()}\n",
    "                phy_diag_gt.append(gt_d)\n",
    "\n",
    "        # time-averaged metrics\n",
    "        if model_name != \"GT\":\n",
    "            from scipy.stats import pearsonr, spearmanr, wasserstein_distance\n",
    "\n",
    "            pred_kyspec_ta = np.stack(phy_diag[traj][\"kyspec\"], 0).mean(0)\n",
    "            pred_qspec_ta = np.stack(phy_diag[traj][\"qspec\"], 0).mean(0)\n",
    "            gt_kyspec_ta = np.stack([d[\"kyspec\"] for d in phy_diag_gt], 0).mean(0)\n",
    "            gt_qspec_ta = np.stack([d[\"qspec\"] for d in phy_diag_gt], 0).mean(0)\n",
    "\n",
    "            ml_metrics[\"kyspec_pc\"].append(pearsonr(pred_kyspec_ta, gt_kyspec_ta)[0])\n",
    "            ml_metrics[\"qspec_pc\"].append(pearsonr(pred_qspec_ta, gt_qspec_ta)[0])\n",
    "\n",
    "            ml_metrics[\"kyspec_l1\"].append(np.abs(pred_kyspec_ta - gt_kyspec_ta).sum())\n",
    "            ml_metrics[\"qspec_l1\"].append(np.abs(pred_qspec_ta - gt_qspec_ta).sum())\n",
    "\n",
    "            ml_metrics[\"kyspec_sc\"].append(spearmanr(pred_kyspec_ta, gt_kyspec_ta)[0])\n",
    "            ml_metrics[\"qspec_sc\"].append(spearmanr(pred_qspec_ta, gt_qspec_ta)[0])\n",
    "\n",
    "            pred_kyspec_ta /= pred_kyspec_ta.sum()\n",
    "            gt_kyspec_ta /= gt_kyspec_ta.sum()\n",
    "            pred_qspec_ta /= pred_qspec_ta.sum()\n",
    "            gt_qspec_ta /= gt_qspec_ta.sum()\n",
    "            ml_metrics[\"kyspec_wd\"].append(\n",
    "                wasserstein_distance(pred_kyspec_ta, gt_kyspec_ta)\n",
    "            )\n",
    "            ml_metrics[\"qspec_wd\"].append(\n",
    "                wasserstein_distance(pred_qspec_ta, gt_qspec_ta)\n",
    "            )\n",
    "\n",
    "        # ml metrics\n",
    "        integral = FluxIntegral()\n",
    "        if model_name != \"GT\":\n",
    "            for pred_df, gt_df in zip(dfs, gt_dfs):\n",
    "                pred_phi, (_, pred_eflux, _) = integral(geom, pred_df.unsqueeze(0))\n",
    "                gt_phi, (_, gt_eflux, _) = integral(geom, gt_df.unsqueeze(0))\n",
    "\n",
    "                ml_eval_dict = ml_eval(\n",
    "                    pred_df,\n",
    "                    gt_df,\n",
    "                    pred_phi,\n",
    "                    gt_phi,\n",
    "                    pred_eflux,\n",
    "                    gt_eflux,\n",
    "                    compressed_size,\n",
    "                )\n",
    "                for k, v in ml_eval_dict.items():\n",
    "                    ml_metrics[k].append(v)\n",
    "\n",
    "            if len(gt_dfs) > 1:\n",
    "                gt_df_ = np.stack([d.cpu().numpy() for d in gt_dfs])\n",
    "                pred_df_ = np.stack([d.cpu().numpy() for d in dfs])\n",
    "                ml_metrics[\"endpoint\"].append(endpoint_error(gt_df_, pred_df_))\n",
    "\n",
    "    ml_metrics = {k: (np.mean(v), np.std(v)) for k, v in ml_metrics.items()}\n",
    "    if model_name != \"GT\":\n",
    "        ml_metrics[\"bpp\"] = ml_metrics[\"bpp\"] * len(dfs)\n",
    "    print(model_name, \"done!\")\n",
    "    return {model_name: phy_diag}, {model_name: ml_metrics}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15c51269",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f3de821",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_diagnostics = {}\n",
    "full_metrics = {}\n",
    "\n",
    "gt_diag, _ = run_eval_diagnostics(TRAJECTORIES, timesteps=TIMESTEPS)\n",
    "full_diagnostics.update(gt_diag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db7d1997",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ZFP\n",
    "zfp_diag, zfp_metrics = run_eval_diagnostics(\n",
    "    TRAJECTORIES, timesteps=TIMESTEPS, model=zfp_recon, model_name=\"ZFP\"\n",
    ")\n",
    "full_diagnostics.update(zfp_diag)\n",
    "full_metrics.update(zfp_metrics)\n",
    "\n",
    "# wavelet\n",
    "wft_diag, wft_metrics = run_eval_diagnostics(\n",
    "    TRAJECTORIES, timesteps=TIMESTEPS, model=wavelet_recon, model_name=\"Wavelet\"\n",
    ")\n",
    "full_diagnostics.update(wft_diag)\n",
    "full_metrics.update(wft_metrics)\n",
    "\n",
    "# PCA\n",
    "pca_diag, pca_metrics = run_eval_diagnostics(\n",
    "    TRAJECTORIES, timesteps=TIMESTEPS, model=pca_recon, model_name=\"PCA\"\n",
    ")\n",
    "full_diagnostics.update(pca_diag)\n",
    "full_metrics.update(pca_metrics)\n",
    "\n",
    "# JPEG2000\n",
    "jp2k_diag, jp2k_metrics = run_eval_diagnostics(\n",
    "    TRAJECTORIES, timesteps=TIMESTEPS, model=jpeg2000_recon, model_name=\"JPEG2000\"\n",
    ")\n",
    "full_diagnostics.update(jp2k_diag)\n",
    "full_metrics.update(jp2k_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b73fdc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NF\n",
    "nf_diag, nf_metrics = run_eval_diagnostics(\n",
    "    TRAJECTORIES, timesteps=TIMESTEPS, model=nfs, model_name=\"NF\"\n",
    ")\n",
    "full_diagnostics.update(nf_diag)\n",
    "full_metrics.update(nf_metrics)\n",
    "\n",
    "# NF + int\n",
    "nf_diag, nf_metrics = run_eval_diagnostics(\n",
    "    TRAJECTORIES, timesteps=TIMESTEPS, model=int_nfs, model_name=\"PINC-NF\"\n",
    ")\n",
    "full_diagnostics.update(nf_diag)\n",
    "full_metrics.update(nf_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db252186",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AE + INT\n",
    "int_ae_diag, int_ae_metrics = run_eval_diagnostics(\n",
    "    TRAJECTORIES,\n",
    "    timesteps=TIMESTEPS,\n",
    "    model=(int_ae, int_ae_cfg),\n",
    "    model_name=\"AE + EVA\",\n",
    "    train_norm_stats=pickle.load(open(f\"{BASE_PATH}/hf_norm_stats.pkl\", \"rb\"))\n",
    ")\n",
    "full_diagnostics.update(int_ae_diag)\n",
    "full_metrics.update(int_ae_metrics)\n",
    "\n",
    "# VQ-VAE + INT\n",
    "int_vqvae_diag, int_vqvae_metrics = run_eval_diagnostics(\n",
    "    TRAJECTORIES,\n",
    "    timesteps=TIMESTEPS,\n",
    "    model=(int_vqvae, int_vqvae_cfg),\n",
    "    model_name=\"VQ-VAE + EVA\",\n",
    "    train_norm_stats=pickle.load(open(f\"{BASE_PATH}/hf_norm_stats.pkl\", \"rb\"))\n",
    ")\n",
    "full_diagnostics.update(int_vqvae_diag)\n",
    "full_metrics.update(int_vqvae_metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba1ebe63",
   "metadata": {},
   "source": [
    "## Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1ee936d",
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_keys = [\"psnr\", \"endpoint\", \"eflux_l1\", \"phi_psnr\", \"kyspec_wd\", \"qspec_wd\"]\n",
    "\n",
    "direction = {\n",
    "    \"l1\": \"min\",\n",
    "    \"psnr\": \"max\",\n",
    "    \"bpp\": \"min\",\n",
    "    \"endpoint\": \"min\",\n",
    "    \"eflux_l1\": \"min\",\n",
    "    \"phi_psnr\": \"max\",\n",
    "    \"kyspec_pc\": \"max\",\n",
    "    \"qspec_pc\": \"max\",\n",
    "    \"kyspec_sc\": \"max\",\n",
    "    \"qspec_sc\": \"max\",\n",
    "    \"kyspec_wd\": \"min\",\n",
    "    \"qspec_wd\": \"min\",\n",
    "    \"kyspec_shd\": \"min\",\n",
    "    \"qspec_shd\": \"min\",\n",
    "    \"kyspec_l1\": \"min\",\n",
    "    \"qspec_l1\": \"min\",\n",
    "}\n",
    "\n",
    "# values_per_metric = {k: [] for k in metrics_keys}\n",
    "# for m, vals in full_metrics.items():\n",
    "#     if m != \"GT\":\n",
    "#         for k in metrics_keys:\n",
    "#             values_per_metric[k].append(vals[k][0])\n",
    "\n",
    "# # Compute column medians for thresholding\n",
    "# medians = {k: np.median(values_per_metric[k]) for k in metrics_keys}\n",
    "\n",
    "# best_idx, second_idx = {}, {}\n",
    "# for k in metrics_keys:\n",
    "#     arr = np.array(values_per_metric[k])\n",
    "#     if direction[k] == \"min\":\n",
    "#         order = np.argsort(arr)\n",
    "#     else:\n",
    "#         order = np.argsort(-arr)\n",
    "#     best_idx[k] = order[0]\n",
    "#     second_idx[k] = order[1]\n",
    "\n",
    "for i, m in enumerate(full_metrics):\n",
    "    vals = full_metrics[m]\n",
    "    cr = f\"{int(vals['cr'][0])}$\\\\times$\"\n",
    "\n",
    "    row_entries = []\n",
    "    for j, k in enumerate(metrics_keys):\n",
    "        val_mean, val_std = vals[k][0], vals[k][1]\n",
    "        if k == \"bpp\":\n",
    "            formatted = f\"{val_mean / 34:.4f}\"\n",
    "        elif \"spec\" in k:\n",
    "            formatted = f\"{val_mean:.4f}\"\n",
    "        else:\n",
    "            formatted = f\"{val_mean:.2f}\"\n",
    "        if val_std > 1e-3 and k != \"bpp\":\n",
    "            formatted += r\"$_{\\pm \" + f\"{val_std:.2f}\" + r\"}$\"\n",
    "\n",
    "        # if i == best_idx[k]:\n",
    "        #     formatted = r\"\\textbf{\" + formatted + \"}\"\n",
    "        # elif i == second_idx[k]:\n",
    "        #     formatted = r\"\\underline{\" + formatted + \"}\"\n",
    "\n",
    "        row_entries.append(formatted)\n",
    "\n",
    "    rule = \"\"\n",
    "    # if \"VQ-VAE\" in m:\n",
    "    #     rule = r\"\\midrule\"\n",
    "    # if i == len(full_metrics) - 1:\n",
    "    #     rule = r\"\\bottomrule\"\n",
    "\n",
    "    # print(rf\"| {m:<8} | {cr} | \" + \" | \".join(row_entries) + \"  | \")\n",
    "    print(rf\"{m:<8} & {cr} & \" + \" & \".join(row_entries) + rf\" \\\\{rule}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c87cd428",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f65b8c27",
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_keys = [\"l1\", \"bpp\", \"phi_l1\", \"kyspec_pc\", \"qspec_pc\", \"kyspec_l1\", \"qspec_l1\"]\n",
    "\n",
    "direction = {\n",
    "    \"l1\": \"min\",\n",
    "    \"psnr\": \"max\",\n",
    "    \"bpp\": \"min\",\n",
    "    \"phi_l1\": \"min\",\n",
    "    \"eflux_l1\": \"min\",\n",
    "    \"phi_psnr\": \"max\",\n",
    "    \"kyspec_pc\": \"max\",\n",
    "    \"qspec_pc\": \"max\",\n",
    "    \"kyspec_sc\": \"max\",\n",
    "    \"qspec_sc\": \"max\",\n",
    "    \"kyspec_wd\": \"min\",\n",
    "    \"qspec_wd\": \"min\",\n",
    "    \"kyspec_shd\": \"min\",\n",
    "    \"qspec_shd\": \"min\",\n",
    "    \"kyspec_l1\": \"min\",\n",
    "    \"qspec_l1\": \"min\",\n",
    "}\n",
    "\n",
    "values_per_metric = {k: [] for k in metrics_keys}\n",
    "for m, vals in full_metrics.items():\n",
    "    for k in metrics_keys:\n",
    "        values_per_metric[k].append(vals[k])\n",
    "\n",
    "# best_idx, second_idx = {}, {}\n",
    "# for k in metrics_keys:\n",
    "#     arr = np.array(values_per_metric[k])\n",
    "#     if direction[k] == \"min\":\n",
    "#         order = np.argsort(arr)\n",
    "#     else:\n",
    "#         order = np.argsort(-arr)\n",
    "#     best_idx[k] = order[0]\n",
    "#     second_idx[k] = order[1]\n",
    "\n",
    "# print(r\"\\begin{tabular}{l|c|cccc}\")\n",
    "# print(r\"\\toprule\")\n",
    "# print(\n",
    "#     r\"         & \\multicolumn{1}{c}{Integrals $\\boldsymbol{\\phi}$}                                              & \\multicolumn{4}{c}{Turbulence $Q^{\\text{spec}}, k_y^{\\text{spec}}$}                                        \\\\ \\midrule\"\n",
    "# )\n",
    "# print(\n",
    "#     r\"         & $\\text{L1}(\\boldsymbol{\\phi})$ $\\downarrow$ & $\\text{PC}(\\overline{k_y^{\\text{spec}}})$ $\\uparrow$ & $\\text{PC}(\\overline{Q^{\\text{spec}}})$ $\\uparrow$ & $\\text{L1}(\\overline{k_y^{\\text{spec}}})$ $\\uparrow$ & $\\text{L1}(\\overline{Q^{\\text{spec}}})$ $\\uparrow$ \\\\ \\midrule\"\n",
    "# )\n",
    "\n",
    "for i, m in enumerate(full_metrics):\n",
    "    vals = full_metrics[m]\n",
    "    cr = f\"{int(vals['cr'][0])}$\\\\times$\"\n",
    "\n",
    "    row_entries = []\n",
    "    for j, k in enumerate(metrics_keys):\n",
    "        val_mean, val_std = vals[k][0], vals[k][1]\n",
    "        if k == \"bpp\":\n",
    "            formatted = f\"{val_mean:.3f}\"\n",
    "        elif \"spec\" in k:\n",
    "            formatted = f\"{val_mean:.4f}\"\n",
    "        else:\n",
    "            formatted = f\"{val_mean:.2f}\"\n",
    "        if val_std > 1e-3 and k != \"bpp\":\n",
    "            formatted += r\"$_{\\pm \" + f\"{val_std:.2f}\" + r\"}$\"\n",
    "\n",
    "        # if i == best_idx[k]:\n",
    "        #     formatted = r\"\\textbf{\" + formatted + \"}\"\n",
    "        # elif i == second_idx[k]:\n",
    "        #     formatted = r\"\\underline{\" + formatted + \"}\"\n",
    "\n",
    "        row_entries.append(formatted)\n",
    "\n",
    "    rule = \"\"\n",
    "    if \"VQ-VAE\" in m:\n",
    "        rule = r\"\\midrule\"\n",
    "    if i == len(full_metrics) - 1:\n",
    "        rule = r\"\\bottomrule\"\n",
    "\n",
    "    print(rf\"{m:<8} & \" + \" & \".join(row_entries) + rf\" \\\\{rule}\")\n",
    "\n",
    "# print(r\"\\end{tabular}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b657df6",
   "metadata": {},
   "source": [
    "## Temporal table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "605731c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_keys = [\"endpoint\", \"dmd_bmode\", \"dmd_dmd\"]\n",
    "\n",
    "direction = {\n",
    "    \"endpoint\": \"min\",\n",
    "    \"dmd_bmode\": \"min\",\n",
    "    \"dmd_dmd\": \"min\"\n",
    "}\n",
    "\n",
    "values_per_metric = {k: [] for k in metrics_keys}\n",
    "for m, vals in full_metrics.items():\n",
    "    for k in metrics_keys:\n",
    "        values_per_metric[k].append(vals[k])\n",
    "\n",
    "medians = {k: np.median(values_per_metric[k]) for k in metrics_keys}\n",
    "\n",
    "for i, (m, vals) in enumerate(full_metrics.items()):\n",
    "    row_entries = []\n",
    "    for j, k in enumerate(metrics_keys):\n",
    "        val_mean, val_std = vals[k][0], vals[k][1]\n",
    "        formatted = f\"{val_mean:.3f}\" + r\"$_{\\pm \" + f\"{val_std:.2f}\" + r\"}$\"\n",
    "        row_entries.append(formatted)\n",
    "\n",
    "    rule = \"\"\n",
    "    if \"VQ-VAE\" in m:\n",
    "        rule = r\" \\midrule\"\n",
    "    if i == len(full_metrics) - 1:\n",
    "        rule = r\" \\bottomrule\"\n",
    "\n",
    "    print(f\"    {m:<16} & \" + \" & \".join(row_entries) + rf\" \\\\{rule}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4feda7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23812c50",
   "metadata": {},
   "source": [
    "## Quantitative plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81c2f958",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_WHITELIST = [\"NF\", \"VQ-VAE + EVA\", \"ZFP\", \"PCA\"]\n",
    "TRAJ_WHITELIST = [\n",
    "    #   \"iteration_13\",\n",
    "    #   \"iteration_115\",\n",
    "    #   \"iteration_131\",\n",
    "    \"iteration_134\",\n",
    "    \"iteration_146\",\n",
    "    \"iteration_148\",\n",
    "    \"iteration_160\",\n",
    "    #   \"iteration_200\",\n",
    "    #   \"iteration_210\",\n",
    "    #   \"iteration_212\",\n",
    "]\n",
    "fig, ax = plt.subplots(\n",
    "    len(MODEL_WHITELIST),\n",
    "    len(TRAJ_WHITELIST),\n",
    "    figsize=(5 * len(TRAJ_WHITELIST), int(1.7 * len(MODEL_WHITELIST))),\n",
    "    gridspec_kw={\"hspace\": 0.0},\n",
    ")\n",
    "plasma = matplotlib.colormaps[\"plasma\"]\n",
    "\n",
    "row = 0\n",
    "for model_name in MODEL_WHITELIST:\n",
    "    cyc_diag = full_diagnostics[model_name]\n",
    "    if \"EVA\" in model_name:\n",
    "        model_name = model_name.replace(\" + EVA\", \"\")\n",
    "\n",
    "    col = 0\n",
    "    for case_name, diag in cyc_diag.items():\n",
    "        if case_name not in TRAJ_WHITELIST:\n",
    "            continue\n",
    "        kys = diag[\"kyspec\"]\n",
    "        gt_kys = full_diagnostics[\"GT\"][case_name][\"kyspec\"]\n",
    "\n",
    "        n_lines = len(kys)\n",
    "        for t in range(n_lines):\n",
    "            color = plasma((t + 0.1) / max(1, n_lines - 0.5))\n",
    "            ax[row, col].plot(\n",
    "                gt_kys[t][1:],\n",
    "                linewidth=2,\n",
    "                color=color,\n",
    "                alpha=0.5,\n",
    "                zorder=0,\n",
    "                linestyle=\"--\",\n",
    "            )\n",
    "            model_curve = kys[t][1:]\n",
    "            gt_curve = gt_kys[t][1:]\n",
    "            ax[row, col].plot(model_curve, linewidth=3, color=color, zorder=1)\n",
    "\n",
    "            # corr, _ = pearsonr(model_curve, gt_curve)\n",
    "            # corr, _ = pearsonr(np.log(model_curve), np.log(gt_curve))\n",
    "            # ax[row, col].text(\n",
    "            #     0.98,\n",
    "            #     0.98 - t * 0.1,\n",
    "            #     rf\"$PC={corr:.2f}$\",\n",
    "            #     color=color,\n",
    "            #     fontsize=14,\n",
    "            #     transform=ax[row, col].transAxes,\n",
    "            #     ha=\"right\",\n",
    "            #     va=\"top\",\n",
    "            # )\n",
    "\n",
    "        ax[row, col].tick_params(axis=\"x\", which=\"both\", labelsize=12)\n",
    "        ax[row, col].tick_params(axis=\"y\", which=\"both\", labelsize=12)\n",
    "\n",
    "        if col == 0 and row == (len(MODEL_WHITELIST) // 2):\n",
    "            ax[row, col].set_ylabel(r\"$|\\phi(k_y)|^2$\", fontsize=26)\n",
    "\n",
    "        if row == len(MODEL_WHITELIST) - 1:\n",
    "            ax[row, col].set_xlabel(r\"$k_y$\", fontsize=26)\n",
    "        else:\n",
    "            ax[row, col].set_xticklabels([])\n",
    "            ax[row, col].tick_params(axis=\"x\", which=\"both\", length=0)\n",
    "\n",
    "        ymin, ymax = ax[row, col].get_ylim()\n",
    "        yticks = [\n",
    "            ymin + 0.1 * (ymax - ymin),\n",
    "            ymin + 0.5 * (ymax - ymin),\n",
    "            ymin + 0.9 * (ymax - ymin),\n",
    "        ]\n",
    "        ax[row, col].set_yticks([round(yt, 1) for yt in yticks])\n",
    "        ax[row, col].grid(True, linestyle=\"--\", alpha=0.6)\n",
    "\n",
    "        ax[row, col].set_xscale(\"log\")\n",
    "        ax[row, col].set_yscale(\"log\")\n",
    "\n",
    "        # ax[row, col].set_xlim(0, 15)\n",
    "\n",
    "        col += 1\n",
    "    col = col - 1\n",
    "\n",
    "    ax[row, -1].text(\n",
    "        1.05,\n",
    "        0.5,\n",
    "        model_name,\n",
    "        transform=ax[row, -1].transAxes,\n",
    "        fontsize=22,\n",
    "        rotation=270,\n",
    "        va=\"center\",\n",
    "        ha=\"left\",\n",
    "    )\n",
    "    if row != len(MODEL_WHITELIST) - 1:\n",
    "        ax[row, col].set_xticks([])\n",
    "\n",
    "    row += 1\n",
    "\n",
    "\n",
    "t_indices = [TIME[t] for t in TIMESTEPS]\n",
    "n_lines = max(\n",
    "    len(kys)\n",
    "    for cyc in full_diagnostics.values()\n",
    "    for _, diag in cyc.items()\n",
    "    for kys in [diag[\"kyspec\"]]\n",
    ")\n",
    "\n",
    "handles = []\n",
    "labels = []\n",
    "for t, physical_time in enumerate(t_indices):\n",
    "    color = plasma((t + 0.1) / max(1, n_lines - 0.5))\n",
    "    handles.append(matplotlib.lines.Line2D([0], [0], color=color, lw=3))\n",
    "    labels.append(rf\"$t={physical_time:.1f}\" + r\"R/V_{\\mathrm{r}}$\")\n",
    "\n",
    "fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc=\"upper center\",\n",
    "    ncol=3,\n",
    "    fontsize=26,\n",
    "    frameon=False,\n",
    "    bbox_to_anchor=(0.5, 1.01),\n",
    ")\n",
    "fig.savefig(\"cascade_iclr.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "247b316b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Q(k_y) spectra: one row per model, one column per trajectory; GT is drawn dashed behind each model curve.\n",
    "MODEL_WHITELIST = [\"NF\", \"VQ-VAE + EVA\", \"ZFP\", \"Wavelet\"]\n",
    "TRAJ_WHITELIST = [\n",
    "    #   \"iteration_13\",\n",
    "    #   \"iteration_115\",\n",
    "    #   \"iteration_131\",\n",
    "    \"iteration_134\",\n",
    "    \"iteration_146\",\n",
    "    \"iteration_148\",\n",
    "    \"iteration_160\",\n",
    "    #   \"iteration_200\",\n",
    "    #   \"iteration_210\",\n",
    "    #   \"iteration_212\",\n",
    "]\n",
    "fig, ax = plt.subplots(\n",
    "    len(MODEL_WHITELIST),\n",
    "    len(TRAJ_WHITELIST),\n",
    "    figsize=(5 * len(TRAJ_WHITELIST), int(1.7 * len(MODEL_WHITELIST))),\n",
    "    gridspec_kw={\"hspace\": 0.0},\n",
    ")\n",
    "plasma = matplotlib.colormaps[\"plasma\"]\n",
    "\n",
    "row = 0\n",
    "\n",
    "for model_name in MODEL_WHITELIST:\n",
    "    cyc_diag = full_diagnostics[model_name]\n",
    "    if \"EVA\" in model_name:\n",
    "        model_name = model_name.replace(\" + EVA\", \"\")\n",
    "\n",
    "    col = 0\n",
    "    for case_name, diag in cyc_diag.items():\n",
    "        if case_name not in TRAJ_WHITELIST:\n",
    "            continue\n",
    "        kys = diag[\"qspec\"]\n",
    "        gt_kys = full_diagnostics[\"GT\"][case_name][\"qspec\"]\n",
    "\n",
    "        n_lines = len(kys)\n",
    "        for t in range(n_lines):\n",
    "            color = plasma((t + 0.1) / max(1, n_lines - 0.5))\n",
    "            ax[row, col].plot(\n",
    "                gt_kys[t][1:],\n",
    "                linewidth=2,\n",
    "                color=color,\n",
    "                alpha=0.5,\n",
    "                zorder=0,\n",
    "                linestyle=\"--\",\n",
    "            )\n",
    "            model_curve = kys[t][1:]\n",
    "            # Floor non-positive values at the smallest positive one so the curve stays\n",
    "            # visible on the log axis. Guarded: .min() on an empty selection raises\n",
    "            # ValueError when a reconstruction has no positive entries at this step.\n",
    "            if (model_curve > 0).any():\n",
    "                model_curve = np.clip(\n",
    "                    model_curve, model_curve[model_curve > 0].min(), np.inf\n",
    "                )\n",
    "            ax[row, col].plot(model_curve, linewidth=3, color=color, zorder=1)\n",
    "\n",
    "        ax[row, col].tick_params(axis=\"x\", which=\"both\", labelsize=12)\n",
    "        ax[row, col].tick_params(axis=\"y\", which=\"both\", labelsize=12)\n",
    "\n",
    "        if col == 0 and row == (len(MODEL_WHITELIST) // 2):\n",
    "            ax[row, col].set_ylabel(r\"$Q(k_y)$\", fontsize=26)\n",
    "\n",
    "        if row == len(MODEL_WHITELIST) - 1:\n",
    "            ax[row, col].set_xlabel(r\"$k_y$\", fontsize=26)\n",
    "        else:\n",
    "            ax[row, col].set_xticklabels([])\n",
    "            ax[row, col].tick_params(axis=\"x\", which=\"both\", length=0)\n",
    "\n",
    "        # NOTE(review): the log set_yscale call below installs a LogLocator, which\n",
    "        # replaces these fixed ticks; confirm whether they are still intended.\n",
    "        ymin, ymax = ax[row, col].get_ylim()\n",
    "        yticks = [\n",
    "            ymin + 0.1 * (ymax - ymin),\n",
    "            ymin + 0.5 * (ymax - ymin),\n",
    "            ymin + 0.9 * (ymax - ymin),\n",
    "        ]\n",
    "        ax[row, col].set_yticks([round(yt, 1) for yt in yticks])\n",
    "        ax[row, col].grid(True, linestyle=\"--\", alpha=0.6)\n",
    "\n",
    "        ax[row, col].set_xscale(\"log\")\n",
    "        ax[row, col].set_yscale(\"log\")\n",
    "\n",
    "        col += 1\n",
    "    col = col - 1\n",
    "\n",
    "    ax[row, -1].text(\n",
    "        1.05,\n",
    "        0.5,\n",
    "        model_name,\n",
    "        transform=ax[row, -1].transAxes,\n",
    "        fontsize=22,\n",
    "        rotation=270,\n",
    "        va=\"center\",\n",
    "        ha=\"left\",\n",
    "    )\n",
    "    if row != len(MODEL_WHITELIST) - 1:\n",
    "        ax[row, col].set_xticks([])\n",
    "\n",
    "    row += 1\n",
    "\n",
    "\n",
    "t_indices = [TIME[t] for t in TIMESTEPS]\n",
    "n_lines = max(\n",
    "    len(kys)\n",
    "    for cyc in full_diagnostics.values()\n",
    "    for _, diag in cyc.items()\n",
    "    for kys in [diag[\"qspec\"]]\n",
    ")\n",
    "\n",
    "handles = []\n",
    "labels = []\n",
    "for t, physical_time in enumerate(t_indices):\n",
    "    color = plasma((t + 0.1) / max(1, n_lines - 0.5))\n",
    "    handles.append(matplotlib.lines.Line2D([0], [0], color=color, lw=3))\n",
    "    labels.append(rf\"$t={physical_time:.1f}\" + r\"R/V_{\\mathrm{r}}$\")\n",
    "\n",
    "fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc=\"upper center\",\n",
    "    ncol=3,\n",
    "    fontsize=26,\n",
    "    frameon=False,\n",
    "    bbox_to_anchor=(0.5, 1.01),\n",
    ")\n",
    "fig.savefig(\"qcascade_iclr.pdf\", bbox_inches=\"tight\")"
   ]
  },
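  {
   "cell_type": "markdown",
   "id": "e97cd8d9",
   "metadata": {},
   "source": [
    "## Extra plots\n",
    "\n",
    "Extended versions of the cascade figures: $|\\phi(k_y)|^2$ and $Q(k_y)$ spectra for every model (including GT) over a larger set of trajectories. Dashed lines show the ground truth. Each panel is annotated with the Pearson correlation (PC) between the model and ground-truth spectra at each time step."
   ]
  },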
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db51f9d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# |phi(k_y)|^2 spectra for every model (rows) and trajectory (columns); GT is drawn dashed,\n",
    "# and each panel is annotated with the per-timestep Pearson correlation (PC) against GT.\n",
    "MODEL_WHITELIST = [\n",
    "    \"GT\",\n",
    "    \"NF\",\n",
    "    \"AE + EVA\",\n",
    "    \"VQ-VAE + EVA\",\n",
    "    \"ZFP\",\n",
    "    \"Wavelet\",\n",
    "    \"PCA\",\n",
    "    \"JPEG2000\",\n",
    "]\n",
    "TRAJ_WHITELIST = [\n",
    "    \"iteration_13\",\n",
    "    \"iteration_115\",\n",
    "    \"iteration_131\",\n",
    "    \"iteration_134\",\n",
    "    \"iteration_146\",\n",
    "    \"iteration_148\",\n",
    "    \"iteration_160\",\n",
    "    \"iteration_200\",\n",
    "    \"iteration_210\",\n",
    "    \"iteration_212\",\n",
    "]\n",
    "fig, ax = plt.subplots(\n",
    "    len(MODEL_WHITELIST),\n",
    "    len(TRAJ_WHITELIST),\n",
    "    figsize=(5 * len(TRAJ_WHITELIST), int(2.5 * len(MODEL_WHITELIST))),\n",
    "    gridspec_kw={\"hspace\": 0.0},\n",
    ")\n",
    "plasma = matplotlib.colormaps[\"plasma\"]\n",
    "\n",
    "row = 0\n",
    "\n",
    "for model_name in MODEL_WHITELIST:\n",
    "    cyc_diag = full_diagnostics[model_name]\n",
    "    if \"EVA\" in model_name:\n",
    "        model_name = model_name.replace(\" + EVA\", \"\")\n",
    "\n",
    "    col = 0\n",
    "    for case_name, diag in cyc_diag.items():\n",
    "        if case_name not in TRAJ_WHITELIST:\n",
    "            continue\n",
    "        kys = diag[\"kyspec\"]\n",
    "        gt_kys = full_diagnostics[\"GT\"][case_name][\"kyspec\"]\n",
    "\n",
    "        n_lines = len(kys)\n",
    "\n",
    "        for t in range(n_lines):\n",
    "            color = plasma((t + 0.1) / max(1, n_lines - 0.5))\n",
    "            raw_curve = kys[t][1:]\n",
    "            gt_curve = gt_kys[t][1:]\n",
    "            # Floor non-positive values at the smallest positive one so the curve stays\n",
    "            # visible on the log axis. This is display-only: the PC below uses raw_curve\n",
    "            # so the annotation reflects the actual reconstructed spectrum.\n",
    "            model_curve = raw_curve\n",
    "            if (raw_curve > 0).any():\n",
    "                model_curve = np.clip(\n",
    "                    raw_curve, raw_curve[raw_curve > 0].min(), np.inf\n",
    "                )\n",
    "            ax[row, col].plot(model_curve, linewidth=3, color=color, zorder=1)\n",
    "\n",
    "            if model_name != \"GT\":\n",
    "                ax[row, col].plot(\n",
    "                    gt_curve,\n",
    "                    linewidth=2,\n",
    "                    color=color,\n",
    "                    alpha=0.5,\n",
    "                    zorder=0,\n",
    "                    linestyle=\"--\",\n",
    "                )\n",
    "                corr, _ = pearsonr(raw_curve, gt_curve)\n",
    "                ax[row, col].text(\n",
    "                    0.98,\n",
    "                    0.98 - t * 0.08,\n",
    "                    rf\"$PC={corr:.2f}$\",\n",
    "                    color=color,\n",
    "                    fontsize=14,\n",
    "                    transform=ax[row, col].transAxes,\n",
    "                    ha=\"right\",\n",
    "                    va=\"top\",\n",
    "                )\n",
    "\n",
    "        if col == 0:\n",
    "            ax[row, col].set_ylabel(r\"$|\\phi(k_y)|^2$\", fontsize=24)\n",
    "\n",
    "        if row == len(MODEL_WHITELIST) - 1:\n",
    "            ax[row, col].set_xlabel(r\"$k_y$\", fontsize=24)\n",
    "        else:\n",
    "            ax[row, col].set_xticklabels([])\n",
    "            ax[row, col].tick_params(axis=\"x\", which=\"both\", length=0)\n",
    "\n",
    "        # NOTE(review): the log set_yscale call below installs a LogLocator, which\n",
    "        # replaces these fixed ticks; confirm whether they are still intended.\n",
    "        ymin, ymax = ax[row, col].get_ylim()\n",
    "        yticks = [\n",
    "            ymin + 0.1 * (ymax - ymin),\n",
    "            ymin + 0.5 * (ymax - ymin),\n",
    "            ymin + 0.9 * (ymax - ymin),\n",
    "        ]\n",
    "        ax[row, col].set_yticks([round(yt, 1) for yt in yticks])\n",
    "        ax[row, col].grid(True, linestyle=\"--\", alpha=0.6)\n",
    "\n",
    "        ax[row, col].set_xscale(\"log\")\n",
    "        ax[row, col].set_yscale(\"log\")\n",
    "        col += 1\n",
    "    col = col - 1\n",
    "\n",
    "    ax[row, -1].text(\n",
    "        1.05,\n",
    "        0.5,\n",
    "        model_name,\n",
    "        transform=ax[row, -1].transAxes,\n",
    "        fontsize=20,\n",
    "        rotation=270,\n",
    "        va=\"center\",\n",
    "        ha=\"left\",\n",
    "    )\n",
    "    if row != len(MODEL_WHITELIST) - 1:\n",
    "        ax[row, col].set_xticks([])\n",
    "\n",
    "    row += 1\n",
    "\n",
    "\n",
    "t_indices = [TIME[t] for t in TIMESTEPS]\n",
    "n_lines = max(\n",
    "    len(kys)\n",
    "    for cyc in full_diagnostics.values()\n",
    "    for _, diag in cyc.items()\n",
    "    for kys in [diag[\"kyspec\"]]\n",
    ")\n",
    "\n",
    "handles = []\n",
    "labels = []\n",
    "for t, physical_time in enumerate(t_indices):\n",
    "    color = plasma((t + 0.1) / max(1, n_lines - 0.5))\n",
    "    handles.append(matplotlib.lines.Line2D([0], [0], color=color, lw=3))\n",
    "    labels.append(rf\"$t={physical_time:.1f}\" + r\"R/V_{\\mathrm{r}}$\")\n",
    "\n",
    "fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc=\"upper center\",\n",
    "    ncol=3,\n",
    "    fontsize=24,\n",
    "    frameon=False,\n",
    "    bbox_to_anchor=(0.5, 0.92),\n",
    "    # bbox_to_anchor=(0.5, 0.98),\n",
    ")\n",
    "fig.savefig(\"extra_cascade_iclr.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f132b39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Q(k_y) spectra for every model (rows) and trajectory (columns); GT is drawn dashed,\n",
    "# and each panel is annotated with the per-timestep Pearson correlation (PC) against GT.\n",
    "MODEL_WHITELIST = [\n",
    "    \"GT\",\n",
    "    \"NF\",\n",
    "    \"AE + EVA\",\n",
    "    \"VQ-VAE + EVA\",\n",
    "    \"ZFP\",\n",
    "    \"Wavelet\",\n",
    "    \"PCA\",\n",
    "    \"JPEG2000\",\n",
    "]\n",
    "TRAJ_WHITELIST = [\n",
    "    \"iteration_13\",\n",
    "    \"iteration_115\",\n",
    "    \"iteration_131\",\n",
    "    \"iteration_134\",\n",
    "    \"iteration_146\",\n",
    "    \"iteration_148\",\n",
    "    \"iteration_160\",\n",
    "    \"iteration_200\",\n",
    "    \"iteration_210\",\n",
    "    \"iteration_212\",\n",
    "]\n",
    "fig, ax = plt.subplots(\n",
    "    len(MODEL_WHITELIST),\n",
    "    len(TRAJ_WHITELIST),\n",
    "    figsize=(5 * len(TRAJ_WHITELIST), int(2.5 * len(MODEL_WHITELIST))),\n",
    "    gridspec_kw={\"hspace\": 0.0},\n",
    ")\n",
    "plasma = matplotlib.colormaps[\"plasma\"]\n",
    "\n",
    "row = 0\n",
    "\n",
    "for model_name in MODEL_WHITELIST:\n",
    "    cyc_diag = full_diagnostics[model_name]\n",
    "    if \"EVA\" in model_name:\n",
    "        model_name = model_name.replace(\" + EVA\", \"\")\n",
    "\n",
    "    col = 0\n",
    "    for case_name, diag in cyc_diag.items():\n",
    "        if case_name not in TRAJ_WHITELIST:\n",
    "            continue\n",
    "        kys = diag[\"qspec\"]\n",
    "        gt_kys = full_diagnostics[\"GT\"][case_name][\"qspec\"]\n",
    "\n",
    "        n_lines = len(kys)\n",
    "\n",
    "        for t in range(n_lines):\n",
    "            color = plasma((t + 0.1) / max(1, n_lines - 0.5))\n",
    "            raw_curve = kys[t][1:]\n",
    "            gt_curve = gt_kys[t][1:]\n",
    "            # Floor non-positive values at the smallest positive one so the curve stays\n",
    "            # visible on the log axis. This is display-only: the PC below uses raw_curve\n",
    "            # so the annotation reflects the actual reconstructed spectrum.\n",
    "            model_curve = raw_curve\n",
    "            if (raw_curve > 0).any():\n",
    "                model_curve = np.clip(\n",
    "                    raw_curve, raw_curve[raw_curve > 0].min(), np.inf\n",
    "                )\n",
    "            ax[row, col].plot(model_curve, linewidth=3, color=color, zorder=1)\n",
    "\n",
    "            if model_name != \"GT\":\n",
    "                ax[row, col].plot(\n",
    "                    gt_curve,\n",
    "                    linewidth=2,\n",
    "                    color=color,\n",
    "                    alpha=0.5,\n",
    "                    zorder=0,\n",
    "                    linestyle=\"--\",\n",
    "                )\n",
    "                corr, _ = pearsonr(raw_curve, gt_curve)\n",
    "                ax[row, col].text(\n",
    "                    0.98,\n",
    "                    0.98 - t * 0.08,\n",
    "                    rf\"$PC={corr:.2f}$\",\n",
    "                    color=color,\n",
    "                    fontsize=14,\n",
    "                    transform=ax[row, col].transAxes,\n",
    "                    ha=\"right\",\n",
    "                    va=\"top\",\n",
    "                )\n",
    "\n",
    "        if col == 0:\n",
    "            ax[row, col].set_ylabel(r\"$Q(k_y)$\", fontsize=24)\n",
    "\n",
    "        if row == len(MODEL_WHITELIST) - 1:\n",
    "            ax[row, col].set_xlabel(r\"$k_y$\", fontsize=24)\n",
    "        else:\n",
    "            ax[row, col].set_xticklabels([])\n",
    "            ax[row, col].tick_params(axis=\"x\", which=\"both\", length=0)\n",
    "\n",
    "        # NOTE(review): the log set_yscale call below installs a LogLocator, which\n",
    "        # replaces these fixed ticks; confirm whether they are still intended.\n",
    "        ymin, ymax = ax[row, col].get_ylim()\n",
    "        yticks = [\n",
    "            ymin + 0.1 * (ymax - ymin),\n",
    "            ymin + 0.5 * (ymax - ymin),\n",
    "            ymin + 0.9 * (ymax - ymin),\n",
    "        ]\n",
    "        ax[row, col].set_yticks([round(yt, 1) for yt in yticks])\n",
    "        ax[row, col].grid(True, linestyle=\"--\", alpha=0.6)\n",
    "\n",
    "        ax[row, col].set_xscale(\"log\")\n",
    "        ax[row, col].set_yscale(\"log\")\n",
    "        col += 1\n",
    "    col = col - 1\n",
    "\n",
    "    ax[row, -1].text(\n",
    "        1.05,\n",
    "        0.5,\n",
    "        model_name,\n",
    "        transform=ax[row, -1].transAxes,\n",
    "        fontsize=20,\n",
    "        rotation=270,\n",
    "        va=\"center\",\n",
    "        ha=\"left\",\n",
    "    )\n",
    "    if row != len(MODEL_WHITELIST) - 1:\n",
    "        ax[row, col].set_xticks([])\n",
    "\n",
    "    row += 1\n",
    "\n",
    "\n",
    "t_indices = [TIME[t] for t in TIMESTEPS]\n",
    "# Legend colours must come from the same spectrum plotted above (qspec), not kyspec.\n",
    "n_lines = max(\n",
    "    len(kys)\n",
    "    for cyc in full_diagnostics.values()\n",
    "    for _, diag in cyc.items()\n",
    "    for kys in [diag[\"qspec\"]]\n",
    ")\n",
    "\n",
    "handles = []\n",
    "labels = []\n",
    "for t, physical_time in enumerate(t_indices):\n",
    "    color = plasma((t + 0.1) / max(1, n_lines - 0.5))\n",
    "    handles.append(matplotlib.lines.Line2D([0], [0], color=color, lw=3))\n",
    "    labels.append(rf\"$t={physical_time:.1f}\" + r\"R/V_{\\mathrm{r}}$\")\n",
    "\n",
    "fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc=\"upper center\",\n",
    "    ncol=3,\n",
    "    fontsize=24,\n",
    "    frameon=False,\n",
    "    bbox_to_anchor=(0.5, 0.92),\n",
    "    # bbox_to_anchor=(0.5, 0.98),\n",
    ")\n",
    "fig.savefig(\"extra_qcascade_iclr.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67976d9c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mhd",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
