#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
FRL vs Baseline (GNN, fair backbone) + equal-physical-time rollout
------------------------------------------------------------------
Goal: Make RMSE Ratio(256/64) ~ 1 by evaluating rollouts over the SAME
physical time horizon across resolutions (different per-res step counts).

- Same backbone for both models (Grid-GraphSAGE, residual head y=x+Δ, abs XY PE)
- ONLY three FRL extras vs Baseline:
  (1) multi-resolution training (FRL: 32 downsampled from 64, plus 64; Baseline: 64 only)
  (2) sinusoidal Nyquist-normalized PE (FRL only)
  (3) spectral loss (FRL only)

Evaluation:
- Free rollout to a fixed physical-time horizon (e.g., 10*dt@64)
  For each resolution r with time step dt_r, we roll for steps_r ≈ horizon / dt_r.
- Summary prints the horizon-based RMSEs and RMSE Ratio 256/64.
"""

import math
import random
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# -----------------------------
# Utilities
# -----------------------------

def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU + every CUDA device) RNGs with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


# -----------------------------
# Part 1: 2D Pseudospectral Solver with 2/3 Dealiasing
# -----------------------------

class ImprovedPseudoSpectralSolver2D:
    """2D vorticity-form incompressible Navier–Stokes (periodic) pseudospectral solver.

    The vorticity field is advanced with classical RK4 whose stages are evaluated in
    Fourier space. The advection term is dealiased with the 2/3 rule (square mask on
    |kx|, |ky|). Diffusion is treated explicitly. The time step is the smaller of a
    diffusive limit (0.5 * dx**2 / nu) and an advective one (0.1 * dx).

    NOTE(review): with ``psi_hat = -omega_hat / k**2`` (so Laplacian(psi) = omega) and
    ``(u, v) = (psi_y, -psi_x)``, the recovered velocity has curl ``-omega``. The
    evolved field is therefore the sign-flip of a standard-convention NS solution.
    This is left unchanged so generated data keep their sign; confirm the convention
    is intended.
    """

    def __init__(self, resolution: int, domain_size: float = 2*math.pi, nu: float = 1e-2):
        """
        Args:
            resolution: grid points per side (N).
            domain_size: side length L of the square periodic domain.
            nu: kinematic viscosity.
        """
        self.N = int(resolution)
        self.L = float(domain_size)
        self.nu = float(nu)
        self.dx = self.L / self.N
        # Angular wavenumbers: fftfreq returns cycles per unit length.
        k1d = 2 * math.pi * np.fft.fftfreq(self.N, d=self.L / self.N)
        self.kx, self.ky = np.meshgrid(k1d, k1d, indexing='ij')
        # True |k|^2 (zero at the mean mode), used for the Laplacian in the diffusion term.
        self.k2_lap = self.kx**2 + self.ky**2
        # |k|^2 with the mean mode set to 1 so the Poisson inversion never divides by zero.
        # It must NOT be used for diffusion: that would spuriously damp the mean vorticity.
        self.k2 = self.k2_lap.copy()
        self.k2[0, 0] = 1.0
        self.k_nyq = (math.pi * self.N / self.L)
        kcut = (2.0/3.0) * self.k_nyq
        self.dealias = (np.abs(self.kx) <= kcut) & (np.abs(self.ky) <= kcut)
        dt_diff = 0.5 * self.dx**2 / self.nu
        dt_adv  = 0.1 * self.dx
        self.dt = float(min(dt_diff, dt_adv))
        self.ikx = 1j * self.kx
        self.iky = 1j * self.ky

    def initial_vorticity(self, seed: Optional[int] = None) -> np.ndarray:
        """Return 4 sin(2x) sin(2y) plus six periodic Gaussian blobs as an (N, N) float64 array.

        The blob centres are uniform in the domain. The amplitudes are uniform in
        [-0.5, 0.5) and the width is 0.08 * L. Random draws do not depend on N, so the
        same seed yields the same continuous field on every grid. With ``seed=None``
        NumPy's global RNG is used.
        """
        rs = np.random.RandomState(seed) if seed is not None else np.random
        x = np.linspace(0.0, self.L, self.N, endpoint=False)
        y = np.linspace(0.0, self.L, self.N, endpoint=False)
        X, Y = np.meshgrid(x, y, indexing='ij')
        omega = 4.0 * np.sin(2 * X) * np.sin(2 * Y)
        for _ in range(6):
            cx = rs.uniform(0, self.L); cy = rs.uniform(0, self.L)
            # Minimum-image (periodic) squared distance to the blob centre.
            r2 = ((X - cx + self.L/2) % self.L - self.L/2)**2 + ((Y - cy + self.L/2) % self.L - self.L/2)**2
            omega += 0.5 * np.exp(-r2 / (0.08 * self.L)**2) * (rs.rand() * 2 - 1)
        return omega.astype(np.float64)

    def velocity_from_vorticity_hat(self, omega_hat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Recover real-space (u, v) from the vorticity spectrum via the stream function."""
        psi_hat = - omega_hat / self.k2
        u_hat =  self.iky * psi_hat
        v_hat = -self.ikx * psi_hat
        u = np.fft.ifft2(u_hat).real
        v = np.fft.ifft2(v_hat).real
        return u, v

    def rhs_hat(self, omega: np.ndarray) -> np.ndarray:
        """Spectral RHS: -(dealiased u . grad omega)^ - nu |k|^2 omega^."""
        omega_hat = np.fft.fft2(omega)
        u, v = self.velocity_from_vorticity_hat(omega_hat)
        omega_x = np.fft.ifft2(self.ikx * omega_hat).real
        omega_y = np.fft.ifft2(self.iky * omega_hat).real
        nonlinear = u * omega_x + v * omega_y
        nonlinear_hat = np.fft.fft2(nonlinear)
        nonlinear_hat = np.where(self.dealias, nonlinear_hat, 0.0)
        # True Laplacian: the mean mode (k = 0) is not diffused.
        diffusion_hat = - self.nu * self.k2_lap * omega_hat
        rhs_hat = - nonlinear_hat + diffusion_hat
        return rhs_hat

    def step(self, omega: np.ndarray) -> np.ndarray:
        """Advance ``omega`` by one RK4 step of size ``self.dt`` and return the new real field."""
        dt = self.dt
        omega_hat0 = np.fft.fft2(omega)  # hoisted: identical for all four stages
        k1 = self.rhs_hat(omega)
        k2 = self.rhs_hat(np.fft.ifft2(omega_hat0 + 0.5*dt*k1).real)
        k3 = self.rhs_hat(np.fft.ifft2(omega_hat0 + 0.5*dt*k2).real)
        k4 = self.rhs_hat(np.fft.ifft2(omega_hat0 +     dt*k3).real)
        omega_hat = omega_hat0 + (dt/6.0)*(k1 + 2*k2 + 2*k3 + k4)
        omega_new = np.fft.ifft2(omega_hat).real
        return omega_new

    def generate_trajectory(self, n_steps: int, seed: Optional[int] = None) -> np.ndarray:
        """Return ``n_steps + 1`` frames (initial condition included) as a [T+1, N, N] array."""
        omega = self.initial_vorticity(seed=seed)
        traj = [omega.copy()]
        for _ in range(n_steps):
            omega = self.step(omega); traj.append(omega.copy())
        return np.stack(traj, axis=0)  # [T+1, N, N]


