#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Shallow-water-style J(z) ablation with a G-flexibility interaction (MeshFT pointwise reduction)

Goal
----
Test the interaction:
  - conservative coupling:   J fixed vs J(z) (same incidence-wired skew pattern; only gains vary)
  - metric-side capacity:    G diagonal vs G full (3x3 PD)

Truth model (for data generation)
--------------------------------
We generate trajectories from a structure-matched pointwise-reduction system
  dz/dt = J_true(z) e_true(z),  with e_true(z)=G_true z,
where:
  - G_true is diagonal and state-independent (co-energy is integrable).
  - J_true(z) has signed-incidence wiring with orientation-even positive gains cx(z), cy(z),
    spatially varying and depending on [h, mx^2, my^2], with mean=1 gauge-fix per sample.

Models (2x2)
------------
  J-mode:
    - "fixed":    cx=cy=1
    - "state":    cx(z),cy(z) predicted from [h, mx^2, my^2], mean=1 gauge-fix
  G-mode:
    - "diag":     diagonal PD per-cell
    - "full":     full 3x3 PD per-cell via Cholesky

We run the 4 combinations:
  (1) J_fixed  + G_diag
  (2) J_fixed  + G_full
  (3) J_state  + G_diag
  (4) J_state  + G_full

Data regimes
-----------
We keep low/high n_traj knobs and run all modes.

Metrics
-------
- best_val_kstep_mse: k-step MSE on validation pairs (same unroll k as training)
- rollout_vrmse_mean: VRMSE over long rollout vs RK4 truth (mean over a few val trajectories)

Integral errors (truth-aligned; recommended):
- mass_err_T / momx_err_T / momy_err_T / energy_err_T:
    absolute error at final time: |I_pred(T)-I_true(T)|
- mass_err_rmse / ...:
    time-series RMSE: sqrt(mean_t (I_pred(t)-I_true(t))^2)

Structure sanity (optional):
- mass_drift_pred:   |mass_pred(T)-mass_pred(0)|
- energy_drift_pred: |E_pred(T)-E_pred(0)| under truth-consistent quadratic energy

Gain identification:
- gain_mse_cx/gain_mse_cy: MSE of predicted cx,cy vs truth cx,cy on initial states (val trajs)

Outputs
-------
outdir/
  ntraj{N}/seed{S:04d}/mode_{MODE}/summary.json  (+ model.pt checkpoint; MODE = J_{j}__G_{g})
  aggregate.json  (mean/std over seeds per (n_traj, mode))

Run example
-----------
python shallow_jdep_ablation_gflex.py --device cuda --seeds 0,1,2 --ntraj_low 4 --ntraj_high 64
"""

import os, math, json, random, argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ----------------------------
# Repro
# ----------------------------

def set_seed(seed: int, deterministic: bool = False):
    """
    Seed Python `random`, NumPy and PyTorch (CPU and all CUDA devices) with one integer.

    Args:
        seed: seed value (coerced with int()).
        deterministic: if True, also disable cuDNN autotuning, request deterministic cuDNN
            kernels and, when the installed torch supports it, enable
            torch.use_deterministic_algorithms(True).

    Note:
        PYTHONHASHSEED is exported for child processes only. Setting it at runtime does not
        change hash() randomisation of the interpreter that is already running.
    """
    import warnings

    seed = int(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        try:
            torch.use_deterministic_algorithms(True)
        except (AttributeError, RuntimeError) as err:
            # Best-effort: older torch builds lack this API. Keep going, but do not hide it.
            warnings.warn(f"torch.use_deterministic_algorithms unavailable: {err}", RuntimeWarning)

def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a fresh CPU torch.Generator seeded with int(seed), independent of the global RNG."""
    generator = torch.Generator(device="cpu")
    return generator.manual_seed(int(seed))

def seed_worker(worker_id: int):
    """
    DataLoader worker_init_fn: seed NumPy and `random` from the current torch initial seed.

    The torch seed is reduced modulo 2**32 so it fits NumPy's seed range. worker_id is unused.
    """
    derived_seed = torch.initial_seed() % (2**32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)


# ----------------------------
# Periodic finite differences (cell/edge operators)
# ----------------------------

def roll_x(f: torch.Tensor, shift: int) -> torch.Tensor:
    """Periodic shift along the last (x) axis: roll_x(f, s)[..., i] == f[..., i - s]."""
    return f.roll(shift, -1)

def roll_y(f: torch.Tensor, shift: int) -> torch.Tensor:
    """Periodic shift along the second-to-last (y) axis: roll_y(f, s)[..., j, :] == f[..., j - s, :]."""
    return f.roll(shift, -2)

# cell -> edge averages (A)
def avg_x(f: torch.Tensor) -> torch.Tensor:
    """Cell -> x-face average A_x: (A_x f)_i = 0.5 * (f_i + f_{i+1}), periodic; face i lies between cells i and i+1."""
    right_neighbour = roll_x(f, -1)
    return 0.5 * (f + right_neighbour)

def avg_y(f: torch.Tensor) -> torch.Tensor:
    """Cell -> y-face average A_y: (A_y f)_j = 0.5 * (f_j + f_{j+1}), periodic; face j lies between cells j and j+1."""
    upper_neighbour = roll_y(f, -1)
    return 0.5 * (f + upper_neighbour)

# edge -> cell adjoint averages (A^T)
def avg_x_T(g: torch.Tensor) -> torch.Tensor:
    """Adjoint of avg_x (x-face -> cell): (A_x^T g)_i = 0.5 * (g_i + g_{i-1}), i.e. mean of cell i's two x-faces."""
    left_face = roll_x(g, 1)
    return 0.5 * (g + left_face)

def avg_y_T(g: torch.Tensor) -> torch.Tensor:
    """Adjoint of avg_y (y-face -> cell): (A_y^T g)_j = 0.5 * (g_j + g_{j-1}), i.e. mean of cell j's two y-faces."""
    lower_face = roll_y(g, 1)
    return 0.5 * (g + lower_face)

# edge -> cell divergence (D)
def Dx(fx_edge: torch.Tensor) -> torch.Tensor:
    """x-divergence of face fluxes into cells: (D_x f)_i = f_i - f_{i-1} (right face minus left face), periodic."""
    left_face_flux = roll_x(fx_edge, 1)
    return fx_edge - left_face_flux

def Dy(fy_edge: torch.Tensor) -> torch.Tensor:
    """y-divergence of face fluxes into cells: (D_y f)_j = f_j - f_{j-1} (upper face minus lower face), periodic."""
    lower_face_flux = roll_y(fy_edge, 1)
    return fy_edge - lower_face_flux

def div_edge_flux(fx_edge: torch.Tensor, fy_edge: torch.Tensor) -> torch.Tensor:
    """Discrete cell divergence D_x fx + D_y fy of a face-centred flux pair (periodic)."""
    x_part = Dx(fx_edge)
    y_part = Dy(fy_edge)
    return x_part + y_part

# cell -> edge negative adjoint gradients (D^T)
def Dx_T(p_cell: torch.Tensor) -> torch.Tensor:
    """Adjoint of Dx (cell -> x-face): (D_x^T p)_i = p_i - p_{i+1}, i.e. minus the forward x-difference."""
    next_cell = roll_x(p_cell, -1)
    return p_cell - next_cell

