{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8292b839-76ba-4243-b79e-7fb268a4e6ac",
   "metadata": {},
   "source": [
    "# OmniField (TQV → TQV) — Notebook Guide\n",
    "\n",
    "This notebook trains and evaluates a **TQV → TQV** setup using **OmniField** on the ClimSim dataset. It is configured for the Perlmutter environment and uses a **chronological split**.\n",
    "\n",
    "---\n",
    "\n",
    "For now, use this notebook to view the model code; the data samples must be downloaded before training.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f526f46-2149-42f8-9f38-87685d1caece",
   "metadata": {},
   "source": [
    "Dataset for the Venn-diagram sensor setup (fixed T/Q/V observation masks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c484973c-98f0-4049-bac8-8dae8cda2891",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random, numpy as np, torch\n",
    "from torch.utils.data import Dataset\n",
    "import netCDF4 as nc\n",
    "\n",
    "def load_idx2latlon(grid_meta_path):\n",
    "    \"\"\"Read the grid-metadata NetCDF file and return one ``(lat, lon)`` tuple per grid point.\n",
    "\n",
    "    Entry ``i`` is ``(float(lat[i]), float(lon[i]))``; the list has ``len(lat)`` entries.\n",
    "    \"\"\"\n",
    "    with nc.Dataset(grid_meta_path) as grid:\n",
    "        lats = grid.variables['lat'][:]\n",
    "        lons = grid.variables['lon'][:]\n",
    "    pairs = []\n",
    "    for k in range(len(lats)):\n",
    "        pairs.append((float(lats[k]), float(lons[k])))\n",
    "    return pairs\n",
    "\n",
    "REGIONS = ['T','Q','V','TQ','TV','QV','TQV']\n",
    "REG2BITS = {\n",
    "    'T':   (1,0,0),\n",
    "    'Q':   (0,1,0),\n",
    "    'V':   (0,0,1),\n",
    "    'TQ':  (1,1,0),\n",
    "    'TV':  (1,0,1),\n",
    "    'QV':  (0,1,1),\n",
    "    'TQV': (1,1,1),\n",
    "}\n",
    "\n",
    "def _assign_venn_indices(n_points, sparsity, triple_fraction, seed, triple_fixed_count=None):\n",
    "    rng = np.random.RandomState(seed)\n",
    "    K_mod = int(round(sparsity * n_points))\n",
    "    assert K_mod > 0, \"sparsity too small\"\n",
    "    a = int(triple_fixed_count if triple_fixed_count is not None else round(triple_fraction * K_mod))\n",
    "    assert 0 <= a <= K_mod, \"invalid triple size\"\n",
    "\n",
    "    # distribute counts per region with t > p\n",
    "    p = (K_mod - a) // 4\n",
    "    t = K_mod - a - 2 * p\n",
    "    if t <= p:\n",
    "        p = max(0, p - 1)\n",
    "        t = K_mod - a - 2 * p\n",
    "    assert t >= 0 and p >= 0\n",
    "    assert a + 2 * p + t == K_mod  # per-modality total\n",
    "\n",
    "    cnt = {'TQV': a, 'TQ': p, 'TV': p, 'QV': p, 'T': t, 'Q': t, 'V': t}\n",
    "\n",
    "    perm = rng.permutation(n_points)\n",
    "    cursor = 0\n",
    "    region_indices = {}\n",
    "    for r in ['TQV','TQ','TV','QV','T','Q','V']:\n",
    "        k = cnt[r]\n",
    "        if k > 0:\n",
    "            region_indices[r] = perm[cursor:cursor+k]\n",
    "            cursor += k\n",
    "        else:\n",
    "            region_indices[r] = np.empty((0,), dtype=int)\n",
    "\n",
    "    mask_T = np.zeros(n_points, dtype=bool)\n",
    "    mask_Q = np.zeros(n_points, dtype=bool)\n",
    "    mask_V = np.zeros(n_points, dtype=bool)\n",
    "    for r, idxs in region_indices.items():\n",
    "        bT, bQ, bV = REG2BITS[r]\n",
    "        if bT: mask_T[idxs] = True\n",
    "        if bQ: mask_Q[idxs] = True\n",
    "        if bV: mask_V[idxs] = True\n",
    "\n",
    "    assert mask_T.sum() == K_mod and mask_Q.sum() == K_mod and mask_V.sum() == K_mod\n",
    "    assert (mask_T & mask_Q & mask_V).sum() == a\n",
    "\n",
    "    union_mask = mask_T | mask_Q | mask_V\n",
    "    fixed_idx = np.sort(np.where(union_mask)[0])\n",
    "\n",
    "    inv = {}\n",
    "    for r, idxs in region_indices.items():\n",
    "        for gi in idxs:\n",
    "            inv[gi] = r\n",
    "    region_of_local = [inv[gi] for gi in fixed_idx]\n",
    "\n",
    "    return fixed_idx, region_of_local, {'T': mask_T, 'Q': mask_Q, 'V': mask_V}\n",
    "\n",
    "\n",
    "class ClimSimTQVForecastVennFixed(Dataset):\n",
    "    \"\"\"\n",
    "    Venn dataset:\n",
    "      - Fixed Venn partition (masks for T/Q/V).\n",
    "      - Inputs come from `input_region`: \"union\" (default) or \"triple\".\n",
    "      - Inputs include only modalities requested by `input_modalities` (e.g., (1,0,0) for T-only).\n",
    "      - Targets are always full-field [T,Q,V] at t_out.\n",
    "      - Returns `supervised_idx` = input indices.\n",
    "    \"\"\"\n",
    "    def __init__(self, file_list, grid_meta_path, sparsity=0.02, triple_fraction=0.25,\n",
    "                 norm_stats=None, input_modalities=(1,0,0), input_region=\"union\", seed=123,train=False):\n",
    "        self.file_list = file_list\n",
    "        self.idx2latlon = load_idx2latlon(grid_meta_path)\n",
    "        self.norm_stats = norm_stats\n",
    "        self.horizons = [3, 6, 9, 12, 15, 18]\n",
    "        self.seq_len = 19\n",
    "        self.N = len(self.idx2latlon)\n",
    "\n",
    "        self.train = train\n",
    "        \n",
    "        self.fixed_idx, self.region_of_local, self.mod_masks = _assign_venn_indices(\n",
    "            n_points=self.N, sparsity=sparsity, triple_fraction=triple_fraction, seed=seed\n",
    "        )\n",
    "        self.union_idx = self.fixed_idx\n",
    "        self.triple_idx = np.where(self.mod_masks['T'] & self.mod_masks['Q'] & self.mod_masks['V'])[0]\n",
    "        self.triple_idx = np.sort(self.triple_idx)\n",
    "\n",
    "        self.input_region = input_region  # \"union\" or \"triple\"\n",
    "        self.input_idx = self.triple_idx if input_region == \"triple\" else self.union_idx\n",
    "\n",
    "        self._rng = random.Random(seed)\n",
    "        assert len(input_modalities) == 3\n",
    "        self.input_modalities = tuple(bool(int(x)) for x in input_modalities)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.file_list) - self.seq_len\n",
    "\n",
    "    def _norm(self, arr, key):\n",
    "        if self.norm_stats and key in self.norm_stats:\n",
    "            mu, sigma = self.norm_stats[key]\n",
    "            arr = (arr - mu) / sigma\n",
    "        return arr\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # load 19-step window\n",
    "        seq = [np.load(self.file_list[idx + i]) for i in range(self.seq_len)]\n",
    "        t_in = 0\n",
    "        # t_out = self._rng.choice(self.horizons)\n",
    "        if self.train:\n",
    "            t_out = 3\n",
    "        else:\n",
    "            t_out = self._rng.choice(self.horizons)\n",
    "\n",
    "        T_t  = self._norm(seq[t_in][\"state_t\"], \"T\")\n",
    "        Q_t  = self._norm(seq[t_in][\"state_q\"], \"Q\")\n",
    "        V_t  = self._norm(seq[t_in][\"state_v\"], \"V\")\n",
    "        T_tp = self._norm(seq[t_out][\"state_t\"], \"T\")\n",
    "        Q_tp = self._norm(seq[t_out][\"state_q\"], \"Q\")\n",
    "        V_tp = self._norm(seq[t_out][\"state_v\"], \"V\")\n",
    "\n",
    "        data_T, mesh_T = [], []\n",
    "        data_Q, mesh_Q = [], []\n",
    "        data_V, mesh_V = [], []\n",
    "        supervision_mask = []\n",
    "        \n",
    "        tau = float(t_out) / 18.0  # normalize to [0,1]\n",
    "\n",
    "        # --- INPUTS ONLY FROM CHOSEN REGION ---\n",
    "        for gi in self.input_idx:\n",
    "            lat, lon = self.idx2latlon[gi]\n",
    "            # modality membership from fixed Venn masks\n",
    "            bT = bool(self.mod_masks['T'][gi])\n",
    "            bQ = bool(self.mod_masks['Q'][gi])\n",
    "            bV = bool(self.mod_masks['V'][gi])\n",
    "            # for triple region, these will all be True\n",
    "            supervision_mask.append([bT, bQ, bV])\n",
    "\n",
    "            # gate by requested input_modalities\n",
    "            if self.input_modalities[0] and bT:\n",
    "                data_T.append([lon, lat, float(T_t[gi])]); mesh_T.append([lon, lat])\n",
    "            if self.input_modalities[1] and bQ:\n",
    "                data_Q.append([lon, lat, float(Q_t[gi])]); mesh_Q.append([lon, lat])\n",
    "            if self.input_modalities[2] and bV:\n",
    "                data_V.append([lon, lat, float(V_t[gi])]); mesh_V.append([lon, lat])\n",
    "\n",
    "        # --- FULL-FIELD TARGETS ---\n",
    "        data_y, mesh_y = [], []\n",
    "        for gi in range(self.N):\n",
    "            lat, lon = self.idx2latlon[gi]\n",
    "            data_y.append([float(T_tp[gi]), float(Q_tp[gi]), float(V_tp[gi])])\n",
    "            mesh_y.append([lon, lat])\n",
    "\n",
    "        used_modalities = [int(self.input_modalities[0]),\n",
    "                           int(self.input_modalities[1]),\n",
    "                           int(self.input_modalities[2])]\n",
    "\n",
    "        to_tensor = lambda x: torch.tensor(x, dtype=torch.float32)\n",
    "        return (\n",
    "            to_tensor(data_T), to_tensor(data_Q), to_tensor(data_V),\n",
    "            to_tensor(mesh_T), to_tensor(mesh_Q), to_tensor(mesh_V),\n",
    "            to_tensor(data_y), to_tensor(mesh_y),\n",
    "            torch.tensor(self.input_idx, dtype=torch.long),         \n",
    "            torch.tensor(used_modalities, dtype=torch.bool),\n",
    "            torch.tensor(supervision_mask, dtype=torch.bool),\n",
    "            torch.tensor(tau, dtype=torch.float32)  \n",
    "        )\n",
    "\n",
    "    def venn_counts(self):\n",
    "        counts = {r:0 for r in REGIONS}\n",
    "        for r in self.region_of_local:\n",
    "            counts[r] += 1\n",
    "        return counts\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1defe22f-9a97-4c4b-bd7e-495983010fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -----------------------------------------------\n",
    "from torch.utils.data import DataLoader\n",
    "import torch, torch.nn as nn\n",
    "from torch.optim import AdamW\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "import glob\n",
    "from torch.utils.data import random_split, DataLoader\n",
    "\n",
    "import glob, numpy as np\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "\n",
    "\n",
    "# load the norm stats, samples, and grid info (must download to train)\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays; only load a trusted file.\n",
    "norm_stats    = dict(np.load(\"norm_TQV_full.npz\", allow_pickle=True))\n",
    "file_list     = sorted(glob.glob(\"processed/**/*.npz\", recursive=True))\n",
    "grid_meta_path = \"ClimSim_high-res_grid-info.nc\"\n",
    "\n",
    "sparsity = 0.02  # 2% -> 432 \n",
    "\n",
    "# Build the dataset from the variables above so that editing them actually takes effect.\n",
    "dataset = ClimSimTQVForecastVennFixed(\n",
    "    file_list=file_list,\n",
    "    grid_meta_path=grid_meta_path,\n",
    "    sparsity=sparsity,\n",
    "    triple_fraction=0.25,\n",
    "    norm_stats=norm_stats,\n",
    "    input_modalities=(1,1,1),   # use all three modalities (T, Q, V) as inputs\n",
    "    input_region=\"union\",\n",
    "    seed=123,\n",
    ")\n",
    "\n",
    "\n",
    "def pretty_counts(counts):\n",
    "    order = ['T', 'Q', 'V', 'TQ', 'TV', 'QV', 'TQV']\n",
    "    return {k: counts[k] for k in order}\n",
    "\n",
    "# Per-modality counts: each modality should observe exactly K_mod points.\n",
    "k_T, k_Q, k_V = (int(dataset.mod_masks[m].sum()) for m in ('T', 'Q', 'V'))\n",
    "\n",
    "# Size of the triple overlap (T∩Q∩V) and of the union (unique sensor locations).\n",
    "triple_cnt = int(np.logical_and.reduce([dataset.mod_masks[m] for m in ('T', 'Q', 'V')]).sum())\n",
    "union_cnt  = int(np.logical_or.reduce([dataset.mod_masks[m] for m in ('T', 'Q', 'V')]).sum())\n",
    "\n",
    "# Exact region-by-region membership counts, as recorded by the dataset.\n",
    "venn = dataset.venn_counts()\n",
    "\n",
    "print(\"=== Sensor/Region Summary ===\")\n",
    "print(f\"Per-modality (K_mod): T={k_T}, Q={k_Q}, V={k_V}\")\n",
    "print(f\"Triple region (T∩Q∩V): {triple_cnt}\")\n",
    "print(f\"Union of all modalities (unique locations): {union_cnt}\")\n",
    "print(\"Venn breakdown (exact memberships):\")\n",
    "for k, v in pretty_counts(venn).items():\n",
    "    print(f\"  {k}: {v}\")\n",
    "\n",
    "# Confirm the dataset feeds the model the region it was configured for.\n",
    "print(\"\\n=== DataLoader Input Region Check ===\")\n",
    "print(f\"dataset.input_region = {dataset.input_region!r}\")\n",
    "print(f\"len(dataset.input_idx) = {len(dataset.input_idx)}\")\n",
    "if dataset.input_region == \"union\":\n",
    "    print(f\"Matches union? {len(dataset.input_idx) == union_cnt}\")\n",
    "elif dataset.input_region == \"triple\":\n",
    "    print(f\"Matches triple? {len(dataset.input_idx) == triple_cnt}\")\n",
    "else:\n",
    "    print(\"Unknown input_region (expected 'union' or 'triple').\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e726d3-60b6-4d47-989f-21f2c33da05d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==== Clean, runnable training cell====\n",
    "\n",
    "# --- Imports ---\n",
    "import os\n",
    "import glob\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "from torch.optim import AdamW\n",
    "\n",
    "from itertools import cycle\n",
    "from tqdm import tqdm\n",
    "from math import log\n",
    "\n",
    "from einops import rearrange, repeat\n",
    "from cosine_annealing_warmup import CosineAnnealingWarmupRestarts\n",
    "\n",
    "\n",
    "# --- Device ---\n",
    "DEVICE = \"cuda:3\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "\n",
    "# ===============================================================\n",
    "# --- 1. The One True Perceiver IO Model Architecture ---\n",
    "# ===============================================================\n",
    "# helpers\n",
    "def exists(val):\n",
    "    return val is not None\n",
    "\n",
    "def default(val, d):\n",
    "    \"\"\"Return ``val``, or the fallback ``d`` when ``val`` is None.\"\"\"\n",
    "    if val is None:\n",
    "        return d\n",
    "    return val\n",
    "\n",
    "from functools import wraps\n",
    "def cache_fn(f):\n",
    "    cache = None\n",
    "    @wraps(f)\n",
    "    def cached_fn(*args, _cache = True, **kwargs):\n",
    "        if not _cache:\n",
    "            return f(*args, **kwargs)\n",
    "        nonlocal cache\n",
    "        if cache is not None:\n",
    "            return cache\n",
    "        cache = f(*args, **kwargs)\n",
    "        return cache\n",
    "    return cached_fn\n",
    "\n",
    "# helper classes\n",
    "class PreNorm(nn.Module):\n",
    "    \"\"\"Wrap ``fn`` so its input is LayerNorm-ed first.\n",
    "\n",
    "    With ``context_dim`` set, the ``context`` keyword argument (which the caller must\n",
    "    then supply) is normalized by a second LayerNorm before being forwarded to ``fn``.\n",
    "    \"\"\"\n",
    "    def __init__(self, dim, fn, context_dim = None):\n",
    "        super().__init__()\n",
    "        self.fn = fn\n",
    "        self.norm = nn.LayerNorm(dim)\n",
    "        self.norm_context = None if context_dim is None else nn.LayerNorm(context_dim)\n",
    "\n",
    "    def forward(self, x, **kwargs):\n",
    "        normed_x = self.norm(x)\n",
    "        if self.norm_context is not None:\n",
    "            kwargs['context'] = self.norm_context(kwargs['context'])\n",
    "        return self.fn(normed_x, **kwargs)\n",
    "\n",
    "class GEGLU(nn.Module):\n",
    "    def forward(self, x):\n",
    "        x, gates = x.chunk(2, dim = -1)\n",
    "        return x * F.gelu(gates)\n",
    "\n",
    "class FeedForward(nn.Module):\n",
    "    \"\"\"GEGLU MLP: ``dim -> 2 * mult * dim`` (gated down to ``mult * dim``) ``-> dim``.\"\"\"\n",
    "    def __init__(self, dim, mult = 4):\n",
    "        super().__init__()\n",
    "        hidden = dim * mult\n",
    "        # kept as one Sequential named `net` so checkpoint keys (net.0.*, net.2.*) match\n",
    "        self.net = nn.Sequential(*[nn.Linear(dim, hidden * 2), GEGLU(), nn.Linear(hidden, dim)])\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.net(x)\n",
    "        return out\n",
    "\n",
    "class Attention(nn.Module):\n",
    "    \"\"\"Multi-head scaled dot-product attention from ``x`` (queries) to ``context``.\n",
    "\n",
    "    ``context=None`` gives self-attention. If ``mask`` is given it is flattened to\n",
    "    ``[batch, n_context]`` and context positions where it is False get (almost) zero\n",
    "    weight. The detached post-softmax weights of the latest call are kept in\n",
    "    ``latest_attn`` with shape ``[batch * heads, n_query, n_context]``.\n",
    "    \"\"\"\n",
    "    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64):\n",
    "        super().__init__()\n",
    "        inner_dim = heads * dim_head\n",
    "        kv_dim = query_dim if context_dim is None else context_dim\n",
    "        self.scale = dim_head ** -0.5\n",
    "        self.heads = heads\n",
    "        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)\n",
    "        self.to_kv = nn.Linear(kv_dim, 2 * inner_dim, bias = False)  # keys and values, split in forward\n",
    "        self.to_out = nn.Linear(inner_dim, query_dim)\n",
    "        # NOTE(review): keeping the last attention map holds a [batch*heads, n_query, n_context]\n",
    "        # tensor alive between calls, which can be large for full-grid decoder queries.\n",
    "        self.latest_attn = None\n",
    "\n",
    "    def forward(self, x, context = None, mask = None):\n",
    "        h = self.heads\n",
    "        kv_source = x if context is None else context\n",
    "        q = self.to_q(x)\n",
    "        k, v = self.to_kv(kv_source).chunk(2, dim = -1)\n",
    "\n",
    "        def split_heads(t):\n",
    "            return rearrange(t, 'b n (h d) -> (b h) n d', h = h)\n",
    "\n",
    "        q, k, v = split_heads(q), split_heads(k), split_heads(v)\n",
    "        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale\n",
    "\n",
    "        if mask is not None:\n",
    "            flat_mask = rearrange(mask, 'b ... -> b (...)')\n",
    "            head_mask = repeat(flat_mask, 'b j -> (b h) () j', h = h)\n",
    "            # largest finite negative value, so masked logits vanish after softmax\n",
    "            sim.masked_fill_(~head_mask, -torch.finfo(sim.dtype).max)\n",
    "\n",
    "        attn = sim.softmax(dim = -1)\n",
    "        self.latest_attn = attn.detach()\n",
    "        out = torch.einsum('b i j, b j d -> b i d', attn, v)\n",
    "        merged = rearrange(out, '(b h) n d -> b n (h d)', h = h)\n",
    "        return self.to_out(merged)\n",
    "\n",
    "# This helper function creates the sinusoidal embeddings\n",
    "def get_sinusoidal_embeddings(n, d):\n",
    "    \"\"\"\n",
    "    Generates sinusoidal positional embeddings.\n",
    "    \n",
    "    Args:\n",
    "        n (int): The number of positions (num_latents).\n",
    "        d (int): The embedding dimension (latent_dim).\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: A tensor of shape (n, d) with sinusoidal embeddings.\n",
    "    \"\"\"\n",
    "    assert d % 2 == 0, \"latent_dim must be an even number for sinusoidal embeddings\"\n",
    "    position = torch.arange(n, dtype=torch.float).unsqueeze(1)\n",
    "    div_term = torch.exp(torch.arange(0, d, 2).float() * -(log(10000.0) / d))\n",
    "    pe = torch.zeros(n, d)\n",
    "    pe[:, 0::2] = torch.sin(position * div_term)\n",
    "    pe[:, 1::2] = torch.cos(position * div_term)\n",
    "    return pe\n",
    "\n",
    "\n",
    "class CascadedBlock(nn.Module):\n",
    "    \"\"\"One encoder stage of the cascaded Perceiver.\n",
    "\n",
    "    Learned latents (initialized from sinusoidal codes) cross-attend to ``context``,\n",
    "    optionally add a residual from the previous stage (projected to ``dim`` when\n",
    "    ``residual_dim`` differs), then apply self-attention and a feed-forward, each with\n",
    "    a skip connection. ``x`` is accepted for interface compatibility and ignored.\n",
    "    \"\"\"\n",
    "    def __init__(self, dim, n_latents, input_dim, cross_heads, cross_dim_head, self_heads, self_dim_head, residual_dim=None):\n",
    "        super().__init__()\n",
    "        init_latents = get_sinusoidal_embeddings(n_latents, dim)\n",
    "        self.latents = nn.Parameter(init_latents, requires_grad=True)\n",
    "        cross = Attention(dim, input_dim, heads=cross_heads, dim_head=cross_dim_head)\n",
    "        self.cross_attn = PreNorm(dim, cross, context_dim=input_dim)\n",
    "        latent_attn = Attention(dim, heads=self_heads, dim_head=self_dim_head)\n",
    "        self.self_attn = PreNorm(dim, latent_attn)\n",
    "        needs_proj = bool(residual_dim) and residual_dim != dim\n",
    "        self.residual_proj = nn.Linear(residual_dim, dim) if needs_proj else None\n",
    "        self.ff = PreNorm(dim, FeedForward(dim))\n",
    "\n",
    "    def forward(self, x, context, mask=None, residual=None):\n",
    "        batch = context.size(0)\n",
    "        latents = repeat(self.latents, 'n d -> b n d', b=batch)\n",
    "        latents = latents + self.cross_attn(latents, context=context, mask=mask)\n",
    "        if residual is not None:\n",
    "            projected = residual if self.residual_proj is None else self.residual_proj(residual)\n",
    "            latents = latents + projected\n",
    "        latents = latents + self.self_attn(latents)\n",
    "        return latents + self.ff(latents)\n",
    "\n",
    "\n",
    "class CascadedPerceiverIO(nn.Module):\n",
    "    \"\"\"Cascaded multi-modal Perceiver IO: encodes T/Q/V inputs, decodes T, Q, V at queries.\n",
    "\n",
    "    Encoder: at stage ``k`` each present modality runs its own ``CascadedBlock``\n",
    "    (``encoder_blocks_{T,Q,V}[k]``) over its input tokens. From stage 1 on, every block\n",
    "    also receives a residual computed from the mean-pooled fused latent of the previous\n",
    "    stage (``global2latent_proj_{T,Q,V}[k]``). A stage's modality latents are\n",
    "    concatenated along the token axis and refined by ``self_attn_blocks``; the result is\n",
    "    the fused latent for the next stage.\n",
    "\n",
    "    Decoder: the queries cross-attend to the final fused latent once per variable,\n",
    "    giving ``(T_out, Q_out, V_out)``, each ``[B, n_queries, 1]``.\n",
    "\n",
    "    ``self_attn_blocks`` and ``global2latent_proj_*`` are sized from ``latent_dims[-1]``,\n",
    "    so in practice all entries of ``latent_dims`` must be equal.\n",
    "\n",
    "    ``input_proj*``, ``projection_matrix``, ``cross_*_from_*``, ``sa_queries_*``,\n",
    "    ``global_proj_*``, ``encoder_blocks``, ``decoder_cross_attn``, ``decoder_ff`` and\n",
    "    ``to_logits`` are not used by ``forward``; they are kept (and created in the original\n",
    "    order) so existing checkpoints load and parameter initialization is unchanged.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        *,\n",
    "        input_dim,\n",
    "        queries_dim,\n",
    "        logits_dim = None,\n",
    "        latent_dims=(512, 512, 512),\n",
    "        num_latents=(256, 256, 256),\n",
    "        cross_heads = 4,\n",
    "        cross_dim_head = 128,\n",
    "        self_heads = 8,\n",
    "        self_dim_head = 128,\n",
    "        decoder_ff = False,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        assert len(latent_dims) == len(num_latents), \"latent_dims and num_latents must have same length\"\n",
    "\n",
    "        self.latent_dims  = list(latent_dims)\n",
    "        self.num_latents  = list(num_latents)\n",
    "        final_latent_dim = latent_dims[-1]\n",
    "\n",
    "        def make_input_proj():\n",
    "            \"\"\"Unused 3 -> 128 MLP (see class docstring).\"\"\"\n",
    "            return nn.Sequential(\n",
    "                nn.Linear(3, 128),\n",
    "                nn.GELU(),\n",
    "                nn.Linear(128, 128)\n",
    "            ).to(DEVICE)\n",
    "\n",
    "        def make_encoder_blocks():\n",
    "            \"\"\"One CascadedBlock per stage; stage k > 0 gets the previous width as residual_dim.\"\"\"\n",
    "            blocks = nn.ModuleList()\n",
    "            prev_dim = None\n",
    "            for dim, n_latents in zip(latent_dims, num_latents):\n",
    "                blocks.append(CascadedBlock(\n",
    "                    dim=dim,\n",
    "                    n_latents=n_latents,\n",
    "                    input_dim=input_dim,\n",
    "                    cross_heads=cross_heads,\n",
    "                    cross_dim_head=cross_dim_head,\n",
    "                    self_heads=self_heads,\n",
    "                    self_dim_head=self_dim_head,\n",
    "                    residual_dim=prev_dim\n",
    "                ))\n",
    "                prev_dim = dim\n",
    "            return blocks\n",
    "\n",
    "        def make_latent_cross_attn():\n",
    "            \"\"\"Latent-to-latent cross-attention at width ``final_latent_dim`` (unused).\"\"\"\n",
    "            return PreNorm(\n",
    "                final_latent_dim,\n",
    "                Attention(\n",
    "                    query_dim=final_latent_dim,\n",
    "                    context_dim=final_latent_dim,\n",
    "                    heads=cross_heads,\n",
    "                    dim_head=cross_dim_head\n",
    "                )\n",
    "            )\n",
    "\n",
    "        def make_decoder_cross_attn():\n",
    "            \"\"\"Queries (width ``queries_dim``) attend to the final fused latent.\"\"\"\n",
    "            return PreNorm(\n",
    "                queries_dim,\n",
    "                Attention(queries_dim, final_latent_dim, heads=cross_heads, dim_head=cross_dim_head),\n",
    "                context_dim=final_latent_dim\n",
    "            )\n",
    "\n",
    "        def make_decoder_ff():\n",
    "            \"\"\"Optional feed-forward after decoder cross-attention (``decoder_ff`` flag).\"\"\"\n",
    "            return PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None\n",
    "\n",
    "        def make_global2latent_proj():\n",
    "            \"\"\"Per stage: pooled global latent -> flattened [n_latents_k * dim_k] residual.\"\"\"\n",
    "            return nn.ModuleList([\n",
    "                nn.Linear(final_latent_dim, n_latents_k * dim_k)\n",
    "                for n_latents_k, dim_k in zip(num_latents, latent_dims)\n",
    "            ])\n",
    "\n",
    "        self.input_proj = make_input_proj()\n",
    "        self.input_proj_T = make_input_proj()\n",
    "        self.input_proj_Q = make_input_proj()\n",
    "        self.input_proj_V = make_input_proj()\n",
    "\n",
    "        # NOTE(review): `nn.Parameter(...).to(DEVICE)` yields a plain (non-Parameter) tensor\n",
    "        # when DEVICE is a GPU, so this is only registered as a parameter on CPU. It is\n",
    "        # unused by forward and left as-is so state_dict keys of existing checkpoints match.\n",
    "        self.projection_matrix = nn.Parameter(torch.randn(4, 128) / np.sqrt(4)).to(DEVICE)\n",
    "\n",
    "        # Cross-modal fusion attentions, named X_from_Y (unused by forward).\n",
    "        self.cross_T_from_Q = make_latent_cross_attn()\n",
    "        self.cross_T_from_V = make_latent_cross_attn()\n",
    "        self.cross_Q_from_T = make_latent_cross_attn()\n",
    "        self.cross_Q_from_V = make_latent_cross_attn()\n",
    "        self.cross_V_from_T = make_latent_cross_attn()\n",
    "        self.cross_V_from_Q = make_latent_cross_attn()\n",
    "\n",
    "        # --- Per-modality encoder blocks (used by forward) ---\n",
    "        self.encoder_blocks_T = make_encoder_blocks()\n",
    "        self.encoder_blocks_Q = make_encoder_blocks()\n",
    "        self.encoder_blocks_V = make_encoder_blocks()\n",
    "\n",
    "        self.sa_queries_T = PreNorm(queries_dim, Attention(queries_dim, heads=4, dim_head=64))\n",
    "        self.sa_queries_Q = PreNorm(queries_dim, Attention(queries_dim, heads=4, dim_head=64))\n",
    "        self.sa_queries_V = PreNorm(queries_dim, Attention(queries_dim, heads=4, dim_head=64))\n",
    "\n",
    "        self.global_proj_T = nn.Linear(final_latent_dim, input_dim)\n",
    "        self.global_proj_Q = nn.Linear(final_latent_dim, input_dim)\n",
    "        self.global_proj_V = nn.Linear(final_latent_dim, input_dim)\n",
    "\n",
    "        # Stage residuals derived from the previous stage's fused latent (used by forward).\n",
    "        self.global2latent_proj_T = make_global2latent_proj()\n",
    "        self.global2latent_proj_Q = make_global2latent_proj()\n",
    "        self.global2latent_proj_V = make_global2latent_proj()\n",
    "\n",
    "        # Cascaded encoder blocks (generic list) - kept for checkpoint compatibility\n",
    "        self.encoder_blocks = make_encoder_blocks()\n",
    "\n",
    "        # Generic decoder (unused by forward)\n",
    "        self.decoder_cross_attn = make_decoder_cross_attn()\n",
    "        self.decoder_ff = make_decoder_ff()\n",
    "        self.to_logits = nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()\n",
    "\n",
    "        # Per-variable decoder heads, each producing one value per query.\n",
    "        self.decoder_cross_attn_T = make_decoder_cross_attn()\n",
    "        self.decoder_ff_T = make_decoder_ff()\n",
    "        self.to_logits_T = nn.Linear(queries_dim, 1)\n",
    "\n",
    "        self.decoder_cross_attn_Q = make_decoder_cross_attn()\n",
    "        self.decoder_ff_Q = make_decoder_ff()\n",
    "        self.to_logits_Q = nn.Linear(queries_dim, 1)\n",
    "\n",
    "        self.decoder_cross_attn_V = make_decoder_cross_attn()\n",
    "        self.decoder_ff_V = make_decoder_ff()\n",
    "        self.to_logits_V = nn.Linear(queries_dim, 1)\n",
    "\n",
    "        # Refinement of the fused latent at every stage: 3 x (self-attention, feed-forward).\n",
    "        self.self_attn_blocks = nn.Sequential(*[\n",
    "            nn.Sequential(\n",
    "                PreNorm(latent_dims[-1], Attention(latent_dims[-1], heads=self_heads, dim_head=self_dim_head)),\n",
    "                PreNorm(latent_dims[-1], FeedForward(latent_dims[-1]))\n",
    "            )\n",
    "            for _ in range(3)\n",
    "        ])\n",
    "\n",
    "    def forward(self, x_T, x_Q, x_V, queries, used_modalities):\n",
    "        \"\"\"Encode the present modalities and decode T, Q, V at ``queries``.\n",
    "\n",
    "        Args:\n",
    "            x_T, x_Q, x_V: token tensors used as ``context`` by the per-modality encoder\n",
    "                blocks (last dim ``input_dim``), or None for an absent modality.\n",
    "            queries: decoder query tokens, ``[n, queries_dim]`` (repeated over the batch)\n",
    "                or ``[B, n, queries_dim]``.\n",
    "            used_modalities: indexable flags for (T, Q, V); a modality is encoded only if\n",
    "                its flag is truthy and its tensor is not None.\n",
    "\n",
    "        Returns:\n",
    "            (T_out, Q_out, V_out), each ``[B, n, 1]``.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if no modality is present at some stage.\n",
    "        \"\"\"\n",
    "        def residual_from_global(global_latent, proj_layer, n_latents_k, dim_k):\n",
    "            # No residual for the first stage; afterwards mean-pool the previous fused\n",
    "            # latent over tokens and project it to this stage's latent shape.\n",
    "            if global_latent is None:\n",
    "                return None\n",
    "            G_pool = global_latent.mean(dim=1)                      # [B, Dg]\n",
    "            return proj_layer(G_pool).view(G_pool.size(0), n_latents_k, dim_k)\n",
    "\n",
    "        # (encoder blocks, global->latent projections, inputs) in T, Q, V order.\n",
    "        modality_branches = (\n",
    "            (self.encoder_blocks_T, self.global2latent_proj_T, x_T),\n",
    "            (self.encoder_blocks_Q, self.global2latent_proj_Q, x_Q),\n",
    "            (self.encoder_blocks_V, self.global2latent_proj_V, x_V),\n",
    "        )\n",
    "\n",
    "        global_latent = None\n",
    "        # One stage per configured latent level (matches the length of each encoder list).\n",
    "        num_stages = len(self.num_latents)\n",
    "\n",
    "        for stage_idx in range(num_stages):\n",
    "            stage_latents = []\n",
    "            nL_k = self.num_latents[stage_idx]\n",
    "            d_k  = self.latent_dims[stage_idx]\n",
    "\n",
    "            for m, (blocks, g2l_proj, x_m) in enumerate(modality_branches):\n",
    "                if used_modalities[m] and x_m is not None:\n",
    "                    R_m = residual_from_global(global_latent, g2l_proj[stage_idx], nL_k, d_k)\n",
    "                    stage_latents.append(blocks[stage_idx](x=None, context=x_m, residual=R_m))\n",
    "\n",
    "            if not stage_latents:\n",
    "                raise ValueError(\"No modalities present in this batch.\")\n",
    "\n",
    "            # === Fuse present modality latents into new global ===\n",
    "            fused_latent = torch.cat(stage_latents, dim=1)\n",
    "            for sa_block in self.self_attn_blocks:\n",
    "                fused_latent = sa_block[0](fused_latent) + fused_latent\n",
    "                fused_latent = sa_block[1](fused_latent) + fused_latent\n",
    "            global_latent = fused_latent  # pass to next stage\n",
    "\n",
    "        # === Prepare queries ===\n",
    "        if queries.ndim == 2:\n",
    "            queries = repeat(queries, 'n d -> b n d', b=global_latent.size(0))\n",
    "\n",
    "        # === Decoder: cross-attention (+ optional feed-forward) per output variable ===\n",
    "        def decode_branch_with_query(cross_attn, ff, head):\n",
    "            q = queries\n",
    "            x = cross_attn(q, context=global_latent)\n",
    "            x = x + q\n",
    "            if ff is not None:\n",
    "                x = x + ff(x)\n",
    "            return head(x)\n",
    "\n",
    "        T_out = decode_branch_with_query(self.decoder_cross_attn_T, self.decoder_ff_T, self.to_logits_T)\n",
    "        Q_out = decode_branch_with_query(self.decoder_cross_attn_Q, self.decoder_ff_Q, self.to_logits_Q)\n",
    "        V_out = decode_branch_with_query(self.decoder_cross_attn_V, self.decoder_ff_V, self.to_logits_V)\n",
    "\n",
    "        return T_out, Q_out, V_out\n",
    "\n",
    "\n",
    "class GaussianFourierFeatures(nn.Module):\n",
    "    \"\"\"\n",
    "    Random Fourier feature encoder: coords -> [sin(coords @ B), cos(coords @ B)].\n",
    "\n",
    "    B is a fixed (non-trainable) [in_features, mapping_size] standard-normal matrix\n",
    "    multiplied by `scale`. It is registered as a buffer, so it follows .to(device) and\n",
    "    is stored in state_dict. The output width is 2 * mapping_size.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, mapping_size, scale=15.0):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.mapping_size = mapping_size\n",
    "        gaussian_matrix = torch.randn((in_features, mapping_size)) * scale\n",
    "        self.register_buffer('B', gaussian_matrix)\n",
    "\n",
    "    def forward(self, coords):\n",
    "        \"\"\"Encode coords [..., in_features] into features [..., 2 * mapping_size].\"\"\"\n",
    "        angles = torch.matmul(coords, self.B)\n",
    "        return torch.cat((angles.sin(), angles.cos()), dim=-1)\n",
    "\n",
    "\n",
    "# ==== Experiment setup: data, encoders, model, optimizer, LR schedule ====\n",
    "\n",
    "# Normalization stats, the sorted list of preprocessed .npz samples, and grid metadata.\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute\n",
    "# arbitrary code -- only load trusted files.\n",
    "norm_stats = dict(np.load(\"norm_TQV_full.npz\", allow_pickle=True))\n",
    "file_list  = sorted(glob.glob(\"processed/**/*.npz\", recursive=True))\n",
    "grid_meta_path = \"ClimSim_high-res_grid-info.nc\"\n",
    "\n",
    "# Fraction of grid points sampled as inputs (2% -> 432 points, assuming a 21,600-point grid).\n",
    "sparsity = 0.02  # 2% -> 432\n",
    "\n",
    "# triple_fraction: presumably the share of sampled points observed in all of T, Q and V\n",
    "# (the 'TQV' Venn region) -- confirm against ClimSimTQVForecastVennFixed.\n",
    "dataset = ClimSimTQVForecastVennFixed(\n",
    "    file_list=file_list,\n",
    "    grid_meta_path=grid_meta_path,\n",
    "    sparsity=sparsity,\n",
    "    triple_fraction=0.25,\n",
    "    norm_stats=norm_stats,\n",
    "    input_modalities=(1,1,1),   # only TQV values\n",
    "    input_region=\"union\",\n",
    "    seed=123,\n",
    ")\n",
    "\n",
    "# Contiguous split (no shuffling before splitting): the first train_len samples train and\n",
    "# the remainder validate. This is chronological only if the dataset's index order is\n",
    "# (presumably via the sorted file list) -- confirm.\n",
    "train_len = 9000\n",
    "assert len(dataset) > train_len\n",
    "train_T = Subset(dataset, range(0, train_len))\n",
    "val_T   = Subset(dataset, range(train_len, len(dataset)))\n",
    "\n",
    "# The validation helpers squeeze dim 0 of every tensor, so val batches must have size 1.\n",
    "train_loader_T = DataLoader(train_T, batch_size=8, shuffle=True)\n",
    "val_loader_T   = DataLoader(val_T,   batch_size=1, shuffle=False)\n",
    "\n",
    "# Fixed random Fourier encoders: 2-D coords -> 64 features (sin+cos of 32 projections),\n",
    "# tau -> 32 features. They hold buffers only, so they have no trainable parameters.\n",
    "pos_enc  = GaussianFourierFeatures(2, 32).to(DEVICE)\n",
    "time_enc = GaussianFourierFeatures(1, 16, scale=10.0).to(DEVICE)\n",
    "\n",
    "# queries_dim = 64 (pos) + 32 (time) = 96, matching how queries are built in the loop.\n",
    "# NOTE(review): input_dim=192 presumably = input-projection width + 64 positional\n",
    "# features -- confirm against CascadedPerceiverIO.\n",
    "model = CascadedPerceiverIO(\n",
    "    input_dim   = 192,\n",
    "    queries_dim = 96,\n",
    "    logits_dim  = None,\n",
    "    latent_dims = (128,128,128),\n",
    "    num_latents = (128,128,128),\n",
    "    decoder_ff  = True,\n",
    ").to(DEVICE)\n",
    "\n",
    "device = DEVICE\n",
    "max_lr = 8e-5\n",
    "min_lr = 8e-6\n",
    "# NOTE(review): int(0.1 * 10000) = 1000 warmup steps, i.e. 1% of TOTAL_ITERS (100000).\n",
    "# If a 10% warmup was intended this should be int(0.1 * TOTAL_ITERS) -- confirm.\n",
    "warmup_steps = int(0.1 * 10000)\n",
    "weight_decay = 1e-4\n",
    "\n",
    "# --------- hyper-params ----------\n",
    "TOTAL_ITERS  = 100000\n",
    "PRINT_EVERY  = 100\n",
    "VAL_EVERY    = 100\n",
    "\n",
    "opt = AdamW(\n",
    "    model.parameters(),\n",
    "    lr=max_lr,\n",
    "    betas=(0.9, 0.999),\n",
    "    weight_decay=weight_decay\n",
    ")\n",
    "\n",
    "# One cycle spanning the whole run (scheduler.step() is called once per iteration);\n",
    "# presumably linear warmup to max_lr then cosine decay to min_lr -- confirm with the\n",
    "# CosineAnnealingWarmupRestarts implementation.\n",
    "scheduler = CosineAnnealingWarmupRestarts(\n",
    "    opt,\n",
    "    first_cycle_steps=TOTAL_ITERS,\n",
    "    max_lr=max_lr,\n",
    "    min_lr=min_lr,\n",
    "    warmup_steps=warmup_steps\n",
    ")\n",
    "\n",
    "def save_checkpoint(model, pos_enc, time_enc, optimizer, epoch, val_loss, save_path):\n",
    "    \"\"\"\n",
    "    Write a training checkpoint to `save_path` via torch.save.\n",
    "\n",
    "    The saved dict holds the state_dicts of the model, positional encoder, time encoder\n",
    "    and optimizer, plus `epoch` and `val_loss`. Missing parent directories are created.\n",
    "    `val_loss` must be a number: it is formatted with :.4f in the log line.\n",
    "    \"\"\"\n",
    "    parent_dir = os.path.dirname(save_path)\n",
    "    if parent_dir:\n",
    "        os.makedirs(parent_dir, exist_ok=True)\n",
    "\n",
    "    torch.save(\n",
    "        {\n",
    "            \"epoch\": epoch,\n",
    "            \"model_state\": model.state_dict(),\n",
    "            \"pos_enc_state\": pos_enc.state_dict(),\n",
    "            \"time_enc_state\": time_enc.state_dict(),\n",
    "            \"optimizer_state\": optimizer.state_dict(),\n",
    "            \"val_loss\": val_loss,\n",
    "        },\n",
    "        save_path,\n",
    "    )\n",
    "    print(f\"Saved model to {save_path} (val_loss={val_loss:.4f})\")\n",
    "\n",
    "def load_checkpoint(model, pos_enc, time_enc, optimizer, load_path, device=None):\n",
    "    \"\"\"\n",
    "    Restore model, positional encoder, time encoder and (optionally) optimizer state.\n",
    "\n",
    "    Args:\n",
    "        model, pos_enc, time_enc: modules whose state_dicts are loaded in place.\n",
    "        optimizer: optimizer to restore, or None to skip; also skipped when the\n",
    "            checkpoint has no \"optimizer_state\" entry.\n",
    "        load_path: path of a file written by save_checkpoint.\n",
    "        device: map_location for torch.load; defaults to the global DEVICE.\n",
    "\n",
    "    Returns:\n",
    "        The raw checkpoint dict.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = DEVICE\n",
    "    checkpoint = torch.load(load_path, map_location=device)\n",
    "    model.load_state_dict(checkpoint[\"model_state\"])\n",
    "    pos_enc.load_state_dict(checkpoint[\"pos_enc_state\"])\n",
    "    time_enc.load_state_dict(checkpoint[\"time_enc_state\"])\n",
    "    if optimizer is not None and \"optimizer_state\" in checkpoint:\n",
    "        optimizer.load_state_dict(checkpoint[\"optimizer_state\"])\n",
    "\n",
    "    # 'epoch'/'val_loss' are optional metadata. Formatting a missing (None) val_loss\n",
    "    # with :.4f would raise TypeError after the weights were already loaded.\n",
    "    val_loss = checkpoint.get(\"val_loss\")\n",
    "    val_loss_str = f\"{val_loss:.4f}\" if val_loss is not None else \"n/a\"\n",
    "    print(f\"Loaded model from {load_path} (epoch={checkpoint.get('epoch')}, val_loss={val_loss_str})\")\n",
    "    return checkpoint\n",
    "\n",
    "def get_lr(optimizer):\n",
    "    \"\"\"Current learning rate of the optimizer's first parameter group.\"\"\"\n",
    "    first_group = optimizer.param_groups[0]\n",
    "    return first_group[\"lr\"]\n",
    "\n",
    "def extract_T(outs):\n",
    "    \"\"\"\n",
    "    Select the T prediction from a model output.\n",
    "\n",
    "    Supports:\n",
    "      - tuple/list: (pred_T, pred_Q, pred_V) -> first element\n",
    "      - dict: {'T': ..., 'Q': ..., 'V': ...} -> outs['T']\n",
    "      - anything else (e.g. a tensor) -> returned unchanged, assumed to be T\n",
    "\n",
    "    The selected object is returned as-is (no reshaping); callers squeeze a trailing\n",
    "    size-1 axis themselves. Deliberately NOT wrapped in torch.no_grad(): it is used on\n",
    "    the training path, where the returned prediction must stay in the autograd graph.\n",
    "    \"\"\"\n",
    "    if isinstance(outs, (list, tuple)):\n",
    "        assert len(outs) >= 1, \"Empty outputs!\"\n",
    "        return outs[0]\n",
    "    if isinstance(outs, dict):\n",
    "        assert 'T' in outs, \"Dict output missing key 'T'.\"\n",
    "        return outs['T']\n",
    "    return outs\n",
    "\n",
    "@torch.no_grad()\n",
    "def run_validation_fullfield_T(model, val_loader, pos_enc_fn, time_enc):\n",
    "    \"\"\"\n",
    "    Fast T-only check: full-field T MSE on the first batch of `val_loader`.\n",
    "\n",
    "    Expects val_loader batches of size 1 (dim 0 is squeezed and tau is viewed as (1,)).\n",
    "    Puts the model in eval mode and leaves it there.\n",
    "    Returns (val_loss_scalar, val_T_mse_scalar) -- identical floats, since the loss is T-only.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    (data_T, data_Q, data_V,\n",
    "     mesh_T, mesh_Q, mesh_V,\n",
    "     data_y, mesh_y,\n",
    "     _supervised_idx, used_modalities, _supervision_mask, tau) = next(iter(val_loader))\n",
    "\n",
    "    def embed_modality(input_proj, values, coords):\n",
    "        # Project values and append positional features; None marks an empty modality.\n",
    "        values, coords = values.to(DEVICE), coords.to(DEVICE)\n",
    "        if values.numel() == 0:\n",
    "            return None\n",
    "        return torch.cat([input_proj(values), pos_enc_fn(coords)], dim=-1)\n",
    "\n",
    "    x_T = embed_modality(model.input_proj_T, data_T, mesh_T)\n",
    "    x_Q = embed_modality(model.input_proj_Q, data_Q, mesh_Q)\n",
    "    x_V = embed_modality(model.input_proj_V, data_V, mesh_V)\n",
    "\n",
    "    target_T = data_y.to(DEVICE).squeeze(0)[:, 0]\n",
    "    query_coords = mesh_y.to(DEVICE).squeeze(0)\n",
    "    tau = tau.to(DEVICE).view(1)\n",
    "    modality_flags = tuple(bool(flag) for flag in used_modalities.squeeze(0).tolist())\n",
    "\n",
    "    # Queries = positional features of every target point + time features of tau,\n",
    "    # broadcast to each point.\n",
    "    spatial_q = pos_enc_fn(query_coords).unsqueeze(0)\n",
    "    time_q = time_enc(tau[:, None])[:, None, :].expand(-1, spatial_q.size(1), -1)\n",
    "    queries = torch.cat([spatial_q, time_q], dim=-1)\n",
    "\n",
    "    pred_T = extract_T(model(x_T=x_T, x_Q=x_Q, x_V=x_V, queries=queries,\n",
    "                             used_modalities=modality_flags))\n",
    "    if pred_T.ndim == 3 and pred_T.size(-1) == 1:\n",
    "        pred_T = pred_T.squeeze(-1)\n",
    "    mse_T = float(F.mse_loss(pred_T.squeeze(0), target_T))\n",
    "    return mse_T, mse_T\n",
    "\n",
    "@torch.no_grad()\n",
    "def run_validation_fullfield_T_total(model, val_loader, pos_enc_fn, time_enc, max_batches=None):\n",
    "    \"\"\"\n",
    "    T-only validation over the entire val_loader (or its first `max_batches` batches).\n",
    "\n",
    "    Expects val_loader batches of size 1 (dim 0 is squeezed and tau is viewed as (1,)).\n",
    "    Puts the model in eval mode and leaves it there.\n",
    "    Returns the mean per-batch full-field T MSE (0.0 if no batch was evaluated).\n",
    "    \"\"\"\n",
    "    from itertools import islice\n",
    "\n",
    "    model.eval()\n",
    "    total, count = 0.0, 0\n",
    "    # islice stops *before* requesting batch number `max_batches`; a `break` inside the\n",
    "    # loop would still make the DataLoader load (read from disk) one extra batch.\n",
    "    # Negative limits evaluate nothing, as before.\n",
    "    limit = None if max_batches is None else max(0, max_batches)\n",
    "    for batch in islice(val_loader, limit):\n",
    "        (data_T, data_Q, data_V,\n",
    "         mesh_T, mesh_Q, mesh_V,\n",
    "         data_y, mesh_y,\n",
    "         supervised_idx, used_modalities, supervision_mask, tau) = batch\n",
    "\n",
    "        data_T, mesh_T = data_T.to(DEVICE), mesh_T.to(DEVICE)\n",
    "        data_Q, mesh_Q = data_Q.to(DEVICE), mesh_Q.to(DEVICE)\n",
    "        data_V, mesh_V = data_V.to(DEVICE), mesh_V.to(DEVICE)\n",
    "        data_y = data_y.to(DEVICE).squeeze(0)\n",
    "        mesh_y = mesh_y.to(DEVICE).squeeze(0)\n",
    "        tau    = tau.to(DEVICE).view(1)\n",
    "        used_modalities = tuple(bool(x) for x in used_modalities.squeeze(0).tolist())\n",
    "\n",
    "        # Empty modalities are passed as None so the model skips them.\n",
    "        x_T = torch.cat([model.input_proj_T(data_T), pos_enc_fn(mesh_T)], dim=-1) if data_T.numel() > 0 else None\n",
    "        x_Q = torch.cat([model.input_proj_Q(data_Q), pos_enc_fn(mesh_Q)], dim=-1) if data_Q.numel() > 0 else None\n",
    "        x_V = torch.cat([model.input_proj_V(data_V), pos_enc_fn(mesh_V)], dim=-1) if data_V.numel() > 0 else None\n",
    "\n",
    "        # Full-grid queries: positional features + time features of tau per target point.\n",
    "        q_spatial = pos_enc_fn(mesh_y).unsqueeze(0)\n",
    "        tfeat = time_enc(tau[:, None])\n",
    "        tfeat = tfeat[:, None, :].expand(-1, q_spatial.size(1), -1)\n",
    "        queries_full = torch.cat([q_spatial, tfeat], dim=-1)\n",
    "\n",
    "        outs = model(x_T=x_T, x_Q=x_Q, x_V=x_V, queries=queries_full, used_modalities=used_modalities)\n",
    "        pred_T = extract_T(outs)\n",
    "        if pred_T.ndim == 3 and pred_T.size(-1) == 1:\n",
    "            pred_T = pred_T.squeeze(-1)\n",
    "        pred_T = pred_T.squeeze(0)\n",
    "\n",
    "        tgt_T = data_y[:, 0]\n",
    "        lT = F.mse_loss(pred_T, tgt_T).item()\n",
    "        total += lT\n",
    "        count += 1\n",
    "\n",
    "    return total / max(1, count)\n",
    "\n",
    "def pos_enc_batched(pos_enc_fn, coords):\n",
    "    \"\"\"\n",
    "    Apply a row-wise encoder to coordinates with arbitrary leading dimensions.\n",
    "\n",
    "    coords [..., C] is flattened to [M, C], passed through pos_enc_fn (-> [M, D]),\n",
    "    and the result is reshaped back to [..., D].\n",
    "    \"\"\"\n",
    "    width = coords.shape[-1]\n",
    "    encoded = pos_enc_fn(coords.reshape(-1, width))\n",
    "    return encoded.view(*coords.shape[:-1], encoded.shape[-1])\n",
    "\n",
    "# ==== Training loop: TQV inputs -> full-grid T prediction ====\n",
    "# NOTE(review): if `cycle` is itertools.cycle, it replays a cached copy of the first\n",
    "# pass over train_loader_T: shuffling stops after the first pass and every batch is\n",
    "# held in memory. A generator that re-iterates the DataLoader avoids both -- confirm\n",
    "# which `cycle` is in scope.\n",
    "train_iter = cycle(train_loader_T)\n",
    "\n",
    "# Sums over the last PRINT_EVERY iterations; both receive the same value because\n",
    "# loss == lT (T-only objective).\n",
    "running_loss = 0.0\n",
    "running_mse_T = 0.0   # scalar running avg for T only\n",
    "\n",
    "# Best mean T MSE from the full validation pass; the checkpoint is overwritten on improvement.\n",
    "best_val_loss = float(\"inf\")\n",
    "save_path = \"ClimSim_checkpoints/best_model.pt\"\n",
    "\n",
    "for it in tqdm(range(1, TOTAL_ITERS + 1)):\n",
    "    # The validation helpers call model.eval(); switch back every iteration.\n",
    "    model.train()\n",
    "\n",
    "    (data_T, data_Q, data_V,\n",
    "     mesh_T, mesh_Q, mesh_V,\n",
    "     data_y, mesh_y,\n",
    "     supervised_idx, used_modalities, supervision_mask, tau) = next(train_iter)\n",
    "\n",
    "    # ---- move (NO squeeze; keep batch dim) ----\n",
    "    data_T = data_T.to(DEVICE)       # [B, S_T, 3]\n",
    "    mesh_T = mesh_T.to(DEVICE)       # [B, S_T, 2]\n",
    "    data_Q = data_Q.to(DEVICE)       # [B, S_Q, 3]\n",
    "    mesh_Q = mesh_Q.to(DEVICE)       # [B, S_Q, 2]\n",
    "    data_V = data_V.to(DEVICE)       # [B, S_V, 3]\n",
    "    mesh_V = mesh_V.to(DEVICE)       # [B, S_V, 2]\n",
    "    data_y = data_y.to(DEVICE)       # [B, N, 3]\n",
    "    mesh_y = mesh_y.to(DEVICE)       # [B, N, 2]\n",
    "    tau    = tau.to(DEVICE).view(-1)\n",
    "\n",
    "    # Multimodal inputs → T target\n",
    "    # NOTE(review): the flags are hard-coded to all-present. The batch's own\n",
    "    # `used_modalities` is unpacked above but ignored here (the validation helpers do\n",
    "    # use it). Empty modalities are still skipped because their x_* is None -- confirm intended.\n",
    "    used_modalities_bits = (True, True, True)\n",
    "\n",
    "    # ---- inputs: project + concat positional enc (batched) ----\n",
    "    x_T = torch.cat([model.input_proj_T(data_T), pos_enc_batched(pos_enc, mesh_T)], dim=-1) if data_T.numel() > 0 else None\n",
    "    x_Q = torch.cat([model.input_proj_Q(data_Q), pos_enc_batched(pos_enc, mesh_Q)], dim=-1) if data_Q.numel() > 0 else None\n",
    "    x_V = torch.cat([model.input_proj_V(data_V), pos_enc_batched(pos_enc, mesh_V)], dim=-1) if data_V.numel() > 0 else None\n",
    "\n",
    "    # ---- FULL-GRID queries (batched) ----\n",
    "    tfeat = time_enc(tau[:, None])                                         # [B, Dtime]\n",
    "    tfeat_expanded = tfeat[:, None, :].expand(-1, mesh_y.shape[1], -1)     # [B, N, Dtime]\n",
    "    queries_spatial = pos_enc_batched(pos_enc, mesh_y)                      # [B, N, Dpos]\n",
    "    queries_full = torch.cat([queries_spatial, tfeat_expanded], dim=-1)     # [B, N, Dpos+Dtime]\n",
    "\n",
    "    # ---- forward ----\n",
    "    outs = model(\n",
    "        x_T=x_T, x_Q=x_Q, x_V=x_V,\n",
    "        queries=queries_full,\n",
    "        used_modalities=used_modalities_bits\n",
    "    )\n",
    "    pred_T = extract_T(outs)           # [B,N,1] or [B,N]\n",
    "    if pred_T.ndim == 3 and pred_T.size(-1) == 1:\n",
    "        pred_T = pred_T.squeeze(-1)    # [B,N]\n",
    "\n",
    "    # ---- targets ----\n",
    "    tgt_T = data_y[..., 0]             # [B, N]\n",
    "\n",
    "    # ---- full-field MSE over the batch (T only) ----\n",
    "    lT = F.mse_loss(pred_T, tgt_T)     # scalar across B*N\n",
    "    loss = lT\n",
    "\n",
    "    # Backprop with global grad-norm clipping at 1.0; the LR schedule advances once per\n",
    "    # iteration. pos_enc/time_enc hold only buffers, so only model.parameters() train.\n",
    "    opt.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "    opt.step()\n",
    "    scheduler.step()\n",
    "\n",
    "    # ---- logging ----\n",
    "    with torch.no_grad():\n",
    "        running_mse_T += lT.item()\n",
    "        running_loss  += float(loss)\n",
    "\n",
    "    if it % PRINT_EVERY == 0:\n",
    "        mT = running_mse_T / PRINT_EVERY\n",
    "        print(f\"  [TQV->T|Full|Batched]  lr={get_lr(opt):.2e}  B={data_T.shape[0]}\")\n",
    "        print(f\"[Iter {it}] Total Loss: {running_loss / PRINT_EVERY:.6f}\")\n",
    "        print(f\"  Full-field MSE  T: {mT:.6f}\")\n",
    "        running_loss = 0.0\n",
    "        running_mse_T = 0.0\n",
    "\n",
    "    # Quick check on the first validation batch only (val_loader_T is not shuffled).\n",
    "    if it % VAL_EVERY == 0:\n",
    "        val_loss, val_mT = run_validation_fullfield_T(model, val_loader_T, pos_enc, time_enc)\n",
    "        print(f\"[VAL] T-only Loss (full-field): {val_loss:.6f}   (T MSE: {val_mT:.6f})\")\n",
    "\n",
    "    # Full validation pass; keep the checkpoint with the lowest mean T MSE\n",
    "    # (its \"epoch\" field stores the iteration number `it`).\n",
    "    if it % (VAL_EVERY * 5) == 0:\n",
    "        val_loss_full = run_validation_fullfield_T_total(model, val_loader_T, pos_enc, time_enc)\n",
    "        print(f\"[VAL-TOTAL] Mean T MSE over val: {val_loss_full:.6f}\")\n",
    "        if val_loss_full < best_val_loss:\n",
    "            best_val_loss = val_loss_full\n",
    "            save_checkpoint(model, pos_enc, time_enc, opt, it, val_loss_full, save_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a710fc4-d56a-48f2-98d8-b503e6538181",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==== Build dataset & contiguous train/val loaders (TQV inputs -> T target) ====\n",
    "import glob, numpy as np\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute\n",
    "# arbitrary code -- only load trusted files.\n",
    "# `with` closes the NpzFile handle; dict(...) has already read every array into memory.\n",
    "with np.load(\"norm_TQV_full.npz\", allow_pickle=True) as _norm_npz:\n",
    "    norm_stats = dict(_norm_npz)\n",
    "file_list      = sorted(glob.glob(\"processed/**/*.npz\", recursive=True))\n",
    "grid_meta_path = \"ClimSim_high-res_grid-info.nc\"\n",
    "\n",
    "sparsity = 0.02  # fraction of grid points sampled as inputs (2% -> 432 points)\n",
    "\n",
    "# Pass the config variables above (rather than repeating literals) so that editing\n",
    "# them actually changes the dataset.\n",
    "dataset = ClimSimTQVForecastVennFixed(\n",
    "    file_list=file_list,\n",
    "    grid_meta_path=grid_meta_path,\n",
    "    sparsity=sparsity,\n",
    "    triple_fraction=0.25,\n",
    "    norm_stats=norm_stats,\n",
    "    input_modalities=(1,1,1),\n",
    "    input_region=\"union\",\n",
    "    seed=123,\n",
    ")\n",
    "\n",
    "# Contiguous split: the first train_len samples train, the rest validate.\n",
    "train_len = 9000\n",
    "assert len(dataset) > train_len\n",
    "train_T = Subset(dataset, range(0, train_len))\n",
    "val_T   = Subset(dataset, range(train_len, len(dataset)))\n",
    "\n",
    "train_loader_T = DataLoader(train_T, batch_size=8, shuffle=True)\n",
    "# The validation helpers squeeze dim 0, so keep batch_size=1 here.\n",
    "val_loader_T   = DataLoader(val_T, batch_size=1, shuffle=False)\n",
    "\n",
    "print(len(train_T))  # number of samples in the train subset\n",
    "print(len(val_T))    # number of samples in the validation subset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d23ab76e-dbb7-46d5-8e65-82e08cc08878",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Robust loader for ClimSim norm stats (T/Q/V) from NPZ files.\n",
    "# Accepts several on-disk layouts and produces per-variable scalar (mu, sd) pairs plus\n",
    "# a small normalizer object for converting to and from normalized units.\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Variable keys expected in the stats file (same order as the [T, Q, V] data channels).\n",
    "TQV_KEYS = (\"T\", \"Q\", \"V\")\n",
    "\n",
    "def _to_float1(x, fallback=0.0):\n",
    "    \"\"\"Return the first element of `x` (after flattening) as a float.\n",
    "\n",
    "    `fallback` is returned (as a float) when x is None or has no elements.\n",
    "    \"\"\"\n",
    "    flat = None if x is None else np.ravel(np.asarray(x))\n",
    "    if flat is None or flat.size == 0:\n",
    "        return float(fallback)\n",
    "    return float(flat[0])\n",
    "\n",
    "def _mu_sd_from_dict(d):\n",
    "    \"\"\"(mu, sd) from a dict keyed 'mean'/'std' (or 'mu'/'sigma'); sd floored at 1e-6.\"\"\"\n",
    "    mu = _to_float1(d.get(\"mean\", d.get(\"mu\", None)), 0.0)\n",
    "    sd = max(_to_float1(d.get(\"std\", d.get(\"sigma\", None)), 1.0), 1e-6)\n",
    "    return mu, sd\n",
    "\n",
    "def _extract_mu_sd_from_entry(v):\n",
    "    \"\"\"\n",
    "    Parse one norm-stats entry into (mu, sd), or return None if it cannot be parsed.\n",
    "\n",
    "    Accepted forms:\n",
    "      - dict with 'mean'/'std' or 'mu'/'sigma' (missing mean -> 0.0, missing std -> 1.0)\n",
    "      - single-element object ndarray wrapping such a dict or a [mu, std] sequence\n",
    "        (typically how np.load returns a pickled Python object)\n",
    "      - array-like with >= 2 elements: first element -> mu, second -> sd\n",
    "    Whenever a pair is returned, sd is floored at 1e-6 so it is safe to divide by.\n",
    "    \"\"\"\n",
    "    if isinstance(v, dict):\n",
    "        return _mu_sd_from_dict(v)\n",
    "\n",
    "    wraps_single_object = (\n",
    "        isinstance(v, np.ndarray)\n",
    "        and v.dtype == object\n",
    "        and (v.shape == () or (v.size == 1 and v.ndim <= 1))\n",
    "    )\n",
    "    if wraps_single_object:\n",
    "        inner = v.reshape(()).item()\n",
    "        if isinstance(inner, dict):\n",
    "            return _mu_sd_from_dict(inner)\n",
    "        v = np.asarray(inner)  # maybe a [mu, std] pair; parsed below\n",
    "\n",
    "    arr = np.asarray(v)\n",
    "    if arr.size < 2:\n",
    "        return None\n",
    "    # NOTE: for multi-dimensional arrays arr[0]/arr[1] are rows, and _to_float1 keeps only\n",
    "    # their first element.\n",
    "    return _to_float1(arr[0], 0.0), max(_to_float1(arr[1], 1.0), 1e-6)\n",
    "\n",
    "def load_tqv_norm_stats(npz_path: str):\n",
    "    \"\"\"\n",
    "    Load scalar T/Q/V normalization stats from an .npz file.\n",
    "\n",
    "    Returns:\n",
    "      - stats_tuples: { 'T': (mu, sd), 'Q': (mu, sd), 'V': (mu, sd) }\n",
    "      - Normalizer with .mu/.sd dicts and .norm/.denorm methods (NumPy or torch)\n",
    "    Tries these layouts, in order:\n",
    "      A) Keys 'T','Q','V' with dicts or [mu, sd] arrays\n",
    "      B) Split keys: 'T_mu'/'T_std' (or 'mu_T'/'std_T'), likewise for Q and V\n",
    "      C) Packed under 'stats' → {'T':..., 'Q':..., 'V':...}\n",
    "    Raises:\n",
    "      ValueError if no layout matches or an entry cannot be parsed as (mu, sd).\n",
    "    \"\"\"\n",
    "    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can\n",
    "    # execute arbitrary code -- only load trusted stats files.\n",
    "    # The NpzFile is closed once every array has been read into the dict; without the\n",
    "    # context manager the underlying file handle stays open.\n",
    "    with np.load(npz_path, allow_pickle=True) as npz:\n",
    "        z = dict(npz)\n",
    "    stats = {}\n",
    "\n",
    "    # Case A: direct per-key entries\n",
    "    if all(k in z for k in TQV_KEYS):\n",
    "        for k in TQV_KEYS:\n",
    "            got = _extract_mu_sd_from_entry(z[k])\n",
    "            if got is None:\n",
    "                raise ValueError(f\"Key '{k}' present but not parseable as (mu, sd). \"\n",
    "                                 f\"Type={type(z[k])}, shape={np.shape(z[k])}\")\n",
    "            stats[k] = got\n",
    "    else:\n",
    "        # Case B: split keys (all three variables must be present, else fall through to C)\n",
    "        ok = True\n",
    "        tmp = {}\n",
    "        for k in TQV_KEYS:\n",
    "            mu = z.get(f\"{k}_mu\", z.get(f\"mu_{k}\", None))\n",
    "            sd = z.get(f\"{k}_std\", z.get(f\"std_{k}\", None))\n",
    "            if (mu is None) or (sd is None):\n",
    "                ok = False\n",
    "                break\n",
    "            tmp[k] = ( _to_float1(mu, 0.0), max(_to_float1(sd, 1.0), 1e-6) )\n",
    "        if ok:\n",
    "            stats = tmp\n",
    "        else:\n",
    "            # Case C: packed under a 'stats' key\n",
    "            packed = z.get(\"stats\", None)\n",
    "            if packed is None:\n",
    "                raise ValueError(\n",
    "                    \"Could not find T/Q/V norm stats in NPZ. \"\n",
    "                    f\"Available keys: {sorted(z.keys())}\"\n",
    "                )\n",
    "            # A pickled dict comes back from np.load wrapped in an object ndarray.\n",
    "            if isinstance(packed, np.ndarray) and packed.dtype == object:\n",
    "                packed = packed.reshape(()).item()\n",
    "            if not isinstance(packed, dict):\n",
    "                raise ValueError(\"Key 'stats' exists but is not a dict-like container.\")\n",
    "            for k in TQV_KEYS:\n",
    "                if k not in packed:\n",
    "                    raise ValueError(f\"'stats' missing key '{k}'.\")\n",
    "                got = _extract_mu_sd_from_entry(packed[k])\n",
    "                if got is None:\n",
    "                    raise ValueError(f\"'stats[{k}]' not parseable as (mu, sd).\")\n",
    "                stats[k] = got\n",
    "\n",
    "    class TQVNormalizer:\n",
    "        # Scalar affine transform per variable: norm(x) = (x - mu) / sd,\n",
    "        # denorm(x) = x * sd + mu. Torch tensors stay tensors; anything else goes via NumPy.\n",
    "        def __init__(self, stats_tuples):\n",
    "            self.mu = {k: float(stats_tuples[k][0]) for k in TQV_KEYS}\n",
    "            self.sd = {k: float(stats_tuples[k][1]) for k in TQV_KEYS}\n",
    "        def norm(self, key, x):\n",
    "            mu, sd = self.mu[key], self.sd[key]\n",
    "            if isinstance(x, torch.Tensor):\n",
    "                return (x - mu) / sd\n",
    "            return (np.asarray(x) - mu) / sd\n",
    "        def denorm(self, key, x):\n",
    "            mu, sd = self.mu[key], self.sd[key]\n",
    "            if isinstance(x, torch.Tensor):\n",
    "                return x * sd + mu\n",
    "            return np.asarray(x) * sd + mu\n",
    "\n",
    "    return stats, TQVNormalizer(stats)\n",
    "\n",
    "# ==== Load & print ====\n",
    "# norm_stats_tuples: {'T'|'Q'|'V': (mu, sd)}; TQV_NORM: the normalizer that\n",
    "# denorm_tqv_batch reads as a module-level global.\n",
    "norm_stats_tuples, TQV_NORM = load_tqv_norm_stats(\"norm_TQV_full.npz\")\n",
    "print(\"ClimSim normalization (μ, σ):\")\n",
    "for k in TQV_KEYS:\n",
    "    mu, sd = norm_stats_tuples[k]\n",
    "    print(f\"  {k}: mu={mu:.6f}, std={sd:.6f}\")\n",
    "\n",
    "# Example helper: denorm [B, N, 3] tensor with channels [T,Q,V]\n",
    "def denorm_tqv_batch(y_norm: torch.Tensor):\n",
    "    \"\"\"\n",
    "    Convert a normalized tensor whose last axis holds the [T, Q, V] channels back to\n",
    "    original (un-normalized) units, using the global TQV_NORM.\n",
    "\n",
    "    Returns {'T': ..., 'Q': ..., 'V': ...}, each with the channel axis dropped.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        name: TQV_NORM.denorm(name, y_norm[..., channel])\n",
    "        for channel, name in enumerate((\"T\", \"Q\", \"V\"))\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6e46d24-a9ba-40c3-afa3-29a6c4096d05",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ddim",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