# -----------------------------
# Part 2: Strict Anti-Aliased Downsampling
# -----------------------------

def fft_lowpass_downsample(field: np.ndarray, target_size: int) -> np.ndarray:
    """Anti-aliased downsampling of a square periodic field by spectral truncation.

    Keeps only the Fourier modes representable on a ``target_size`` grid (a centred
    crop of the fftshift-ed spectrum) and transforms back. ``np.fft`` leaves the
    forward transform unscaled and divides by the number of points on the inverse.
    The result is therefore multiplied by ``(target_size / N)**2`` so grid values
    keep their physical amplitude (a constant field stays the same constant).

    Args:
        field: 2D array of shape (N, N).
        target_size: output side length, ``0 < target_size <= N``.

    Returns:
        float32 array of shape (target_size, target_size).

    Raises:
        ValueError: if ``field`` is not a square 2D array or ``target_size`` is out of range.
    """
    if field.ndim != 2 or field.shape[0] != field.shape[1]:
        raise ValueError(f"expected a square 2D field, got shape {field.shape}")
    N = field.shape[0]
    if not 0 < target_size <= N:
        raise ValueError(f"target_size must satisfy 0 < target_size <= {N}, got {target_size}")
    spec = np.fft.fftshift(np.fft.fft2(field))
    mid = N // 2
    half = target_size // 2
    # After fftshift the zero mode sits at index N // 2. The crop keeps modes -half..half-1
    # (even target) or -half..half (odd target), i.e. exactly np.fft's coarse-grid modes.
    start, stop = mid - half, mid + half + (target_size % 2)
    cropped = spec[start:stop, start:stop]
    scale = (target_size * target_size) / float(N * N)
    coarse = np.fft.ifft2(np.fft.ifftshift(cropped)) * scale
    # Taking .real also symmetrises the unpaired Nyquist bin kept for even target sizes.
    return coarse.real.astype(np.float32)


