{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited project modules (e.g. models/, dataset/) automatically before each cell\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Expose only the first GPU to this kernel; set before torch initialises CUDA.\n",
    "os.environ.update(CUDA_VISIBLE_DEVICES=\"0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw binary distribution-function dump inside `directory`, read as float64 via np.fromfile\n",
    "filename2 = \"K15\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the raw outputs of one simulation run\n",
    "directory = \"/system/user/publicdata/gyrokinetics/raw/cyclone4_2\"\n",
    "# Default potential snapshot file name\n",
    "filename = \"Poten00000200\"\n",
    "path = os.path.join(directory, filename)\n",
    "\n",
    "\n",
    "def _load_txt(name):\n",
    "    \"\"\"Load the whitespace-separated text table `name` from `directory`.\"\"\"\n",
    "    return np.loadtxt(os.path.join(directory, name))\n",
    "\n",
    "\n",
    "# Load each auxiliary table of the run by file name\n",
    "time = _load_txt(\"time.dat\")\n",
    "sgrid = _load_txt(\"sgrid\")\n",
    "xphi = _load_txt(\"xphi\")\n",
    "yphi = _load_txt(\"yphi\")\n",
    "fluxes = _load_txt(\"fluxes.dat\")\n",
    "kyspec = _load_txt(\"kyspec\")\n",
    "krho = _load_txt(\"krho\")\n",
    "vpgr = _load_txt(\"vpgr.dat\")\n",
    "vperp = _load_txt(\"vperp.dat\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of parallel direction grid points\n",
    "ns = sgrid.shape[1] if len(sgrid.shape) > 1 else sgrid.shape[0]\n",
    "\n",
    "# number of x, y grid points (in real space)\n",
    "nx, ny = xphi.shape[1], xphi.shape[0]\n",
    "\n",
    "# number of modes in x and y direction\n",
    "nkx, nky = krho.shape[1], krho.shape[0]\n",
    "\n",
    "# get velocity space resolutions\n",
    "nvpar, nmu = vpgr.shape[1], vpgr.shape[0]\n",
    "\n",
    "# Plot flux trace over the first N_TRACE output steps (column 1 of fluxes.dat)\n",
    "N_TRACE = 180\n",
    "LINE_COLOR = (32 / 255, 70 / 255, 125 / 255)\n",
    "MARKER_COLOR = (182 / 255, 72 / 255, 66 / 255)\n",
    "# Seeded generator: the previous unseeded np.random.uniform produced a different\n",
    "# scatter (and a different saved fluxes.svg) on every execution.\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(8, 3))\n",
    "ax.plot(time[:N_TRACE], fluxes[:N_TRACE, 1], lw=3, c=LINE_COLOR)\n",
    "# NOTE(review): these markers are random placeholders, not simulation data\n",
    "ax.scatter(\n",
    "    np.linspace(0, 30, 10),\n",
    "    rng.uniform(0, 25, 10),\n",
    "    marker=\"x\",\n",
    "    c=MARKER_COLOR,\n",
    ")\n",
    "ax.grid()\n",
    "ax.set_xlabel(r\"$t$\", fontsize=20)\n",
    "ax.set_ylabel(r\"$\\int \\delta f$\", fontsize=20)\n",
    "ax.tick_params(labelsize=14)\n",
    "fig.savefig(\"fluxes.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import colormaps\n",
    "\n",
    "\n",
    "c_map = colormaps[\"plasma\"]\n",
    "\n",
    "# number of parallel direction grid points\n",
    "ns = sgrid.shape[1] if len(sgrid.shape) > 1 else sgrid.shape[0]\n",
    "\n",
    "# number of x, y grid points (in real space)\n",
    "nx, ny = xphi.shape[1], xphi.shape[0]\n",
    "\n",
    "# number of modes in x and y direction\n",
    "nkx, nky = krho.shape[1], krho.shape[0]\n",
    "\n",
    "# get velocity space resolutions\n",
    "nvpar, nmu = vpgr.shape[1], vpgr.shape[0]\n",
    "\n",
    "# index along the parallel (s) axis of the x-y plane that is plotted\n",
    "S_SLICE = 8\n",
    "\n",
    "\n",
    "def load_potential(name):\n",
    "    \"\"\"Read the potential dump `name` from `directory` and reshape it to (nx, ns, ny).\n",
    "\n",
    "    The flat text data is reshaped in Fortran (column-major) order.\n",
    "    \"\"\"\n",
    "    raw = np.loadtxt(os.path.join(directory, name))\n",
    "    return np.reshape(raw, (nx, ns, ny), order=\"F\")\n",
    "\n",
    "\n",
    "# Dedicated names instead of reassigning the global `filename`: other cells read\n",
    "# `filename`, and overwriting it here silently changed which dump they loaded.\n",
    "pred_file = \"Poten00000380\"\n",
    "gt_file = \"Poten00000400\"\n",
    "\n",
    "# Plot real space potential: two snapshots side by side\n",
    "fig, ax = plt.subplots(1, 2, figsize=(12, 3))\n",
    "fig.subplots_adjust(wspace=0.05)\n",
    "\n",
    "phi = load_potential(pred_file)\n",
    "ax[0].matshow(np.squeeze(phi[:, S_SLICE, :]).T, cmap=c_map)\n",
    "ax[0].set_title(r\"$\\phi_{pred}$\", fontsize=24)\n",
    "# NOTE(review): `x` mixes figure-coordinate positions; presumably meant to centre\n",
    "# a single label under both panels\n",
    "ax[0].set_xlabel(\n",
    "    r\"$x_{\\phi}$\",\n",
    "    fontsize=20,\n",
    "    x=ax[0].get_position().x0 + ax[0].get_position().width + ax[1].get_position().x0,\n",
    ")\n",
    "ax[0].set_ylabel(r\"$y_{\\phi}$\", fontsize=20)\n",
    "ax[0].set_xticks([])\n",
    "ax[0].set_yticks([])\n",
    "\n",
    "phi = load_potential(gt_file)\n",
    "ax[1].matshow(np.squeeze(phi[:, S_SLICE, :]).T, cmap=c_map)\n",
    "ax[1].set_title(r\"$\\phi_{GT}$\", fontsize=24)\n",
    "ax[1].set_xticks([])\n",
    "ax[1].set_yticks([])\n",
    "fig.savefig(\"potentials.svg\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of parallel direction grid points\n",
    "ns = sgrid.shape[1] if len(sgrid.shape) > 1 else sgrid.shape[0]\n",
    "\n",
    "# number of x, y grid points (in real space)\n",
    "nx, ny = xphi.shape[1], xphi.shape[0]\n",
    "\n",
    "# number of modes in x and y direction\n",
    "nkx, nky = krho.shape[1], krho.shape[0]\n",
    "\n",
    "# get velocity space resolutions\n",
    "nvpar, nmu = vpgr.shape[1], vpgr.shape[0]\n",
    "\n",
    "# indices of the plotted slices: parallel (s) index, kx mode, ky mode\n",
    "S_IDX, KX_IDX, KY_IDX = 8, 80, 20\n",
    "\n",
    "# load real space electric potential data; Fortran order, like the other\n",
    "# reshapes of these dumps (the C-order reshape scrambled the axes)\n",
    "a = np.loadtxt(os.path.join(directory, filename))\n",
    "phi = np.reshape(a, (nx, ns, ny), order=\"F\")\n",
    "\n",
    "# Each figure is saved *before* plt.show(): with the inline backend show() closes\n",
    "# the figure, so a later plt.savefig() wrote an empty image. Distinct file names\n",
    "# keep the publication figures (potentials.svg / fluxes.svg) from being clobbered.\n",
    "fig_phi, ax_phi = plt.subplots()\n",
    "mesh = ax_phi.pcolor(xphi, yphi, np.squeeze(phi[:, S_IDX, :]).T, shading=\"auto\")\n",
    "fig_phi.colorbar(mesh, ax=ax_phi)\n",
    "ax_phi.set_title(\"Real Space Potential\")\n",
    "ax_phi.set_xlabel(\"xphi\")\n",
    "ax_phi.set_ylabel(\"yphi\")\n",
    "fig_phi.savefig(\"potentials_raw.svg\")\n",
    "plt.show()\n",
    "\n",
    "# Plot flux trace\n",
    "fig_flux, ax_flux = plt.subplots()\n",
    "ax_flux.plot(time, fluxes[:, 1])\n",
    "ax_flux.set_title(\"Flux Trace\")\n",
    "ax_flux.set_xlabel(\"Time\")\n",
    "ax_flux.set_ylabel(\"Fluxes\")\n",
    "fig_flux.savefig(\"fluxes_raw.svg\")\n",
    "plt.show()\n",
    "\n",
    "# Load full distribution function data\n",
    "ff = np.fromfile(os.path.join(directory, filename2), dtype=np.float64)\n",
    "\n",
    "# Reshape the distribution function (first index is re/im component of fourier weights)\n",
    "f = np.reshape(ff, (2, nvpar, nmu, ns, nkx, nky), order=\"F\")\n",
    "\n",
    "# Plot distribution function in velocity space\n",
    "fig_v, ax_v = plt.subplots()\n",
    "mesh = ax_v.pcolor(\n",
    "    vpgr, vperp, np.squeeze(f[0, :, :, S_IDX, KX_IDX, KY_IDX]).T, shading=\"auto\"\n",
    ")\n",
    "fig_v.colorbar(mesh, ax=ax_v)\n",
    "ax_v.set_title(\"Distribution Function (Velocity Space)\")\n",
    "ax_v.set_xlabel(\"vpgr\")\n",
    "ax_v.set_ylabel(\"vperp\")\n",
    "plt.show()\n",
    "\n",
    "# Additional plot of the distribution function (vpar vs. s at mu index 2)\n",
    "fig_s, ax_s = plt.subplots()\n",
    "mesh = ax_s.pcolor(np.squeeze(f[0, :, 2, :, KX_IDX, KY_IDX]).T, shading=\"auto\")\n",
    "fig_s.colorbar(mesh, ax=ax_s)\n",
    "ax_s.set_title(\"Distribution Function Slice\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import colormaps\n",
    "\n",
    "\n",
    "def force_aspect(ax, aspect=1):\n",
    "    \"\"\"Draw the first image on `ax` with a box width/height ratio of `aspect`.\n",
    "\n",
    "    Does nothing when `ax` holds no image or the image has zero height.\n",
    "    \"\"\"\n",
    "    images = ax.get_images()\n",
    "    if not images:\n",
    "        return\n",
    "    left, right, bottom, top = images[0].get_extent()\n",
    "    if top == bottom:\n",
    "        return\n",
    "    ax.set_aspect(abs((right - left) / (top - bottom)) / aspect)\n",
    "\n",
    "\n",
    "def distribution_5D(fname, loc):\n",
    "    \"\"\"Plot a 2-D slice of a raw distribution-function dump for every axis pair.\n",
    "\n",
    "    The dump is reshaped to (2, nvpar, nmu, ns, nkx, nky); component 0 gives a 5-D\n",
    "    array, and for every pair (i, j) of its axes the other three axes are fixed\n",
    "    at indices chosen by `loc`.\n",
    "\n",
    "    Args:\n",
    "        fname: file name (inside `directory`) of a float64 binary dump.\n",
    "        loc: where the three hidden axes are sliced: \"start\", \"mid\" or \"end\".\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure (also saved as `{fname}_{loc}.png`).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `loc` is not one of the supported values.\n",
    "    \"\"\"\n",
    "    if loc not in (\"start\", \"mid\", \"end\"):\n",
    "        # an unknown loc previously fell through and crashed with UnboundLocalError\n",
    "        raise ValueError(f\"loc must be 'start', 'mid' or 'end', got {loc!r}\")\n",
    "\n",
    "    # Load full distribution function data\n",
    "    with open(os.path.join(directory, fname), \"rb\") as fid:\n",
    "        ff = np.fromfile(fid, dtype=np.float64)\n",
    "\n",
    "    # Reshape the distribution function (first index is re/im component of fourier weights)\n",
    "    f = np.reshape(ff, (2, nvpar, nmu, ns, nkx, nky), order=\"F\").astype(\"float32\")\n",
    "    # Alternative (disabled): transform the spatial modes to real space.\n",
    "    # f = np.moveaxis(f, 0, -1)\n",
    "    # f = f.copy().view(dtype=np.complex64)\n",
    "    # f = np.fft.ifftn(f, axes=(3, 4))\n",
    "    # f = np.stack([f.real, f.imag]).squeeze()\n",
    "\n",
    "    # all (i, j) pairs with i < j over the 5 axes of f[0]\n",
    "    comb = torch.combinations(torch.arange(5), 2).tolist()\n",
    "\n",
    "    fig, ax = plt.subplots(5, 5, figsize=(20, 20))\n",
    "    c_map = colormaps[\"RdBu\"]\n",
    "    c_map.set_bad(\"k\")\n",
    "\n",
    "    for i, j in comb:\n",
    "        other = tuple(o for o in range(5) if o not in (i, j))\n",
    "        # hidden axes first, the displayed pair (i, j) last\n",
    "        xx = f[0].transpose(*other, i, j)\n",
    "        if loc == \"start\":\n",
    "            a, b, c = 2, xx.shape[1] // 2, 2\n",
    "        elif loc == \"mid\":\n",
    "            a, b, c = xx.shape[0] // 2, xx.shape[1] // 2, xx.shape[2] // 2\n",
    "        else:  # \"end\"\n",
    "            a, b, c = -2, xx.shape[1] // 2, -2\n",
    "        ax[i, j].matshow(xx[a, b, c], cmap=c_map)\n",
    "        force_aspect(ax[i, j])\n",
    "        ax[i, j].set_xticks([])\n",
    "        ax[i, j].set_yticks([])\n",
    "\n",
    "    # drop the unused diagonal / lower-triangle panels\n",
    "    for i in range(5):\n",
    "        for j in range(5):\n",
    "            if [i, j] not in comb:\n",
    "                ax[i, j].remove()\n",
    "\n",
    "    fig.savefig(f\"{fname}_{loc}.png\", bbox_inches=\"tight\", transparent=True)\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Slice the hidden axes near the start, middle and end of their ranges\n",
    "for loc in (\"start\", \"mid\", \"end\"):\n",
    "    _ = distribution_5D(\"K50\", loc=loc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run directory of the trained model (config.yaml and best.pth are read from it)\n",
    "CKP = \"/system/user/publicwork/fpaische/plasmamodelling/outputs/20250115_150422\"\n",
    "# Fall back to CPU so the notebook still runs where no GPU is available\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from models import get_model\n",
    "\n",
    "from dataset.cyclone import CycloneDataset\n",
    "\n",
    "# Read the training config; the context manager closes the file handle that the\n",
    "# bare open() call used to leak.\n",
    "with open(f\"{CKP}/config.yaml\", \"r\") as cfg_file:\n",
    "    cfg = OmegaConf.create(yaml.safe_load(cfg_file))\n",
    "\n",
    "# Disable ITG conditioning for this notebook\n",
    "cfg.model.swin.itg_conditioning = False\n",
    "\n",
    "# Validation split only\n",
    "data = CycloneDataset(\n",
    "    active_keys=cfg.dataset.active_keys,\n",
    "    split=\"val\",\n",
    "    random_seed=cfg.seed,\n",
    "    normalization=cfg.dataset.normalization,\n",
    "    spatial_ifft=cfg.dataset.spatial_ifft,\n",
    "    bundle_seq_length=cfg.model.bundle_seq_length,\n",
    "    trajectories=cfg.dataset.validation_trajectories,\n",
    ")\n",
    "# NOTE(review): no training split is loaded; `traindata` is an empty placeholder\n",
    "traindata = []\n",
    "print(f\"Train: {len(traindata)}, Val: {len(data)}\")\n",
    "\n",
    "model = get_model(cfg, data)\n",
    "\n",
    "# Strip `module.` from the saved parameter names before loading\n",
    "loaded_ckp = torch.load(f\"{CKP}/best.pth\", map_location=device, weights_only=True)\n",
    "model.load_state_dict(\n",
    "    {k.replace(\"module.\", \"\"): v for k, v in loaded_ckp[\"model_state_dict\"].items()}\n",
    ")\n",
    "\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def mse_time_histogram(model_fn, data):\n",
    "    \"\"\"Roll `model_fn` out autoregressively over `data` and collect per-step MSE.\n",
    "\n",
    "    Only the first sample's input is used; every later step feeds the previous\n",
    "    prediction back into the model and compares it with that sample's target.\n",
    "\n",
    "    Args:\n",
    "        model_fn: module called as ``model_fn(x, timestep=ts)``; switched to eval mode.\n",
    "        data: indexable samples exposing ``x``, ``y`` and ``timestep_index``.\n",
    "\n",
    "    Returns:\n",
    "        ``(fig, losses)``, where ``losses`` maps a timestep index to a list of MSEs.\n",
    "        For an empty ``data`` the figure has no bars and ``losses`` is empty.\n",
    "    \"\"\"\n",
    "    model_fn.eval()\n",
    "\n",
    "    losses = defaultdict(list)\n",
    "    # `data[0]` below raised IndexError for an empty split (e.g. `traindata = []`)\n",
    "    if len(data) == 0:\n",
    "        return mse_time_histogram_from_losses(losses), losses\n",
    "\n",
    "    x = data[0].x.to(device).unsqueeze(0)\n",
    "    with torch.no_grad():\n",
    "        for idx in tqdm(range(len(data))):\n",
    "            sample = data[idx]\n",
    "            target = sample.y.numpy()\n",
    "            ts = sample.timestep_index.to(device).unsqueeze(0)\n",
    "            # autoregressive: the prediction becomes the next input\n",
    "            x = model_fn(x, timestep=ts)\n",
    "\n",
    "            mse = np.mean((x.squeeze(0).cpu().numpy() - target) ** 2)\n",
    "            losses[ts.squeeze().item()].append(mse)\n",
    "\n",
    "    fig = mse_time_histogram_from_losses(losses)\n",
    "\n",
    "    return fig, losses\n",
    "\n",
    "\n",
    "def mse_time_histogram_from_losses(losses):\n",
    "    \"\"\"Bar-plot mean +/- std of the MSE per timestep from a ``{t: [mse, ...]}`` mapping.\"\"\"\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(10, 5))\n",
    "\n",
    "    times = sorted(losses.keys())\n",
    "    losses_mean = [np.mean(losses[t]) for t in times]\n",
    "    losses_std = [np.std(losses[t]) for t in times]\n",
    "\n",
    "    # Bar plot with error bars\n",
    "    ax.bar(times, losses_mean, yerr=losses_std, alpha=0.7, capsize=5, color=\"blue\")\n",
    "    ax.set_xlabel(\"Time Step\")\n",
    "    ax.set_ylabel(\"Mean Squared Error\")\n",
    "    ax.set_title(\"MSE by Time Step\")\n",
    "    ax.grid(True)\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `traindata` is created as an empty list in the model-loading cell\n",
    "_, losses = mse_time_histogram(model_fn=model, data=traindata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Autoregressive rollout over the validation split\n",
    "_, val_losses = mse_time_histogram(model_fn=model, data=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge both loss dicts (validation wins on shared timesteps) and plot\n",
    "_ = mse_time_histogram_from_losses({**losses, **val_losses})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import colormaps\n",
    "\n",
    "\n",
    "def force_aspect(ax, aspect=1):\n",
    "    \"\"\"Draw the first image on `ax` with a box width/height ratio of `aspect`.\n",
    "\n",
    "    Does nothing when `ax` holds no image or the image has zero height.\n",
    "    \"\"\"\n",
    "    images = ax.get_images()\n",
    "    if not images:\n",
    "        return\n",
    "    left, right, bottom, top = images[0].get_extent()\n",
    "    if top == bottom:\n",
    "        return\n",
    "    ax.set_aspect(abs((right - left) / (top - bottom)) / aspect)\n",
    "\n",
    "\n",
    "# NOTE: redefines (shadows) the file-based `distribution_5D` from the data section.\n",
    "def distribution_5D(model_fn, idx=0, dataset=None):\n",
    "    \"\"\"Plot the std of one model prediction over every pair of its 5 axes.\n",
    "\n",
    "    Args:\n",
    "        model_fn: module called as ``model_fn(x, timestep=ts)``; switched to eval mode.\n",
    "        idx: index of the sample to predict from.\n",
    "        dataset: samples exposing ``x`` and ``timestep_index``; defaults to\n",
    "            ``traindata``, or to the validation set ``data`` when ``traindata`` is empty.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure; cells with zero std are drawn black (NaN).\n",
    "    \"\"\"\n",
    "    if dataset is None:\n",
    "        # `traindata` is an empty placeholder in this notebook; indexing it always failed\n",
    "        dataset = traindata if len(traindata) > 0 else data\n",
    "    labels = [\"vpar\", \"mu\", \"s\", \"x\", \"y\"]\n",
    "\n",
    "    model_fn.eval()\n",
    "    sample = dataset[idx]\n",
    "    x = sample.x.to(device).unsqueeze(0)\n",
    "    ts = sample.timestep_index.to(device).unsqueeze(0)\n",
    "    with torch.no_grad():\n",
    "        f = model_fn(x, timestep=ts).squeeze(0).cpu().numpy()\n",
    "\n",
    "    comb = torch.combinations(torch.arange(5), 2).tolist()\n",
    "\n",
    "    fig, ax = plt.subplots(5, 5, figsize=(20, 20))\n",
    "    c_map = colormaps[\"rainbow\"]\n",
    "    c_map.set_bad(\"k\")\n",
    "\n",
    "    # only the first panel of each row gets axis labels\n",
    "    last_labelled_row = -1\n",
    "    for i, j in comb:\n",
    "        other = tuple(o for o in range(5) if o not in (i, j))\n",
    "        xx = f[0].std(other)\n",
    "        xx[xx == 0] = np.nan\n",
    "        ax[i, j].matshow(xx, cmap=c_map)\n",
    "\n",
    "        if i > last_labelled_row:\n",
    "            ax[i, j].set_ylabel(labels[i], fontsize=20)\n",
    "            ax[i, j].set_xlabel(labels[j], fontsize=20)\n",
    "            last_labelled_row = i\n",
    "\n",
    "        force_aspect(ax[i, j])\n",
    "\n",
    "    for i in range(5):\n",
    "        for j in range(5):\n",
    "            if [i, j] not in comb:\n",
    "                ax[i, j].remove()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction statistics for sample 10 (displayed by the following `fig` cell)\n",
    "fig = distribution_5D(model, idx=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the figure stored in `fig`\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import animation\n",
    "\n",
    "\n",
    "def animate_5D(model_fn, title=\"\", frames=10, start=55, idx=0, dataset=None):\n",
    "    \"\"\"Animate one 2-D slice of the target (top) vs. the model prediction (bottom).\n",
    "\n",
    "    Frame ``t`` shows ``[0, start + t, 7, :, 85, :]`` of ``sample.y`` and of the\n",
    "    model output; each panel's color range is fixed over all frames.\n",
    "    (Axis meaning assumed to follow the vpar, mu, s, x, y labels used elsewhere.)\n",
    "\n",
    "    Args:\n",
    "        model_fn: module called as ``model_fn(x)``; switched to eval mode.\n",
    "        title: figure suptitle.\n",
    "        frames: requested number of frames; clipped so ``start + t`` stays in range.\n",
    "        start: first index along axis 1 of the sample.\n",
    "        idx: index of the sample to animate.\n",
    "        dataset: samples exposing ``x`` and ``y``; defaults to ``traindata``, or to\n",
    "            the validation set ``data`` when ``traindata`` is empty.\n",
    "\n",
    "    Returns:\n",
    "        A ``matplotlib.animation.FuncAnimation``.\n",
    "    \"\"\"\n",
    "    if dataset is None:\n",
    "        # `traindata` is an empty placeholder in this notebook; indexing it always failed\n",
    "        dataset = traindata if len(traindata) > 0 else data\n",
    "    mu_idx, x_idx = 7, 85  # fixed indices of the plotted slice\n",
    "\n",
    "    plt.rcParams[\"animation.html\"] = \"jshtml\"\n",
    "    plt.ioff()\n",
    "\n",
    "    fig, ax = plt.subplots(2, 1, figsize=(10, 10))\n",
    "    fig.tight_layout()\n",
    "    fig.suptitle(title)\n",
    "\n",
    "    model_fn.eval()\n",
    "    sample = dataset[idx]\n",
    "    K2 = sample.y.numpy()\n",
    "    x = sample.x.to(device).unsqueeze(0)\n",
    "    with torch.no_grad():\n",
    "        f = model_fn(x).squeeze(0).cpu().numpy()\n",
    "\n",
    "    gt_slices = K2[0, :, mu_idx, :, x_idx, :]\n",
    "    pred_slices = f[0, :, mu_idx, :, x_idx, :]\n",
    "\n",
    "    # Initial plots to set up the colorbars (color range fixed across frames)\n",
    "    gt_im = ax[0].matshow(gt_slices[start], vmax=gt_slices.max(), vmin=gt_slices.min())\n",
    "    pred_im = ax[1].matshow(\n",
    "        pred_slices[start], vmax=pred_slices.max(), vmin=pred_slices.min()\n",
    "    )\n",
    "\n",
    "    cbar_gt = fig.colorbar(\n",
    "        gt_im, ax=ax[0], orientation=\"vertical\", fraction=0.046, pad=0.04\n",
    "    )\n",
    "    cbar_gt.set_label(\"Ground Truth\", fontsize=12)\n",
    "\n",
    "    cbar_pred = fig.colorbar(\n",
    "        pred_im, ax=ax[1], orientation=\"vertical\", fraction=0.046, pad=0.04\n",
    "    )\n",
    "    cbar_pred.set_label(\"Prediction\", fontsize=12)\n",
    "\n",
    "    # never index past the end of the animated axis\n",
    "    frames = min(frames, len(gt_slices) - start)\n",
    "\n",
    "    def animate(t):\n",
    "        gt_im.set_array(gt_slices[start + t])\n",
    "        pred_im.set_array(pred_slices[start + t])\n",
    "\n",
    "    return animation.FuncAnimation(fig, animate, frames=frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ani = animate_5D(model, title=\"\", frames=32, start=0, idx=0)\n",
    "ani"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the animation above to a GIF at 8 frames per second\n",
    "ani.save(\"shift2.gif\", writer=animation.PillowWriter(fps=8, bitrate=400))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import animation\n",
    "\n",
    "\n",
    "def animate_5D2(model_fn, title=\"\", frames=10, start=55, idx=0):\n",
    "    \"\"\"Animate five target/prediction slice panels (top row target, bottom prediction).\n",
    "\n",
    "    Each frame advances the index ``start + t`` along a different axis per column;\n",
    "    a column stops updating once that index runs past the axis it slices.\n",
    "\n",
    "    Args:\n",
    "        model_fn: module called as ``model_fn(x)``; switched to eval mode.\n",
    "        title: figure suptitle.\n",
    "        frames: number of animation frames.\n",
    "        start: first index along the animated axes.\n",
    "        idx: index of the sample in the validation set ``data``.\n",
    "\n",
    "    Returns:\n",
    "        A ``matplotlib.animation.FuncAnimation``.\n",
    "    \"\"\"\n",
    "    labels = [\"vpar\", \"mu\", \"s\", \"x\", \"y\"]\n",
    "    x_idx = 85  # fixed index on axis 4 of every slice below\n",
    "\n",
    "    plt.rcParams[\"animation.html\"] = \"jshtml\"\n",
    "    plt.ioff()\n",
    "\n",
    "    fig, ax = plt.subplots(2, 5, figsize=(15, 10))\n",
    "    fig.tight_layout()\n",
    "    fig.suptitle(title)\n",
    "\n",
    "    model_fn.eval()\n",
    "    sample = data[idx]\n",
    "    K2 = sample.y.numpy()\n",
    "    x = sample.x.to(device).unsqueeze(0)\n",
    "    with torch.no_grad():\n",
    "        f = model_fn(x).squeeze(0).cpu().numpy()\n",
    "\n",
    "    def color_limits(arr):\n",
    "        \"\"\"Per-column (vmin, vmax) pairs for `arr`.\"\"\"\n",
    "        return {\n",
    "            \"vpar\": (arr[:, :, 0, :, x_idx, :].min(), arr[:, :, 0, :, x_idx, :].max()),\n",
    "            \"mu\": (arr[:, 0, :, :, x_idx, :].min(), arr[:, 0, :, :, x_idx, :].max()),\n",
    "            \"s\": (arr[:, :, :, :, x_idx, :].min(), arr[:, :, :, :, x_idx, :].max()),\n",
    "            \"x\": (arr[:, :, :, :, x_idx, :].min(), arr[:, :, :, :, x_idx, :].max()),\n",
    "            \"y\": (arr.min(), arr.max()),\n",
    "        }\n",
    "\n",
    "    # Stored as (vmin, vmax): the old (max, min) tuples were unpacked as\n",
    "    # `vmin, vmax`, inverting every color range.\n",
    "    gt_limits = color_limits(K2)\n",
    "    pred_limits = color_limits(f)\n",
    "\n",
    "    # Initialize matshow and colorbars\n",
    "    # NOTE(review): later frames may set arrays whose shape differs from this\n",
    "    # initial slice — confirm the panels render as intended.\n",
    "    ims = []\n",
    "    cbars = []\n",
    "    for row in range(2):\n",
    "        for col in range(5):\n",
    "            label = labels[col]\n",
    "            vmin, vmax = gt_limits[label] if row == 0 else pred_limits[label]\n",
    "            im = ax[row, col].matshow(\n",
    "                K2[0, 0, 0, :, x_idx, :] if row == 0 else f[0, 0, 0, :, x_idx, :],\n",
    "                vmin=vmin,\n",
    "                vmax=vmax,\n",
    "            )\n",
    "            ims.append(im)\n",
    "            cbar = fig.colorbar(\n",
    "                im, ax=ax[row, col], orientation=\"vertical\", fraction=0.046, pad=0.04\n",
    "            )\n",
    "            cbar.set_label(\n",
    "                f\"{label} ({'Ground Truth' if row == 0 else 'Prediction'})\", fontsize=12\n",
    "            )\n",
    "            cbars.append(cbar)\n",
    "\n",
    "            # Set labels\n",
    "            ax[row, col].set_xlabel(label, fontsize=15)\n",
    "            ax[row, col].set_ylabel(\"s\" if col == 2 else \"y\", fontsize=15)\n",
    "\n",
    "    def animate(t):\n",
    "        k = start + t\n",
    "        # Each guard checks the axis that its slices index with `k`; the old\n",
    "        # guards compared against unrelated axes and could raise IndexError.\n",
    "        if k < K2.shape[3]:\n",
    "            ims[0].set_array(K2[0, :, 0, k, x_idx, :])\n",
    "            ims[5].set_array(f[0, :, 0, k, x_idx, :])\n",
    "            ims[1].set_array(K2[0, 0, :, k, x_idx, :])\n",
    "            ims[6].set_array(f[0, 0, :, k, x_idx, :])\n",
    "\n",
    "        if k < K2.shape[1]:\n",
    "            ims[2].set_array(K2[0, k, 0, :, x_idx, :])\n",
    "            ims[7].set_array(f[0, k, 0, :, x_idx, :])\n",
    "\n",
    "        if k < K2.shape[5]:\n",
    "            ims[3].set_array(K2[0, :, 0, :, x_idx, k])\n",
    "            ims[8].set_array(f[0, :, 0, :, x_idx, k])\n",
    "\n",
    "        if k < K2.shape[4]:\n",
    "            ims[4].set_array(K2[0, 0, 0, :, k, :])\n",
    "            ims[9].set_array(f[0, 0, 0, :, k, :])\n",
    "\n",
    "    return animation.FuncAnimation(fig, animate, frames=frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ani = animate_5D2(model, title=\"\", frames=5, start=0)\n",
    "ani"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the animation above to a GIF at 4 frames per second\n",
    "ani.save(\"big_ani.gif\", writer=animation.PillowWriter(fps=4, bitrate=400))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def force_aspect(ax, aspect=1):\n",
    "    \"\"\"Scale `ax` so its first image is drawn with a width/height ratio of `aspect`.\"\"\"\n",
    "    x_lo, x_hi, y_lo, y_hi = ax.get_images()[0].get_extent()\n",
    "    data_ratio = abs((x_hi - x_lo) / (y_hi - y_lo))\n",
    "    ax.set_aspect(data_ratio / aspect)\n",
    "\n",
    "\n",
    "def plot4x4(f, title=\"\", mark_bad=False):\n",
    "    \"\"\"Upper-triangle grid of means of ``f[0]`` over every pair of its 5 axes.\n",
    "\n",
    "    Panel (i, j) shows ``f[0]`` averaged over the other three axes and transposed,\n",
    "    so axis i runs horizontally. With ``mark_bad``, cells whose std over those\n",
    "    axes is zero are drawn black (NaN). Returns the matplotlib Figure.\n",
    "    \"\"\"\n",
    "    axis_names = [\"vpar\", \"mu\", \"s\", \"x\", \"y\"]\n",
    "    pairs = torch.combinations(torch.arange(5), 2).tolist()\n",
    "\n",
    "    fig, grid = plt.subplots(5, 5, figsize=(20, 20))\n",
    "    fig.suptitle(title)\n",
    "    cmap = colormaps[\"coolwarm\"]\n",
    "    cmap.set_bad(\"k\")\n",
    "\n",
    "    volume = f[0]\n",
    "    for row, col in pairs:\n",
    "        reduce_axes = tuple(k for k in range(5) if k not in (row, col))\n",
    "        panel = volume.mean(reduce_axes)\n",
    "        if mark_bad:\n",
    "            panel[volume.std(reduce_axes) == 0] = np.nan\n",
    "\n",
    "        cell = grid[row, col]\n",
    "        image = cell.imshow(panel.T, cmap=cmap)\n",
    "        fig.colorbar(image, ax=cell)\n",
    "        cell.set_xlabel(axis_names[row])\n",
    "        cell.set_ylabel(axis_names[col])\n",
    "        force_aspect(cell)\n",
    "\n",
    "    # discard the diagonal and lower-triangle panels\n",
    "    for row in range(5):\n",
    "        for col in range(5):\n",
    "            if [row, col] not in pairs:\n",
    "                grid[row, col].remove()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "# Attribute access, as in the other cells (`sample.x`, `sample.y`)\n",
    "sample = data[10]\n",
    "K1 = sample.x.numpy()\n",
    "K2 = sample.y.numpy()\n",
    "x = sample.x.to(device).unsqueeze(0)\n",
    "with torch.no_grad():\n",
    "    f = model(x).squeeze(0).cpu().numpy()\n",
    "\n",
    "# Pointwise squared error, so `mse.mean()` is the actual mean squared error\n",
    "# (the signed difference `K2 - f` averaged to a bias, not an MSE).\n",
    "mse = (K2 - f) ** 2\n",
    "\n",
    "plot4x4(K2, \"GT\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction panels, titled with the mean of `mse`\n",
    "plot4x4(f, title=f\"PRED (MSE = {mse.mean():.4f})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Error-field panels, titled with the mean of `mse`\n",
    "plot4x4(mse, title=f\"MSE = {mse.mean():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import colormaps\n",
    "\n",
    "\n",
    "def force_aspect(ax, aspect=1):\n",
    "    \"\"\"Draw the first image on `ax` with a box width/height ratio of `aspect`.\n",
    "\n",
    "    Does nothing when `ax` holds no image or the image has zero height.\n",
    "    \"\"\"\n",
    "    images = ax.get_images()\n",
    "    if not images:\n",
    "        return\n",
    "    left, right, bottom, top = images[0].get_extent()\n",
    "    if top == bottom:\n",
    "        return\n",
    "    ax.set_aspect(abs((right - left) / (top - bottom)) / aspect)\n",
    "\n",
    "\n",
    "def plot4x4_sided(x1, x2, title=\"\", mark_bad=False):\n",
    "    \"\"\"Side-by-side (x1 | x2) upper-triangle grid of 2-D means over each axis pair.\n",
    "\n",
    "    For every pair (i, j) of the 5 axes of ``x1[0]`` / ``x2[0]``, the left image\n",
    "    (titled GT) shows ``x1`` and the right image (titled PRED) ``x2``, each averaged\n",
    "    over the remaining three axes and with its own color range.\n",
    "\n",
    "    Args:\n",
    "        x1, x2: arrays whose element ``[0]`` is 5-D.\n",
    "        title: figure suptitle.\n",
    "        mark_bad: draw cells with zero std over the averaged axes black (NaN).\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure.\n",
    "    \"\"\"\n",
    "    labels = [\"vpar\", \"mu\", \"s\", \"x\", \"y\"]\n",
    "    comb = torch.combinations(torch.arange(5), 2).tolist()\n",
    "\n",
    "    fig, ax = plt.subplots(5, 5, figsize=(30, 12))\n",
    "    # The grid only supplies cell positions: the placeholder axes are removed and\n",
    "    # each used cell is replaced by two hand-placed half-width axes below.\n",
    "    for i in range(5):\n",
    "        for j in range(5):\n",
    "            ax[i, j].remove()\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.suptitle(title)\n",
    "    c_map = colormaps[\"coolwarm\"]\n",
    "    c_map.set_bad(\"k\")\n",
    "\n",
    "    for i, j in comb:\n",
    "        other = tuple(o for o in range(5) if o not in (i, j))\n",
    "        x1_plot = x1[0].mean(other)\n",
    "        x2_plot = x2[0].mean(other)\n",
    "        x1_vmax, x1_vmin = x1_plot.max(), x1_plot.min()\n",
    "        x2_vmax, x2_vmin = x2_plot.max(), x2_plot.min()\n",
    "\n",
    "        if mark_bad:\n",
    "            # std over the same 5-D array as the mean; using the full array reduced\n",
    "            # the wrong axes and produced a mask of mismatched shape\n",
    "            x1_plot[x1[0].std(other) == 0] = np.nan\n",
    "            x2_plot[x2[0].std(other) == 0] = np.nan\n",
    "\n",
    "        # Split the grid cell into two half-width axes\n",
    "        pos = ax[i, j].get_position()\n",
    "        width = pos.width / 2\n",
    "        ax1 = fig.add_axes([pos.x0, pos.y0, width, pos.height])\n",
    "        ax2 = fig.add_axes([pos.x0 + width, pos.y0, width, pos.height])\n",
    "\n",
    "        im1 = ax1.imshow(x1_plot, cmap=c_map, vmax=x1_vmax, vmin=x1_vmin)\n",
    "        im2 = ax2.imshow(x2_plot, cmap=c_map, vmax=x2_vmax, vmin=x2_vmin)\n",
    "\n",
    "        fig.colorbar(im1, ax=ax1)\n",
    "        fig.colorbar(im2, ax=ax2)\n",
    "\n",
    "        if i == 0:\n",
    "            ax1.set_title(\"GT\")\n",
    "            ax2.set_title(\"PRED\")\n",
    "\n",
    "        # imshow puts array rows (axis i) on the y-axis and columns (axis j) on the\n",
    "        # x-axis; the labels used to be swapped.\n",
    "        for sub in (ax1, ax2):\n",
    "            sub.set_xlabel(labels[j])\n",
    "            sub.set_ylabel(labels[i])\n",
    "            force_aspect(sub)\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the input with `patch_encode` and decode it again with `patch_decode`.\n",
    "# NOTE: `model.cpu()` moves the model in place; later cells run it on CPU.\n",
    "model.eval()\n",
    "sample = data[0]\n",
    "K1 = sample.x.numpy()\n",
    "K2 = sample.y.numpy()\n",
    "model = model.cpu()\n",
    "# the sample is already a CPU tensor; no GPU round-trip needed\n",
    "x = sample.x.unsqueeze(0)\n",
    "with torch.no_grad():\n",
    "    f = model.patch_decode(*model.patch_encode(x)).squeeze(0).numpy()\n",
    "\n",
    "fig = plot4x4_sided(K1, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the side-by-side figure stored in `fig`\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "sample = data[1]\n",
    "K2 = sample.y.numpy()\n",
    "x = sample.x.unsqueeze(0)\n",
    "ts = sample.timestep_index.unsqueeze(0)\n",
    "with torch.no_grad():\n",
    "    f = model.cpu()(x, timestep=ts).squeeze(0).numpy()\n",
    "\n",
    "# Target vs. prediction, matching the GT / PRED panel titles (previously the\n",
    "# input and target were plotted and the prediction `f` went unused).\n",
    "fig = plot4x4_sided(K2, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "sample = data[2]\n",
    "K2 = sample.y.numpy()\n",
    "x = sample.x.unsqueeze(0)\n",
    "ts = sample.timestep_index.unsqueeze(0)\n",
    "with torch.no_grad():\n",
    "    f = model.cpu()(x, timestep=ts).squeeze(0).numpy()\n",
    "\n",
    "# Target vs. prediction, matching the GT / PRED panel titles (previously the\n",
    "# input and target were plotted and the prediction `f` went unused).\n",
    "fig = plot4x4_sided(K2, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import animation, colormaps\n",
    "\n",
    "\n",
    "def force_aspect(ax, aspect=1):\n",
    "    \"\"\"Draw the first image on `ax` with a box width/height ratio of `aspect`.\n",
    "\n",
    "    Does nothing when `ax` holds no image or the image has zero height.\n",
    "    \"\"\"\n",
    "    images = ax.get_images()\n",
    "    if not images:\n",
    "        return\n",
    "    left, right, bottom, top = images[0].get_extent()\n",
    "    if top == bottom:\n",
    "        return\n",
    "    ax.set_aspect(abs((right - left) / (top - bottom)) / aspect)\n",
    "\n",
    "\n",
    "def plot4x4_animate(model_fn, title=\"\", mark_bad=False, frames=5, start_idx=0):\n",
    "    \"\"\"Animate an autoregressive rollout as side-by-side (GT | PRED) axis-pair means.\n",
    "\n",
    "    Starting from ``data[start_idx].x``, the model is applied ``frames`` times, each\n",
    "    time to its own previous output. Frame ``t`` compares ``data[start_idx + t].y``\n",
    "    with the ``t``-th prediction, averaged over every pair of the 5 axes.\n",
    "\n",
    "    Args:\n",
    "        model_fn: module called as ``model_fn(x)``; switched to eval mode.\n",
    "        title: initial suptitle (replaced by the timestep on every frame).\n",
    "        mark_bad: draw cells with zero std over the averaged axes black (NaN).\n",
    "        frames: requested number of frames; clipped to the samples available.\n",
    "        start_idx: index into the validation set ``data`` of the first frame.\n",
    "\n",
    "    Returns:\n",
    "        A ``matplotlib.animation.FuncAnimation``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``start_idx`` is not a valid index into ``data``.\n",
    "    \"\"\"\n",
    "    plt.rcParams[\"animation.html\"] = \"jshtml\"\n",
    "    plt.ioff()\n",
    "\n",
    "    labels = [\"vpar\", \"mu\", \"s\", \"x\", \"y\"]\n",
    "    comb = torch.combinations(torch.arange(5), 2).tolist()\n",
    "\n",
    "    # data[start_idx + t] must exist for every frame t\n",
    "    frames = min(frames, len(data) - start_idx)\n",
    "    if frames <= 0:\n",
    "        raise ValueError(f\"start_idx={start_idx} is outside the {len(data)} samples\")\n",
    "\n",
    "    fig, ax = plt.subplots(5, 5, figsize=(30, 12))\n",
    "    for i in range(5):\n",
    "        for j in range(5):\n",
    "            ax[i, j].remove()\n",
    "\n",
    "    fig.suptitle(title)\n",
    "    c_map = colormaps[\"coolwarm\"]\n",
    "    c_map.set_bad(\"k\")\n",
    "\n",
    "    # Create each (GT, PRED) axes pair once; frames clear and redraw them instead\n",
    "    # of stacking 20 new axes on top of the old ones every frame.\n",
    "    pairs = {}\n",
    "    for i, j in comb:\n",
    "        pos = ax[i, j].get_position()\n",
    "        width = pos.width / 2\n",
    "        pairs[i, j] = (\n",
    "            fig.add_axes([pos.x0, pos.y0, width, pos.height]),\n",
    "            fig.add_axes([pos.x0 + width, pos.y0, width, pos.height]),\n",
    "        )\n",
    "\n",
    "    # Was the global `model`; put the model that is actually rolled out in eval mode.\n",
    "    model_fn.eval()\n",
    "    # Follow the model's current device (earlier cells call model.cpu()).\n",
    "    first_param = next(model_fn.parameters(), None)\n",
    "    model_device = first_param.device if first_param is not None else device\n",
    "\n",
    "    preds = []\n",
    "    x_next = data[start_idx].x.to(model_device).unsqueeze(0)\n",
    "    with torch.no_grad():\n",
    "        for _ in range(frames):\n",
    "            x_next = model_fn(x_next)\n",
    "            preds.append(x_next.squeeze(0).cpu().numpy())\n",
    "\n",
    "    def animate(t):\n",
    "        x1 = data[start_idx + t].y.numpy()\n",
    "        xp = preds[t]\n",
    "        ts = data[start_idx + t].timestep.numpy().item()\n",
    "        fig.suptitle(f\"ts={ts:.2f}\", fontsize=30)\n",
    "\n",
    "        for i, j in comb:\n",
    "            other = tuple(o for o in range(5) if o not in (i, j))\n",
    "            x1_plot = x1[0].mean(other)\n",
    "            xp_plot = xp[0].mean(other)\n",
    "            x1_vmax, x1_vmin = x1_plot.max(), x1_plot.min()\n",
    "            xp_vmax, xp_vmin = xp_plot.max(), xp_plot.min()\n",
    "\n",
    "            if mark_bad:\n",
    "                # std over the same 5-D array as the mean\n",
    "                x1_plot[x1[0].std(other) == 0] = np.nan\n",
    "                xp_plot[xp[0].std(other) == 0] = np.nan\n",
    "\n",
    "            ax1, ax2 = pairs[i, j]\n",
    "            ax1.clear()\n",
    "            ax2.clear()\n",
    "            ax1.imshow(x1_plot, cmap=c_map, vmax=x1_vmax, vmin=x1_vmin)\n",
    "            ax2.imshow(xp_plot, cmap=c_map, vmax=xp_vmax, vmin=xp_vmin)\n",
    "\n",
    "            if i == 0:\n",
    "                ax1.set_title(\"GT\")\n",
    "                ax2.set_title(\"PRED\")\n",
    "            # imshow: rows (axis i) on the y-axis, columns (axis j) on the x-axis\n",
    "            for sub in (ax1, ax2):\n",
    "                sub.set_xlabel(labels[j])\n",
    "                sub.set_ylabel(labels[i])\n",
    "                force_aspect(sub)\n",
    "\n",
    "    return animation.FuncAnimation(fig, animate, frames=frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ani = plot4x4_animate(model, start_idx=0, frames=10)\n",
    "ani"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the side-by-side animation to a GIF, one frame per second\n",
    "ani.save(\"sided.gif\", writer=animation.PillowWriter(fps=1, bitrate=600))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mhd",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
