{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bbe6e53d",
   "metadata": {},
   "source": [
    "# Extra results: weight compression, latent interpolation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "759e150c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension; mode 2 re-imports modules before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "58f9d897",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "\n",
    "# Add the parent directory to the import path (presumably where `neural_fields` lives).\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "# Expose only GPU index 3 to this process. NOTE(review): hard-coded device id.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2f84380e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from copy import deepcopy\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neural_fields.nf_utils import compress_weights, load_nf, sample_field\n",
    "from neural_fields.gk_losses import integral_losses\n",
    "from neural_fields.data import CycloneNFDataset, CycloneNFDataLoader\n",
    "from neural_fields.nf_train import train_nf, eval_diagnose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "50d03981",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(gt, pred):\n",
    "    \"\"\"Return the mean absolute error and the PSNR of ``pred`` against ``gt``.\n",
    "\n",
    "    Args:\n",
    "        gt: Reference tensor; its maximum is used as the PSNR peak value.\n",
    "        pred: Tensor with the same shape as ``gt``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(l1, psnr)``: ``l1`` is a Python float, ``psnr`` a 0-dim tensor.\n",
    "    \"\"\"\n",
    "    gt, pred = gt.cpu(), pred.cpu()\n",
    "    diff = pred - gt\n",
    "    l1 = diff.abs().mean().item()\n",
    "    mse = (diff**2).mean()\n",
    "    # PSNR = 10 * log10(peak^2 / MSE); `mse` is already squared error, so it must\n",
    "    # not be squared a second time.\n",
    "    psnr = 10 * torch.log10(gt.max() ** 2 / mse)\n",
    "    return l1, psnr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fc5fa37",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\"\n",
    "\n",
    "TIMESTEPS = [100, 120, 140, 160, 180, 200]\n",
    "TRAJECTORIES = [\n",
    "    \"iteration_13\",\n",
    "    \"iteration_115\",\n",
    "    \"iteration_131\",\n",
    "    \"iteration_134\",\n",
    "    \"iteration_146\",\n",
    "    \"iteration_148\",\n",
    "    \"iteration_160\",\n",
    "    \"iteration_200\",\n",
    "    \"iteration_210\",\n",
    "    \"iteration_212\",\n",
    "]\n",
    "\n",
    "# Spacing between two consecutive entries of TIMESTEPS.\n",
    "DELTA = TIMESTEPS[1] - TIMESTEPS[0]\n",
    "\n",
    "# Offset from the base timestep for each stored snapshot:\n",
    "# \"a\" = base, \"b\" = base + DELTA, \"c\" = halfway between the two.\n",
    "snapshot_offsets = {\"a\": 0, \"b\": DELTA, \"c\": DELTA // 2}\n",
    "\n",
    "# DATASETS[trajectory][timestep][\"a\" | \"b\" | \"c\"] -> CycloneNFDataset\n",
    "DATASETS = {}\n",
    "for traj in TRAJECTORIES:\n",
    "    DATASETS[traj] = {}\n",
    "    for t in TIMESTEPS:\n",
    "        DATASETS[traj][t] = {\n",
    "            key: CycloneNFDataset(\n",
    "                trajectory=f\"{traj}.h5\",\n",
    "                timesteps=t + offset,\n",
    "                normalize=\"zscore\",\n",
    "                normalize_coords=False,\n",
    "                realpotens=True,\n",
    "            )\n",
    "            for key, offset in snapshot_offsets.items()\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b90a22de",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL = \"mlp\"\n",
    "POSTFIX = \"x1163\"\n",
    "\n",
    "\n",
    "def _checkpoint_absent(traj, t, prefix=\"\"):\n",
    "    \"\"\"Return True when the (optionally prefixed) checkpoint file is not on disk.\"\"\"\n",
    "    ckp_name = f\"{MODEL}_{traj}_t{t}_{POSTFIX}.pt\"\n",
    "    return not os.path.exists(f\"../nf_ckps/{prefix}{ckp_name}\")\n",
    "\n",
    "\n",
    "# Per trajectory: timesteps without a plain checkpoint ...\n",
    "missing = {\n",
    "    traj: [t for t in TIMESTEPS if _checkpoint_absent(traj, t)]\n",
    "    for traj in TRAJECTORIES\n",
    "}\n",
    "# ... and timesteps without an \"int_\"-prefixed checkpoint.\n",
    "missing_int = {\n",
    "    traj: [t for t in TIMESTEPS if _checkpoint_absent(traj, t, prefix=\"int_\")]\n",
    "    for traj in TRAJECTORIES\n",
    "}\n",
    "missing_int"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50a8420f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_int_checkpoint(traj, t):\n",
    "    \"\"\"Load the ``int_``-prefixed checkpoint of one trajectory/timestep in eval mode.\"\"\"\n",
    "    ckp_name = f\"int_{MODEL}_{traj}_t{t}_{POSTFIX}.pt\"\n",
    "    return load_nf(f\"../nf_ckps/{ckp_name}\", device).to(device).eval()\n",
    "\n",
    "\n",
    "# MODELS[trajectory][timestep] -> neural field on `device`\n",
    "MODELS = {}\n",
    "for traj in TRAJECTORIES:\n",
    "    MODELS[traj] = {t: _load_int_checkpoint(traj, t) for t in TIMESTEPS}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b0be16",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MODELS[\"iteration_13\"][200]\n",
    "\n",
    "# Materialise the parameter list once so both totals are computed over the same tensors.\n",
    "params = list(model.parameters())\n",
    "n_params = sum(p.numel() for p in params)\n",
    "model_size = sum(p.nbytes for p in params)\n",
    "\n",
    "# Ratio of the stored field's byte size to the byte size of the parameters.\n",
    "field_bytes = DATASETS[\"iteration_13\"][200][\"a\"].full_df.nbytes\n",
    "compression = field_bytes / model_size\n",
    "print(f\"Params: {n_params / 1e3:.2f}k, compression: {compression:.2f}x\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23c0cfe0",
   "metadata": {},
   "source": [
    "## Weight space compression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aeca12d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference run: the uncompressed `model` on the iteration_13 / t=200 dataset.\n",
    "# `metrics_only=True` presumably suppresses plots (confirm in neural_fields.nf_train).\n",
    "eval_diagnose(\n",
    "    model=model,\n",
    "    data=DATASETS[\"iteration_13\"][200][\"a\"],\n",
    "    T=None,\n",
    "    device=device,\n",
    "    metrics_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45931e54",
   "metadata": {},
   "source": [
    "### ZFP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a3773f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compress the weights of `model` with the \"zfp\" method (tolerance 0.01).\n",
    "model_zfp, original_size, compressed_size = compress_weights(\n",
    "    model, method=\"zfp\", tolerance=0.01\n",
    ")\n",
    "\n",
    "bytes_per_mb = 1024**2\n",
    "reference_data = DATASETS[\"iteration_13\"][200][\"a\"]\n",
    "\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "# Field elements per model parameter, then multiplied by the gain on the weights.\n",
    "data_compression = reference_data.full_df.numel() / n_params\n",
    "weight_compression = original_size / compressed_size\n",
    "total_compression = data_compression * weight_compression\n",
    "print(\n",
    "    f\"Params: {n_params / 1e3:.2f}k\",\n",
    "    f\"Model original: {original_size / bytes_per_mb:.2f}MB\",\n",
    "    f\"Model compressed: {compressed_size / bytes_per_mb:.2f}MB\",\n",
    "    f\"Weight compression: {weight_compression:.2f}x\",\n",
    "    f\"Data compression: {data_compression:.2f}x --> {total_compression:.2f}x\",\n",
    "    sep=\"\\n\",\n",
    ")\n",
    "print()\n",
    "eval_diagnose(\n",
    "    model=model_zfp, data=reference_data, T=None, device=device, metrics_only=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d713c114",
   "metadata": {},
   "source": [
    "### ZipNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc0c1d9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compress the weights of `model` with the \"zipnn\" method.\n",
    "model_zipnn, original_size, compressed_size = compress_weights(model, method=\"zipnn\")\n",
    "\n",
    "bytes_per_mb = 1024**2\n",
    "reference_data = DATASETS[\"iteration_13\"][200][\"a\"]\n",
    "\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "# Field elements per model parameter, then multiplied by the gain on the weights.\n",
    "data_compression = reference_data.full_df.numel() / n_params\n",
    "weight_compression = original_size / compressed_size\n",
    "total_compression = data_compression * weight_compression\n",
    "print(\n",
    "    f\"Params: {n_params / 1e3:.2f}k\",\n",
    "    f\"Model original: {original_size / bytes_per_mb:.2f}MB\",\n",
    "    f\"Model compressed: {compressed_size / bytes_per_mb:.2f}MB\",\n",
    "    f\"Weight compression: {weight_compression:.2f}x\",\n",
    "    f\"Data compression: {data_compression:.2f}x --> {total_compression:.2f}x\",\n",
    "    sep=\"\\n\",\n",
    ")\n",
    "print()\n",
    "eval_diagnose(\n",
    "    model=model_zipnn, data=reference_data, T=None, device=device, metrics_only=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4ded49e",
   "metadata": {},
   "source": [
    "### Float8 quantize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73b3bc04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compress the weights of `model` with the \"quantize8\" method.\n",
    "model_quant, original_size, compressed_size = compress_weights(\n",
    "    model, method=\"quantize8\"\n",
    ")\n",
    "\n",
    "bytes_per_mb = 1024**2\n",
    "reference_data = DATASETS[\"iteration_13\"][200][\"a\"]\n",
    "\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "# Field elements per model parameter, then multiplied by the gain on the weights.\n",
    "data_compression = reference_data.full_df.numel() / n_params\n",
    "weight_compression = original_size / compressed_size\n",
    "total_compression = data_compression * weight_compression\n",
    "print(\n",
    "    f\"Params: {n_params / 1e3:.2f}k\",\n",
    "    f\"Model original: {original_size / bytes_per_mb:.2f}MB\",\n",
    "    f\"Model compressed: {compressed_size / bytes_per_mb:.2f}MB\",\n",
    "    f\"Weight compression: {weight_compression:.2f}x\",\n",
    "    f\"Data compression: {data_compression:.2f}x --> {total_compression:.2f}x\",\n",
    "    sep=\"\\n\",\n",
    ")\n",
    "print()\n",
    "eval_diagnose(\n",
    "    model=model_quant, data=reference_data, T=None, device=device, metrics_only=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5f8f8cb",
   "metadata": {},
   "source": [
    "### Quantitative eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f4a4e37",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _integral_metrics(nf_model, data):\n",
    "    \"\"\"Return ``(psnr, flux loss, phi loss)`` of ``nf_model`` evaluated on ``data``.\"\"\"\n",
    "    losses = integral_losses(nf_model, data=data, device=device)\n",
    "    mse = losses[\"df loss\"]\n",
    "    # NOTE(review): the ratio divides by mse**2. That is a PSNR only if \"df loss\" is an\n",
    "    # RMSE-like quantity rather than a mean squared error - confirm in gk_losses.\n",
    "    psnr = 10 * torch.log10(data.full_df.max() ** 2 / mse**2)\n",
    "    return psnr, losses[\"flux loss\"], losses[\"phi loss\"]\n",
    "\n",
    "\n",
    "results = []\n",
    "with torch.no_grad():\n",
    "    for traj in TRAJECTORIES:\n",
    "        for t in TIMESTEPS:\n",
    "            data = DATASETS[traj][t][\"a\"].to(device)\n",
    "            model_base = MODELS[traj][t].to(device)\n",
    "\n",
    "            # Uncompressed weights.\n",
    "            psnr_base, l1q_base, l1p_base = _integral_metrics(model_base, data)\n",
    "\n",
    "            # ZFP (tolerance 0.001 here, unlike the 0.01 of the single-model demo).\n",
    "            model_zfp, original_size, compressed_size = compress_weights(\n",
    "                model_base, method=\"zfp\", tolerance=0.001\n",
    "            )\n",
    "            cr_zfp = original_size / compressed_size\n",
    "            psnr_zfp, l1q_zfp, l1p_zfp = _integral_metrics(model_zfp, data)\n",
    "\n",
    "            # ZipNN\n",
    "            model_zipnn, original_size, compressed_size = compress_weights(\n",
    "                model_base, method=\"zipnn\"\n",
    "            )\n",
    "            cr_zip = original_size / compressed_size\n",
    "            psnr_zip, l1q_zip, l1p_zip = _integral_metrics(model_zipnn, data)\n",
    "\n",
    "            record = {\"Trajectory\": traj, \"Timestep\": t}\n",
    "            record.update(\n",
    "                {\n",
    "                    \"PSNR\": psnr_base.item(),\n",
    "                    \"L1Q\": l1q_base.item(),\n",
    "                    \"L1P\": l1p_base.item(),\n",
    "                }\n",
    "            )\n",
    "            record.update(\n",
    "                {\n",
    "                    \"ZFP CR\": cr_zfp,\n",
    "                    \"ZFP PSNR\": psnr_zfp.item(),\n",
    "                    \"ZFP L1Q\": l1q_zfp.item(),\n",
    "                    \"ZFP L1P\": l1p_zfp.item(),\n",
    "                }\n",
    "            )\n",
    "            record.update(\n",
    "                {\n",
    "                    \"ZipNN CR\": cr_zip,\n",
    "                    \"ZipNN PSNR\": psnr_zip.item(),\n",
    "                    \"ZipNN L1Q\": l1q_zip.item(),\n",
    "                    \"ZipNN L1P\": l1p_zip.item(),\n",
    "                }\n",
    "            )\n",
    "            results.append(record)\n",
    "\n",
    "df_metrics = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61733bde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean \"phi loss\" with ZFP-compressed weights next to the uncompressed baseline.\n",
    "df_metrics[\"ZFP L1P\"].mean(), df_metrics[\"L1P\"].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26851454",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _mean_ratio(col):\n",
    "    \"\"\"Column mean formatted with one decimal and a LaTeX times sign.\"\"\"\n",
    "    return f\"{df_metrics[col].mean():.1f}$\\\\times$\"\n",
    "\n",
    "\n",
    "def _mean_rel_change(base_col, variant_col, digits=3):\n",
    "    \"\"\"Mean of (base - variant) / variant * 100 over all rows, as a LaTeX percentage.\"\"\"\n",
    "    rel = (df_metrics[base_col] - df_metrics[variant_col]) / df_metrics[variant_col] * 100\n",
    "    return f\"{rel.mean():.{digits}f}\\\\%\"\n",
    "\n",
    "\n",
    "# Row label -> (ZFP cell, ZipNN cell)\n",
    "row = {\n",
    "    \"Extra CR\": (_mean_ratio(\"ZFP CR\"), _mean_ratio(\"ZipNN CR\")),\n",
    "    r\"$\\Delta$ PSNR ($\\boldsymbol{f}$)\": (\n",
    "        _mean_rel_change(\"PSNR\", \"ZFP PSNR\", digits=5),\n",
    "        _mean_rel_change(\"PSNR\", \"ZipNN PSNR\"),\n",
    "    ),\n",
    "    r\"$\\Delta$ L1 ($Q$)\": (\n",
    "        _mean_rel_change(\"L1Q\", \"ZFP L1Q\"),\n",
    "        _mean_rel_change(\"L1Q\", \"ZipNN L1Q\"),\n",
    "    ),\n",
    "    r\"$\\Delta$ L1 ($\\boldsymbol{{\\phi}}$)\": (\n",
    "        _mean_rel_change(\"L1P\", \"ZFP L1P\"),\n",
    "        _mean_rel_change(\"L1P\", \"ZipNN L1P\"),\n",
    "    ),\n",
    "}\n",
    "\n",
    "# Assemble the table from a fixed header, one line per metric, and a fixed footer.\n",
    "table_header = [\n",
    "    r\"\\begin{wraptable}{r}{0.4\\linewidth}\",\n",
    "    r\"\\centering\",\n",
    "    r\"\\vspace{-16px}\",\n",
    "    r\"\\caption{Hybrid compression.\\label{tab:hybrid}}\",\n",
    "    r\"\\vspace{-8px}\",\n",
    "    r\"\\begin{tabular}{lcccc}\",\n",
    "    r\"\\toprule\",\n",
    "    r\"Metric & & \\texttt{ZFP} & \\texttt{ZipNN} \\\\\",\n",
    "    r\"\\midrule\",\n",
    "]\n",
    "table_footer = [\n",
    "    r\"\\bottomrule\",\n",
    "    r\"\\end{tabular}\",\n",
    "    r\"\\vspace{-8px}\",\n",
    "    r\"\\end{wraptable}\",\n",
    "]\n",
    "\n",
    "table_rows = []\n",
    "for metric, (zfp_val, zip_val) in row.items():\n",
    "    # Compression ratio and PSNR rows get an up arrow, the L1 rows a down arrow.\n",
    "    if \"CR\" in metric or \"PSNR\" in metric:\n",
    "        arrow = r\"$\\uparrow$\"\n",
    "    else:\n",
    "        arrow = r\"$\\downarrow$\"\n",
    "    table_rows.append(f\"{metric} & {arrow} & {zfp_val} & {zip_val} \\\\\\\\\")\n",
    "\n",
    "table = \"\\n\".join(table_header + table_rows + table_footer)\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "802fedd3",
   "metadata": {},
   "source": [
    "## Interpolation experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49cf55e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def interpolate_models(model_a, model_b, alpha: float = 0.5):\n",
    "    \"\"\"Blend two models entry-by-entry in state-dict space.\n",
    "\n",
    "    Every entry of the result is ``(1 - alpha) * a + alpha * b``. For ``alpha`` equal\n",
    "    to 0.0 or 1.0 a deep copy of ``model_a`` / ``model_b`` is returned unchanged.\n",
    "    Neither input model is modified.\n",
    "    \"\"\"\n",
    "    # Exact endpoints: hand back an independent copy without any arithmetic.\n",
    "    if alpha == 0.0 or alpha == 1.0:\n",
    "        return deepcopy(model_a if alpha == 0.0 else model_b)\n",
    "\n",
    "    weights_a = model_a.state_dict()\n",
    "    weights_b = model_b.state_dict()\n",
    "    blended = {\n",
    "        name: (1 - alpha) * tensor_a + alpha * weights_b[name]\n",
    "        for name, tensor_a in weights_a.items()\n",
    "    }\n",
    "\n",
    "    # The copy of model_a only provides the module structure; its weights are replaced.\n",
    "    result = deepcopy(model_a)\n",
    "    result.load_state_dict(blended)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2095cc3",
   "metadata": {},
   "source": [
    "### \"Metalearning\" a shared initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf12c09",
   "metadata": {},
   "outputs": [],
   "source": [
    "from neural_fields.models.mlp import MLPNF\n",
    "\n",
    "\n",
    "# Number of short training rounds per (trajectory, timestep); one progress tick each.\n",
    "n_rounds = 4\n",
    "\n",
    "# torch.save does not create missing directories, so make sure the target exists\n",
    "# before spending time on training.\n",
    "os.makedirs(\"../nf_shared_init\", exist_ok=True)\n",
    "\n",
    "total = len(TRAJECTORIES) * len(TIMESTEPS) * n_rounds\n",
    "with tqdm(total=total) as pbar:\n",
    "    for traj in TRAJECTORIES:\n",
    "        # Accept names with or without the \".h5\" suffix. Every other cell passes\n",
    "        # f\"{traj}.h5\" to CycloneNFDataset, so do the same here.\n",
    "        traj_name = traj.replace(\".h5\", \"\")\n",
    "\n",
    "        # One model and one optimizer per trajectory, trained across all timesteps.\n",
    "        model0 = MLPNF(\n",
    "            in_dim=5,\n",
    "            out_dim=2,\n",
    "            n_layers=5,\n",
    "            dim=64,\n",
    "            act_fn=torch.nn.SiLU,\n",
    "            use_checkpoint=False,\n",
    "            skips=True,\n",
    "            embed_type=\"discrete\",\n",
    "        )\n",
    "        optim = torch.optim.AdamW(model0.parameters(), 5e-3)\n",
    "        for timestep in TIMESTEPS:\n",
    "            # The dataset arguments do not change between rounds, so build it once.\n",
    "            data = CycloneNFDataset(\n",
    "                trajectory=f\"{traj_name}.h5\",\n",
    "                timesteps=timestep,\n",
    "                normalize=\"zscore\",\n",
    "                normalize_coords=False,\n",
    "                realpotens=True,\n",
    "            )\n",
    "            for _ in range(n_rounds):\n",
    "                loader = CycloneNFDataLoader(data, 5096, preload=True, shuffle=True)\n",
    "                model0, _, losses = train_nf(\n",
    "                    model0,\n",
    "                    n_epochs=2,\n",
    "                    data=data,\n",
    "                    loader=loader,\n",
    "                    device=device,\n",
    "                    optim=optim,\n",
    "                    use_tqdm=False,\n",
    "                    use_print=False,\n",
    "                )\n",
    "                pbar.set_postfix(\n",
    "                    {\n",
    "                        \"traj\": traj,\n",
    "                        \"MSE\": losses[\"train/loss\"],\n",
    "                        \"PSNR\": losses[\"val/df psnr\"].item(),\n",
    "                    }\n",
    "                )\n",
    "                pbar.update(1)\n",
    "        torch.save(model0.state_dict(), f\"../nf_shared_init/{traj_name}.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c0150c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAJ = \"iteration_13\"\n",
    "ALPHAS = [0.0, 0.25, 0.5, 0.75, 1.0]\n",
    "\n",
    "# Interval endpoints and its midpoint, in timestep units.\n",
    "TIME_A = TIMESTEPS[0]\n",
    "TIME_B = TIME_A + DELTA\n",
    "TIME_C = TIME_A + DELTA // 2\n",
    "\n",
    "\n",
    "def _ground_truth_field(alpha):\n",
    "    \"\"\"Load the stored field at the timestep ``alpha`` of the way from TIME_A to TIME_B.\"\"\"\n",
    "    timestep = int(TIME_A + DELTA * alpha)\n",
    "    # No `normalize` argument is passed here, unlike for the entries of DATASETS.\n",
    "    return CycloneNFDataset(f\"{TRAJ}.h5\", timesteps=timestep, realpotens=True).full_df\n",
    "\n",
    "\n",
    "GT_DFS = [_ground_truth_field(a) for a in ALPHAS]\n",
    "interp_models = dict(a=MODELS[TRAJ][TIME_A], b=MODELS[TRAJ][TIME_B])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288c419b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_fields_grid(\n",
    "    dfs,\n",
    "    suptitle=None,\n",
    "    projections=None,\n",
    "    titles=None,\n",
    "    row_labels=None,\n",
    "    cmap=\"RdBu_r\",\n",
    "    use_labels=False,\n",
    "):\n",
    "    \"\"\"Draw averaged 2D projections of several fields as a grid of heatmaps.\n",
    "\n",
    "    The grid has one column per field in ``dfs`` and one row per entry of\n",
    "    ``projections``. All panels of a row share one colour range.\n",
    "\n",
    "    Args:\n",
    "        dfs: Stacked tensor of fields, or a list of tensors that gets stacked along a\n",
    "            new leading axis. Each field is averaged over its first axis and then\n",
    "            over the axes named by the row's projection.\n",
    "        suptitle: Optional text drawn rotated at the right edge of the figure.\n",
    "        projections: Tuples of axes to average out, one per row. Defaults to\n",
    "            ``[(0, 1, 2), (1, 3, 4)]``.\n",
    "        titles: Optional per-column indices into the physical-time table; used for\n",
    "            the titles of the first row. The table is only read when this is given.\n",
    "        row_labels: Optional y-axis labels, one per row, placed on the first column.\n",
    "        cmap: Matplotlib colormap name.\n",
    "        use_labels: If True, annotate each row with the names of the axes that the\n",
    "            projection keeps.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib figure.\n",
    "    \"\"\"\n",
    "    all_labels = [r\"$v_{\\parallel}$\", r\"$\\mu$\", r\"$s$\", r\"$x$\", r\"$y$\"]\n",
    "\n",
    "    # The time table only feeds the column titles, so skip the file read otherwise.\n",
    "    phy_times = None\n",
    "    if titles is not None:\n",
    "        phy_times = np.loadtxt(\n",
    "            \"<path>/raw/iteration_13/time.dat\"\n",
    "        )\n",
    "\n",
    "    if isinstance(dfs, list):\n",
    "        dfs = torch.stack(dfs, 0)\n",
    "\n",
    "    num_cols = dfs.shape[0]\n",
    "    if projections is None:\n",
    "        projections = [(0, 1, 2), (1, 3, 4)]\n",
    "    num_rows = len(projections)\n",
    "\n",
    "    # squeeze=False always returns a 2D array of axes, including single-row,\n",
    "    # single-column and 1x1 grids, so no manual reshaping is needed.\n",
    "    fig, ax = plt.subplots(\n",
    "        num_rows, num_cols, figsize=(2 * num_cols, 0.94 * num_rows), squeeze=False\n",
    "    )\n",
    "\n",
    "    for i, proj in enumerate(projections):\n",
    "        data_list = [dfs[j].mean(0).mean(proj) for j in range(num_cols)]\n",
    "        # Shared colour limits across the row.\n",
    "        vmin, vmax = min(d.min() for d in data_list), max(d.max() for d in data_list)\n",
    "        for j in range(num_cols):\n",
    "            ax[i, j].matshow(\n",
    "                data_list[j], cmap=cmap, vmin=vmin, vmax=vmax, aspect=\"auto\"\n",
    "            )\n",
    "            ax[i, j].set_xticks([])\n",
    "            ax[i, j].set_yticks([])\n",
    "            # Column titles (first row)\n",
    "            if i == 0 and titles is not None:\n",
    "                ax[i, j].set_title(\n",
    "                    rf\"$t={phy_times[titles[j]]:.1f}R/V_{{\\mathrm{{r}}}}$\", fontsize=18\n",
    "                )\n",
    "\n",
    "        # Row labels on the left (once per row)\n",
    "        if row_labels is not None:\n",
    "            ax[i, 0].set_ylabel(row_labels[i], fontsize=16)\n",
    "\n",
    "        # Axis labels corresponding to the projection\n",
    "        if use_labels:\n",
    "            label_text = \"/\".join([all_labels[o] for o in range(5) if o not in proj])\n",
    "            fig.text(\n",
    "                0.02,\n",
    "                1.0 * (num_rows - i) / num_rows - 0.25,\n",
    "                label_text,\n",
    "                va=\"center\",\n",
    "                ha=\"center\",\n",
    "                fontsize=18,\n",
    "            )\n",
    "\n",
    "    fig.subplots_adjust(\n",
    "        left=0.06, right=0.96, top=0.95, bottom=0, wspace=0.0, hspace=0.0\n",
    "    )\n",
    "\n",
    "    if suptitle is not None:\n",
    "        fig.text(0.97, 0.5, suptitle, va=\"center\", ha=\"left\", fontsize=20, rotation=-90)\n",
    "    fig.patch.set_alpha(0.0)\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddf7d1e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _sample_weight_blend(alpha):\n",
    "    \"\"\"Sample the field of the model blended at ``alpha`` on the TIME_A dataset.\"\"\"\n",
    "    blended = interpolate_models(interp_models[\"a\"], interp_models[\"b\"], alpha=alpha)\n",
    "    field = sample_field(blended, DATASETS[TRAJ][TIME_A][\"a\"], device, timestep=None)\n",
    "    return field.cpu()\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    weight_space_fields = [_sample_weight_blend(a) for a in ALPHAS]\n",
    "dfs = torch.stack(weight_space_fields, 0)\n",
    "\n",
    "# Column titles are looked up from the timestep each alpha corresponds to.\n",
    "column_timesteps = [int(TIME_A + DELTA * a) for a in ALPHAS]\n",
    "fig = plot_fields_grid(\n",
    "    dfs, titles=column_timesteps, suptitle=\"Weight space\", use_labels=True\n",
    ")\n",
    "fig.savefig(\"figures/weight_interp.pdf\", bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2cb37a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Endpoints of the blend: the \"a\" and \"b\" snapshots stored under TIME_A.\n",
    "df_start = DATASETS[TRAJ][TIME_A][\"a\"].full_df\n",
    "df_end = DATASETS[TRAJ][TIME_A][\"b\"].full_df\n",
    "\n",
    "with torch.no_grad():\n",
    "    data_space_fields = [(1 - a) * df_start + a * df_end for a in ALPHAS]\n",
    "dfs = torch.stack(data_space_fields, 0)\n",
    "\n",
    "fig = plot_fields_grid(dfs, suptitle=\"Data space\", use_labels=True)\n",
    "fig.savefig(\"figures/data_interp.pdf\", bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9069ce58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One stored field per interpolation coefficient, in ALPHAS order.\n",
    "dfs = torch.stack([GT_DFS[i] for i in range(len(ALPHAS))], 0)\n",
    "\n",
    "fig = plot_fields_grid(dfs, suptitle=\"Ground truth\", use_labels=True)\n",
    "fig.savefig(\"figures/gt_interp.pdf\", bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c978ea02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recomputes DELTA with the same expression used in the cell that defines TIMESTEPS.\n",
    "DELTA = TIMESTEPS[1] - TIMESTEPS[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e6220e0",
   "metadata": {},
   "source": [
    "### Autoencoders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c988725",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import load_model\n",
    "\n",
    "ae_root = \"<path>\"\n",
    "\n",
    "\n",
    "def _load_for_eval(file):\n",
    "    \"\"\"Load a checkpoint below ``ae_root``.\n",
    "\n",
    "    Returns the model (on ``device``, eval mode), its config, and the checkpoint path.\n",
    "    \"\"\"\n",
    "    ckp = f\"{ae_root}/{file}\"\n",
    "    net, _, config = load_model(ckp, device=device)\n",
    "    return net.to(device).eval(), config, ckp\n",
    "\n",
    "\n",
    "ae, cfg, ae_ckp = _load_for_eval(\"ae_pre/ckp.pth\")\n",
    "# `cfg` is rebound here: later cells see the config returned for the VQ-VAE checkpoint.\n",
    "vqvae, cfg, vqvae_ckp = _load_for_eval(\"vqvae_base/best.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3795afac",
   "metadata": {},
   "source": [
    "### Interpolation table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da73dc42",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from dataset.cyclone_diff import CycloneDiffusionDataset\n",
    "\n",
    "\n",
    "def compute_metrics(gt, pred):\n",
    "    \"\"\"Return the mean absolute error and the PSNR of ``pred`` against ``gt``.\n",
    "\n",
    "    Args:\n",
    "        gt: Reference tensor; its maximum is used as the PSNR peak value.\n",
    "        pred: Tensor with the same shape as ``gt``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(l1, psnr)``: ``l1`` is a Python float, ``psnr`` a 0-dim tensor.\n",
    "    \"\"\"\n",
    "    gt, pred = gt.cpu(), pred.cpu()\n",
    "    diff = pred - gt\n",
    "    l1 = diff.abs().mean().item()\n",
    "    mse = (diff**2).mean()\n",
    "    # PSNR = 10 * log10(peak^2 / MSE); `mse` is already squared error, so it must\n",
    "    # not be squared a second time.\n",
    "    psnr = 10 * torch.log10(gt.max() ** 2 / mse)\n",
    "    return l1, psnr\n",
    "\n",
    "\n",
    "def autoencoder_interp_latent(\n",
    "    model, lat_a, lat_b, condition_a, condition_b, pad_axes, alpha, dataset=None\n",
    "):\n",
    "    \"\"\"Decode a linear blend of two latents (and their conditionings) into a field.\n",
    "\n",
    "    Args:\n",
    "        model: Autoencoder exposing ``decode`` and, optionally, ``quantize``.\n",
    "        lat_a: Latent used at ``alpha = 0``.\n",
    "        lat_b: Latent used at ``alpha = 1``.\n",
    "        condition_a: Conditioning that belongs to ``lat_a``.\n",
    "        condition_b: Conditioning that belongs to ``lat_b``.\n",
    "        pad_axes: Passed through unchanged to ``model.decode``.\n",
    "        alpha: Blend weight applied to both latents and conditionings.\n",
    "        dataset: Object whose ``denormalize(df=..., file_index=0)`` is applied to the\n",
    "            decoded field. When omitted, the notebook-level ``valdata_ae`` is used,\n",
    "            which keeps existing call sites working.\n",
    "\n",
    "    Returns:\n",
    "        ``df[[0, 1]] + df[[2, 3]]`` of the denormalised decoded field.\n",
    "    \"\"\"\n",
    "    if dataset is None:\n",
    "        # Legacy path: fall back to the global created by the evaluation loop.\n",
    "        dataset = valdata_ae\n",
    "    zdf = (1 - alpha) * lat_a + alpha * lat_b\n",
    "    condition = (1 - alpha) * condition_a + alpha * condition_b\n",
    "    # Models with a quantizer get the blended latent quantized before decoding.\n",
    "    if hasattr(model, \"quantize\"):\n",
    "        zdf = model.quantize(zdf)\n",
    "    df = model.decode(zdf, pad_axes, condition=condition)[\"df\"].squeeze(0)\n",
    "    df = dataset.denormalize(df=df, file_index=0)\n",
    "    # NOTE(review): presumably entries 2-3 hold a separately stored component that\n",
    "    # is added back onto entries 0-1 (cf. `separate_zf`) - confirm.\n",
    "    return df[[0, 1]] + df[[2, 3]]\n",
    "\n",
    "\n",
    "# Midpoint blend weight shared by weight-, data- and latent-space interpolation.\n",
    "alpha = 0.5\n",
    "\n",
    "results = []\n",
    "with torch.no_grad():\n",
    "    # Loop-invariant: built once; only `norm_stats` is read from it below.\n",
    "    # NOTE: `cfg` is whatever the last `load_model` call returned.\n",
    "    traindata_ae = CycloneDiffusionDataset(\n",
    "        path=cfg.dataset.path,\n",
    "        split=\"train\",\n",
    "        input_fields=[\"df\", \"phi\", \"flux\"],\n",
    "        random_seed=cfg.seed,\n",
    "        normalization=cfg.dataset.normalization,\n",
    "        normalization_scope=cfg.dataset.normalization_scope,\n",
    "        trajectories=cfg.dataset.training_trajectories,\n",
    "        separate_zf=cfg.dataset.separate_zf,\n",
    "        real_potens=True,\n",
    "        cond_filters=vars(cfg.dataset.training_cond_filters),\n",
    "        stage=\"autoencoder\",\n",
    "        conditions=[\"itg\", \"dg\", \"s_hat\", \"q\"],\n",
    "    )\n",
    "\n",
    "    for traj in TRAJECTORIES:\n",
    "        # Depends on the trajectory only, so all of its timesteps share one instance.\n",
    "        # Kept under the name `valdata_ae` because autoencoder_interp_latent reads it.\n",
    "        valdata_ae = CycloneDiffusionDataset(\n",
    "            path=\"<path>\",\n",
    "            split=\"train\",\n",
    "            input_fields=[\"df\", \"phi\", \"flux\"],\n",
    "            random_seed=cfg.seed,\n",
    "            normalization=cfg.dataset.normalization,\n",
    "            normalization_scope=cfg.dataset.normalization_scope,\n",
    "            normalization_stats=traindata_ae.norm_stats,\n",
    "            trajectories=[f\"{traj}.h5\"],\n",
    "            separate_zf=cfg.dataset.separate_zf,\n",
    "            real_potens=True,\n",
    "            stage=\"autoencoder\",\n",
    "            conditions=[\"itg\", \"dg\", \"s_hat\", \"q\"],\n",
    "        )\n",
    "\n",
    "        # The last timestep is skipped: MODELS has no entry at t + DELTA for it.\n",
    "        for t in TIMESTEPS[:-1]:\n",
    "            ds_a = DATASETS[traj][t][\"a\"].to(device)\n",
    "            ds_b = DATASETS[traj][t][\"b\"].to(device)\n",
    "            ds_c = DATASETS[traj][t][\"c\"].to(device)\n",
    "\n",
    "            # Target: the stored field at the midpoint timestep.\n",
    "            gt = ds_c.full_df\n",
    "\n",
    "            # extremes: score the left endpoint itself against the midpoint\n",
    "            gt_a = ds_a.full_df\n",
    "            gt_b = ds_b.full_df\n",
    "            l1_a, psnr_a = compute_metrics(gt, gt_a)\n",
    "\n",
    "            # weight space: blend the two endpoint models and sample the field\n",
    "            model_a = MODELS[traj][t]\n",
    "            model_b = MODELS[traj][t + DELTA]\n",
    "            model_mid = interpolate_models(model_a, model_b, alpha)\n",
    "            pred_mid = sample_field(model_mid, ds_c, device=device, timestep=None)\n",
    "            nf_l1_mid, nf_psnr_mid = compute_metrics(gt, pred_mid)\n",
    "\n",
    "            # data space: blend the two endpoint fields directly\n",
    "            gt_avg = (1 - alpha) * gt_a + alpha * gt_b\n",
    "            data_l1_mid, data_psnr_mid = compute_metrics(gt, gt_avg)\n",
    "\n",
    "            # autoencoders: encode both endpoints, blend the latents, decode\n",
    "            # NOTE(review): assumes the dataset index equals the timestep - confirm.\n",
    "            sample_a = valdata_ae[t]\n",
    "            sample_b = valdata_ae[t + DELTA]\n",
    "            df_a = sample_a.df.unsqueeze(0).to(device)\n",
    "            df_b = sample_b.df.unsqueeze(0).to(device)\n",
    "            condition_a = sample_a.conditioning.unsqueeze(0).to(device)\n",
    "            condition_b = sample_b.conditioning.unsqueeze(0).to(device)\n",
    "            ae_zdf_a, ae_condition_a, ae_pad_axes = ae.encode(\n",
    "                df_a, condition=condition_a\n",
    "            )\n",
    "            ae_zdf_b, ae_condition_b, _ = ae.encode(df_b, condition=condition_b)\n",
    "            vqvae_zdf_a, vqvae_condition_a, vqvae_pad_axes = vqvae.encode(\n",
    "                df_a, condition=condition_a\n",
    "            )\n",
    "            vqvae_zdf_b, vqvae_condition_b, _ = vqvae.encode(\n",
    "                df_b, condition=condition_b\n",
    "            )\n",
    "\n",
    "            ae_df_ab = autoencoder_interp_latent(\n",
    "                ae,\n",
    "                lat_a=ae_zdf_a,\n",
    "                lat_b=ae_zdf_b,\n",
    "                condition_a=ae_condition_a,\n",
    "                condition_b=ae_condition_b,\n",
    "                pad_axes=ae_pad_axes,\n",
    "                alpha=alpha,\n",
    "            ).cpu()\n",
    "            vqvae_df_ab = autoencoder_interp_latent(\n",
    "                vqvae,\n",
    "                lat_a=vqvae_zdf_a,\n",
    "                lat_b=vqvae_zdf_b,\n",
    "                condition_a=vqvae_condition_a,\n",
    "                condition_b=vqvae_condition_b,\n",
    "                pad_axes=vqvae_pad_axes,\n",
    "                alpha=alpha,\n",
    "            ).cpu()\n",
    "\n",
    "            # NOTE(review): `gt` comes from a dataset configured with z-score\n",
    "            # normalisation while the AE output is denormalised - confirm both\n",
    "            # are expressed in the same units.\n",
    "            ae_l1_mid, ae_psnr_mid = compute_metrics(gt, ae_df_ab)\n",
    "            vqvae_l1_mid, vqvae_psnr_mid = compute_metrics(gt, vqvae_df_ab)\n",
    "\n",
    "            results.append(\n",
    "                {\n",
    "                    \"Trajectory\": traj,\n",
    "                    \"Timestep\": t,\n",
    "                    \"Extremes L1\": l1_a,\n",
    "                    \"Extremes PSNR\": psnr_a,\n",
    "                    \"NF L1\": nf_l1_mid,\n",
    "                    \"NF PSNR\": nf_psnr_mid,\n",
    "                    \"AE L1\": ae_l1_mid,\n",
    "                    \"AE PSNR\": ae_psnr_mid,\n",
    "                    \"VQ-VAE l1\": vqvae_l1_mid,\n",
    "                    \"VQ-VAE PSNR\": vqvae_psnr_mid,\n",
    "                    \"Data L1\": data_l1_mid,\n",
    "                    \"Data PSNR\": data_psnr_mid,\n",
    "                }\n",
    "            )\n",
    "\n",
    "df_metrics = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa56e983",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row label -> (PSNR column, L1 column) of df_metrics.\n",
    "metric_columns = {\n",
    "    \"Extremes\": (\"Extremes PSNR\", \"Extremes L1\"),\n",
    "    r\"$\\boldsymbol{f}$ (data)\": (\"Data PSNR\", \"Data L1\"),\n",
    "    \"NF (weights)\": (\"NF PSNR\", \"NF L1\"),\n",
    "    \"AE (latents)\": (\"AE PSNR\", \"AE L1\"),\n",
    "    \"VQ-VAE (latents)\": (\"VQ-VAE PSNR\", \"VQ-VAE l1\"),\n",
    "}\n",
    "\n",
    "# Row label -> (formatted mean PSNR, formatted mean L1)\n",
    "row = {\n",
    "    label: (\n",
    "        f\"{df_metrics[psnr_col].mean():.1f}\",\n",
    "        f\"{df_metrics[l1_col].mean():.2f}\",\n",
    "    )\n",
    "    for label, (psnr_col, l1_col) in metric_columns.items()\n",
    "}\n",
    "\n",
    "# Build table\n",
    "table = [\n",
    "    r\"\\begin{tabular}{lcc}\",\n",
    "    r\"\\toprule\",\n",
    "    r\"Model & PSNR & L1 \\\\\",\n",
    "    r\"\\midrule\",\n",
    "]\n",
    "\n",
    "# The loop variable is `label`, not `model`: a top-level loop rebinds notebook\n",
    "# globals, and `model` is the neural field used by the compression cells.\n",
    "for label, (psnr, l1) in row.items():\n",
    "    table.append(f\"{label} & {psnr} & {l1} \\\\\\\\\")\n",
    "\n",
    "table += [\n",
    "    r\"\\bottomrule\",\n",
    "    r\"\\end{tabular}\",\n",
    "]\n",
    "\n",
    "table = \"\\n\".join(table)\n",
    "\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbfe40ef",
   "metadata": {},
   "source": [
    "## Physics interpretation tables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "69bf982a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "27d91dc0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>physical_losses</th>\n",
       "      <th>pre_train/loss</th>\n",
       "      <th>pre_val/df loss</th>\n",
       "      <th>pre_val/flux loss</th>\n",
       "      <th>pre_val/phi loss</th>\n",
       "      <th>pre_val/phi mse</th>\n",
       "      <th>pre_val/kxspec loss</th>\n",
       "      <th>pre_val/kyspec loss</th>\n",
       "      <th>pre_val/phi_zf loss</th>\n",
       "      <th>pre_val/qspec loss</th>\n",
       "      <th>...</th>\n",
       "      <th>fine_val/kyspec loss</th>\n",
       "      <th>fine_val/phi_zf loss</th>\n",
       "      <th>fine_val/qspec loss</th>\n",
       "      <th>fine_val/qspec monotonicity loss</th>\n",
       "      <th>fine_val/kyspec monotonicity loss</th>\n",
       "      <th>fine_val/mass loss</th>\n",
       "      <th>fine_val/df psnr</th>\n",
       "      <th>fine_val/phi psnr</th>\n",
       "      <th>fine_val/ky pc</th>\n",
       "      <th>fine_val/q pc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>['int', 'diag', 'mono']</td>\n",
       "      <td>0.144532</td>\n",
       "      <td>0.675091</td>\n",
       "      <td>48.414301</td>\n",
       "      <td>4.386669</td>\n",
       "      <td>49.361196</td>\n",
       "      <td>6.242375</td>\n",
       "      <td>3.128465</td>\n",
       "      <td>4.897076</td>\n",
       "      <td>1.513220</td>\n",
       "      <td>...</td>\n",
       "      <td>0.244808</td>\n",
       "      <td>1.511079</td>\n",
       "      <td>1.405257</td>\n",
       "      <td>0.027829</td>\n",
       "      <td>0.020584</td>\n",
       "      <td>55083.955469</td>\n",
       "      <td>38.276908</td>\n",
       "      <td>5.216967</td>\n",
       "      <td>0.885497</td>\n",
       "      <td>0.957384</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>['int', 'diag']</td>\n",
       "      <td>0.144555</td>\n",
       "      <td>0.675612</td>\n",
       "      <td>48.978326</td>\n",
       "      <td>4.225558</td>\n",
       "      <td>47.475914</td>\n",
       "      <td>5.995861</td>\n",
       "      <td>3.005203</td>\n",
       "      <td>4.339094</td>\n",
       "      <td>1.531040</td>\n",
       "      <td>...</td>\n",
       "      <td>0.513344</td>\n",
       "      <td>1.526442</td>\n",
       "      <td>1.392169</td>\n",
       "      <td>0.029726</td>\n",
       "      <td>0.019438</td>\n",
       "      <td>47669.738867</td>\n",
       "      <td>38.301330</td>\n",
       "      <td>4.431185</td>\n",
       "      <td>0.881355</td>\n",
       "      <td>0.956025</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>['int']</td>\n",
       "      <td>0.144589</td>\n",
       "      <td>0.676681</td>\n",
       "      <td>47.958145</td>\n",
       "      <td>3.542368</td>\n",
       "      <td>31.265018</td>\n",
       "      <td>3.952453</td>\n",
       "      <td>1.983528</td>\n",
       "      <td>3.931111</td>\n",
       "      <td>1.498965</td>\n",
       "      <td>...</td>\n",
       "      <td>1.609294</td>\n",
       "      <td>2.563313</td>\n",
       "      <td>1.423964</td>\n",
       "      <td>0.038717</td>\n",
       "      <td>0.027225</td>\n",
       "      <td>88080.227734</td>\n",
       "      <td>36.677201</td>\n",
       "      <td>-0.822728</td>\n",
       "      <td>0.863505</td>\n",
       "      <td>0.952384</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>['mono']</td>\n",
       "      <td>0.144560</td>\n",
       "      <td>0.678388</td>\n",
       "      <td>48.171551</td>\n",
       "      <td>4.264625</td>\n",
       "      <td>47.247162</td>\n",
       "      <td>5.984286</td>\n",
       "      <td>2.999337</td>\n",
       "      <td>4.397497</td>\n",
       "      <td>1.505866</td>\n",
       "      <td>...</td>\n",
       "      <td>851.945351</td>\n",
       "      <td>50.661923</td>\n",
       "      <td>1.998080</td>\n",
       "      <td>0.000387</td>\n",
       "      <td>0.007085</td>\n",
       "      <td>534017.580990</td>\n",
       "      <td>37.286213</td>\n",
       "      <td>-42.209501</td>\n",
       "      <td>0.718817</td>\n",
       "      <td>0.960670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[]</td>\n",
       "      <td>0.144786</td>\n",
       "      <td>0.677366</td>\n",
       "      <td>48.590229</td>\n",
       "      <td>4.445485</td>\n",
       "      <td>59.039023</td>\n",
       "      <td>7.410323</td>\n",
       "      <td>3.712404</td>\n",
       "      <td>4.499970</td>\n",
       "      <td>1.518661</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>['diag']</td>\n",
       "      <td>0.144750</td>\n",
       "      <td>0.676907</td>\n",
       "      <td>48.944201</td>\n",
       "      <td>4.449937</td>\n",
       "      <td>70.805398</td>\n",
       "      <td>8.912115</td>\n",
       "      <td>4.462749</td>\n",
       "      <td>4.543012</td>\n",
       "      <td>1.529551</td>\n",
       "      <td>...</td>\n",
       "      <td>1.671660</td>\n",
       "      <td>2.715998</td>\n",
       "      <td>1.324488</td>\n",
       "      <td>0.027751</td>\n",
       "      <td>0.017595</td>\n",
       "      <td>54169.543229</td>\n",
       "      <td>38.762988</td>\n",
       "      <td>4.652642</td>\n",
       "      <td>0.873894</td>\n",
       "      <td>0.956328</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>6 rows × 34 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "           physical_losses  pre_train/loss  pre_val/df loss  \\\n",
       "5  ['int', 'diag', 'mono']        0.144532         0.675091   \n",
       "3          ['int', 'diag']        0.144555         0.675612   \n",
       "1                  ['int']        0.144589         0.676681   \n",
       "4                 ['mono']        0.144560         0.678388   \n",
       "0                       []        0.144786         0.677366   \n",
       "2                 ['diag']        0.144750         0.676907   \n",
       "\n",
       "   pre_val/flux loss  pre_val/phi loss  pre_val/phi mse  pre_val/kxspec loss  \\\n",
       "5          48.414301          4.386669        49.361196             6.242375   \n",
       "3          48.978326          4.225558        47.475914             5.995861   \n",
       "1          47.958145          3.542368        31.265018             3.952453   \n",
       "4          48.171551          4.264625        47.247162             5.984286   \n",
       "0          48.590229          4.445485        59.039023             7.410323   \n",
       "2          48.944201          4.449937        70.805398             8.912115   \n",
       "\n",
       "   pre_val/kyspec loss  pre_val/phi_zf loss  pre_val/qspec loss  ...  \\\n",
       "5             3.128465             4.897076            1.513220  ...   \n",
       "3             3.005203             4.339094            1.531040  ...   \n",
       "1             1.983528             3.931111            1.498965  ...   \n",
       "4             2.999337             4.397497            1.505866  ...   \n",
       "0             3.712404             4.499970            1.518661  ...   \n",
       "2             4.462749             4.543012            1.529551  ...   \n",
       "\n",
       "   fine_val/kyspec loss  fine_val/phi_zf loss  fine_val/qspec loss  \\\n",
       "5              0.244808              1.511079             1.405257   \n",
       "3              0.513344              1.526442             1.392169   \n",
       "1              1.609294              2.563313             1.423964   \n",
       "4            851.945351             50.661923             1.998080   \n",
       "0                   NaN                   NaN                  NaN   \n",
       "2              1.671660              2.715998             1.324488   \n",
       "\n",
       "   fine_val/qspec monotonicity loss  fine_val/kyspec monotonicity loss  \\\n",
       "5                          0.027829                           0.020584   \n",
       "3                          0.029726                           0.019438   \n",
       "1                          0.038717                           0.027225   \n",
       "4                          0.000387                           0.007085   \n",
       "0                               NaN                                NaN   \n",
       "2                          0.027751                           0.017595   \n",
       "\n",
       "   fine_val/mass loss  fine_val/df psnr  fine_val/phi psnr  fine_val/ky pc  \\\n",
       "5        55083.955469         38.276908           5.216967        0.885497   \n",
       "3        47669.738867         38.301330           4.431185        0.881355   \n",
       "1        88080.227734         36.677201          -0.822728        0.863505   \n",
       "4       534017.580990         37.286213         -42.209501        0.718817   \n",
       "0                 NaN               NaN                NaN             NaN   \n",
       "2        54169.543229         38.762988           4.652642        0.873894   \n",
       "\n",
       "   fine_val/q pc  \n",
       "5       0.957384  \n",
       "3       0.956025  \n",
       "1       0.952384  \n",
       "4       0.960670  \n",
       "0            NaN  \n",
       "2       0.956328  \n",
       "\n",
       "[6 rows x 34 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Grid-search results (one row per `physical_losses` combination),\n",
    "# ordered by pre-training df PSNR, best first. The original row labels are kept.\n",
    "df = pd.read_csv(\"../grid_search_mlp_pinn.csv\")\n",
    "df = df.sort_values(by=\"pre_val/df psnr\", ascending=False)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "75b6de19",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " & $\\phantom{+}\\mathcal{L}_{\\text{recon}}$ & 38.89 & 48.59 & 4.45 & 3.71 & 1.52 \\\\\n",
      " & \\cellcolor{green!10}{$+ \\mathcal{L}_{\\text{int}}$} & 36.68 & 10.35 & 2.55 & 1.61 & 1.42 \\\\\n",
      " & \\cellcolor{yellow!20}{$+ \\mathcal{L}_{\\text{diag}}$} & 38.76 & 41.39 & 2.25 & 1.67 & 1.32 \\\\\n",
      " & \\cellcolor{blue!10}{$+ \\mathcal{L}_{\\text{grad}}$} & 37.29 & 63.94 & 44.18 & 851.95 & 2 \\\\\n",
      " & $+ \\mathcal{L}_{\\text{PINN}}$ & 38.28 & 28.03 & 1.41 & 0.24 & 1.41 \\\\\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Loss configuration -> row label of `df` (looked up with .loc, so these are the\n",
    "# labels the frame kept from the CSV, not positions after sorting).\n",
    "loss_map = {\"recon\": 0, \"int\": 1, \"diag\": 2, \"grad\": 4, \"PINN\": 5}\n",
    "\n",
    "# LaTeX symbol of each reported quantity -> metric suffix of the CSV column.\n",
    "# Only the suffixes are used below; their order fixes the column order of the table.\n",
    "metrics_cols = {\n",
    "    \"f\": \"df psnr\",\n",
    "    \"Q\": \"flux loss\",\n",
    "    r\"\\phi\": \"phi loss\",\n",
    "    \"k_y^{\\\\text{spec}}\": \"kyspec loss\",\n",
    "    \"Q^{\\\\text{spec}}\": \"qspec loss\"\n",
    "}\n",
    "\n",
    "# First cell of each row; any loss name not listed here falls back to the PINN cell.\n",
    "loss_cells = {\n",
    "    \"recon\": r\"$\\phantom{+}\\mathcal{L}_{\\text{recon}}$\",\n",
    "    \"int\": r\"\\cellcolor{green!10}{$+ \\mathcal{L}_{\\text{int}}$}\",\n",
    "    \"diag\": r\"\\cellcolor{yellow!20}{$+ \\mathcal{L}_{\\text{diag}}$}\",\n",
    "    \"grad\": r\"\\cellcolor{blue!10}{$+ \\mathcal{L}_{\\text{grad}}$}\",\n",
    "}\n",
    "pinn_cell = r\"$+ \\mathcal{L}_{\\text{PINN}}$\"\n",
    "\n",
    "\n",
    "def format_metric(val):\n",
    "    \"\"\"Return `val` with two decimals and trailing zeros stripped; empty string for NaN.\"\"\"\n",
    "    if pd.isna(val):\n",
    "        return \"\"\n",
    "    return f\"{val:.2f}\".rstrip(\"0\").rstrip(\".\")\n",
    "\n",
    "\n",
    "latex_rows = []\n",
    "\n",
    "for loss_name, idx in loss_map.items():\n",
    "    row = df.loc[idx]\n",
    "\n",
    "    # The reconstruction-only baseline is read from the pre-training validation\n",
    "    # columns, every other configuration from the fine-tuning ones.\n",
    "    prefix = \"fine_val\" if loss_name != \"recon\" else \"pre_val\"\n",
    "\n",
    "    metrics_values = [\n",
    "        format_metric(row[f\"{prefix}/{metric}\"]) for metric in metrics_cols.values()\n",
    "    ]\n",
    "    loss_cell = loss_cells.get(loss_name, pinn_cell)\n",
    "\n",
    "    # Leading empty cell, then the loss label, then the metrics, closed by a row break.\n",
    "    latex_rows.append(f\" & {loss_cell} & \" + \" & \".join(metrics_values) + r\" \\\\\")\n",
    "\n",
    "# One table row per line\n",
    "latex_nf = \"\\n\".join(latex_rows)\n",
    "print(latex_nf)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ed0e01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SIREN grid-search results, ordered by pre-training df PSNR (best first).\n",
    "grid = pd.read_csv(\"../grid_search_siren.csv\")\n",
    "grid = grid.sort_values(by=\"pre_val/df psnr\", ascending=False)\n",
    "grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42d0bde3",
   "metadata": {},
   "outputs": [],
   "source": [
    "caption = \"SIREN grid search combinations.\"\n",
    "label = \"tab:siren_combinations\"\n",
    "\n",
    "# Work on a copy so the raw grid-search frame `grid` is left untouched.\n",
    "df_fmt = grid.copy()\n",
    "\n",
    "# Frequency scales as short strings: two decimals, trailing zeros (and a dangling dot) removed.\n",
    "df_fmt[\"first_w0\"] = df_fmt[\"first_w0\"].map(lambda x: f\"{x:.2f}\".rstrip(\"0\").rstrip(\".\"))\n",
    "df_fmt[\"hidden_w0\"] = df_fmt[\"hidden_w0\"].map(lambda x: f\"{x:.2f}\".rstrip(\"0\").rstrip(\".\"))\n",
    "df_fmt[r\"$w_0^{\\text{initial}}$\"] = df_fmt[\"first_w0\"]\n",
    "df_fmt[r\"$w_0^{\\text{hidden}}$\"] = df_fmt[\"hidden_w0\"]\n",
    "\n",
    "# The column key must be a raw string: without the r prefix, \"\\b\" would be a backspace\n",
    "# character and the column would not match the one selected for the table below.\n",
    "# (A second, non-raw copy of this assignment used to create such a stray column.)\n",
    "# NOTE(review): the logged PSNR is divided by 4 before reporting - presumably it is\n",
    "# accumulated over 4 items; confirm against the training code.\n",
    "df_fmt[r\"PSNR$_\\boldsymbol{f}$\"] = (\n",
    "    (df_fmt[\"pre_val/df psnr\"] / 4)\n",
    "    .round(2)\n",
    "    .map(lambda x: f\"{x:.2f}\".rstrip(\"0\").rstrip(\".\"))\n",
    ")\n",
    "\n",
    "# Human-readable labels for the categorical hyper-parameters.\n",
    "df_fmt[\"Embedding\"] = df_fmt[\"embed_type\"].replace(\n",
    "    {\"sincos_continuous\": \"SinCos\", \"discrete\": \"Discrete\", \"linear\": \"Linear\"}\n",
    ")\n",
    "df_fmt[\"Skip\"] = df_fmt[\"skips\"].map({True: \"Yes\", False: \"No\"})\n",
    "# The learning rate is written as a constant; it is not read from the CSV.\n",
    "df_fmt[\"Learning rate\"] = r\"$5e{-3}$\"\n",
    "\n",
    "# Six columns, matching the six entries of column_format below.\n",
    "table_df = df_fmt[\n",
    "    [\n",
    "        \"Embedding\",\n",
    "        r\"$w_0^{\\text{initial}}$\",\n",
    "        r\"$w_0^{\\text{hidden}}$\",\n",
    "        \"Skip\",\n",
    "        \"Learning rate\",\n",
    "        r\"PSNR$_\\boldsymbol{f}$\",\n",
    "    ]\n",
    "]\n",
    "\n",
    "# escape=False keeps the math-mode headers and cells as written.\n",
    "latex_code = table_df.to_latex(index=False, escape=False, column_format=\"lccccl\")\n",
    "\n",
    "# Wrap the tabular in a table environment with caption and label.\n",
    "latex_code = (\n",
    "    \"\\\\begin{table}[h]\\n\"\n",
    "    \"\\\\centering\\n\"\n",
    "    \"\\\\renewcommand{\\\\arraystretch}{1.2}\\n\"\n",
    "    f\"\\\\caption{{{caption}\\\\label{{{label}}}}}\\n\" + latex_code + \"\\\\end{table}\\n\"\n",
    ")\n",
    "print(latex_code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcb5b953",
   "metadata": {},
   "outputs": [],
   "source": [
    "# WIRE grid-search results, ordered by pre-training df PSNR (best first).\n",
    "grid = pd.read_csv(\"../grid_search_wire.csv\")\n",
    "grid = grid.sort_values(by=\"pre_val/df psnr\", ascending=False)\n",
    "grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6944c487",
   "metadata": {},
   "outputs": [],
   "source": [
    "caption = \"WIRE grid search combinations.\"\n",
    "label = \"tab:wire_combinations\"\n",
    "\n",
    "# Work on a copy so the raw grid-search frame `grid` is left untouched.\n",
    "df_fmt = grid.copy()\n",
    "\n",
    "# Frequency scales as short strings: two decimals, trailing zeros (and a dangling dot) removed.\n",
    "df_fmt[\"first_w0\"] = df_fmt[\"first_w0\"].map(lambda x: f\"{x:.2f}\".rstrip(\"0\").rstrip(\".\"))\n",
    "df_fmt[\"hidden_w0\"] = df_fmt[\"hidden_w0\"].map(lambda x: f\"{x:.2f}\".rstrip(\"0\").rstrip(\".\"))\n",
    "df_fmt[r\"$w_0^{\\text{initial}}$\"] = df_fmt[\"first_w0\"]\n",
    "df_fmt[r\"$w_0^{\\text{hidden}}$\"] = df_fmt[\"hidden_w0\"]\n",
    "\n",
    "# Human-readable labels for the embedding type.\n",
    "df_fmt[\"Embedding\"] = df_fmt[\"embed_type\"].replace(\n",
    "    {\"sincos_continuous\": \"SinCos\", \"discrete\": \"Discrete\", \"linear\": \"Linear\"}\n",
    ")\n",
    "# The learning rate is written as a constant; it is not read from the CSV.\n",
    "df_fmt[\"Learning rate\"] = r\"$1e{-2}$\"\n",
    "# NOTE(review): the logged PSNR is divided by 4 before reporting - presumably it is\n",
    "# accumulated over 4 items; confirm against the training code.\n",
    "df_fmt[r\"PSNR$_\\boldsymbol{f}$\"] = (\n",
    "    (df_fmt[\"pre_val/df psnr\"] / 4)\n",
    "    .round(2)\n",
    "    .map(lambda x: f\"{x:.2f}\".rstrip(\"0\").rstrip(\".\"))\n",
    ")\n",
    "\n",
    "table_df = df_fmt[\n",
    "    [\n",
    "        \"Embedding\",\n",
    "        r\"$w_0^{\\text{initial}}$\",\n",
    "        r\"$w_0^{\\text{hidden}}$\",\n",
    "        \"Learning rate\",\n",
    "        r\"PSNR$_\\boldsymbol{f}$\",\n",
    "    ]\n",
    "]\n",
    "\n",
    "# Five columns are selected above, so the column spec has five entries\n",
    "# (a six-entry spec would add an empty trailing column to the tabular).\n",
    "# escape=False keeps the math-mode headers and cells as written.\n",
    "latex_code = table_df.to_latex(index=False, escape=False, column_format=\"lcccl\")\n",
    "\n",
    "# Wrap the tabular in a table environment with caption and label.\n",
    "latex_code = (\n",
    "    \"\\\\begin{table}[h]\\n\"\n",
    "    \"\\\\centering\\n\"\n",
    "    \"\\\\renewcommand{\\\\arraystretch}{1.2}\\n\"\n",
    "    f\"\\\\caption{{{caption}\\\\label{{{label}}}}}\\n\" + latex_code + \"\\\\end{table}\\n\"\n",
    ")\n",
    "print(latex_code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c295944",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MLP grid-search results, ordered by pre-training df PSNR (best first).\n",
    "grid = pd.read_csv(\"../grid_search_mlp_.csv\")\n",
    "grid = grid.sort_values(by=\"pre_val/df psnr\", ascending=False)\n",
    "grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15fac31c",
   "metadata": {},
   "outputs": [],
   "source": [
    "caption = \"MLP grid search combinations.\"\n",
    "label = \"tab:mlp_combinations\"\n",
    "\n",
    "# Work on a copy so the raw grid-search frame `grid` is left untouched.\n",
    "df_fmt = grid.copy()\n",
    "\n",
    "# Human-readable labels for the categorical hyper-parameters.\n",
    "df_fmt[\"Activation\"] = df_fmt[\"act_fn\"].replace(\n",
    "    {\"silu\": \"SiLU\", \"relu\": \"ReLU\", \"gelu\": \"GELU\"}\n",
    ")\n",
    "df_fmt[\"Embedding\"] = df_fmt[\"embed_type\"].replace(\n",
    "    {\"sincos_continuous\": \"SinCos\", \"discrete\": \"Discrete\", \"linear\": \"Linear\"}\n",
    ")\n",
    "df_fmt[\"Skip\"] = df_fmt[\"skips\"].map({True: \"Yes\", False: \"No\"})\n",
    "# The learning rate is written as a constant; it is not read from the CSV.\n",
    "df_fmt[\"Learning rate\"] = r\"$5e{-3}$\"\n",
    "# NOTE(review): the logged PSNR is divided by 4 before reporting - presumably it is\n",
    "# accumulated over 4 items; confirm against the training code.\n",
    "df_fmt[r\"PSNR$_\\boldsymbol{f}$\"] = (\n",
    "    (df_fmt[\"pre_val/df psnr\"] / 4)\n",
    "    .round(2)\n",
    "    .map(lambda x: f\"{x:.2f}\".rstrip(\"0\").rstrip(\".\"))\n",
    ")\n",
    "\n",
    "table_df = df_fmt[[\"Activation\", \"Embedding\", \"Skip\", \"Learning rate\", r\"PSNR$_\\boldsymbol{f}$\"]]\n",
    "\n",
    "# Five columns are selected above, so the column spec has five entries\n",
    "# (a six-entry spec would add an empty trailing column to the tabular).\n",
    "latex_code = table_df.to_latex(\n",
    "    index=False, escape=False, column_format=\"lcccl\"  # escape=False: allow math mode in headers\n",
    ")\n",
    "\n",
    "# Wrap the tabular in a table environment with caption and label.\n",
    "latex_code = (\n",
    "    \"\\\\begin{table}[h]\\n\"\n",
    "    \"\\\\centering\\n\"\n",
    "    \"\\\\renewcommand{\\\\arraystretch}{1.2}\\n\"\n",
    "    f\"\\\\caption{{{caption}\\\\label{{{label}}}}}\\n\" + latex_code + \"\\\\end{table}\\n\"\n",
    ")\n",
    "print(latex_code)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d137266c",
   "metadata": {},
   "source": [
    "## Timing experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "868b88b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "device = \"cuda\"\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from einops import rearrange\n",
    "\n",
    "import os\n",
    "import io\n",
    "import pywt\n",
    "import zfpy\n",
    "from sklearn.decomposition import PCA\n",
    "import glymur"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ade175a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import time\n",
    "\n",
    "from neural_fields.nf_utils import sample_field\n",
    "from neural_fields.nf_train import train_nf\n",
    "from neural_fields.models.mlp import MLPNF\n",
    "from neural_fields.data import CycloneNFDataset, CycloneNFDataLoader\n",
    "\n",
    "\n",
    "# Dataset for the timing runs: a single trajectory file at timestep 200.\n",
    "dataset_kwargs = dict(\n",
    "    trajectory=\"iteration_13.h5\",\n",
    "    timesteps=200,\n",
    "    normalize=\"zscore\",\n",
    "    normalize_coords=False,\n",
    "    realpotens=True,\n",
    ")\n",
    "data = CycloneNFDataset(**dataset_kwargs)\n",
    "\n",
    "# Second positional argument of the loader (2048), kept positional as in the call below.\n",
    "loader_arg = 2048\n",
    "loader = CycloneNFDataLoader(\n",
    "    data, loader_arg, preload=True, shuffle=True, pin_memory=True\n",
    ")\n",
    "\n",
    "# Neural field: input/output sizes are taken from the dataset.\n",
    "model_kwargs = dict(\n",
    "    n_layers=5,\n",
    "    dim=64,\n",
    "    act_fn=nn.SiLU,\n",
    "    use_checkpoint=False,\n",
    "    skips=True,\n",
    "    embed_type=\"discrete\",\n",
    ")\n",
    "model = MLPNF(data.ndim, data.nchannels, **model_kwargs)\n",
    "\n",
    "# Main optimizer plus a separate SGD optimizer that is handed to train_nf as `aux_optim`.\n",
    "optim = torch.optim.AdamW(model.parameters(), lr=5e-3, weight_decay=1e-7)\n",
    "aux_optim = torch.optim.SGD(model.parameters(), lr=5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5ab39cfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Method name -> (compression time, reconstruction time), both in milliseconds.\n",
    "runtimes = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2322b4cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CUDA kernels are launched asynchronously, so the device is synchronised before\n",
    "# every clock read; otherwise queued GPU work can fall outside the measured interval.\n",
    "torch.cuda.synchronize()\n",
    "start = time.perf_counter_ns()\n",
    "\n",
    "# \"Compression\" of a neural field = fitting it. Stage 1: 5 epochs on the field loss only.\n",
    "model, _, _ = train_nf(\n",
    "    model,\n",
    "    optim=optim,\n",
    "    n_epochs=5,\n",
    "    data=data,\n",
    "    loader=loader,\n",
    "    device=device,\n",
    "    field_loss=True,\n",
    "    physical_loss=False,\n",
    "    use_print=False,\n",
    "    use_tqdm=False\n",
    ")\n",
    "# Stage 2: 20 epochs on the physical losses only, with the auxiliary optimizer attached.\n",
    "model, _, _ = train_nf(\n",
    "    model,\n",
    "    optim=optim,\n",
    "    n_epochs=20,\n",
    "    data=data,\n",
    "    loader=loader,\n",
    "    device=device,\n",
    "    aux_optim=aux_optim,\n",
    "    field_loss=False,\n",
    "    physical_loss=True,\n",
    "    use_conflictfree=\"pseudo\",\n",
    "    integral_loss_weight={\"flux\": 1.0, \"phi\": 0.01},\n",
    "    physical_loss_weight={\n",
    "        \"kyspec\": 1.0,\n",
    "        \"qspec\": 1.0,\n",
    "        \"kyspec monotonicity\": 1.0,\n",
    "        \"qspec monotonicity\": 1.0,\n",
    "        \"mass\": 0.0,\n",
    "    },\n",
    "    use_print=False,\n",
    "    use_tqdm=False\n",
    ")\n",
    "torch.cuda.synchronize()\n",
    "compress_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "# Reconstruction = evaluating the fitted field with sample_field.\n",
    "start = time.perf_counter_ns()\n",
    "_ = sample_field(model, data, device)\n",
    "torch.cuda.synchronize()\n",
    "recon_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "runtimes[\"NF\"] = (compress_time, recon_time)\n",
    "runtimes[\"NF\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "784925c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset.cyclone_diff import CycloneDiffusionDataset\n",
    "from utils import load_model\n",
    "\n",
    "\n",
    "ae_root = \"<path>\"\n",
    "\n",
    "# AE: load the base checkpoint first (its config `cfg` is used for the dataset below),\n",
    "# then replace the model with the PEFT checkpoint.\n",
    "cr_1208 = f\"{ae_root}/20250911_143423/best.pth\"\n",
    "ae, _, cfg = load_model(cr_1208, device=device)\n",
    "cr_1208_peft = f\"{ae_root}/20250920_190044/ckp.pth\"\n",
    "ae, _, ae_cfg = load_model(cr_1208_peft, device=device, load_peft=True)\n",
    "\n",
    "# VQ-VAE: same two-step loading.\n",
    "cr_77368 = f\"{ae_root}/20250911_143815/best.pth\"\n",
    "vqvae, _, vqvae_cfg = load_model(cr_77368, device=device)\n",
    "cr_77368_peft = f\"{ae_root}/20250919_073133/best.pth\"\n",
    "vqvae, _, vqvae_cfg = load_model(cr_77368_peft, device=device, load_peft=True)\n",
    "\n",
    "ae = ae.to(device)\n",
    "vqvae = vqvae.to(device)\n",
    "valdata_ae = CycloneDiffusionDataset(\n",
    "    path=\"<path>\",\n",
    "    split=\"train\",\n",
    "    input_fields=[\"df\", \"phi\", \"flux\"],\n",
    "    random_seed=cfg.seed,\n",
    "    normalization=None,\n",
    "    trajectories=[\"iteration_13.h5\"],\n",
    "    separate_zf=cfg.dataset.separate_zf,\n",
    "    real_potens=True,\n",
    "    stage=cfg.stage,\n",
    "    conditions=[\"itg\", \"dg\", \"s_hat\", \"q\"],\n",
    ")\n",
    "\n",
    "# CUDA kernels are launched asynchronously, so the device is synchronised before\n",
    "# every clock read; otherwise only the kernel launches would be timed.\n",
    "torch.cuda.synchronize()\n",
    "start = time.perf_counter_ns()\n",
    "# AE compression: fetch the sample, move it to the device and encode it\n",
    "sample = valdata_ae[200]\n",
    "df = sample.df.unsqueeze(0).to(device)\n",
    "condition = sample.conditioning.unsqueeze(0).to(device)\n",
    "z, cond, pad = ae.encode(df, condition=condition)\n",
    "\n",
    "torch.cuda.synchronize()\n",
    "compress_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "# AE reconstruction: decode the latent\n",
    "start = time.perf_counter_ns()\n",
    "ae_df = ae.decode(z, pad, condition=cond)[\"df\"]\n",
    "# A 4-entry leading axis is folded into 2 by adding entries (0, 1) and (2, 3).\n",
    "if ae_df.shape[0] == 4:\n",
    "    ae_df = ae_df[[0, 1]] + ae_df[[2, 3]]\n",
    "\n",
    "torch.cuda.synchronize()\n",
    "recon_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "runtimes[\"AE\"] = (compress_time, recon_time)\n",
    "\n",
    "start = time.perf_counter_ns()\n",
    "# VQ-VAE compression: same steps as for the AE\n",
    "sample = valdata_ae[200]\n",
    "df = sample.df.unsqueeze(0).to(device)\n",
    "condition = sample.conditioning.unsqueeze(0).to(device)\n",
    "z, cond, pad = vqvae.encode(df, condition=condition)\n",
    "\n",
    "torch.cuda.synchronize()\n",
    "compress_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "# VQ-VAE reconstruction (the result reuses the name `ae_df`)\n",
    "start = time.perf_counter_ns()\n",
    "ae_df = vqvae.decode(z, pad, condition=cond)[\"df\"]\n",
    "if ae_df.shape[0] == 4:\n",
    "    ae_df = ae_df[[0, 1]] + ae_df[[2, 3]]\n",
    "\n",
    "torch.cuda.synchronize()\n",
    "recon_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "runtimes[\"VQ-VAE\"] = (compress_time, recon_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d518892a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- ZFP: time compression and decompression of the full field ---\n",
    "df = data.full_df\n",
    "\n",
    "start = time.perf_counter_ns()\n",
    "\n",
    "# Sizes needed to undo the axis merge after decompression.\n",
    "vp, s = df.shape[1], df.shape[3]\n",
    "# Collapse the 6 axes into 4 (c, vp*vm, s*y, x) and move to a NumPy array;\n",
    "# the reshape and the host transfer are part of the timed compression.\n",
    "df = rearrange(df, \"c vp vm s x y -> c (vp vm) (s y) x\").cpu().numpy()\n",
    "# NOTE(review): tolerance=2000 is a hard-coded error tolerance for zfp - confirm the value.\n",
    "zfp_compressed = zfpy.compress_numpy(df, tolerance=2000)\n",
    "\n",
    "compress_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "start = time.perf_counter_ns()\n",
    "\n",
    "# Decompress, restore the original 6-axis layout and convert back to a tensor.\n",
    "zf_df = zfpy.decompress_numpy(zfp_compressed)\n",
    "zf_df = rearrange(zf_df, \"c (vp vm) (s y) x -> c vp vm s x y\", vp=vp, s=s)\n",
    "zf_df = torch.from_numpy(zf_df)\n",
    "\n",
    "recon_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "runtimes[\"ZFP\"] = (compress_time, recon_time)\n",
    "\n",
    "# --- Wavelet: time a thresholded Haar transform and its inverse ---\n",
    "df = data.full_df\n",
    "\n",
    "\n",
    "start = time.perf_counter_ns()\n",
    "\n",
    "df = df.cpu().numpy()\n",
    "vp, vm, s, x, y = df.shape[1:]\n",
    "\n",
    "coeffs = []\n",
    "# The first two entries of the leading axis are transformed independently.\n",
    "for c in range(2):\n",
    "    # Single-level n-dimensional Haar decomposition of one entry.\n",
    "    dec = pywt.wavedecn(df[c], wavelet=\"haar\", mode=\"periodization\", level=1)\n",
    "    coeff, slices = pywt.coeffs_to_array(dec)\n",
    "    # Hard threshold: coefficients with magnitude below 22 are zeroed in place.\n",
    "    # No entropy coding follows, so only transform + threshold is timed.\n",
    "    coeff[np.abs(coeff) < 22] = 0\n",
    "    coeffs.append((coeff, slices))\n",
    "\n",
    "compress_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "start = time.perf_counter_ns()\n",
    "\n",
    "wt_df = []\n",
    "for coeff, slices in coeffs:\n",
    "    # Rebuild the coefficient structure and invert the transform.\n",
    "    recon = pywt.array_to_coeffs(coeff, slices, output_format=\"wavedecn\")\n",
    "    recon = pywt.waverecn(recon, wavelet=\"haar\", mode=\"periodization\")\n",
    "    wt_df.append(recon)\n",
    "\n",
    "wt_df = np.stack(wt_df, axis=0)\n",
    "# Crop to the input extents in case the reconstruction is larger than the input.\n",
    "wt_df = wt_df[:, :vp, :vm, :s, :x, :y]\n",
    "wt_df = torch.from_numpy(wt_df)\n",
    "\n",
    "recon_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "runtimes[\"Wavelet\"] = (compress_time, recon_time)\n",
    "\n",
    "\n",
    "# --- PCA: time fitting/projection and the inverse transform ---\n",
    "# Note: `df` set here is also read by the JPEG2000 section further down in this cell.\n",
    "df = data.full_df\n",
    "\n",
    "vp, vm, s, x, y = df.shape[1:]\n",
    "# One 2-D matrix per leading-axis entry: rows = (vp, vm, s), columns = (x, y).\n",
    "# This reshape happens before the timer starts, so it is not part of the measurement.\n",
    "df_np = rearrange(df, \"c vp vm s x y -> c (vp vm s) (x y)\").cpu().numpy()\n",
    "\n",
    "pca_results = []\n",
    "compressed = []\n",
    "compressed_size = 0\n",
    "\n",
    "# ------------------------\n",
    "# compression timer start\n",
    "# ------------------------\n",
    "start = time.perf_counter_ns()\n",
    "\n",
    "pcas = []  # keep fitted PCA objects for later reconstruction\n",
    "for c in range(2):  # one PCA per leading-axis entry; exactly 2 entries are assumed\n",
    "    pca = PCA(n_components=2)\n",
    "    transformed = pca.fit_transform(df_np[c])\n",
    "    pcas.append(pca)\n",
    "\n",
    "    # store compressed form (components + params)\n",
    "    # Here \"components\" holds the projected data returned by fit_transform.\n",
    "    compressed_version = {\n",
    "        \"components\": transformed,\n",
    "        \"mean\": pca.mean_,\n",
    "        \"explained_variance\": pca.explained_variance_,\n",
    "    }\n",
    "    compressed.append(compressed_version)\n",
    "\n",
    "    # NOTE(review): the size counts the projected data, mean and explained variance only;\n",
    "    # the fitted PCA object used for inverse_transform below is not included - confirm intent.\n",
    "    compressed_size += (\n",
    "        transformed.nbytes + pca.mean_.nbytes + pca.explained_variance_.nbytes\n",
    "    )\n",
    "\n",
    "compress_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "# ------------------------\n",
    "# reconstruction timer start\n",
    "# ------------------------\n",
    "start = time.perf_counter_ns()\n",
    "\n",
    "for c in range(2):\n",
    "    reconstructed = pcas[c].inverse_transform(compressed[c][\"components\"])\n",
    "    pca_results.append(reconstructed)\n",
    "\n",
    "pca_df = np.stack(pca_results, axis=0)\n",
    "\n",
    "# reshape back\n",
    "pca_df = rearrange(\n",
    "    pca_df,\n",
    "    \"c (vp vm s) (x y) -> c vp vm s x y\",\n",
    "    vp=vp,\n",
    "    vm=vm,\n",
    "    s=s,\n",
    "    x=x,\n",
    "    y=y,\n",
    ")\n",
    "pca_df = torch.from_numpy(pca_df)\n",
    "\n",
    "recon_time = (time.perf_counter_ns() - start) / 1e6  # ns -> ms\n",
    "\n",
    "runtimes[\"PCA\"] = (compress_time, recon_time)\n",
    "\n",
    "\n",
    "start = time.perf_counter_ns()\n",
    "\n",
    "c, vp, vm, _, x, _ = df.shape\n",
    "df_np = df.cpu().numpy()\n",
    "\n",
    "df_flat = rearrange(df_np, \"c vp vm s x y -> c (vp vm s) (x y)\")\n",
    "\n",
    "compressed_data = []\n",
    "compressed_size = 0.0\n",
    "recon_flat = np.zeros_like(df_flat)\n",
    "\n",
    "for ch in range(c):\n",
    "    slice_ = df_flat[ch]\n",
    "\n",
    "    mn, mx = slice_.min(), slice_.max()\n",
    "    if mx - mn == 0:\n",
    "        norm_slice = np.zeros_like(slice_)\n",
    "    else:\n",
    "        norm_slice = (slice_ - mn) / (mx - mn)\n",
    "\n",
    "    img_uint16 = (norm_slice * 65535).astype(np.uint16)\n",
    "\n",
    "    glymur.Jp2k(\"/tmp/df.jp2\", data=img_uint16, cratios=[100.0 / 0.1])\n",
    "    compressed_data.append({\"bytes\": None, \"min\": mn, \"max\": mx})\n",
    "    compressed_size += os.path.getsize(\"/tmp/df.jp2\")\n",
    "\n",
    "compress_time = (time.perf_counter_ns() - start) / 1e6\n",
    "\n",
    "start = time.perf_counter_ns()\n",
    "\n",
    "for ch in range(c):\n",
    "    jp2 = glymur.Jp2k(\"/tmp/df.jp2\")\n",
    "    recon_uint16 = jp2[:]\n",
    "    mn, mx = compressed_data[ch][\"min\"], compressed_data[ch][\"max\"]\n",
    "    recon_norm = recon_uint16.astype(np.float32) / 65535.0\n",
    "    recon_flat[ch] = recon_norm * (mx - mn) + mn\n",
    "\n",
    "recon_np = rearrange(\n",
    "    recon_flat, \"c (vp vm s) (x y) -> c vp vm s x y\", vp=vp, vm=vm, x=x\n",
    ")\n",
    "recon_np = torch.from_numpy(recon_np)\n",
    "\n",
    "recon_time = (time.perf_counter_ns() - start) / 1e6\n",
    "\n",
    "runtimes[\"JPEG2000\"] = (compress_time, recon_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "371b37db",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'NF': (-96347.158581, -260.17468),\n",
       " 'ZFP': (144.629112, 66.08852),\n",
       " 'Wavelet': (1300.92176, 804.220093),\n",
       " 'PCA': (377.350204, 149.443536),\n",
       " 'JPEG2000': (4173.886618, 261.586194)}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "runtimes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d4ecc12",
   "metadata": {},
   "source": [
    "## Initialization experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e3d256a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = CycloneNFDataset(\n",
    "    trajectory=\"iteration_13.h5\",\n",
    "    timesteps=160,\n",
    "    normalize=\"zscore\",\n",
    "    normalize_coords=False,\n",
    ")\n",
    "loader = CycloneNFDataLoader(data, 2048, preload=True, shuffle=True, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d8ef10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# grid = data.grid\n",
    "\n",
    "# step = 1 / (2 * grid.shape[-2])\n",
    "\n",
    "# shifted = grid + step\n",
    "\n",
    "# grid = torch.stack([grid, shifted], dim=-2)\n",
    "# grid = grid.flatten(start_dim=-3, end_dim=-2)\n",
    "\n",
    "# data.grid = grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f6da5f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "from neural_fields.models.mlp import MLPNF\n",
    "from neural_fields.models.siren import SIREN\n",
    "from neural_fields.models.wire import WIRE\n",
    "\n",
    "MODELS = {}\n",
    "\n",
    "model = SIREN(\n",
    "    data.ndim,\n",
    "    data.nchannels,\n",
    "    n_layers=5,\n",
    "    dim=64,\n",
    "    skips=True,\n",
    "    first_w0=1.0,\n",
    "    hidden_w0=3.0,\n",
    "    readout_w0=3.0,\n",
    "    embed_type=\"linear\",\n",
    "    clip_out=False,\n",
    ")\n",
    "MODELS[\"siren\"] = model\n",
    "\n",
    "model = SIREN(\n",
    "    data.ndim,\n",
    "    data.nchannels,\n",
    "    n_layers=5,\n",
    "    dim=64,\n",
    "    skips=True,\n",
    "    first_w0=1.0,\n",
    "    hidden_w0=3.0,\n",
    "    readout_w0=3.0,\n",
    "    embed_type=\"sincos_discrete\",\n",
    "    clip_out=False,\n",
    ")\n",
    "MODELS[\"siren sincos\"] = model\n",
    "\n",
    "model = SIREN(\n",
    "    data.ndim,\n",
    "    data.nchannels,\n",
    "    n_layers=5,\n",
    "    dim=64,\n",
    "    skips=True,\n",
    "    first_w0=1.0,\n",
    "    hidden_w0=3.0,\n",
    "    readout_w0=3.0,\n",
    "    embed_type=\"discrete\",\n",
    "    clip_out=False,\n",
    ")\n",
    "MODELS[\"siren nn.Embedding\"] = model\n",
    "\n",
    "model = MLPNF(\n",
    "    data.ndim,\n",
    "    data.nchannels,\n",
    "    n_layers=5,\n",
    "    dim=64,\n",
    "    act_fn=nn.SiLU,\n",
    "    use_checkpoint=False,\n",
    "    skips=True,\n",
    "    embed_type=\"linear\",\n",
    ")\n",
    "MODELS[\"mlp\"] = model\n",
    "\n",
    "model = MLPNF(\n",
    "    data.ndim,\n",
    "    data.nchannels,\n",
    "    n_layers=5,\n",
    "    dim=64,\n",
    "    act_fn=nn.SiLU,\n",
    "    use_checkpoint=False,\n",
    "    skips=True,\n",
    "    embed_type=\"sincos_discrete\",\n",
    ")\n",
    "MODELS[\"mlp sincos\"] = model\n",
    "\n",
    "model = MLPNF(\n",
    "    data.ndim,\n",
    "    data.nchannels,\n",
    "    n_layers=5,\n",
    "    dim=64,\n",
    "    act_fn=nn.SiLU,\n",
    "    use_checkpoint=False,\n",
    "    skips=True,\n",
    "    embed_type=\"discrete\",\n",
    ")\n",
    "MODELS[\"mlp nn.Embedding\"] = model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "206e3272",
   "metadata": {},
   "outputs": [],
   "source": [
    "from einops import rearrange\n",
    "\n",
    "from nf_utils import plot_diag, sample_field, plotND, phi_fft\n",
    "from gk_losses import spectra_losses, get_integrals\n",
    "\n",
    "\n",
    "def to_complex(x: torch.Tensor) -> torch.Tensor:\n",
    "    assert x.shape[0] == 2, x.shape\n",
    "    x = rearrange(x, \"c ... -> ... c\").contiguous()\n",
    "    return torch.view_as_complex(x).squeeze()\n",
    "\n",
    "\n",
    "def to_real(x: torch.Tensor) -> torch.Tensor:\n",
    "    return rearrange(torch.view_as_real(x), \"... c -> c ...\").squeeze()\n",
    "\n",
    "\n",
    "def df_fft(df: torch.Tensor, norm: str = \"forward\"):\n",
    "    if df.shape[0] == 4:\n",
    "        df = df[[0, 1]] + df[[2, 3]]\n",
    "    df = to_complex(df)\n",
    "    df = torch.fft.fftn(df, dim=(-5, -4, -3, -2, -1), norm=norm)\n",
    "    df = torch.fft.fftshift(df, dim=(-2,))\n",
    "    return to_real(df)\n",
    "\n",
    "\n",
    "def df_ifft(df: torch.Tensor, norm: str = \"forward\"):\n",
    "    if df.shape[0] == 4:\n",
    "        df = df[[0, 1]] + df[[2, 3]]\n",
    "    df = to_complex(df)\n",
    "    df = torch.fft.ifftshift(df, dim=(-2,))\n",
    "    df = torch.fft.ifftn(df, dim=(-5, -4, -3, -2, -1), norm=norm)\n",
    "    return to_real(df)\n",
    "\n",
    "\n",
    "gt_diagz, pred_diagz = [], []\n",
    "pred_dfz = {}\n",
    "pred_phiz = {}\n",
    "for m, nf in MODELS.items():\n",
    "    data.to(device)\n",
    "    nf.to(device)\n",
    "    with torch.no_grad():\n",
    "        pred_df = sample_field(nf, data, device, timestep=None).to(device)\n",
    "        pred_df = df_fft(pred_df)\n",
    "        pred_dfz[m] = pred_df\n",
    "        pred_df = df_ifft(pred_df)\n",
    "\n",
    "    gt_df = data.full_df.to(device)\n",
    "    pred_phi, (pred_pflux, pred_eflux, _) = get_integrals(\n",
    "        pred_df,\n",
    "        data,\n",
    "        flux_fields=True,\n",
    "        spectral_df=False,\n",
    "        phi_integral=False,\n",
    "    )\n",
    "    pred_phiz[m] = to_real(phi_fft(pred_phi))\n",
    "    gt_phi, (_, gt_eflux, _) = get_integrals(\n",
    "        gt_df.to(device), data, flux_fields=True, spectral_df=False\n",
    "    )\n",
    "    spec_losses, (gt_diag, pred_diag) = spectra_losses(\n",
    "        pred_df.cpu(),\n",
    "        pred_phi.cpu(),\n",
    "        pred_eflux.cpu(),\n",
    "        gt_df.cpu(),\n",
    "        gt_phi.cpu(),\n",
    "        gt_eflux.cpu(),\n",
    "        data.ds,\n",
    "    )\n",
    "    gt_diagz.append(gt_diag)\n",
    "    pred_diagz.append(pred_diag)\n",
    "    plotND(df_fft(pred_df), n=5, title=m)\n",
    "    plotND(torch.log(torch.abs(to_real(phi_fft(pred_phi))) ** 2), n=3, title=m)\n",
    "fig_diag = plot_diag(gt_diagz, pred_diagz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfea56bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(len(pred_dfz), 3, figsize=(15, 3 * len(pred_dfz)))\n",
    "i = 0\n",
    "for m, pd in pred_dfz.items():\n",
    "    # df\n",
    "    x = pd.sum(0).sum((0, 1, 2, 3))\n",
    "    ax[i, 0].set_title(f\"{m} df(ky)\")\n",
    "    ax[i, 0].plot(x.cpu().numpy())\n",
    "    # phi\n",
    "    phi = pred_phiz[m]\n",
    "    x = phi.sum(0).sum((0, 1))\n",
    "    ax[i, 1].set_title(f\"{m} phi(ky)\")\n",
    "    ax[i, 1].plot(x.cpu().numpy())\n",
    "    x = phi.sum(0).sum((0, 1)) ** 2\n",
    "    ax[i, 2].set_title(f\"{m} phi(ky)^2\")\n",
    "    ax[i, 2].plot(x.cpu().numpy())\n",
    "    for j in [1, 2]:\n",
    "        ax[i, j].set_xscale(\"log\")\n",
    "        ax[i, j].set_yscale(\"log\")\n",
    "    i += 1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mhd",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