# -----------------------------
# Part 3: Datasets
# -----------------------------

class DirectOneStepDataset(Dataset):
    """Consecutive-frame pairs (x_t, x_{t+1}) from trajectories at their native resolution.

    Every adjacent pair of frames along axis 0 of each trajectory becomes one sample.
    Frames are stored as float32 with a leading channel axis of size 1.
    """
    def __init__(self, trajectories: List[np.ndarray]):
        super().__init__()
        self.pairs = [
            (traj[t].astype(np.float32)[None, ...], traj[t + 1].astype(np.float32)[None, ...])
            for traj in trajectories
            for t in range(traj.shape[0] - 1)
        ]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx: int):
        src, dst = self.pairs[idx]
        return torch.from_numpy(src).float(), torch.from_numpy(dst).float()

class DownsampleFromTrajDataset(Dataset):
    """Consecutive-frame pairs (x_t, x_{t+1}) after per-frame FFT low-pass downsampling.

    Each frame of each (higher-resolution) trajectory goes through
    ``fft_lowpass_downsample`` to ``target_resolution``. Adjacent downsampled frames
    then form samples with a leading channel axis of size 1.
    """
    def __init__(self, trajectories: List[np.ndarray], target_resolution: int):
        super().__init__()
        self.H = int(target_resolution)
        self.pairs = []
        for traj in trajectories:
            frames = [fft_lowpass_downsample(snapshot, self.H) for snapshot in traj]
            self.pairs.extend(
                (cur[None, ...], nxt[None, ...]) for cur, nxt in zip(frames, frames[1:])
            )

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx: int):
        src, dst = self.pairs[idx]
        return torch.from_numpy(src).float(), torch.from_numpy(dst).float()


# -----------------------------
# Part 4: GNN Backbone (shared)
# -----------------------------