def Dy_T(p_cell: torch.Tensor) -> torch.Tensor:
    """Adjoint of Dy (cell -> y-face): (D_y^T p)_j = p_j - p_{j+1}, i.e. minus the forward y-difference."""
    next_cell = roll_y(p_cell, -1)
    return p_cell - next_cell


# ----------------------------
# Incidence-wired skew operator J(c): apply to co-energy e
# ----------------------------

def apply_J_pointwise_skew(e: torch.Tensor, cx_cell: torch.Tensor, cy_cell: torch.Tensor) -> torch.Tensor:
    """
    Apply the incidence-wired skew operator J(c) to a co-energy field.

    With cell->face averages A, face->cell divergences D, their adjoints A^T, D^T and face
    gains C_x = avg_x(cx), C_y = avg_y(cy):

        hdot  = -(D_x C_x A_x e_mx + D_y C_y A_y e_my)
        mxdot =   A_x^T C_x D_x^T e_h
        mydot =   A_y^T C_y D_y^T e_h

    The (mx,h)/(my,h) blocks are minus the transposes of the (h,mx)/(h,my) blocks, so J is
    skew-symmetric. sum(hdot) vanishes because D telescopes on the periodic grid.

    Args:
        e: (B, 3, H, W) co-energy channels (e_h, e_mx, e_my).
        cx_cell, cy_cell: (B, 1, H, W) cell gains (expected positive, orientation-even).

    Returns:
        (B, 3, H, W) tendencies (hdot, mxdot, mydot).
    """
    e_h, e_mx, e_my = e[:, 0:1], e[:, 1:2], e[:, 2:3]

    gain_x = avg_x(cx_cell)
    gain_y = avg_y(cy_cell)

    # Depth: minus the divergence of gain-weighted, face-averaged momentum co-energy.
    h_rate = -div_edge_flux(gain_x * avg_x(e_mx), gain_y * avg_y(e_my))
    # Momenta: adjoint wiring, i.e. face-averaged, gain-weighted minus-gradient of e_h.
    mx_rate = avg_x_T(gain_x * Dx_T(e_h))
    my_rate = avg_y_T(gain_y * Dy_T(e_h))

    return torch.cat([h_rate, mx_rate, my_rate], dim=1)


# ----------------------------
# Truth dynamics: dz/dt = J_true(z) e_true(z),  e_true = G_true z
# ----------------------------

def true_coenergy_linear(z: torch.Tensor, g0: float = 1.0) -> torch.Tensor:
    """
    Truth co-energy e = G_true z with the constant diagonal metric G_true = diag(g0, 1, 1).

    Args:
        z: state (3, H, W) or (B, 3, H, W). An unbatched input gains a leading batch axis.
        g0: weight on the depth channel.

    Returns:
        (B, 3, H, W) tensor (e_h, e_mx, e_my) = (g0*h, mx, my).
    """
    batched = z.unsqueeze(0) if z.dim() == 3 else z
    return torch.cat((g0 * batched[:, 0:1], batched[:, 1:2], batched[:, 2:3]), dim=1)

def true_cxcy(z: torch.Tensor, mode: str, c_min: float = 0.2, c_max: float = 5.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return the truth edge-gain fields cx(z), cy(z), each of shape (B, 1, H, W).

    Args:
        z: state (3, H, W) or (B, 3, H, W) with channels (h, mx, my).
        mode: one of
            - "fixed": cx = cy = 1.
            - "const": spatially constant cx = 1.8, cy = 0.7.
            - "state": spatially varying gains from the orientation-even features
              (h, mx^2, my^2), squashed into [c_min, c_max] by a sigmoid, then rescaled to
              unit spatial mean per sample. After this gauge-fix the values are no longer
              guaranteed to lie inside [c_min, c_max].
        c_min, c_max: sigmoid range used by "state" only.

    Raises:
        ValueError: for an unknown mode. An explicit raise is used rather than `assert`,
            which is stripped under `python -O` and would silently fall into "state".
    """
    if z.dim() == 3:
        z = z.unsqueeze(0)
    B, _, H, W = z.shape

    if mode == "fixed":
        cx = torch.ones((B,1,H,W), device=z.device, dtype=z.dtype)
        cy = torch.ones((B,1,H,W), device=z.device, dtype=z.dtype)
        return cx, cy

    if mode == "const":
        cx = torch.full((B,1,H,W), 1.8, device=z.device, dtype=z.dtype)
        cy = torch.full((B,1,H,W), 0.7, device=z.device, dtype=z.dtype)
        return cx, cy

    if mode != "state":
        raise ValueError(f"unknown truth J mode {mode!r}; expected 'fixed', 'const' or 'state'")

    h  = z[:, 0:1]
    mx = z[:, 1:2]
    my = z[:, 2:3]
    z_even = torch.cat([h, mx*mx, my*my], dim=1)

    # Deliberately nonlinear and orientation-even (depends on h, mx^2, my^2 only).
    raw = (
        2.0 * torch.tanh(3.0 * (z_even[:,0:1] - 1.0))
        + 0.6 * torch.tanh(2.0 * (z_even[:,1:2] - z_even[:,1:2].mean(dim=(-2,-1), keepdim=True)))
        - 0.6 * torch.tanh(2.0 * (z_even[:,2:3] - z_even[:,2:3].mean(dim=(-2,-1), keepdim=True)))
    )

    # cx and cy respond in opposite directions to the same raw signal.
    cx = c_min + (c_max - c_min) * torch.sigmoid(raw)
    cy = c_min + (c_max - c_min) * torch.sigmoid(-raw)

    # Gauge-fix: unit spatial mean per sample.
    cx = cx / cx.mean(dim=(-2,-1), keepdim=True).clamp_min(1e-6)
    cy = cy / cy.mean(dim=(-2,-1), keepdim=True).clamp_min(1e-6)
    return cx, cy

def true_vector_field_Jdep(z: torch.Tensor, g0: float, j_mode: str) -> torch.Tensor:
    """
    Truth right-hand side dz/dt = J(z) e(z).

    e(z) = diag(g0, 1, 1) z comes from true_coenergy_linear. J(z) is apply_J_pointwise_skew
    with gains from true_cxcy(z, mode=j_mode).

    Args:
        z: state (3, H, W) or (B, 3, H, W).

    Returns:
        Tendency with the same rank as the input, (3, H, W) or (B, 3, H, W).
    """
    unbatched = z.dim() == 3
    zb = z.unsqueeze(0) if unbatched else z
    gain_x, gain_y = true_cxcy(zb, mode=j_mode)
    coenergy = true_coenergy_linear(zb, g0=g0)
    rhs = apply_J_pointwise_skew(coenergy, cx_cell=gain_x, cy_cell=gain_y)
    return rhs[0] if unbatched else rhs


def integrate_true(
    z0: torch.Tensor,
    dt: float,
    n_steps: int,
    integrator: str = "rk4",
    clamp_h_min: float = 1e-3,
    truth_j_mode: str = "state",
    truth_g0: float = 1.0,
) -> torch.Tensor:
    """
    Integrate the truth system dz/dt = J_true(z) e_true(z) from a single state.

    Args:
        z0: initial state (3, H, W); it is cloned, not modified.
        dt: step size.
        n_steps: number of steps; the output holds n_steps + 1 states (z0 included).
        integrator: "rk4" for classical RK4; any other value uses forward Euler.
        clamp_h_min: depth is clamped from below to this value after every step.
        truth_j_mode, truth_g0: forwarded to true_vector_field_Jdep.

    Returns:
        (n_steps + 1, 3, H, W) trajectory.
    """
    assert z0.dim() == 3 and z0.shape[0] == 3

    def rhs(state: torch.Tensor) -> torch.Tensor:
        return true_vector_field_Jdep(state, truth_g0, truth_j_mode)

    use_rk4 = integrator == "rk4"
    h_floor = float(clamp_h_min)

    state = z0.clone()
    states = [state.clone()]
    for _ in range(int(n_steps)):
        if not use_rk4:
            state = state + dt * rhs(state)
        else:
            k1 = rhs(state)
            k2 = rhs(state + 0.5*dt*k1)
            k3 = rhs(state + 0.5*dt*k2)
            k4 = rhs(state + dt*k3)
            state = state + (dt/6.0) * (k1 + 2*k2 + 2*k3 + k4)

        # `state` is a fresh tensor from the update above, so this in-place clamp is safe.
        state[0] = torch.clamp(state[0], min=h_floor)
        states.append(state.clone())
    return torch.stack(states, dim=0)


# ----------------------------
# ICs
# ----------------------------

def lowpass_random_field(H: int, W: int, device: torch.device, cutoff: float = 0.25) -> torch.Tensor:
    """
    Draw a smooth, periodic random field of shape (H, W) with roughly unit std.

    White Gaussian noise from the global torch RNG is transformed with rfft2. Every mode whose
    radial frequency |k| (cycles per cell) is >= cutoff is zeroed. The result is transformed
    back with irfft2 and divided by its std + 1e-6.
    """
    white = torch.randn(H, W, device=device)
    ky = torch.fft.fftfreq(H, d=1.0).to(device)
    kx = torch.fft.rfftfreq(W, d=1.0).to(device)
    Ky, Kx = torch.meshgrid(ky, kx, indexing="ij")
    keep = (torch.sqrt(Kx**2 + Ky**2) < cutoff).float()
    smooth = torch.fft.irfft2(torch.fft.rfft2(white) * keep, s=(H, W))
    return smooth / (smooth.std() + 1e-6)

def make_wave_ic(
    H: int, W: int, device: torch.device,
    amp_wave: float = 0.12,
    amp_noise: float = 0.01,
    vel_scale: float = 0.02,
) -> torch.Tensor:
    """
    Random shallow-water initial condition z0 = (h0, mx0, my0) of shape (3, H, W).

    h0 = 1 + amp_wave * (normalised product-of-sines wave, optionally plus a second wave at
    half amplitude) + amp_noise * smooth noise, clamped to >= 1e-3. Momenta are h0 * u with
    u = vel_scale * smooth noise per component. Wave numbers and phases come from NumPy's
    global RNG; the smooth noise comes from torch's global RNG (lowpass_random_field).

    The grid is periodic: N uniformly spaced points on [-pi, pi) with spacing 2*pi/N and no
    duplicated endpoint. This matches the torch.roll-based periodic operators.
    torch.linspace(-pi, pi, N) would include both endpoints. The first and last cells would
    then be the same physical point, which leaves a repeated sample and a kink at the seam.
    """
    y1d = torch.arange(H, device=device, dtype=torch.float32) * (2.0 * math.pi / H) - math.pi
    x1d = torch.arange(W, device=device, dtype=torch.float32) * (2.0 * math.pi / W) - math.pi
    ys, xs = torch.meshgrid(y1d, x1d, indexing="ij")

    # Integer wave numbers keep every sinusoid exactly periodic on this grid.
    kx = int(np.random.randint(1, 4))
    ky = int(np.random.randint(1, 4))
    px = float(np.random.uniform(0, 2*np.pi))
    py = float(np.random.uniform(0, 2*np.pi))

    base = torch.sin(kx * xs + px) * torch.sin(ky * ys + py)
    if np.random.rand() < 0.5:
        kx2 = int(np.random.randint(1, 4))
        ky2 = int(np.random.randint(1, 4))
        px2 = float(np.random.uniform(0, 2*np.pi))
        py2 = float(np.random.uniform(0, 2*np.pi))
        base = base + 0.5 * torch.sin(kx2 * xs + px2) * torch.sin(ky2 * ys + py2)

    base = base / (base.abs().max() + 1e-6)
    smooth = lowpass_random_field(H, W, device, cutoff=0.25)

    h0 = 1.0 + amp_wave * base + amp_noise * smooth
    h0 = torch.clamp(h0, min=1e-3)

    ux0 = vel_scale * lowpass_random_field(H, W, device, cutoff=0.25)
    uy0 = vel_scale * lowpass_random_field(H, W, device, cutoff=0.25)
    mx0 = h0 * ux0
    my0 = h0 * uy0
    return torch.stack([h0, mx0, my0], dim=0)


# ----------------------------
# Dataset (trajectory-level split)
# ----------------------------

class ShallowWaterKStepPairsDataset(Dataset):
    """
    Truth trajectories from the J(z) model, served as k-step pairs (z_t, z_{t+k}).

    On construction, n_trajectories trajectories are generated on `device_gen`: each starts
    from make_wave_ic and is advanced traj_len - 1 steps with integrate_true. They are stored
    on CPU as `store_dtype` in self.z with shape (N, T=traj_len, 3, H, W).

    Item idx maps to (trajectory idx // per, start time idx % per). Here per = traj_len - k,
    which covers every start t with t + k <= traj_len - 1 (the last stored state). Each item
    is {"z_t", "z_tk"} (float32, (3, H, W)) plus integer "traj_idx" and "t".
    """
    def __init__(
        self,
        n_trajectories: int,
        traj_len: int,
        H: int,
        W: int,
        dt: float,
        k_step: int,
        truth_j_mode: str,
        truth_g0: float,
        device_gen: str = "cpu",
        integrator: str = "rk4",
        clamp_h_min: float = 1e-3,
        store_dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        assert k_step >= 1
        assert traj_len > k_step + 1
        assert truth_j_mode in ("fixed","const","state")

        self.n_traj = int(n_trajectories)
        self.traj_len = int(traj_len)
        self.H, self.W = int(H), int(W)
        self.dt = float(dt)
        self.k = int(k_step)
        self.truth_j_mode = str(truth_j_mode)
        self.truth_g0 = float(truth_g0)
        self.device_gen = torch.device(device_gen)
        self.integrator = str(integrator)
        self.clamp_h_min = float(clamp_h_min)
        self.store_dtype = store_dtype

        self.z = self._generate().to("cpu", dtype=self.store_dtype)  # (N,T,3,H,W)

    def _make_ic(self) -> torch.Tensor:
        return make_wave_ic(self.H, self.W, self.device_gen)

    def _generate(self) -> torch.Tensor:
        zs = []
        for _ in range(self.n_traj):
            z0 = self._make_ic()
            traj = integrate_true(
                z0=z0,
                dt=self.dt,
                n_steps=self.traj_len - 1,
                integrator=self.integrator,
                clamp_h_min=self.clamp_h_min,
                truth_j_mode=self.truth_j_mode,
                truth_g0=self.truth_g0,
            )
            zs.append(traj)
        return torch.stack(zs, dim=0)

    def _pairs_per_traj(self) -> int:
        # States are indexed 0..traj_len-1, so start t is valid iff t + k <= traj_len - 1.
        # That gives traj_len - k starts; the former traj_len - 1 - k dropped the last pair.
        return self.traj_len - self.k

    def __len__(self):
        return self.n_traj * self._pairs_per_traj()

    def __getitem__(self, idx: int):
        traj_idx, t_idx = divmod(idx, self._pairs_per_traj())
        z_t = self.z[traj_idx, t_idx].to(torch.float32)
        z_tk = self.z[traj_idx, t_idx + self.k].to(torch.float32)
        return {"z_t": z_t, "z_tk": z_tk, "traj_idx": int(traj_idx), "t": int(t_idx)}

class TrajSubset(Dataset):
    """
    k-step pair view restricted to a chosen subset of a base dataset's trajectories.

    Only base.z (N, T, 3, H, W), base.traj_len and base.k are read. Index idx maps to
    (traj_indices[idx // per_traj], start time idx % per_traj), where per_traj = traj_len - k
    is the number of valid start times t with t + k <= traj_len - 1. Items use the same dict
    schema as the base dataset, with "traj_idx" being the base-dataset trajectory index.
    """
    def __init__(self, base: "ShallowWaterKStepPairsDataset", traj_indices: List[int]):
        super().__init__()
        self.base = base
        self.traj_indices = [int(i) for i in traj_indices]
        # States are indexed 0..traj_len-1, so every t in [0, traj_len - k) has a valid target.
        self.per_traj = (base.traj_len - base.k)

    def __len__(self):
        return len(self.traj_indices) * self.per_traj

    def __getitem__(self, idx: int):
        local_tr, local_t = divmod(idx, self.per_traj)
        tid = self.traj_indices[local_tr]
        z_t = self.base.z[tid, local_t].to(torch.float32)
        z_tk = self.base.z[tid, local_t + self.base.k].to(torch.float32)
        return {"z_t": z_t, "z_tk": z_tk, "traj_idx": int(tid), "t": int(local_t)}


# ----------------------------
# Truth-consistent integrals (recommended diagnostics)
# ----------------------------

@torch.no_grad()
def compute_integrals(traj: torch.Tensor, g0_energy: float) -> Dict[str, np.ndarray]:
    """
    Per-time-step domain integrals of a trajectory under the truth quadratic energy.

    With G_true = diag(g0, 1, 1):
        mass   = sum h
        momx   = sum mx      (not necessarily conserved when cx varies in space)
        momy   = sum my      (likewise for cy)
        energy = sum(0.5*g0*h^2 + 0.5*mx^2 + 0.5*my^2)

    Args:
        traj: (T, 3, H, W) trajectory.
        g0_energy: depth weight g0 of the energy.

    Returns:
        Dict with keys "mass", "momx", "momy", "energy", each a length-T NumPy array.
    """
    h, mx, my = traj[:, 0], traj[:, 1], traj[:, 2]
    spatial = (1, 2)
    energy_density = 0.5*float(g0_energy)*h*h + 0.5*mx*mx + 0.5*my*my
    totals = {
        "mass": h.sum(dim=spatial),
        "momx": mx.sum(dim=spatial),
        "momy": my.sum(dim=spatial),
        "energy": energy_density.sum(dim=spatial),
    }
    return {name: series.cpu().numpy() for name, series in totals.items()}


# ----------------------------
# Model: J-mode x G-mode (2x2)
# ----------------------------

class GNet(nn.Module):
    """
    Convolutional head producing raw per-cell metric parameters from orientation-even features.

    Architecture: n_layers x (3x3 circular-padded Conv2d + SiLU), then a 1x1 Conv2d head.
    g_mode selects the head width:
      - "diag": 3 channels (later mapped to positive diagonal entries).
      - "full": 6 channels (later mapped to a lower-triangular Cholesky factor).
    Input is (B, 3, H, W); output is (B, 3|6, H, W), unconstrained.
    """
    def __init__(self, g_mode: str, hidden_ch: int = 64, n_layers: int = 3):
        super().__init__()
        assert g_mode in ("diag", "full")
        self.g_mode = str(g_mode)
        head_ch = {"diag": 3, "full": 6}[self.g_mode]

        # Channel schedule 3 -> hidden -> ... -> hidden. Conv then SiLU per layer, in the same
        # creation order as before, so seeded initialisation and state_dict keys are unchanged.
        chans = [3] + [hidden_ch] * int(n_layers)
        stack = []
        for c_in, c_out in zip(chans[:-1], chans[1:]):
            stack.extend((nn.Conv2d(c_in, c_out, 3, padding=1, padding_mode="circular"), nn.SiLU()))
        self.body = nn.Sequential(*stack)
        self.out = nn.Conv2d(chans[-1], head_ch, 1)

    def forward(self, z_even: torch.Tensor) -> torch.Tensor:
        features = self.body(z_even)
        return self.out(features)

def build_pd_3x3_from_cholesky(params6: torch.Tensor, eps: float = 1e-4, gamma: float = 1e-4) -> torch.Tensor:
    """
    Map 6 raw channels per cell to a symmetric positive-definite 3x3 matrix field.

    Channels (0, 2, 5) become the Cholesky diagonal via softplus + eps. Channels (1, 3, 4) are
    the strictly lower entries (l21, l31, l32). The result is G = L L^T + gamma * I.

    Args:
        params6: (B, 6, H, W) raw parameters.

    Returns:
        (B, 3, 3, H, W) PD matrices.
    """
    def positive(ch: int) -> torch.Tensor:
        return F.softplus(params6[:, ch]) + eps

    l11, l22, l33 = positive(0), positive(2), positive(5)
    l21, l31, l32 = params6[:, 1], params6[:, 3], params6[:, 4]
    zero = torch.zeros_like(l11)

    L = torch.stack(
        [
            torch.stack([l11, zero, zero], dim=1),
            torch.stack([l21, l22, zero], dim=1),
            torch.stack([l31, l32, l33], dim=1),
        ],
        dim=1,
    )  # (B, 3, 3, H, W), lower-triangular over the two matrix axes

    ridge = gamma * torch.eye(3, device=params6.device, dtype=params6.dtype).view(1, 3, 3, 1, 1)
    return torch.einsum("bikHW,bjkHW->bijHW", L, L) + ridge

def build_pd_3x3_diag(params3: torch.Tensor, eps: float = 1e-4, gamma: float = 1e-4) -> torch.Tensor:
    """
    Map 3 raw channels per cell to a diagonal positive-definite 3x3 matrix field.

    Diagonal entry i is softplus(params3[:, i]) + eps + gamma; off-diagonals are zero.

    Args:
        params3: (B, C>=3, H, W) raw parameters; only the first 3 channels are used.

    Returns:
        (B, 3, 3, H, W) diagonal PD matrices.
    """
    B, _, H, W = params3.shape
    G = params3.new_zeros((B, 3, 3, H, W))
    for i in range(3):
        G[:, i, i] = F.softplus(params3[:, i]) + eps + gamma
    return G

class CNet(nn.Module):
    """
    Convolutional head producing 2 raw gain channels (cx, cy) from orientation-even features.

    Architecture: n_layers x (3x3 circular-padded Conv2d + SiLU), then a 1x1 Conv2d to 2
    channels. Input is (B, 3, H, W); output is (B, 2, H, W), unconstrained.
    """
    def __init__(self, hidden_ch: int = 32, n_layers: int = 2):
        super().__init__()
        # Conv then SiLU per layer, in the original creation order (same init and state_dict keys).
        chans = [3] + [hidden_ch] * int(n_layers)
        stack = []
        for c_in, c_out in zip(chans[:-1], chans[1:]):
            stack.extend((nn.Conv2d(c_in, c_out, 3, padding=1, padding_mode="circular"), nn.SiLU()))
        self.body = nn.Sequential(*stack)
        self.out = nn.Conv2d(chans[-1], 2, 1)

    def forward(self, z_even: torch.Tensor) -> torch.Tensor:
        features = self.body(z_even)
        return self.out(features)

class MeshFTPointwiseJGNet(nn.Module):
    """
    Learned one-step map for z = (h, mx, my), advanced with Heun's method:

        z_{t+1} = z_t + (dt/2) * (f(z_t) + f(z_t + dt * f(z_t)))
        f(z)    = J_c(z) e(z),   e(z) = G(z) z   (per-cell 3x3 contraction)

    J_c is apply_J_pointwise_skew with cell gains (cx, cy):
      - j_mode="fixed": cx = cy = 1.
      - j_mode="state": cx, cy = softplus(CNet(z_even)) + c_min, rescaled to unit spatial
        mean per sample when unitmean_c is True.
    G is a per-cell PD metric from GNet:
      - g_mode="diag": build_pd_3x3_diag.
      - g_mode="full": build_pd_3x3_from_cholesky.
    Both networks see only the orientation-even features z_even = (h, mx^2, my^2).
    """
    def __init__(
        self,
        dt: float,
        j_mode: str,
        g_mode: str,
        hidden_g: int = 64,
        layers_g: int = 3,
        hidden_c: int = 32,
        layers_c: int = 2,
        eps: float = 1e-4,
        gamma: float = 1e-4,
        c_min: float = 0.05,
        unitmean_c: bool = True,
    ):
        super().__init__()
        assert j_mode in ("fixed", "state")
        assert g_mode in ("diag", "full")
        self.dt = float(dt)
        self.j_mode = str(j_mode)
        self.g_mode = str(g_mode)
        self.eps, self.gamma = float(eps), float(gamma)
        self.c_min = float(c_min)
        self.unitmean_c = bool(unitmean_c)

        # g_net is created before c_net, so seeded parameter init consumes the RNG in a fixed order.
        self.g_net = GNet(g_mode=self.g_mode, hidden_ch=int(hidden_g), n_layers=int(layers_g))
        self.c_net: Optional[CNet] = (
            CNet(hidden_ch=int(hidden_c), n_layers=int(layers_c)) if self.j_mode == "state" else None
        )

    def _z_even(self, z: torch.Tensor) -> torch.Tensor:
        """Orientation-even features (h, mx^2, my^2) as (B, 3, H, W)."""
        h, mx, my = z[:, 0:1], z[:, 1:2], z[:, 2:3]
        return torch.cat([h, mx * mx, my * my], dim=1)

    def _build_G(self, z: torch.Tensor) -> torch.Tensor:
        """Per-cell PD metric (B, 3, 3, H, W) for the configured g_mode."""
        raw = self.g_net(self._z_even(z))
        builder = build_pd_3x3_from_cholesky if self.g_mode == "full" else build_pd_3x3_diag
        return builder(raw, eps=self.eps, gamma=self.gamma)

    def _build_c(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cell gains (cx, cy), each (B, 1, H, W)."""
        B, _, H, W = z.shape
        if self.j_mode == "fixed":
            return (
                torch.ones((B, 1, H, W), device=z.device, dtype=z.dtype),
                torch.ones((B, 1, H, W), device=z.device, dtype=z.dtype),
            )

        assert self.c_net is not None
        raw = self.c_net(self._z_even(z))  # (B, 2, H, W)
        gains = []
        for ch in (0, 1):
            gain = F.softplus(raw[:, ch:ch + 1]) + self.c_min
            if self.unitmean_c:
                # Gauge-fix: unit spatial mean per sample.
                gain = gain / gain.mean(dim=(-2, -1), keepdim=True).clamp_min(1e-6)
            gains.append(gain)
        return gains[0], gains[1]

    def vector_field(self, z: torch.Tensor) -> torch.Tensor:
        """f(z) = J_c(z) G(z) z, shape (B, 3, H, W)."""
        metric = self._build_G(z)
        cx, cy = self._build_c(z)
        coenergy = torch.einsum("bijHW,bjHW->biHW", metric, z)
        return apply_J_pointwise_skew(coenergy, cx_cell=cx, cy_cell=cy)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        slope_start = self.vector_field(z)
        slope_end = self.vector_field(z + self.dt * slope_start)
        return z + 0.5 * self.dt * (slope_start + slope_end)


# ----------------------------
# Training / evaluation helpers
# ----------------------------

def clamp_depth(z: torch.Tensor, h_min: float) -> torch.Tensor:
    """Return (B, 3, H, W) with depth channel floored at h_min; momentum channels 1:3 are copied unchanged."""
    depth = z[:, 0:1].clamp(min=float(h_min))
    return torch.cat((depth, z[:, 1:3]), dim=1)

@torch.no_grad()
def kstep_eval_mse(model: nn.Module, loader: DataLoader, k: int, h_min: float, device: torch.device) -> float:
    """
    Element-averaged MSE of a k-step unrolled prediction over a whole loader.

    Each batch's z_t is advanced k times with `model`, with depth clamped to h_min after each
    step, and compared with z_tk. Squared errors are summed over all batches and divided by
    the total number of target elements (0.0 for an empty loader). Puts `model` in eval mode.
    """
    model.eval()
    sq_err_sum, n_elems = 0.0, 0
    n_unroll = int(k)
    for batch in loader:
        pred = batch["z_t"].to(device)
        target = batch["z_tk"].to(device)
        for _ in range(n_unroll):
            pred = clamp_depth(model(pred), h_min)
        sq_err_sum += F.mse_loss(pred, target, reduction="sum").item()
        n_elems += target.numel()
    return sq_err_sum / max(1, n_elems)

def train_one_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    k_unroll: int,
    h_min: float,
    n_epochs: int,
    lr: float,
    weight_decay: float,
    clip_grad: float,
) -> Tuple[nn.Module, float]:
    """
    Train `model` with Adam on the k-step unrolled MSE and keep the best-validation weights.

    Each step rolls the model k_unroll times from z_t, clamping depth to h_min after every
    step, and regresses onto z_{t+k}. Gradient norms are clipped when clip_grad is truthy and
    > 0. After every epoch, kstep_eval_mse is computed on val_loader. The state of the
    epoch with the lowest validation error is restored at the end.

    Returns:
        (model, best validation k-step MSE). If no epoch beats +inf (n_epochs == 0 or every
        validation value is NaN), the final weights are kept and inf is returned.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=float(lr), weight_decay=float(weight_decay))
    use_clip = bool(clip_grad) and clip_grad > 0
    n_unroll = int(k_unroll)

    best_val = float("inf")
    best_state = None

    for epoch in range(1, int(n_epochs) + 1):
        model.train()
        for batch in train_loader:
            pred = batch["z_t"].to(device)
            target = batch["z_tk"].to(device)
            for _ in range(n_unroll):
                pred = clamp_depth(model(pred), h_min)

            loss = F.mse_loss(pred, target)

            optimizer.zero_grad()
            loss.backward()
            if use_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(clip_grad))
            optimizer.step()

        val = kstep_eval_mse(model, val_loader, k_unroll, h_min, device)
        if val < best_val:
            best_val = val
            # Snapshot on CPU so later optimiser steps cannot mutate the saved weights.
            best_state = {name: t.detach().cpu().clone() for name, t in model.state_dict().items()}
        print(f"[train] ep={epoch:03d} val_kstep_mse={val:.3e}")

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, float(best_val)

def _rmse(x: np.ndarray) -> float:
    return float(np.sqrt(np.mean(x*x))) if x.size > 0 else float("nan")

@torch.no_grad()
def rollout_and_metrics(
    model: nn.Module,
    z0: torch.Tensor,           # (3,H,W) CPU float
    dt: float,
    rollout_steps: int,
    h_min: float,
    device: torch.device,
    truth_j_mode: str,
    truth_g0: float,
) -> Dict[str, float]:
    """
    Roll `model` out from z0 and score it against an RK4 truth rollout of equal length.

    Returns a dict with:
      - rollout_vrmse, rollout_mse: state error over all steps, channels and cells.
      - {mass,momx,momy,energy}_err_T: |I_pred(T) - I_true(T)| of the truth-consistent integrals.
      - {mass,momx,momy,energy}_err_rmse: RMSE over time of I_pred(t) - I_true(t).
      - mass_drift_pred, energy_drift_pred: |I_pred(T) - I_pred(0)|, a structure sanity check
        rather than a correctness metric.
    Depth is clamped to h_min after every step in both rollouts.
    """
    z_true = integrate_true(
        z0=z0,
        dt=dt,
        n_steps=rollout_steps,
        integrator="rk4",
        clamp_h_min=h_min,
        truth_j_mode=truth_j_mode,
        truth_g0=truth_g0,
    )

    model.eval()
    state = z0.unsqueeze(0).to(device)
    frames = [state[0].detach().cpu().clone()]
    for _ in range(int(rollout_steps)):
        state = clamp_depth(model(state), h_min)
        frames.append(state[0].detach().cpu().clone())
    z_pred = torch.stack(frames, dim=0)

    mse = float(((z_pred - z_true) ** 2).mean().item())
    # NOTE(review): the variance is pooled over all three channels around one global mean, so it
    # includes the gap between channel means (h ~ 1 vs momenta ~ 0 for make_wave_ic data).
    # A per-channel VRMSE would differ; confirm the pooled definition is intended.
    var_true = float(((z_true - z_true.mean()) ** 2).mean().item())

    out: Dict[str, float] = {
        "rollout_vrmse": float(math.sqrt(mse / (var_true + 1e-8))),
        "rollout_mse": mse,
    }

    I_true = compute_integrals(z_true, g0_energy=truth_g0)
    I_pred = compute_integrals(z_pred, g0_energy=truth_g0)

    for name in ("mass", "momx", "momy", "energy"):
        gap = I_pred[name] - I_true[name]
        out[f"{name}_err_T"] = float(abs(gap[-1]))
        out[f"{name}_err_rmse"] = _rmse(gap)

    out["mass_drift_pred"] = float(abs(I_pred["mass"][-1] - I_pred["mass"][0]))
    out["energy_drift_pred"] = float(abs(I_pred["energy"][-1] - I_pred["energy"][0]))
    return out

@torch.no_grad()
def gain_recon_mse_on_z0(
    model: MeshFTPointwiseJGNet,
    z0_list: List[torch.Tensor],
    device: torch.device,
    truth_j_mode: str,
) -> Dict[str, float]:
    """
    Mean-over-states MSE between model gains (model._build_c) and truth gains (true_cxcy).

    Each z0 in z0_list is a single (3, H, W) state. Per state, the spatial MSE of cx and of cy
    is computed, and the two lists are averaged. Both values are NaN when z0_list is empty.
    """
    per_state = []
    for z0 in z0_list:
        pred_cx, pred_cy = (g.detach().cpu() for g in model._build_c(z0.unsqueeze(0).to(device)))
        true_cx, true_cy = (g.cpu() for g in true_cxcy(z0.unsqueeze(0), mode=truth_j_mode))
        per_state.append((
            float(((pred_cx - true_cx)**2).mean().item()),
            float(((pred_cy - true_cy)**2).mean().item()),
        ))

    if not per_state:
        return {"gain_mse_cx": float("nan"), "gain_mse_cy": float("nan")}

    cx_mses, cy_mses = zip(*per_state)
    return {
        "gain_mse_cx": float(np.mean(cx_mses)),
        "gain_mse_cy": float(np.mean(cy_mses)),
    }


# ----------------------------
# Experiment driver
# ----------------------------

@dataclass
class ExpConfig:
    """Settings for one ablation sweep: data generation, truth model, training, evaluation, outputs."""
    device: str = "cuda"            # requested torch device; run_one_setting uses CPU when CUDA is unavailable
    deterministic: bool = False     # forwarded to set_seed(deterministic=...)

    # data
    H: int = 64                     # grid height (cells)
    W: int = 64                     # grid width (cells)
    dt: float = 0.02                # step for both truth integration and the model's Heun step
    traj_len: int = 260             # states per generated trajectory (traj_len - 1 truth steps)
    k_train: int = 8                # k of the (z_t, z_{t+k}) pairs and the training/val unroll length
    h_min: float = 1e-3             # depth floor applied in truth integration and model rollouts
    val_frac: float = 0.2           # fraction of trajectories held out for validation (at least 1)

    # truth
    truth_j_mode: str = "state"     # truth gain model: {"fixed","const","state"}
    truth_g0: float = 1.0           # energy weight for h in G_true=diag(g0,1,1)

    # training
    batch_size: int = 6             # DataLoader batch size (train and val)
    n_epochs: int = 12              # training epochs; best-validation weights are kept
    lr: float = 1e-3                # Adam learning rate
    weight_decay: float = 0.0       # Adam weight decay
    clip_grad: float = 1.0          # gradient-norm clip; <= 0 disables clipping

    # model capacity
    hidden_g: int = 64              # GNet hidden channels
    layers_g: int = 3               # GNet conv layers
    hidden_c: int = 32              # CNet hidden channels (J_state only)
    layers_c: int = 2               # CNet conv layers (J_state only)
    c_min: float = 0.05             # floor added to softplus gains before the unit-mean rescale
    unitmean_c: bool = True         # rescale learned gains to unit spatial mean per sample

    # eval
    rollout_steps: int = 600        # length of the long evaluation rollout vs RK4 truth
    n_eval_trajs: int = 3           # max number of validation trajectories used for rollout/gain metrics

    # outputs
    outdir: str = "out_shallow_jdep_gflex_ablation"   # root directory for per-run results


def build_dataset_split(cfg: ExpConfig, seed: int, n_traj: int) -> Tuple[ShallowWaterKStepPairsDataset, Dataset, Dataset, List[int], List[int]]:
    """
    Generate the truth dataset for `seed` and split it by trajectory into train/val subsets.

    Global RNGs are seeded with `seed` before generation, so identical (cfg, seed, n_traj)
    reproduce identical trajectories. The trajectory permutation uses a separate generator
    seeded with seed + 20000.

    Validation gets max(1, ceil(val_frac * n_traj)) trajectories. When n_traj >= 2 this is
    capped at n_traj - 1, so at least one training trajectory stays disjoint from validation.
    Otherwise a large val_frac would empty the train split, and the fallback would train on a
    validation trajectory. Only when n_traj == 1 do train and val share that one trajectory.

    Returns:
        (full dataset, train subset, val subset, train trajectory ids, val trajectory ids)
    """
    set_seed(seed, deterministic=cfg.deterministic)
    ds = ShallowWaterKStepPairsDataset(
        n_trajectories=n_traj,
        traj_len=cfg.traj_len,
        H=cfg.H,
        W=cfg.W,
        dt=cfg.dt,
        k_step=cfg.k_train,
        truth_j_mode=cfg.truth_j_mode,
        truth_g0=cfg.truth_g0,
        device_gen="cpu",
        integrator="rk4",
        clamp_h_min=cfg.h_min,
        store_dtype=torch.float16,
    )

    g_split = make_cpu_generator(seed + 20000)
    perm = torch.randperm(ds.n_traj, generator=g_split).tolist()
    n_val = max(1, int(math.ceil(cfg.val_frac * ds.n_traj)))
    if ds.n_traj >= 2:
        n_val = min(n_val, ds.n_traj - 1)  # keep >= 1 train trajectory disjoint from val
    val_ids = perm[:n_val]
    # Only reachable with a single trajectory: reuse it for training.
    train_ids = perm[n_val:] or perm[:1]

    train_ds = TrajSubset(ds, train_ids)
    val_ds = TrajSubset(ds, val_ids)
    return ds, train_ds, val_ds, train_ids, val_ids

def build_loaders(train_ds: Dataset, val_ds: Dataset, cfg: ExpConfig, seed: int) -> Tuple[DataLoader, DataLoader]:
    """
    Build single-process train/val DataLoaders with batch size cfg.batch_size.

    The train loader shuffles with a dedicated CPU generator seeded from `seed`, so the batch
    order is reproducible. The val loader keeps dataset order. seed_worker is attached as
    worker_init_fn, although with num_workers=0 no worker processes are spawned.
    """
    shared = dict(batch_size=cfg.batch_size, num_workers=0, worker_init_fn=seed_worker)
    shuffle_gen = make_cpu_generator(seed)
    train_loader = DataLoader(train_ds, shuffle=True, generator=shuffle_gen, **shared)
    val_loader = DataLoader(val_ds, shuffle=False, **shared)
    return train_loader, val_loader

def mode_tag(j_mode: str, g_mode: str) -> str:
    """Canonical run label for a (J-mode, G-mode) pair, e.g. 'J_state__G_full'."""
    return "J_{}__G_{}".format(j_mode, g_mode)

def run_one_setting(cfg: ExpConfig, seed: int, n_traj: int, j_mode: str, g_mode: str) -> Dict[str, float]:
    """
    Train and evaluate one (n_traj, seed, J-mode, G-mode) cell of the ablation grid.

    Steps:
      1. build the seeded trajectory-level train/val split and loaders;
      2. seed and train a MeshFTPointwiseJGNet, keeping the best-validation weights;
      3. roll out from up to cfg.n_eval_trajs validation initial states vs RK4 truth;
      4. measure gain reconstruction on those initial states;
      5. write summary.json and model.pt to
         cfg.outdir/ntraj{n_traj}/seed{seed:04d}/mode_{J_..__G_..}/ and return the summary.

    Model initialisation is seeded from (seed, j_mode, g_mode) using CRC32 of the mode strings.
    The builtin hash() of a str is randomised per interpreter process unless PYTHONHASHSEED is
    fixed before start-up, and setting it from inside the process (as set_seed does) has no
    effect. Seeding from hash() would therefore give irreproducible initialisations across runs.
    """
    import zlib

    def _stable_str_hash(s: str) -> int:
        # Process-independent, platform-independent string hash.
        return zlib.crc32(str(s).encode("utf-8"))

    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")

    ds, train_ds, val_ds, train_ids, val_ids = build_dataset_split(cfg, seed=seed, n_traj=n_traj)

    loader_seed = seed + 30000
    train_loader, val_loader = build_loaders(train_ds, val_ds, cfg, seed=loader_seed)

    model_seed = seed + 1000 + (_stable_str_hash(j_mode) % 1000) + 17*(_stable_str_hash(g_mode) % 1000)
    set_seed(model_seed, deterministic=cfg.deterministic)
    model = MeshFTPointwiseJGNet(
        dt=cfg.dt,
        j_mode=j_mode,
        g_mode=g_mode,
        hidden_g=cfg.hidden_g,
        layers_g=cfg.layers_g,
        hidden_c=cfg.hidden_c,
        layers_c=cfg.layers_c,
        c_min=cfg.c_min,
        unitmean_c=cfg.unitmean_c,
    )

    model, best_val = train_one_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        k_unroll=cfg.k_train,
        h_min=cfg.h_min,
        n_epochs=cfg.n_epochs,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        clip_grad=cfg.clip_grad,
    )

    # Evaluation initial states: first stored state of the first few validation trajectories.
    use_ids = val_ids[:min(cfg.n_eval_trajs, len(val_ids))]
    z0_list = [ds.z[int(tid), 0].to(torch.float32) for tid in use_ids]

    mets = [
        rollout_and_metrics(
            model=model,
            z0=z0,
            dt=cfg.dt,
            rollout_steps=cfg.rollout_steps,
            h_min=cfg.h_min,
            device=device,
            truth_j_mode=cfg.truth_j_mode,
            truth_g0=cfg.truth_g0,
        )
        for z0 in z0_list
    ]

    def _mean(key: str) -> float:
        return float(np.mean([float(x[key]) for x in mets])) if mets else float("nan")

    gain_m = gain_recon_mse_on_z0(model, z0_list=z0_list, device=device, truth_j_mode=cfg.truth_j_mode)

    out = {
        "seed": int(seed),
        "n_traj": int(n_traj),
        "j_mode": str(j_mode),
        "g_mode": str(g_mode),
        "mode": mode_tag(j_mode, g_mode),
        "train_traj": int(len(train_ids)),
        "val_traj": int(len(val_ids)),
        "best_val_kstep_mse": float(best_val),

        "rollout_vrmse_mean": _mean("rollout_vrmse"),
        "rollout_mse_mean": _mean("rollout_mse"),

        # truth-aligned integral errors (recommended)
        "mass_err_T_mean": _mean("mass_err_T"),
        "momx_err_T_mean": _mean("momx_err_T"),
        "momy_err_T_mean": _mean("momy_err_T"),
        "energy_err_T_mean": _mean("energy_err_T"),

        "mass_err_rmse_mean": _mean("mass_err_rmse"),
        "momx_err_rmse_mean": _mean("momx_err_rmse"),
        "momy_err_rmse_mean": _mean("momy_err_rmse"),
        "energy_err_rmse_mean": _mean("energy_err_rmse"),

        # self drift sanity (not correctness)
        "mass_drift_pred_mean": _mean("mass_drift_pred"),
        "energy_drift_pred_mean": _mean("energy_drift_pred"),

        # gain id
        "gain_mse_cx": float(gain_m["gain_mse_cx"]),
        "gain_mse_cy": float(gain_m["gain_mse_cy"]),
    }

    outdir = Path(cfg.outdir) / f"ntraj{n_traj}" / f"seed{seed:04d}" / f"mode_{out['mode']}"
    outdir.mkdir(parents=True, exist_ok=True)
    with open(outdir / "summary.json", "w") as f:
        json.dump(out, f, indent=2)
    torch.save({"state_dict": model.state_dict(), "config": cfg.__dict__, "result": out}, outdir / "model.pt")

    print(
        f"[done] n_traj={n_traj} seed={seed} mode={out['mode']} "
        f"val_k={out['best_val_kstep_mse']:.3e} vrmse={out['rollout_vrmse_mean']:.3e} "
        f"mass_err_T={out['mass_err_T_mean']:.3e} energy_err_T={out['energy_err_T_mean']:.3e} "
        f"gain_mse(cx)={out['gain_mse_cx']:.3e}"
    )
    return out

def aggregate_results(rows: List[Dict[str, float]]) -> Dict:
    """Summarize per-run result rows into per-(n_traj, mode) statistics.

    Args:
        rows: per-run summary dicts, as returned by ``run_one_setting``. Each row
            must carry ``n_traj``, ``mode`` and every metric named below.
            Rows may contain extra keys; they are ignored.

    Returns:
        ``{"groups": [...]}``. Groups are ordered by ``(n_traj, mode)``. Each
        group holds ``n_traj``, ``mode`` and, for every summarized metric, a
        ``{"mean", "std", "n"}`` dict. ``std`` is the sample standard deviation
        (``ddof=1``) and is reported as ``0.0`` when a group has a single run.
    """
    summarized_metrics = (
        "best_val_kstep_mse",
        "rollout_vrmse_mean",
        "mass_err_T_mean",
        "momx_err_T_mean",
        "momy_err_T_mean",
        "energy_err_T_mean",
        "mass_err_rmse_mean",
        "momx_err_rmse_mean",
        "momy_err_rmse_mean",
        "energy_err_rmse_mean",
        "mass_drift_pred_mean",
        "energy_drift_pred_mean",
        "gain_mse_cx",
        "gain_mse_cy",
    )

    def _stats(values: List[float]) -> Dict[str, float]:
        # Sample statistics over the seeds of one group.
        arr = [float(v) for v in values]
        count = len(arr)
        center = float(np.mean(arr))
        # ddof=1 is undefined for a single sample; report zero spread instead.
        spread = float(np.std(arr, ddof=1)) if count > 1 else 0.0
        return {"mean": center, "std": spread, "n": int(count)}

    # Bucket rows by (data regime, model mode); seeds vary inside a bucket.
    buckets: Dict[Tuple[int, str], List[Dict[str, float]]] = {}
    for row in rows:
        bucket_key = (int(row["n_traj"]), str(row["mode"]))
        buckets.setdefault(bucket_key, []).append(row)

    groups = []
    for n_traj, mode in sorted(buckets):
        members = buckets[(n_traj, mode)]
        entry = {"n_traj": int(n_traj), "mode": str(mode)}
        entry.update(
            {name: _stats([member[name] for member in members]) for name in summarized_metrics}
        )
        groups.append(entry)
    return {"groups": groups}

def parse_int_list(s: str) -> List[int]:
    """Parse a comma-separated string of integers, e.g. the ``--seeds`` flag.

    Whitespace around tokens is ignored and empty tokens are skipped.
    A blank or ``None`` input yields ``[0]``. A non-integer token raises
    ``ValueError`` (from ``int``).
    """
    text = (s or "").strip()
    if text:
        tokens = (tok.strip() for tok in text.split(","))
        return [int(tok) for tok in tokens if tok]
    # Nothing supplied: fall back to a single default value.
    return [0]


def main():
    """Command-line entry point: run the 2x2 (J-mode x G-mode) ablation sweep.

    For each data regime (``--ntraj_low`` / ``--ntraj_high``), each seed in
    ``--seeds`` and each (j_mode, g_mode) combination, calls
    ``run_one_setting``. The per-run rows and their per-(n_traj, mode)
    aggregate are written, together with the config, to
    ``<outdir>/aggregate.json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--deterministic", type=int, default=0)

    ap.add_argument("--seeds", type=str, default="0,1,2")

    ap.add_argument("--ntraj_low", type=int, default=4)
    ap.add_argument("--ntraj_high", type=int, default=64)

    ap.add_argument("--H", type=int, default=64)
    ap.add_argument("--W", type=int, default=64)
    ap.add_argument("--dt", type=float, default=0.02)
    ap.add_argument("--traj_len", type=int, default=260)
    ap.add_argument("--k_train", type=int, default=8)
    ap.add_argument("--rollout_steps", type=int, default=600)
    ap.add_argument("--n_eval_trajs", type=int, default=3)
    ap.add_argument("--h_min", type=float, default=1e-3)

    # truth knobs
    ap.add_argument("--truth_j_mode", type=str, default="state", choices=["fixed","const","state"])
    ap.add_argument("--truth_g0", type=float, default=1.0)

    # training
    ap.add_argument("--batch_size", type=int, default=6)
    ap.add_argument("--n_epochs", type=int, default=12)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--weight_decay", type=float, default=0.0)
    ap.add_argument("--clip_grad", type=float, default=1.0)

    # model capacity
    ap.add_argument("--hidden_g", type=int, default=64)
    ap.add_argument("--layers_g", type=int, default=3)
    ap.add_argument("--hidden_c", type=int, default=32)
    ap.add_argument("--layers_c", type=int, default=2)
    ap.add_argument("--c_min", type=float, default=0.05)
    ap.add_argument("--unitmean_c", type=int, default=1)

    ap.add_argument("--outdir", type=str, default="out_shallow_jdep_gflex_ablation")

    args = ap.parse_args()

    cfg = ExpConfig(
        device=str(args.device),
        deterministic=bool(args.deterministic),

        H=int(args.H),
        W=int(args.W),
        dt=float(args.dt),
        traj_len=int(args.traj_len),
        k_train=int(args.k_train),
        h_min=float(args.h_min),

        truth_j_mode=str(args.truth_j_mode),
        truth_g0=float(args.truth_g0),

        rollout_steps=int(args.rollout_steps),
        n_eval_trajs=int(args.n_eval_trajs),

        batch_size=int(args.batch_size),
        n_epochs=int(args.n_epochs),
        lr=float(args.lr),
        weight_decay=float(args.weight_decay),
        clip_grad=float(args.clip_grad),

        hidden_g=int(args.hidden_g),
        layers_g=int(args.layers_g),
        hidden_c=int(args.hidden_c),
        layers_c=int(args.layers_c),
        c_min=float(args.c_min),
        unitmean_c=bool(args.unitmean_c),

        outdir=str(args.outdir),
    )

    # Every (n_traj, seed, mode) run writes to its own directory
    # (<outdir>/ntraj{N}/seed{S}/mode_{MODE}). A repeated seed (e.g.
    # "--seeds 0,0") or ntraj_low == ntraj_high would re-run the same
    # setting, overwrite its outputs, and double-count it in the aggregate
    # (inflating n and biasing std). Drop repeats but keep the requested order.
    seeds = list(dict.fromkeys(parse_int_list(args.seeds)))
    n_list = list(dict.fromkeys([int(args.ntraj_low), int(args.ntraj_high)]))

    j_modes = ["fixed", "state"]
    g_modes = ["diag", "full"]

    all_rows: List[Dict[str,float]] = []
    for n_traj in n_list:
        for seed in seeds:
            for j_mode in j_modes:
                for g_mode in g_modes:
                    row = run_one_setting(cfg, seed=seed, n_traj=n_traj, j_mode=j_mode, g_mode=g_mode)
                    all_rows.append(row)

    agg = aggregate_results(all_rows)
    outroot = Path(cfg.outdir)
    outroot.mkdir(parents=True, exist_ok=True)
    with open(outroot / "aggregate.json", "w") as f:
        json.dump({"config": cfg.__dict__, "rows": all_rows, "aggregate": agg}, f, indent=2)

    print(f"[saved] {outroot / 'aggregate.json'}")


if __name__ == "__main__":
    # Run the sweep only when executed as a script, not when imported as a module.
    main()