class GridGraphSAGELayer(nn.Module):
    """GraphSAGE-like layer on a periodic 2D grid using 4-neighbour message passing.

    Output = ReLU(BN(W_self * x + W_neigh * mean of the 4 wrap-around neighbours)),
    where both W are 1x1 convolutions.
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # Submodule creation order is kept fixed so seeded initialisation is reproducible.
        self.self_lin = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.neigh_lin = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.roll wraps around the border, giving periodic neighbours along H (dim 2) and W (dim 3).
        neighbour_mean = (
            torch.roll(x, shifts=+1, dims=2)
            + torch.roll(x, shifts=-1, dims=2)
            + torch.roll(x, shifts=+1, dims=3)
            + torch.roll(x, shifts=-1, dims=3)
        ) / 4.0
        mixed = self.self_lin(x) + self.neigh_lin(neighbour_mean)
        return self.act(self.bn(mixed))

def abs_xy_posenc(H: int, W: int, device: torch.device) -> torch.Tensor:
    """Absolute XY coordinates in [0, 1] as a [1, 2, H, W] tensor.

    Channel 0 (x) varies along the width axis and channel 1 (y) along the height axis.
    ``torch.linspace`` includes both endpoints 0 and 1.
    """
    xs = torch.linspace(0, 1, steps=W, device=device)
    ys = torch.linspace(0, 1, steps=H, device=device)
    grid_x = xs.view(1, W).expand(H, W)
    grid_y = ys.view(H, 1).expand(H, W)
    # torch.stack materialises a fresh contiguous tensor from the expanded views.
    return torch.stack((grid_x, grid_y), dim=0).unsqueeze(0)


# -----------------------------
# Part 5: Models
# -----------------------------

class BaselineGNN(nn.Module):
    """Baseline GNN (fair): field(1)+absXY(2) -> GridSAGE -> residual head (zero-init).

    The prediction is ``x + delta``. The output projection starts at zero, so the
    untrained model is the identity map.
    """
    def __init__(self, hidden: int = 64, n_layers: int = 5):
        super().__init__()
        n_inputs = 1 + 2  # vorticity channel + two absolute-coordinate channels
        self.in_proj = nn.Conv2d(n_inputs, hidden, kernel_size=1)
        self.layers = nn.ModuleList(GridGraphSAGELayer(hidden, hidden) for _ in range(n_layers))
        self.out_proj = nn.Conv2d(hidden, 1, kernel_size=1)
        for param in (self.out_proj.weight, self.out_proj.bias):
            nn.init.zeros_(param)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, _, height, width = x.shape
        coords = abs_xy_posenc(height, width, device=x.device).repeat(batch, 1, 1, 1)
        h = self.in_proj(torch.cat([x, coords], dim=1))
        for layer in self.layers:
            h = h + layer(h)  # residual message-passing block
        return x + self.out_proj(h)

class FRLGNN(nn.Module):
    """FRL GNN: field(1)+absXY(2)+Nyquist sin/cos(4*n_freqs) -> GridSAGE -> residual (zero-init).

    The prediction is ``x + delta`` with a zero-initialised output projection.
    """
    def __init__(self, hidden: int = 64, n_layers: int = 5, n_freqs: int = 8):
        super().__init__()
        self.n_freqs = int(n_freqs)
        # field + absolute XY + (sin, cos) along rows and columns for every frequency
        n_inputs = 1 + 2 + 4 * self.n_freqs
        self.in_proj = nn.Conv2d(n_inputs, hidden, kernel_size=1)
        self.layers = nn.ModuleList(GridGraphSAGELayer(hidden, hidden) for _ in range(n_layers))
        self.out_proj = nn.Conv2d(hidden, 1, kernel_size=1)
        for param in (self.out_proj.weight, self.out_proj.bias):
            nn.init.zeros_(param)

    @staticmethod
    def nyq_pe(H: int, W: int, n_freqs: int, device: torch.device) -> torch.Tensor:
        """Index-based sinusoidal encoding, shape [1, 4*n_freqs, H, W].

        For k = 1..n_freqs the phase is pi * (k / n_freqs) * index, where index is the
        integer row (height) or column (width) position. The top frequency
        (k = n_freqs) is therefore the grid Nyquist frequency. Channels per k are
        [sin(row), cos(row), sin(col), cos(col)]. The encoding depends on the pixel
        index, not the physical position, so it differs between resolutions.
        """
        rows = torch.arange(H, device=device).view(H, 1).repeat(1, W)
        cols = torch.arange(W, device=device).view(1, W).repeat(H, 1)
        feats = []
        for k in range(1, n_freqs + 1):
            phase_r = math.pi * (k / float(n_freqs)) * rows
            phase_c = math.pi * (k / float(n_freqs)) * cols
            feats.extend((torch.sin(phase_r), torch.cos(phase_r), torch.sin(phase_c), torch.cos(phase_c)))
        return torch.stack(feats, dim=0).unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, _, height, width = x.shape
        dev = x.device
        coords = abs_xy_posenc(height, width, device=dev).repeat(batch, 1, 1, 1)
        waves = self.nyq_pe(height, width, self.n_freqs, device=dev).repeat(batch, 1, 1, 1)
        h = self.in_proj(torch.cat([x, coords, waves], dim=1))
        for layer in self.layers:
            h = h + layer(h)  # residual message-passing block
        return x + self.out_proj(h)


# -----------------------------
# Part 6: Losses & Metrics
# -----------------------------

def make_radial_weight(h: int, w: int, alpha: float = 1.0, device: Optional[torch.device] = None) -> torch.Tensor:
    """Radial spectral weight map of shape [h, w] in unshifted FFT layout, values in [0, 1].

    Each entry is ``(|f| / max|f|) ** alpha``, where f = (fy, fx) are the
    ``torch.fft.fftfreq`` sample frequencies. The map is all zeros when max|f| == 0
    (h == w == 1). ``device`` defaults to the CPU.
    """
    dev = torch.device('cpu') if device is None else device
    freq_y = torch.fft.fftfreq(h, d=1.0, device=dev).view(-1, 1).repeat(1, w)
    freq_x = torch.fft.fftfreq(w, d=1.0, device=dev).view(1, -1).repeat(h, 1)
    radius = torch.sqrt(freq_x**2 + freq_y**2)
    peak = radius.max()
    normalised = (radius / (peak + 1e-12))**alpha
    return torch.where(peak > 0, normalised, torch.zeros_like(radius))

def spectral_loss(pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """Weighted spectral amplitude MSE.

    Compares |FFT2| of prediction and target over the last two axes, weighted by
    ``make_radial_weight`` (emphasising high frequencies for alpha > 0). Inputs must be
    4D; channel axis 1 is squeezed. The default (unnormalised) forward FFT is used, so
    the loss magnitude depends on the grid size.
    """
    _, _, height, width = pred.shape
    weight = make_radial_weight(height, width, alpha=alpha, device=pred.device)
    amp_pred = torch.fft.fft2(pred.squeeze(1)).abs()
    amp_true = torch.fft.fft2(target.squeeze(1)).abs()
    return ((amp_pred - amp_true).pow(2) * weight).mean()

def evaluate_rollout_rmse_steps(model: nn.Module,
                                trajectories: List[np.ndarray],
                                steps: int,
                                device: torch.device) -> float:
    """Free-rollout RMSE for a given number of model steps.

    For each trajectory (frames along axis 0) and each start frame t0 with
    ``t0 + steps <= last index``, the model is applied autoregressively ``steps``
    times. Every predicted frame t0+k (k = 1..steps) is compared with ground truth.
    Returns sqrt(total squared error / number of compared values).

    Raises:
        ValueError: if ``steps < 1`` or no trajectory has at least ``steps + 1`` frames.
            Without this check the function would report a meaningless RMSE of 0.0.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    model.eval()
    se_sum, count = 0.0, 0
    with torch.no_grad():
        for traj in trajectories:
            T = traj.shape[0] - 1
            H, W = traj.shape[1], traj.shape[2]
            # Convert the whole trajectory once instead of once per rollout step.
            frames = torch.from_numpy(traj.astype(np.float32)).to(device)
            for t0 in range(0, T - steps + 1):
                # clone() so a model that mutates its input cannot corrupt the ground truth.
                x = frames[t0].view(1, 1, H, W).clone()
                for k in range(1, steps + 1):
                    pred = model(x)
                    gt = frames[t0 + k].view(1, 1, H, W)
                    se_sum += F.mse_loss(pred, gt, reduction='sum').item()
                    count += gt.numel()
                    x = pred
    if count == 0:
        raise ValueError(f"no trajectory is long enough for a {steps}-step rollout")
    return math.sqrt(se_sum / count)


# -----------------------------
# Part 7: Training (Baseline vs FRL)
# -----------------------------

def evaluate_rmse_singlestep(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Single-step RMSE of ``model`` over every (x, y) batch yielded by ``loader``.

    Used for early-stopping model selection (the driver passes the 256 validation
    loader). Returns 0.0 if the loader yields no elements.
    """
    model.eval()
    total_se = 0.0
    total_n = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            preds = model(inputs)
            total_se += F.mse_loss(preds, targets, reduction='sum').item()
            total_n += targets.numel()
    return math.sqrt(total_se / max(1, total_n))

def train_baseline_es(model: nn.Module,
                      loader_train: DataLoader,
                      loader_val256: DataLoader,
                      epochs: int,
                      device: torch.device,
                      lr: float = 5e-4,
                      patience: int = 20,
                      min_delta: float = 0.0) -> None:
    """Train with plain MSE (Adam, grad-norm clip 1.0) and early stopping.

    After every epoch the model is scored by single-step RMSE on ``loader_val256``.
    Training stops after ``patience`` epochs without an improvement larger than
    ``min_delta``. The best-scoring weights are loaded back into ``model`` at the end.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_val = float('inf')
    best_state = None
    stale_epochs = 0
    for ep in range(1, epochs + 1):
        model.train()
        loss_sum, n_seen = 0.0, 0
        for xb, yb in loader_train:
            xb, yb = xb.to(device), yb.to(device)
            loss = F.mse_loss(model(xb), yb)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_sum += loss.item() * xb.size(0)
            n_seen += xb.size(0)
        train_loss = loss_sum / max(1, n_seen)
        val_rmse = evaluate_rmse_singlestep(model, loader_val256, device=device)
        print(f"[Baseline-GNN] Epoch {ep:03d}/{epochs}  train_loss={train_loss:.6f}  val256_RMSE(1-step)={val_rmse:.6f}")
        if val_rmse + min_delta < best_val:
            best_val = val_rmse
            # Snapshot on CPU so later optimisation steps cannot alias the saved tensors.
            best_state = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
            stale_epochs = 0
            continue
        stale_epochs += 1
        if stale_epochs >= patience:
            print(f"[Baseline-GNN] Early stopping at epoch {ep}. Best val256 RMSE(1-step)={best_val:.6f}.")
            break
    if best_state is not None:
        model.load_state_dict(best_state)

def train_frl_roundrobin_es(model: nn.Module,
                            loaders_train: Dict[int, DataLoader],
                            loader_val256: DataLoader,
                            epochs: int,
                            device: torch.device,
                            lr: float = 5e-4,
                            alpha_freq: float = 1.0,
                            lambda_freq: float = 0.05,
                            patience: int = 20,
                            min_delta: float = 0.0,
                            res_loss_weights: Optional[Dict[int,float]] = None,
                            lambda_warmup_epochs: int = 5) -> None:
    """Train the FRL model round-robin over several resolutions, with early stopping.

    Each epoch interleaves one batch per active resolution until every loader is
    exhausted. The per-batch loss is ``w_H * (MSE + lambda_ep * spectral_loss)``, with
    ``w_H = res_loss_weights.get(H, 1.0)``. ``lambda_ep`` ramps linearly up to
    ``lambda_freq`` over the first ``lambda_warmup_epochs`` epochs (no ramp if that is
    <= 0). After each epoch the model is scored by single-step RMSE on
    ``loader_val256``. Training stops after ``patience`` epochs without an improvement
    larger than ``min_delta``, and the best weights are restored.

    Resolutions whose weight is 0 are skipped entirely. Running them with a
    zero-scaled loss would still change the model: Adam keeps stepping with its
    momentum on zero gradients, and train-mode BatchNorm updates its running
    statistics.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    best_val = float('inf'); best_state = None; no_improve = 0
    weights = {} if res_loss_weights is None else res_loss_weights
    active = {H: ld for H, ld in loaders_train.items() if float(weights.get(H, 1.0)) != 0.0}
    skipped = sorted(set(loaders_train) - set(active))
    if skipped:
        print(f"[FRL-GNN] Skipping zero-weight training resolutions: {skipped}")
    for ep in range(1, epochs+1):
        model.train(); running = 0.0; n = 0
        if lambda_warmup_epochs > 0:
            cur_lambda = lambda_freq * min(1.0, ep / float(lambda_warmup_epochs))
        else:
            cur_lambda = lambda_freq
        iters = {H: iter(ld) for H, ld in active.items()}
        # Round-robin: one batch per resolution per pass; drop a loader once exhausted.
        while iters:
            for H in list(iters):
                try:
                    x, y = next(iters[H])
                except StopIteration:
                    del iters[H]
                    continue
                w_res = float(weights.get(H, 1.0))
                x = x.to(device); y = y.to(device)
                pred = model(x)
                l_space = F.mse_loss(pred, y)
                l_freq  = spectral_loss(pred, y, alpha=alpha_freq)
                loss = (l_space + cur_lambda * l_freq) * w_res
                opt.zero_grad(); loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
                running += loss.item() * x.size(0); n += x.size(0)
        train_loss = running / max(1, n)
        val_rmse = evaluate_rmse_singlestep(model, loader_val256, device=device)
        print(f"[FRL-GNN] Epoch {ep:03d}/{epochs}  train_loss={train_loss:.6f}  val256_RMSE(1-step)={val_rmse:.6f}")
        if val_rmse + min_delta < best_val:
            best_val = val_rmse
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}; no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"[FRL-GNN] Early stopping at epoch {ep}. Best val256 RMSE(1-step)={best_val:.6f}.")
                break
    if best_state is not None:
        model.load_state_dict(best_state)


# -----------------------------
# Part 8: Driver
# -----------------------------

def downsample_trajs(trajs: List[np.ndarray], target_res: int) -> List[np.ndarray]:
    """Return copies of ``trajs`` with every frame FFT low-pass downsampled to ``target_res``.

    Each output trajectory stacks its downsampled frames along a new axis 0.
    """
    return [
        np.stack([fft_lowpass_downsample(snapshot, target_res) for snapshot in traj], axis=0)
        for traj in trajs
    ]

def main():
    """Run the Baseline-vs-FRL experiment end to end and print the rollout summary.

    The steps are:
    - Generate train (64), validation (256) and test (32/64/128/256) trajectories.
    - Train both models with early stopping on 256 single-step validation RMSE.
    - Evaluate free rollouts over the same physical horizon (10 * dt at 64) at every
      resolution.

    Test trajectories at 64, 128 and 256 share the same seeds. ``initial_vorticity``
    draws its random blobs independently of the grid size, so each seed gives the
    same continuous initial condition on every grid. The RMSE ratios therefore compare
    resolutions rather than different initial conditions.
    """
    set_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # ===== Config =====
    # Training/validation trajectory lengths (single-step supervision)
    N_STEPS_TRAIN = 12
    N_STEPS_VAL   = 12

    EPOCHS_BASE = 100
    EPOCHS_FRL  = 100
    BATCH_SIZE  = 16
    LR          = 5e-4

    # FRL components
    ALPHA_FREQ  = 4
    LAMBDA_FREQ = 0.5
    # NOTE(review): weight 0 for 32 means the 32-from-64 stream contributes no loss, although
    # the module docstring describes FRL as 32 + 64 training. Confirm this ablation is intended.
    RES_LOSS_WEIGHTS = {32: 0, 64: 1.0}
    PATIENCE    = 20
    MIN_DELTA   = 0.0

    TRAIN_SEEDS = [100 + i for i in range(8)]
    VAL_SEEDS   = [400 + i for i in range(2)]
    # Same seeds at every test resolution -> identical initial conditions across grids.
    TEST_SEEDS  = [200 + i for i in range(2)]

    # ===== Solvers =====
    solver64  = ImprovedPseudoSpectralSolver2D(64,  domain_size=2*math.pi, nu=1e-2)
    solver128 = ImprovedPseudoSpectralSolver2D(128, domain_size=2*math.pi, nu=1e-2)
    solver256 = ImprovedPseudoSpectralSolver2D(256, domain_size=2*math.pi, nu=1e-2)
    print(f"Solver64 : N={solver64.N},  dt={solver64.dt:.3e}")
    print(f"Solver128: N={solver128.N}, dt={solver128.dt:.3e}")
    print(f"Solver256: N={solver256.N}, dt={solver256.dt:.3e}")

    # ===== Physical rollout horizon (match "10 steps at 64") =====
    rollout_horizon = 10 * solver64.dt     # same physical time for all resolutions
    steps_64  = max(1, int(round(rollout_horizon / solver64.dt)))
    steps_128 = max(1, int(round(rollout_horizon / solver128.dt)))
    steps_256 = max(1, int(round(rollout_horizon / solver256.dt)))
    steps_32  = steps_64  # our 32 test is spatial downsample of 64 test -> same dt as 64

    max_steps_needed = max(steps_32, steps_64, steps_128, steps_256)
    # Test trajectories must have at least (max_steps_needed + 1) frames; +2 adds a margin.
    N_STEPS_TEST = max(50, max_steps_needed + 2)

    def generate(solver, n_steps, seeds, tag):
        """Generate one trajectory per seed, reporting progress as each one finishes."""
        trajs = []
        for idx, seed in enumerate(seeds, start=1):
            trajs.append(solver.generate_trajectory(n_steps, seed=seed))
            print(f"Generated {tag} trajectory {idx}/{len(seeds)}")
        return trajs

    # ===== Trajectories =====
    train_trajs_64 = generate(solver64, N_STEPS_TRAIN, TRAIN_SEEDS, "train64")
    val_trajs_256  = generate(solver256, N_STEPS_VAL, VAL_SEEDS, "val256")
    test_trajs_64  = generate(solver64, N_STEPS_TEST, TEST_SEEDS, "test64")
    test_trajs_128 = generate(solver128, N_STEPS_TEST, TEST_SEEDS, "test128")
    test_trajs_256 = generate(solver256, N_STEPS_TEST, TEST_SEEDS, "test256")
    test_trajs_32  = downsample_trajs(test_trajs_64, target_res=32)  # spatially downsampled; same dt as 64

    # ===== Datasets & Loaders =====
    train_set_64 = DirectOneStepDataset(train_trajs_64)  # read-only; shared by both loaders
    train_loader_64 = DataLoader(train_set_64, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    frl_train_loader_64 = DataLoader(train_set_64, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    frl_train_loader_32 = DataLoader(DownsampleFromTrajDataset(train_trajs_64, target_resolution=32), batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader_256 = DataLoader(DirectOneStepDataset(val_trajs_256), batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    print("Datasets built.")

    # ===== Train Baseline (64 only) =====
    baseline = BaselineGNN(hidden=64, n_layers=5).to(device)
    train_baseline_es(baseline, train_loader_64, val_loader_256,
                      epochs=EPOCHS_BASE, device=device, lr=LR,
                      patience=PATIENCE, min_delta=MIN_DELTA)

    # ===== Train FRL (32-from-64 + 64) =====
    frl = FRLGNN(hidden=64, n_layers=5, n_freqs=8).to(device)
    train_frl_roundrobin_es(
        frl,
        loaders_train={32: frl_train_loader_32, 64: frl_train_loader_64},
        loader_val256=val_loader_256,
        epochs=EPOCHS_FRL,
        device=device,
        lr=LR,
        alpha_freq=ALPHA_FREQ,
        lambda_freq=LAMBDA_FREQ,
        patience=PATIENCE,
        min_delta=MIN_DELTA,
        res_loss_weights=RES_LOSS_WEIGHTS,
        lambda_warmup_epochs=5,
    )

    # ===== Equal-physical-time rollout RMSE =====
    test_sets = {
        32:  (test_trajs_32,  steps_32),
        64:  (test_trajs_64,  steps_64),
        128: (test_trajs_128, steps_128),
        256: (test_trajs_256, steps_256),
    }
    base_rmse_eqT = {res: evaluate_rollout_rmse_steps(baseline, trajs, steps=n, device=device)
                     for res, (trajs, n) in test_sets.items()}
    frl_rmse_eqT = {res: evaluate_rollout_rmse_steps(frl, trajs, steps=n, device=device)
                    for res, (trajs, n) in test_sets.items()}

    # ===== Summary =====
    print("\n=== Summary (Equal-physical-time rollout RMSE; label = Train/Test) ===")
    print(f"(Horizon ≈ 10 steps @64  → steps: 32={steps_32}, 64={steps_64}, 128={steps_128}, 256={steps_256})")
    print("[Baseline-GNN]")
    print(f"  32 (Test)        : {base_rmse_eqT[32]:.6f}")
    print(f"  64 (Train/Test)  : {base_rmse_eqT[64]:.6f}")
    print(f"  128 (Test)       : {base_rmse_eqT[128]:.6f}")
    print(f"  256 (Test)       : {base_rmse_eqT[256]:.6f}")
    base_ratio = base_rmse_eqT[256] / max(1e-12, base_rmse_eqT[64])
    print(f"  RMSE Ratio 256/64: {base_ratio:.3f}")

    print("[FRL-GNN]")
    print(f"  32 (Train/Test)  : {frl_rmse_eqT[32]:.6f}")
    print(f"  64 (Train/Test)  : {frl_rmse_eqT[64]:.6f}")
    print(f"  128 (Test)       : {frl_rmse_eqT[128]:.6f}")
    print(f"  256 (Test)       : {frl_rmse_eqT[256]:.6f}")
    frl_ratio = frl_rmse_eqT[256] / max(1e-12, frl_rmse_eqT[64])
    print(f"  RMSE Ratio 256/64: {frl_ratio:.3f}")


if __name__ == "__main__":
    # Run the full experiment only when executed as a script, never on import.
    main()
