{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sX6E_OJoS25a"
      },
      "source": [
        "# ∂-NO Benchmark: Experiments\n",
        "\n",
        "## Supplementary Material\n",
        "\n",
        "> **Note:** This notebook is provided for illustration purposes only.  \n",
        "> Full code and reproducible experiments will be available on GitHub upon acceptance.\n",
        ">\n",
        "> Results may vary across hardware configurations.  \n",
        "> **Tip:** Smaller batch sizes generally yield lower test error but require longer training time.  \n",
        "> **Datasets:** Place the dataset files in the `./data` folder.\n",
        "\n",
        "---\n",
        "\n",
        "### Overview\n",
        "\n",
        "This notebook is standalone supplementary material that compares baseline neural operators with their ∂-augmented variants.\n",
        "\n",
        "### Datasets\n",
        "\n",
        "| Dataset | PDE | Domain | Resolution | Train/Test |\n",
        "|---------|-----|--------|------------|------------|\n",
        "| Burgers (ν=0.1) | Viscous Burgers | 1D | 8192 → 1024 | 1000/200 |\n",
        "| Darcy | Elliptic (steady-state) | 2D | 421×421 → 85×85 | 1000/200 |\n",
        "| Navier-Stokes (ν=10⁻⁵) | Incompressible turbulence | 2D+T | 64×64×20 | 1000/200 |\n",
        "\n",
        "### Models\n",
        "\n",
        "| Model | Description | ∂-Augmentation |\n",
        "|-------|-------------|----------------|\n",
        "| FNO | Fourier Neural Operator | ∂-FNO |\n",
        "| Transolver | Attention-based (SOTA) | ∂-Transolver |\n",
        "\n",
        "### ∂-NO Configuration\n",
        "\n",
        "| Setting | 1D Problems | 2D Problems |\n",
        "|---------|-------------|-------------|\n",
        "| Derivative features | $\\mathfrak{D}_x^{\\beta_1} u$, $\\mathfrak{D}_x^{\\beta_2} u$ | $\\mathfrak{D}_x^{\\beta_{x,1}} u$, $\\mathfrak{D}_x^{\\beta_{x,2}} u$, $\\mathfrak{D}_y^{\\beta_{y,1}} u$, $\\mathfrak{D}_y^{\\beta_{y,2}} u$, mixed $\\mathfrak{D}_x^{\\beta_{xy,x}} \\mathfrak{D}_y^{\\beta_{xy,y}} u$ |\n",
        "| Initialization | β₁=1.0, β₂=2.0 | Same per direction |\n",
        "| GL kernel size | K=16 | K=16 |\n",
        "\n",
        "### Training Configuration\n",
        "\n",
        "| Parameter | Value |\n",
        "|-----------|-------|\n",
        "| Optimizer | AdamW |\n",
        "| LR (backbone) | 10⁻³ |\n",
        "| LR (β parameters) | 10⁻² (10× higher) |\n",
        "| Batch size (`thuml` preset) | 8 (default); 4 (Darcy); 20 (NS, FNO) / 2 (NS, Transolver) |\n",
        "| Normalizer | UnitGaussianNormalizer (Darcy) |\n",
        "| NS training | Autoregressive |\n",
        "\n",
        "### Experiment Matrix\n",
        "\n",
        "| | FNO | ∂-FNO | Transolver | ∂-Transolver |\n",
        "|---|:---:|:---:|:---:|:---:|\n",
        "| Burgers (ν=0.1) | ✓ | ✓ | ✓ | ✓ |\n",
        "| Darcy | ✓ | ✓ | ✓ | ✓ |\n",
        "| Navier-Stokes | ✓ | ✓ | ✓ | ✓ |"
      ]
    },
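    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Grünwald-Letnikov Stencil (Reference)\n",
        "\n",
        "The discretization used by `gl_weights` and `gl_diff_1d_raw` (cell 7) is summarized below as read off the code. Here $h$ stands for the learnable spacing $h_{\\text{eff}}$, indices wrap around because of circular padding, and the shorthand $\\delta^{\\pm}_\\beta$ is introduced only for readability.\n",
        "\n",
        "$$w_0 = 1, \\qquad w_k = w_{k-1}\\,\\frac{k-1-\\beta}{k} = (-1)^k \\binom{\\beta}{k}, \\qquad k = 1, \\dots, K-1$$\n",
        "\n",
        "$$\\delta^{-}_\\beta u_n = \\sum_{k=0}^{K-1} w_k\\, u_{n-k}, \\qquad \\delta^{+}_\\beta u_n = \\sum_{k=0}^{K-1} w_k\\, u_{n+k}$$\n",
        "\n",
        "$$\\mathfrak{D}^{\\beta} u_n \\approx h^{-\\beta} \\cdot \\tfrac{1}{2}\\left(\\delta^{-}_\\beta u_n + \\cos(\\pi\\beta)\\, \\delta^{+}_\\beta u_n\\right)$$\n",
        "\n",
        "For $\\beta = 1$ the weights are $(1, -1, 0, \\dots)$ and this reduces to the central difference $(u_{n+1} - u_{n-1})/(2h)$. For $\\beta = 2$ the weights are $(1, -2, 1, 0, \\dots)$, giving the five-point stencil $\\big(\\tfrac{1}{2}u_{n-2} - u_{n-1} + u_n - u_{n+1} + \\tfrac{1}{2}u_{n+2}\\big)/h^2$, which is consistent with $u''$. In 2D the operator is applied along $x$ and then along $y$, with scale $h_x^{-\\beta_x} h_y^{-\\beta_y}$."
      ]
    },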
    {
      "cell_type": "code",
      "metadata": {
        "id": "dlOfVUe7S25c",
        "cellView": "form"
      },
      "source": [
        "# @title 1. IMPORTS\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import math\n",
        "import os\n",
        "import time\n",
        "import csv\n",
        "import glob\n",
        "from datetime import datetime\n",
        "from typing import Dict, List, Tuple, Optional, Union, Any, Callable\n",
        "from dataclasses import dataclass, field\n",
        "\n",
        "try:\n",
        "    import scipy.io\n",
        "    HAS_SCIPY = True\n",
        "except ImportError:\n",
        "    HAS_SCIPY = False\n",
        "\n",
        "try:\n",
        "    import gdown\n",
        "    HAS_GDOWN = True\n",
        "except ImportError:\n",
        "    HAS_GDOWN = False\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 23,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VwB-1sDyS25d",
        "cellView": "form"
      },
      "source": [
        "# @title 2. GOOGLE DRIVE IDS\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "GDRIVE_IDS = {\n",
        "    'burgers_r10':   '16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe',\n",
        "    'darcy_421':     '1Z1uxG9R8AdAGJprG5STcphysjm56_0Jf',\n",
        "    'ns_v1e5':       '1qO46jjKooiymGCjtfKxb9fUfa74fc68Z',\n",
        "}\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 25,
      "outputs": []
    },
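    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell above only lists Google Drive file IDs. A minimal download helper is sketched below, assuming `gdown` is installed (`HAS_GDOWN` from cell 1) and that each ID points to a single file. The helper name `fetch_dataset` and the output file names are hypothetical; adjust them to whatever the data loaders later in the notebook expect."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import os\n",
        "\n",
        "\n",
        "def fetch_dataset(key: str, filename: str, data_dir: str = './data') -> str:\n",
        "    # Hypothetical helper: download GDRIVE_IDS[key] to data_dir/filename unless it already exists\n",
        "    os.makedirs(data_dir, exist_ok=True)\n",
        "    path = os.path.join(data_dir, filename)\n",
        "    if os.path.exists(path):\n",
        "        return path  # already downloaded\n",
        "    if not HAS_GDOWN:\n",
        "        raise RuntimeError(f'gdown is not installed: pip install gdown, or place the file in {data_dir} manually')\n",
        "    import gdown\n",
        "    # Recent gdown releases accept a bare Drive file ID via the `id` keyword\n",
        "    gdown.download(id=GDRIVE_IDS[key], output=path, quiet=False)\n",
        "    return path\n",
        "\n",
        "\n",
        "# Example usage (the file name is an assumption):\n",
        "# fetch_dataset('burgers_r10', 'burgers_r10.mat')"
      ],
      "execution_count": null,
      "outputs": []
    },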
    {
      "cell_type": "code",
      "metadata": {
        "id": "Inu2P0IcS25d",
        "cellView": "form"
      },
      "source": [
        "# @title 3. CONFIGURATION DATACLASS\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "@dataclass\n",
        "class SpectraConfig:\n",
        "    \"\"\"Configuration for DELNO-Spectra v7.\"\"\"\n",
        "\n",
        "    # ─── Dataset ─────────────────────────────────────────────────────────────────\n",
        "    dataset: str = 'burgers_r10'\n",
        "    n_train: int = 1000\n",
        "    n_test: int = 200\n",
        "    resolution: int = 1024\n",
        "\n",
        "    # ─── Model ───────────────────────────────────────────────────────────────────\n",
        "    model_type: str = 'fno'  # 'fno' or 'transolver'\n",
        "\n",
        "    # ─── DELNO-Spectra Features ──────────────────────────────────────────────────\n",
        "    max_order: int = 2\n",
        "    m_max: float = 2.0\n",
        "    include_cross: bool = False\n",
        "    include_grid: bool = True\n",
        "    K: int = 16  # GL kernel size\n",
        "\n",
        "    # ─── FNO Backbone ────────────────────────────────────────────────────────────\n",
        "    width: int = 64\n",
        "    modes: int = 12\n",
        "    n_layers: int = 4\n",
        "\n",
        "    # ─── Transolver ──────────────────────────────────────────────────────────────\n",
        "    n_heads: int = 8\n",
        "    slice_num: int = 32\n",
        "    mlp_ratio: int = 2\n",
        "    dropout: float = 0.0\n",
        "\n",
        "    # ─── Training ────────────────────────────────────────────────────────────────\n",
        "    epochs: int = 5\n",
        "    batch_size: int = 8\n",
        "    lr: float = 1e-3\n",
        "    lr_beta: float = 1e-2\n",
        "    lr_heff: float = 1e-3\n",
        "    weight_decay: float = 1e-5\n",
        "    scheduler: str = 'onecycle'  # 'onecycle', 'step', 'cosine'\n",
        "    pct_start: float = 0.3\n",
        "    step_size: int = 100\n",
        "    gamma: float = 0.5\n",
        "    grad_clip: float = 1.0\n",
        "\n",
        "    # ─── Logging ─────────────────────────────────────────────────────────────────\n",
        "    log_dir: str = './logs'\n",
        "    log_every: int = 10\n",
        "    eval_every: int = 100\n",
        "\n",
        "    # ─── Misc ────────────────────────────────────────────────────────────────────\n",
        "    seed: int = 42\n",
        "    device: str = 'cuda'\n",
        "    preset: str = 'thuml'\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 26,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Nby5tultS25d",
        "cellView": "form"
      },
      "source": [
        "# @title 4. MODULAR CONFIGURATION SYSTEM\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "# Dataset-specific parameters\n",
        "DATASET_PARAMS = {\n",
        "    'burgers_r10': {\n",
        "        'n_train': 1000, 'n_test': 200, 'resolution': 1024,\n",
        "    },\n",
        "    'burgers_r1000': {\n",
        "        'n_train': 1000, 'n_test': 200, 'resolution': 1024,\n",
        "    },\n",
        "    'darcy': {\n",
        "        'n_train': 1000, 'n_test': 200, 'resolution': 85,\n",
        "    },\n",
        "    'ns_v1e3': {\n",
        "        'n_train': 1000, 'n_test': 200, 'resolution': 64,\n",
        "    },\n",
        "    'ns_v1e5': {\n",
        "        'n_train': 1000, 'n_test': 200, 'resolution': 64,\n",
        "    },\n",
        "}\n",
        "\n",
        "# Model-specific parameters\n",
        "MODEL_PARAMS = {\n",
        "    'fno': {\n",
        "        'burgers_r10':  {'width': 64,  'modes': 12, 'n_layers': 4},\n",
        "        'burgers_r1000': {'width': 64,  'modes': 12, 'n_layers': 4},\n",
        "        'darcy':         {'width': 64,  'modes': 12, 'n_layers': 4},  # kept by 'original'; 'thuml' overrides width to 128\n",
        "        'ns_v1e3':       {'width': 64,  'modes': 12, 'n_layers': 4},\n",
        "        'ns_v1e5':       {'width': 64,  'modes': 12, 'n_layers': 4},\n",
        "    },\n",
        "    'transolver': {\n",
        "        'burgers_r10':  {'width': 128, 'n_layers': 8, 'n_heads': 8, 'slice_num': 64},\n",
        "        'burgers_r1000': {'width': 128, 'n_layers': 8, 'n_heads': 8, 'slice_num': 64},\n",
        "        'darcy':         {'width': 128, 'n_layers': 8, 'n_heads': 8, 'slice_num': 64},\n",
        "        'ns_v1e3':       {'width': 256, 'n_layers': 8, 'n_heads': 8, 'slice_num': 32},\n",
        "        'ns_v1e5':       {'width': 256, 'n_layers': 8, 'n_heads': 8, 'slice_num': 32},\n",
        "    },\n",
        "}\n",
        "\n",
        "# Training preset parameters\n",
        "PRESET_PARAMS = {\n",
        "    'original': {\n",
        "        'epochs': 500, 'batch_size': 20, 'lr': 1e-3, 'weight_decay': 1e-4,\n",
        "        'scheduler': 'step', 'step_size': 100, 'gamma': 0.5,\n",
        "    },\n",
        "    'thuml': {\n",
        "        'epochs': 500, 'batch_size': 8, 'lr': 1e-3, 'weight_decay': 1e-5,\n",
        "        'scheduler': 'onecycle', 'pct_start': 0.3,\n",
        "    },\n",
        "    'custom': {\n",
        "        'epochs': 1000, 'batch_size': 8, 'lr': 1e-3, 'weight_decay': 1e-5,\n",
        "        'scheduler': 'step', 'step_size': 200, 'gamma': 0.5,\n",
        "    },\n",
        "}\n",
        "\n",
        "# Dataset-specific overrides per preset\n",
        "DATASET_PRESET_OVERRIDES = {\n",
        "    ('darcy', 'thuml'): {'width': 128, 'batch_size': 4},\n",
        "    ('ns_v1e3', 'thuml'): {'lr': 5e-4, 'batch_size': 20, 'scheduler': 'step'},\n",
        "    ('ns_v1e3', 'original'): {'lr': 1e-3},\n",
        "    ('ns_v1e3', 'custom'): {'lr': 5e-4, 'batch_size': 20},\n",
        "    ('ns_v1e5', 'thuml'): {'lr': 5e-4, 'batch_size': 20, 'scheduler': 'step'},\n",
        "    ('ns_v1e5', 'original'): {'lr': 1e-3},\n",
        "    ('ns_v1e5', 'custom'): {'lr': 5e-4, 'batch_size': 20},\n",
        "    ('burgers_r10', 'original'): {'modes': 16},\n",
        "    ('burgers_r1000', 'original'): {'modes': 16},\n",
        "}\n",
        "\n",
        "# Transolver-specific overrides\n",
        "TRANSOLVER_PRESET_OVERRIDES = {\n",
        "    ('ns_v1e3', 'thuml'): {'batch_size': 2},\n",
        "    ('ns_v1e5', 'thuml'): {'batch_size': 2},\n",
        "    ('ns_v1e3', 'custom'): {'batch_size': 2},\n",
        "    ('ns_v1e5', 'custom'): {'batch_size': 2},\n",
        "}\n",
        "\n",
        "\n",
        "def get_config(dataset: str, model: str, preset: str) -> SpectraConfig:\n",
        "    \"\"\"\n",
        "    Get configuration by composing dataset, model, and preset parameters.\n",
        "\n",
        "    Args:\n",
        "        dataset: 'burgers_r10', 'burgers_r1000', 'darcy', 'ns_v1e3', 'ns_v1e5'\n",
        "        model: 'fno', 'transolver'\n",
        "        preset: 'original', 'thuml', 'custom'\n",
        "\n",
        "    Returns:\n",
        "        SpectraConfig with merged parameters\n",
        "    \"\"\"\n",
        "    if dataset not in DATASET_PARAMS:\n",
        "        raise ValueError(f\"Unknown dataset: {dataset}\")\n",
        "    if model not in MODEL_PARAMS:\n",
        "        raise ValueError(f\"Unknown model: {model}\")\n",
        "    if preset not in PRESET_PARAMS:\n",
        "        raise ValueError(f\"Unknown preset: {preset}\")\n",
        "\n",
        "    # Start with defaults\n",
        "    params = {'dataset': dataset, 'model_type': model, 'preset': preset}\n",
        "\n",
        "    # Add dataset params\n",
        "    params.update(DATASET_PARAMS[dataset])\n",
        "\n",
        "    # Add model params\n",
        "    params.update(MODEL_PARAMS[model][dataset])\n",
        "\n",
        "    # Add preset params\n",
        "    params.update(PRESET_PARAMS[preset])\n",
        "\n",
        "    # Apply dataset-preset overrides\n",
        "    key = (dataset, preset)\n",
        "    if key in DATASET_PRESET_OVERRIDES:\n",
        "        params.update(DATASET_PRESET_OVERRIDES[key])\n",
        "\n",
        "    # Apply transolver-specific overrides\n",
        "    if model == 'transolver' and key in TRANSOLVER_PRESET_OVERRIDES:\n",
        "        params.update(TRANSOLVER_PRESET_OVERRIDES[key])\n",
        "\n",
        "    return SpectraConfig(**params)\n",
        "\n",
        "\n",
        "def list_configs() -> List[str]:\n",
        "    \"\"\"List all available configuration names.\"\"\"\n",
        "    configs = []\n",
        "    for dataset in DATASET_PARAMS.keys():\n",
        "        for model in MODEL_PARAMS.keys():\n",
        "            for preset in PRESET_PARAMS.keys():\n",
        "                configs.append(f\"{dataset}_{model}_{preset}\")\n",
        "    return configs\n",
        "\n",
        "\n",
        "# Legacy compatibility\n",
        "CONFIGS = {name: lambda d=name.rsplit('_', 2)[0], m=name.rsplit('_', 2)[1], p=name.rsplit('_', 2)[2]: get_config(d, m, p)\n",
        "           for name in list_configs()}\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 27,
      "outputs": []
    },
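    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A usage sketch for `get_config`, assuming the cells above have been run. Parameters are merged in the order dataset → model → preset → dataset/preset overrides → Transolver overrides, so later entries win. The helper `demo_get_config` is hypothetical, and the smoke-test variant built with `dataclasses.replace` is only an illustration; its epoch count is arbitrary."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "from dataclasses import replace\n",
        "\n",
        "\n",
        "def demo_get_config():\n",
        "    # Darcy + FNO + 'thuml': the ('darcy', 'thuml') override raises width to 128 and lowers batch_size to 4\n",
        "    cfg_darcy = get_config('darcy', 'fno', 'thuml')\n",
        "    print(cfg_darcy.width, cfg_darcy.batch_size, cfg_darcy.scheduler)  # 128 4 onecycle\n",
        "\n",
        "    # NS (nu=1e-5) + Transolver + 'thuml': the Transolver override (batch_size=2) is applied last\n",
        "    cfg_ns = get_config('ns_v1e5', 'transolver', 'thuml')\n",
        "    print(cfg_ns.width, cfg_ns.batch_size, cfg_ns.lr, cfg_ns.scheduler)  # 256 2 0.0005 step\n",
        "\n",
        "    # Legacy CONFIGS entries are zero-argument factories that call get_config\n",
        "    assert CONFIGS['darcy_fno_thuml']() == cfg_darcy\n",
        "\n",
        "    # Hypothetical smoke-test variant: identical settings, far fewer epochs\n",
        "    return replace(cfg_darcy, epochs=2)\n",
        "\n",
        "\n",
        "smoke_cfg = demo_get_config()"
      ],
      "execution_count": null,
      "outputs": []
    },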
    {
      "cell_type": "code",
      "metadata": {
        "id": "iZUwSqytS25e",
        "cellView": "form"
      },
      "source": [
        "# @title 5. UTILITIES\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "def count_params(model: nn.Module) -> int:\n",
        "    \"\"\"Count trainable parameters.\"\"\"\n",
        "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
        "\n",
        "\n",
        "def rel_l2_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
        "    \"\"\"Relative L2 loss.\"\"\"\n",
        "    pred_flat = pred.reshape(pred.shape[0], -1)\n",
        "    target_flat = target.reshape(target.shape[0], -1)\n",
        "    return (torch.norm(pred_flat - target_flat, dim=-1) /\n",
        "            (torch.norm(target_flat, dim=-1) + 1e-8)).mean()\n",
        "\n",
        "\n",
        "def l2_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
        "    \"\"\"MSE loss.\"\"\"\n",
        "    return F.mse_loss(pred, target)\n",
        "\n",
        "\n",
        "def setup_device(cfg: SpectraConfig) -> torch.device:\n",
        "    \"\"\"Setup compute device.\"\"\"\n",
        "    if cfg.device == 'cuda' and torch.cuda.is_available():\n",
        "        torch.backends.cudnn.benchmark = True\n",
        "        torch.backends.cuda.matmul.allow_tf32 = True\n",
        "        torch.backends.cudnn.allow_tf32 = True\n",
        "        device = torch.device('cuda')\n",
        "        print(f\"Device: {torch.cuda.get_device_name(0)}\")\n",
        "    else:\n",
        "        device = torch.device('cpu')\n",
        "        print(\"Device: CPU\")\n",
        "    return device\n",
        "\n",
        "\n",
        "def set_seed(seed: int):\n",
        "    \"\"\"Set random seeds for reproducibility.\"\"\"\n",
        "    torch.manual_seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "\n",
        "class UnitGaussianNormalizer:\n",
        "    \"\"\"Normalize to zero mean and unit variance.\"\"\"\n",
        "    def __init__(self, x: torch.Tensor, eps: float = 1e-5):\n",
        "        self.mean = x.mean(dim=0, keepdim=True)\n",
        "        self.std = x.std(dim=0, keepdim=True)\n",
        "        self.eps = eps\n",
        "\n",
        "    def encode(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        return (x - self.mean) / (self.std + self.eps)\n",
        "\n",
        "    def decode(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        return x * (self.std + self.eps) + self.mean\n",
        "\n",
        "    def to(self, device: torch.device) -> 'UnitGaussianNormalizer':\n",
        "        self.mean = self.mean.to(device)\n",
        "        self.std = self.std.to(device)\n",
        "        return self\n",
        "\n",
        "\n",
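        "def softplus_inverse(x: float) -> float:\n",
        "    \"\"\"Inverse of softplus: log(exp(x) - 1), computed stably as x + log(1 - exp(-x)).\n",
        "\n",
        "    Non-positive x has no inverse; -10.0 is returned instead (softplus(-10) ≈ 4.5e-5).\n",
        "    \"\"\"\n",
        "    return x + math.log(-math.expm1(-x)) if x > 0 else -10.0\n",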
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 28,
      "outputs": []
    },
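    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Quick sanity checks for the utilities above, assuming cell 5 has been run. A uniform 10% overshoot should give a relative L2 error of about 0.1, the normalizer should round-trip up to float32 round-off, and `softplus_inverse` should invert `F.softplus`. The function name `check_utilities` and the tensor shapes are arbitrary choices for this sketch."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def check_utilities():\n",
        "    torch.manual_seed(0)\n",
        "    target = torch.randn(4, 64)\n",
        "\n",
        "    # pred = 1.1 * target  =>  ||pred - target|| / ||target|| = 0.1 for every sample\n",
        "    print(rel_l2_loss(1.1 * target, target).item())  # ~0.1\n",
        "\n",
        "    # UnitGaussianNormalizer: decode(encode(x)) should recover x up to float32 round-off\n",
        "    norm = UnitGaussianNormalizer(target)\n",
        "    assert torch.allclose(norm.decode(norm.encode(target)), target, atol=1e-5)\n",
        "\n",
        "    # softplus(softplus_inverse(h)) should return h; h = 2 / modes with modes = 12\n",
        "    h = 2.0 / 12\n",
        "    assert abs(F.softplus(torch.tensor(softplus_inverse(h))).item() - h) < 1e-6\n",
        "\n",
        "\n",
        "check_utilities()"
      ],
      "execution_count": null,
      "outputs": []
    },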
    {
      "cell_type": "code",
      "metadata": {
        "id": "rCpspl27S25e",
        "cellView": "form"
      },
      "source": [
        "# @title 6. CSV LOGGER\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "class CSVLogger:\n",
        "    \"\"\"\n",
        "    CSV Logger for training metrics.\n",
        "\n",
        "    Logs: epoch, train_loss, test_loss, rel_l2, lr, time_s, betas, h_effs, scales\n",
        "\n",
        "    For 1D max_order=2: 2 betas, 2 h_effs, 2 scales\n",
        "    For 2D max_order=2: 6 betas, 5 h_effs, 5 scales\n",
        "    \"\"\"\n",
        "    def __init__(self, log_dir: str, experiment_name: str, cfg: SpectraConfig, is_2d: bool = False):\n",
        "        os.makedirs(log_dir, exist_ok=True)\n",
        "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
        "        self.csv_path = os.path.join(log_dir, f\"{experiment_name}_{timestamp}.csv\")\n",
        "        self.cfg = cfg\n",
        "        self.is_2d = is_2d\n",
        "\n",
        "        self.fields = ['epoch', 'train_loss', 'test_loss', 'rel_l2', 'lr', 'time_s']\n",
        "\n",
        "        # Determine number of betas and h_effs\n",
        "        if is_2d and cfg.max_order == 2:\n",
        "            self.n_betas, self.n_heffs = 6, 5\n",
        "            self.deriv_names = ['Dx', 'Dy', 'Dxx', 'Dyy', 'Dxy_x', 'Dxy_y']\n",
        "            self.heff_names = ['Dx', 'Dy', 'Dxx', 'Dyy', 'Dxy']\n",
        "        elif is_2d and cfg.max_order == 1:\n",
        "            self.n_betas, self.n_heffs = 2, 2\n",
        "            self.deriv_names = ['Dx', 'Dy']\n",
        "            self.heff_names = ['Dx', 'Dy']\n",
        "        else:\n",
        "            self.n_betas = self.n_heffs = cfg.max_order\n",
        "            self.deriv_names = [f'D{i+1}' for i in range(cfg.max_order)]\n",
        "            self.heff_names = self.deriv_names\n",
        "\n",
        "        for name in self.deriv_names:\n",
        "            self.fields.append(f'beta_{name}')\n",
        "        for name in self.heff_names:\n",
        "            self.fields.extend([f'h_eff_{name}', f'scale_{name}'])\n",
        "\n",
        "        with open(self.csv_path, 'w', newline='') as f:\n",
        "            csv.DictWriter(f, fieldnames=self.fields).writeheader()\n",
        "\n",
        "        print(f\"  Logging to: {self.csv_path}\")\n",
        "\n",
        "    def log(self, epoch: int, train_loss: float, test_loss: float, rel_l2: float,\n",
        "            lr: float, time_s: float, spectra_params: Optional[Dict] = None):\n",
        "        \"\"\"Log a single epoch's metrics.\"\"\"\n",
        "        row = {\n",
        "            'epoch': epoch, 'train_loss': f'{train_loss:.6f}',\n",
        "            'test_loss': f'{test_loss:.6f}', 'rel_l2': f'{rel_l2:.6f}',\n",
        "            'lr': f'{lr:.2e}', 'time_s': f'{time_s:.1f}'\n",
        "        }\n",
        "\n",
        "        if spectra_params and 'betas' in spectra_params:\n",
        "            betas = spectra_params['betas']\n",
        "            h_effs = spectra_params['h_effs']\n",
        "            scales = spectra_params['scales']\n",
        "\n",
        "            for i, name in enumerate(self.deriv_names):\n",
        "                if i < len(betas):\n",
        "                    row[f'beta_{name}'] = f'{betas[i]:.4f}'\n",
        "\n",
        "            for i, name in enumerate(self.heff_names):\n",
        "                if i < len(h_effs):\n",
        "                    row[f'h_eff_{name}'] = f'{h_effs[i]:.6f}'\n",
        "                    row[f'scale_{name}'] = f'{scales[i]:.4f}'\n",
        "        else:\n",
        "            for name in self.deriv_names:\n",
        "                row[f'beta_{name}'] = ''\n",
        "            for name in self.heff_names:\n",
        "                row[f'h_eff_{name}'] = row[f'scale_{name}'] = ''\n",
        "\n",
        "        with open(self.csv_path, 'a', newline='') as f:\n",
        "            csv.DictWriter(f, fieldnames=self.fields).writerow(row)\n",
        "\n",
        "    def close(self, summary: Optional[Dict] = None):\n",
        "        if summary:\n",
        "            print(f\"  Final: rel_l2={summary.get('best_rel_l2', 0)*100:.3f}%\")\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 29,
      "outputs": []
    },
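    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A usage sketch for `CSVLogger` on a 1D run with `max_order=2`, writing to a temporary directory so nothing is left in `./logs`. The metric values and the `spectra_params` dict are dummy placeholders chosen to match the initial state (β = 1, 2 and h_eff = 2/12), not results; `demo_csv_logger` is a hypothetical name."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import csv\n",
        "import tempfile\n",
        "\n",
        "\n",
        "def demo_csv_logger():\n",
        "    cfg_demo = SpectraConfig(max_order=2)\n",
        "    # Dummy placeholders matching the initial state (beta = 1, 2; h_eff = 2/12), not results\n",
        "    dummy_params = {'betas': [1.0, 2.0], 'h_effs': [2 / 12, 2 / 12], 'scales': [6.0, 36.0]}\n",
        "\n",
        "    with tempfile.TemporaryDirectory() as tmp_dir:\n",
        "        logger = CSVLogger(tmp_dir, 'demo_burgers_fno', cfg_demo, is_2d=False)\n",
        "        logger.log(epoch=1, train_loss=0.05, test_loss=0.06, rel_l2=0.012,\n",
        "                   lr=1e-3, time_s=3.2, spectra_params=dummy_params)\n",
        "        with open(logger.csv_path, newline='') as f:\n",
        "            rows = list(csv.DictReader(f))\n",
        "\n",
        "    print(rows[0]['beta_D2'], rows[0]['h_eff_D1'], rows[0]['scale_D2'])  # 2.0000 0.166667 36.0000\n",
        "\n",
        "\n",
        "demo_csv_logger()"
      ],
      "execution_count": null,
      "outputs": []
    },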
    {
      "cell_type": "code",
      "metadata": {
        "id": "xVyj1KUDS25e",
        "cellView": "form"
      },
      "source": [
        "# @title 7. GRÜNWALD-LETNIKOV FRACTIONAL DERIVATIVE\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "def gl_weights(beta: torch.Tensor, K: int) -> torch.Tensor:\n",
        "    \"\"\"GL weights via stable cumulative product.\"\"\"\n",
        "    device, dtype = beta.device, beta.dtype\n",
        "    k = torch.arange(1, K, dtype=dtype, device=device)\n",
        "    ratios = (k - 1 - beta) / k\n",
        "    cumprods = torch.cumprod(ratios, dim=0)\n",
        "    return torch.cat([torch.ones(1, dtype=dtype, device=device), cumprods])\n",
        "\n",
        "\n",
        "def gl_diff_1d_raw(u: torch.Tensor, beta: torch.Tensor, K: int = 16,\n",
        "                   padding: str = 'circular') -> torch.Tensor:\n",
        "    \"\"\"1D GL fractional derivative WITHOUT scaling.\"\"\"\n",
        "    squeeze = u.dim() == 2\n",
        "    if squeeze:\n",
        "        u = u.unsqueeze(1)\n",
        "    B, C, N = u.shape\n",
        "    K = min(K, N)\n",
        "\n",
        "    w = gl_weights(beta, K)\n",
        "\n",
        "    # Backward difference\n",
        "    u_back = F.pad(u, (K-1, 0), mode=padding)\n",
        "    k_back = w.flip(0).view(1, 1, K).expand(C, 1, K)\n",
        "    d_back = F.conv1d(u_back, k_back, groups=C)[:, :, :N]\n",
        "\n",
        "    # Forward difference\n",
        "    u_fwd = F.pad(u, (0, K-1), mode=padding)\n",
        "    k_fwd = w.view(1, 1, K).expand(C, 1, K)\n",
        "    d_fwd = F.conv1d(u_fwd, k_fwd, groups=C)[:, :, :N]\n",
        "\n",
        "    result = 0.5 * (d_back + torch.cos(math.pi * beta) * d_fwd)\n",
        "    return result.squeeze(1) if squeeze else result\n",
        "\n",
        "\n",
        "def gl_diff_2d_raw(u: torch.Tensor, beta_x: torch.Tensor, beta_y: torch.Tensor,\n",
        "                   K: int = 16, padding: str = 'circular') -> torch.Tensor:\n",
        "    \"\"\"2D GL fractional derivative WITHOUT scaling: along x (last axis, W), then along y (axis H).\"\"\"\n",
        "    B, H, W = u.shape\n",
        "\n",
        "    if beta_x.abs() > 1e-8:\n",
        "        u_flat = u.reshape(B * H, W)\n",
        "        u_flat = gl_diff_1d_raw(u_flat, beta_x, K, padding)\n",
        "        u = u_flat.reshape(B, H, W)\n",
        "\n",
        "    if beta_y.abs() > 1e-8:\n",
        "        u_t = u.permute(0, 2, 1).reshape(B * W, H)\n",
        "        u_t = gl_diff_1d_raw(u_t, beta_y, K, padding)\n",
        "        u = u_t.reshape(B, W, H).permute(0, 2, 1)\n",
        "\n",
        "    return u\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 30,
      "outputs": []
    },
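    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A numerical check of the GL routines above, assuming cell 7 has been run. It compares `gl_weights` against the closed form $(-1)^k\\binom{\\beta}{k}$ (evaluated with `math.gamma` at a non-integer β to avoid poles), then verifies that on a periodic grid the raw output divided by $h^\\beta$ approximates $u'$ for β = 1 and $u''$ for β = 2. The function name, grid size, and tolerances are choices made for this sketch."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def check_gl_stencils(N: int = 256, K: int = 16):\n",
        "    # 1) gl_weights vs. the closed form (-1)^k * binom(beta, k); a non-integer beta avoids gamma poles\n",
        "    beta = 0.5\n",
        "    w = gl_weights(torch.tensor(beta, dtype=torch.float64), K)\n",
        "    ref = [(-1) ** k * math.gamma(beta + 1) / (math.gamma(k + 1) * math.gamma(beta - k + 1))\n",
        "           for k in range(K)]\n",
        "    assert max(abs(a - b) for a, b in zip(w.tolist(), ref)) < 1e-12\n",
        "\n",
        "    # 2) u(x) = sin(x) on a periodic grid; raw output / h^beta should approximate u' and u''\n",
        "    h = 2 * math.pi / N\n",
        "    x = torch.arange(N, dtype=torch.float64) * h\n",
        "    u = torch.sin(x).unsqueeze(0)  # shape (B=1, N)\n",
        "    d1 = gl_diff_1d_raw(u, torch.tensor(1.0, dtype=torch.float64), K) / h       # ~ cos(x)\n",
        "    d2 = gl_diff_1d_raw(u, torch.tensor(2.0, dtype=torch.float64), K) / h ** 2  # ~ -sin(x)\n",
        "\n",
        "    err1 = (d1 - torch.cos(x)).abs().max().item()\n",
        "    err2 = (d2 + torch.sin(x)).abs().max().item()\n",
        "    print(f'max error: first derivative {err1:.2e}, second derivative {err2:.2e}')  # both O(h^2)\n",
        "    assert err1 < 1e-3 and err2 < 1e-3\n",
        "\n",
        "\n",
        "check_gl_stencils()"
      ],
      "execution_count": null,
      "outputs": []
    },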
    {
      "cell_type": "code",
      "metadata": {
        "id": "DVKqHDb4S25e",
        "cellView": "form"
      },
      "source": [
        "# @title 8. SPECTRA FRACTIONAL DERIVATIVE MODULES\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "class SpectraFracDeriv1D(nn.Module):\n",
        "    \"\"\"\n",
        "    1D Fractional derivative with learnable β and h_eff.\n",
        "\n",
        "    Scale formula: scale = h_eff^(-β)\n",
        "    h_eff initialized from: h = 2/modes (spectral grid spacing)\n",
        "    \"\"\"\n",
        "    def __init__(self, beta_init: float, m_max: float, modes: int, K: int = 16):\n",
        "        super().__init__()\n",
        "        self.m_max = m_max\n",
        "        self.K = K\n",
        "        self.modes = modes\n",
        "\n",
        "        # Learnable β (clamped to [0, m_max])\n",
        "        self.beta_raw = nn.Parameter(torch.tensor(beta_init))\n",
        "\n",
        "        # Learnable h_eff initialized from modes: h = 2/modes\n",
        "        h_init = 2.0 / modes\n",
        "        self.h_eff_raw = nn.Parameter(torch.tensor(softplus_inverse(h_init)))\n",
        "\n",
        "    @property\n",
        "    def beta(self) -> torch.Tensor:\n",
        "        return torch.clamp(self.beta_raw, 0.0, self.m_max)\n",
        "\n",
        "    @property\n",
        "    def h_eff(self) -> torch.Tensor:\n",
        "        return F.softplus(self.h_eff_raw)\n",
        "\n",
        "    @property\n",
        "    def scale(self) -> torch.Tensor:\n",
        "        return torch.pow(self.h_eff, -self.beta)\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        d_raw = gl_diff_1d_raw(u, self.beta, self.K)\n",
        "        return self.scale * d_raw\n",
        "\n",
        "    def get_params(self) -> Dict[str, float]:\n",
        "        return {'beta': self.beta.item(), 'h_eff': self.h_eff.item(), 'scale': self.scale.item()}\n",
        "\n",
        "\n",
        "class SpectraFracDeriv2D(nn.Module):\n",
        "    \"\"\"\n",
        "    2D Fractional derivative with learnable β_x, β_y and h_eff.\n",
        "\n",
        "    Scale formula: scale = h_x^(-β_x) * h_y^(-β_y)\n",
        "    At initialization with equal modes in both directions: scale = (2/modes)^(-(β_x + β_y))\n",
        "    \"\"\"\n",
        "    def __init__(self, beta_init_x: float, beta_init_y: float, m_max: float,\n",
        "                 modes_x: int, modes_y: int, K: int = 16):\n",
        "        super().__init__()\n",
        "        self.m_max = m_max\n",
        "        self.K = K\n",
        "        self.modes_x = modes_x\n",
        "        self.modes_y = modes_y\n",
        "\n",
        "        self.beta_raw_x = nn.Parameter(torch.tensor(beta_init_x))\n",
        "        self.beta_raw_y = nn.Parameter(torch.tensor(beta_init_y))\n",
        "\n",
        "        # h_eff for x and y directions (can differ if modes differ)\n",
        "        h_init_x = 2.0 / modes_x\n",
        "        h_init_y = 2.0 / modes_y\n",
        "        self.h_eff_raw_x = nn.Parameter(torch.tensor(softplus_inverse(h_init_x)))\n",
        "        self.h_eff_raw_y = nn.Parameter(torch.tensor(softplus_inverse(h_init_y)))\n",
        "\n",
        "    @property\n",
        "    def beta_x(self) -> torch.Tensor:\n",
        "        return torch.clamp(self.beta_raw_x, 0.0, self.m_max)\n",
        "\n",
        "    @property\n",
        "    def beta_y(self) -> torch.Tensor:\n",
        "        return torch.clamp(self.beta_raw_y, 0.0, self.m_max)\n",
        "\n",
        "    @property\n",
        "    def h_eff_x(self) -> torch.Tensor:\n",
        "        return F.softplus(self.h_eff_raw_x)\n",
        "\n",
        "    @property\n",
        "    def h_eff_y(self) -> torch.Tensor:\n",
        "        return F.softplus(self.h_eff_raw_y)\n",
        "\n",
        "    @property\n",
        "    def scale(self) -> torch.Tensor:\n",
        "        \"\"\"scale = h_x^(-β_x) * h_y^(-β_y)\"\"\"\n",
        "        scale_x = torch.pow(self.h_eff_x, -self.beta_x) if self.beta_x.abs() > 1e-8 else torch.ones_like(self.h_eff_x)\n",
        "        scale_y = torch.pow(self.h_eff_y, -self.beta_y) if self.beta_y.abs() > 1e-8 else torch.ones_like(self.h_eff_y)\n",
        "        return scale_x * scale_y\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        d_raw = gl_diff_2d_raw(u, self.beta_x, self.beta_y, self.K)\n",
        "        return self.scale * d_raw\n",
        "\n",
        "    def get_params(self) -> Dict[str, float]:\n",
        "        return {\n",
        "            'beta_x': self.beta_x.item(), 'beta_y': self.beta_y.item(),\n",
        "            'h_eff_x': self.h_eff_x.item(), 'h_eff_y': self.h_eff_y.item(),\n",
        "            'scale': self.scale.item()\n",
        "        }\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 31,
      "outputs": []
    },
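    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "At initialization h_eff = 2/modes, so with `modes=12` the scale h_eff^(-β) equals 6 for β = 1 and 36 for β = 2, and the 2D mixed term with β_x = β_y = 1 also starts at 36. The sketch below checks this and then shows one way to give the β and h_eff parameters their own learning rates (`lr_beta`, `lr_heff` in `SpectraConfig`). The helper `build_param_groups`, selecting parameters by name substring, and disabling weight decay for those groups are assumptions of this sketch, not necessarily what the training loop later in the notebook does."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def build_param_groups(model: nn.Module, cfg: SpectraConfig) -> List[Dict[str, Any]]:\n",
        "    # Hypothetical helper: split trainable parameters into backbone / beta / h_eff groups by name\n",
        "    backbone, betas, heffs = [], [], []\n",
        "    for name, p in model.named_parameters():\n",
        "        if not p.requires_grad:\n",
        "            continue\n",
        "        if 'beta_raw' in name:\n",
        "            betas.append(p)\n",
        "        elif 'h_eff_raw' in name:\n",
        "            heffs.append(p)\n",
        "        else:\n",
        "            backbone.append(p)\n",
        "    groups = [{'params': backbone, 'lr': cfg.lr, 'weight_decay': cfg.weight_decay}]\n",
        "    if betas:  # assumption: no weight decay on beta, so it is not pulled toward 0\n",
        "        groups.append({'params': betas, 'lr': cfg.lr_beta, 'weight_decay': 0.0})\n",
        "    if heffs:  # assumption: no weight decay on h_eff either\n",
        "        groups.append({'params': heffs, 'lr': cfg.lr_heff, 'weight_decay': 0.0})\n",
        "    return groups\n",
        "\n",
        "\n",
        "def demo_spectra_modules():\n",
        "    # Initial scale h_eff^(-beta) with modes = 12, i.e. h_eff = 2/12\n",
        "    d1 = SpectraFracDeriv1D(beta_init=1.0, m_max=2.0, modes=12, K=16)\n",
        "    d2 = SpectraFracDeriv1D(beta_init=2.0, m_max=2.0, modes=12, K=16)\n",
        "    dxy = SpectraFracDeriv2D(1.0, 1.0, m_max=2.0, modes_x=12, modes_y=12, K=16)\n",
        "    print(round(d1.scale.item(), 3), round(d2.scale.item(), 3), round(dxy.scale.item(), 3))  # 6.0 36.0 36.0\n",
        "\n",
        "    # Toy model: one derivative module plus a linear head standing in for the backbone\n",
        "    toy = nn.ModuleDict({'deriv': d1, 'head': nn.Linear(4, 1)})\n",
        "    optimizer = torch.optim.AdamW(build_param_groups(toy, SpectraConfig()))\n",
        "    print([len(g['params']) for g in optimizer.param_groups])  # [2, 1, 1]\n",
        "\n",
        "\n",
        "demo_spectra_modules()"
      ],
      "execution_count": null,
      "outputs": []
    },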
    {
      "cell_type": "code",
      "metadata": {
        "id": "BfABD9FVS25f",
        "cellView": "form"
      },
      "source": [
        "# @title 9. SPECTRA FEATURES LAYERS\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "class SpectraFeatures1D(nn.Module):\n",
        "    \"\"\"\n",
        "    1D Spectra feature layer.\n",
        "\n",
        "    For max_order=2: D1 (β=1.0), D2 (β=2.0) → 2 betas, 2 h_effs\n",
        "    \"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig):\n",
        "        super().__init__()\n",
        "        self.max_order = cfg.max_order\n",
        "        self.include_cross = cfg.include_cross\n",
        "        self.include_grid = cfg.include_grid\n",
        "        self.modes = cfg.modes\n",
        "\n",
        "        # D1 init at β=1.0, D2 init at β=2.0, etc.\n",
        "        self.derivs = nn.ModuleList([\n",
        "            SpectraFracDeriv1D(float(order), cfg.m_max, cfg.modes, cfg.K)\n",
        "            for order in range(1, cfg.max_order + 1)\n",
        "        ])\n",
        "        self.deriv_names = [f'D{order}' for order in range(1, cfg.max_order + 1)]\n",
        "\n",
        "    @property\n",
        "    def n_channels(self) -> int:\n",
        "        n = 1 if self.include_grid else 0\n",
        "        n += 1 + self.max_order\n",
        "        if self.include_cross:\n",
        "            n += self.max_order\n",
        "        return n\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        B, N = u.shape\n",
        "        features = []\n",
        "\n",
        "        if self.include_grid:\n",
        "            x = torch.linspace(0, 1, N, device=u.device, dtype=u.dtype).unsqueeze(0).expand(B, -1)\n",
        "            features.append(x)\n",
        "\n",
        "        features.append(u)\n",
        "        derivs_computed = [d(u) for d in self.derivs]\n",
        "        features.extend(derivs_computed)\n",
        "\n",
        "        if self.include_cross:\n",
        "            features.extend([u * d for d in derivs_computed])\n",
        "\n",
        "        return torch.stack(features, dim=-1)\n",
        "\n",
        "    def get_params(self) -> Dict[str, Any]:\n",
        "        betas, h_effs, scales = [], [], []\n",
        "        for deriv in self.derivs:\n",
        "            p = deriv.get_params()\n",
        "            betas.append(p['beta'])\n",
        "            h_effs.append(p['h_eff'])\n",
        "            scales.append(p['scale'])\n",
        "        return {'betas': betas, 'h_effs': h_effs, 'scales': scales}\n",
        "\n",
        "    def beta_parameters(self) -> List[nn.Parameter]:\n",
        "        return [d.beta_raw for d in self.derivs]\n",
        "\n",
        "    def heff_parameters(self) -> List[nn.Parameter]:\n",
        "        return [d.h_eff_raw for d in self.derivs]\n",
        "\n",
        "\n",
        "class SpectraFeatures2D(nn.Module):\n",
        "    \"\"\"\n",
        "    2D Spectra feature layer.\n",
        "\n",
        "    For max_order=2:\n",
        "        - Dx (β_x=1.0), Dy (β_y=1.0)  → 2 betas, 2 h_effs\n",
        "        - Dxx (β_x=2.0), Dyy (β_y=2.0) → 2 betas, 2 h_effs\n",
        "        - Dxy (β_x=1.0, β_y=1.0)       → 2 betas, 1 h_eff (shared)\n",
        "\n",
        "    Total: 6 betas, 5 derivative operators (5 h_effs)\n",
        "    \"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig):\n",
        "        super().__init__()\n",
        "        self.max_order = cfg.max_order\n",
        "        self.include_cross = cfg.include_cross\n",
        "        self.include_grid = cfg.include_grid\n",
        "        self.modes = cfg.modes\n",
        "\n",
        "        self.derivs = nn.ModuleList()\n",
        "        self.deriv_names = []\n",
        "\n",
        "        # Order 1: Dx (β=1,0), Dy (β=0,1)\n",
        "        if cfg.max_order >= 1:\n",
        "            self.derivs.append(SpectraFracDeriv2D(1.0, 0.0, cfg.m_max, cfg.modes, cfg.modes, cfg.K))\n",
        "            self.deriv_names.append('Dx')\n",
        "            self.derivs.append(SpectraFracDeriv2D(0.0, 1.0, cfg.m_max, cfg.modes, cfg.modes, cfg.K))\n",
        "            self.deriv_names.append('Dy')\n",
        "\n",
        "        # Order 2: Dxx (β=2,0), Dyy (β=0,2), Dxy (β=1,1)\n",
        "        if cfg.max_order >= 2:\n",
        "            self.derivs.append(SpectraFracDeriv2D(2.0, 0.0, cfg.m_max, cfg.modes, cfg.modes, cfg.K))\n",
        "            self.deriv_names.append('Dxx')\n",
        "            self.derivs.append(SpectraFracDeriv2D(0.0, 2.0, cfg.m_max, cfg.modes, cfg.modes, cfg.K))\n",
        "            self.deriv_names.append('Dyy')\n",
        "            self.derivs.append(SpectraFracDeriv2D(1.0, 1.0, cfg.m_max, cfg.modes, cfg.modes, cfg.K))\n",
        "            self.deriv_names.append('Dxy')\n",
        "\n",
        "    @property\n",
        "    def n_channels(self) -> int:\n",
        "        n = 2 if self.include_grid else 0\n",
        "        n += 1 + len(self.derivs)\n",
        "        if self.include_cross:\n",
        "            n += len(self.derivs)\n",
        "        return n\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        B, H, W = u.shape\n",
        "        device = u.device\n",
        "        features = []\n",
        "\n",
        "        if self.include_grid:\n",
        "            x = torch.linspace(0, 1, H, device=device, dtype=u.dtype)\n",
        "            y = torch.linspace(0, 1, W, device=device, dtype=u.dtype)\n",
        "            gx, gy = torch.meshgrid(x, y, indexing='ij')\n",
        "            features.extend([gx.unsqueeze(0).expand(B, -1, -1), gy.unsqueeze(0).expand(B, -1, -1)])\n",
        "\n",
        "        features.append(u)\n",
        "        derivs_computed = [d(u) for d in self.derivs]\n",
        "        features.extend(derivs_computed)\n",
        "\n",
        "        if self.include_cross:\n",
        "            features.extend([u * d for d in derivs_computed])\n",
        "\n",
        "        return torch.stack(features, dim=-1)\n",
        "\n",
        "    def get_params(self) -> Dict[str, Any]:\n",
        "        \"\"\"\n",
        "        Returns 6 betas, 5 h_effs, 5 scales for max_order=2.\n",
        "\n",
        "        Betas: [β_Dx, β_Dy, β_Dxx, β_Dyy, β_Dxy_x, β_Dxy_y]\n",
        "        H_effs/Scales: one per derivative operator [Dx, Dy, Dxx, Dyy, Dxy]\n",
        "        \"\"\"\n",
        "        betas, h_effs, scales = [], [], []\n",
        "\n",
        "        for deriv, name in zip(self.derivs, self.deriv_names):\n",
        "            p = deriv.get_params()\n",
        "            h_effs.append(p['h_eff_x'] if p['beta_x'] > 1e-8 else p['h_eff_y'])\n",
        "            scales.append(p['scale'])\n",
        "\n",
        "            if name == 'Dxy':\n",
        "                betas.extend([p['beta_x'], p['beta_y']])\n",
        "            elif name in ['Dx', 'Dxx']:\n",
        "                betas.append(p['beta_x'])\n",
        "            else:\n",
        "                betas.append(p['beta_y'])\n",
        "\n",
        "        return {'betas': betas, 'h_effs': h_effs, 'scales': scales}\n",
        "\n",
        "    def beta_parameters(self) -> List[nn.Parameter]:\n",
        "        \"\"\"Returns 6 parameters for max_order=2.\"\"\"\n",
        "        params = []\n",
        "        for d, name in zip(self.derivs, self.deriv_names):\n",
        "            if name == 'Dxy':\n",
        "                params.extend([d.beta_raw_x, d.beta_raw_y])\n",
        "            elif name in ['Dx', 'Dxx']:\n",
        "                params.append(d.beta_raw_x)\n",
        "            else:\n",
        "                params.append(d.beta_raw_y)\n",
        "        return params\n",
        "\n",
        "    def heff_parameters(self) -> List[nn.Parameter]:\n",
        "        \"\"\"Returns h_eff parameters (active ones only).\"\"\"\n",
        "        params = []\n",
        "        for d, name in zip(self.derivs, self.deriv_names):\n",
        "            if name == 'Dxy':\n",
        "                params.extend([d.h_eff_raw_x, d.h_eff_raw_y])\n",
        "            elif name in ['Dx', 'Dxx']:\n",
        "                params.append(d.h_eff_raw_x)\n",
        "            else:\n",
        "                params.append(d.h_eff_raw_y)\n",
        "        return params\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 32,
      "outputs": []
    },
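    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The feature layers above stack channels in a fixed order: grid coordinates (if `include_grid`), then `u`, then one channel per derivative operator, then one cross term `u * D` per operator (if `include_cross`). The sketch below reproduces that bookkeeping with a hypothetical helper, `feature_channel_names`, which is not part of the notebook; it mirrors `n_channels` and the stacking order in `forward()` for `max_order` up to 2.\n",
        "\n",
        "```python\n",
        "def feature_channel_names(dim, max_order, include_grid, include_cross):\n",
        "    # Hypothetical helper mirroring the channel order of SpectraFeatures1D/2D.\n",
        "    if dim == 1:\n",
        "        grid = ['x']\n",
        "        derivs = [f'D{k}' for k in range(1, max_order + 1)]\n",
        "    else:\n",
        "        grid = ['x', 'y']\n",
        "        derivs = []\n",
        "        if max_order >= 1:\n",
        "            derivs += ['Dx', 'Dy']\n",
        "        if max_order >= 2:\n",
        "            derivs += ['Dxx', 'Dyy', 'Dxy']\n",
        "    names = (grid if include_grid else []) + ['u'] + derivs\n",
        "    if include_cross:\n",
        "        names += [f'u*{d}' for d in derivs]\n",
        "    return names\n",
        "\n",
        "names_1d = feature_channel_names(1, 2, include_grid=True, include_cross=True)\n",
        "names_2d = feature_channel_names(2, 2, include_grid=True, include_cross=True)\n",
        "print(len(names_1d), names_1d)  # 6 ['x', 'u', 'D1', 'D2', 'u*D1', 'u*D2']\n",
        "print(len(names_2d), names_2d.index('u'))  # 13 2\n",
        "```\n",
        "\n",
        "With `include_grid=True`, `u` sits at index 2 of the 2D feature tensor rather than index 0, which matters whenever code slices the features by position."
      ]
    },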
    {
      "cell_type": "code",
      "metadata": {
        "id": "5Qn21R6sS25g",
        "cellView": "form"
      },
      "source": [
        "# @title 10. SPECTRAL CONVOLUTION LAYERS\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "class SpectralConv1d(nn.Module):\n",
        "    \"\"\"1D Spectral convolution.\"\"\"\n",
        "    def __init__(self, in_channels: int, out_channels: int, modes: int):\n",
        "        super().__init__()\n",
        "        self.in_channels = in_channels\n",
        "        self.out_channels = out_channels\n",
        "        self.modes = modes\n",
        "        self.scale = 1 / (in_channels * out_channels)\n",
        "        self.weights = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, modes, dtype=torch.cfloat))\n",
        "\n",
        "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        B, C, N = x.shape\n",
        "        x_ft = torch.fft.rfft(x, dim=-1)\n",
        "        out_ft = torch.zeros(B, self.out_channels, N // 2 + 1, dtype=torch.cfloat, device=x.device)\n",
        "        modes = min(self.modes, N // 2 + 1)\n",
        "        out_ft[:, :, :modes] = torch.einsum('bix,iox->box', x_ft[:, :, :modes], self.weights[:, :, :modes])\n",
        "        return torch.fft.irfft(out_ft, n=N, dim=-1)\n",
        "\n",
        "\n",
        "class SpectralConv2d(nn.Module):\n",
        "    \"\"\"2D Spectral convolution.\"\"\"\n",
        "    def __init__(self, in_channels: int, out_channels: int, modes1: int, modes2: int):\n",
        "        super().__init__()\n",
        "        self.in_channels = in_channels\n",
        "        self.out_channels = out_channels\n",
        "        self.modes1 = modes1\n",
        "        self.modes2 = modes2\n",
        "        self.scale = 1 / (in_channels * out_channels)\n",
        "        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, modes1, modes2, dtype=torch.cfloat))\n",
        "        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, modes1, modes2, dtype=torch.cfloat))\n",
        "\n",
        "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        B, C, H, W = x.shape\n",
        "        x_ft = torch.fft.rfft2(x, dim=(-2, -1))\n",
        "        out_ft = torch.zeros(B, self.out_channels, H, W // 2 + 1, dtype=torch.cfloat, device=x.device)\n",
        "        m1, m2 = min(self.modes1, H // 2), min(self.modes2, W // 2 + 1)\n",
        "        out_ft[:, :, :m1, :m2] = torch.einsum('bixy,ioxy->boxy', x_ft[:, :, :m1, :m2], self.weights1[:, :, :m1, :m2])\n",
        "        out_ft[:, :, -m1:, :m2] = torch.einsum('bixy,ioxy->boxy', x_ft[:, :, -m1:, :m2], self.weights2[:, :, :m1, :m2])\n",
        "        return torch.fft.irfft2(out_ft, s=(H, W), dim=(-2, -1))\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 33,
      "outputs": []
    },
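    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`torch.fft.rfft2` returns the full spectrum along the row axis (negative frequencies stored at the end) but only the non-negative half along the column axis. This is why `SpectralConv2d` keeps two blocks, `[:m1, :m2]` with `weights1` and `[-m1:, :m2]` with `weights2`. The sketch below uses a hypothetical helper with identity weights, so that only the truncation remains, to show that a mode with a negative row frequency survives only because of the second block.\n",
        "\n",
        "```python\n",
        "import math\n",
        "import torch\n",
        "\n",
        "def truncate_modes_2d(x, m1, m2, keep_negative_rows=True):\n",
        "    # Same index bookkeeping as SpectralConv2d, with identity weights.\n",
        "    H, W = x.shape[-2:]\n",
        "    x_ft = torch.fft.rfft2(x, dim=(-2, -1))\n",
        "    out_ft = torch.zeros_like(x_ft)\n",
        "    m1, m2 = min(m1, H // 2), min(m2, W // 2 + 1)\n",
        "    out_ft[..., :m1, :m2] = x_ft[..., :m1, :m2]          # block multiplied by weights1\n",
        "    if keep_negative_rows:\n",
        "        out_ft[..., -m1:, :m2] = x_ft[..., -m1:, :m2]    # block multiplied by weights2\n",
        "    return torch.fft.irfft2(out_ft, s=(H, W), dim=(-2, -1))\n",
        "\n",
        "H = W = 32\n",
        "i = torch.arange(H, dtype=torch.float64).view(H, 1)\n",
        "j = torch.arange(W, dtype=torch.float64).view(1, W)\n",
        "u = torch.cos(2 * math.pi * (2 * i / H - 3 * j / W))  # modes (2, -3) and (-2, 3)\n",
        "\n",
        "full = truncate_modes_2d(u, m1=4, m2=4)\n",
        "half = truncate_modes_2d(u, m1=4, m2=4, keep_negative_rows=False)\n",
        "print(torch.allclose(full, u), half.abs().max().item() < 1e-12)  # True True\n",
        "```\n",
        "\n",
        "Because `m1` is capped at `H // 2`, the two row blocks never overlap."
      ]
    },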
    {
      "cell_type": "code",
      "metadata": {
        "id": "WYunXXQzS25g",
        "cellView": "form"
      },
      "source": [
        "# @title 11. FNO MODELS\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "class FNO1d(nn.Module):\n",
        "    \"\"\"1D Fourier Neural Operator.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig):\n",
        "        super().__init__()\n",
        "        self.lift = nn.Linear(2, cfg.width)  # x, u\n",
        "        self.convs = nn.ModuleList([SpectralConv1d(cfg.width, cfg.width, cfg.modes) for _ in range(cfg.n_layers)])\n",
        "        self.ws = nn.ModuleList([nn.Conv1d(cfg.width, cfg.width, 1) for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        B, N = u.shape\n",
        "        x_coord = torch.linspace(0, 1, N, device=u.device).unsqueeze(0).expand(B, -1)\n",
        "        x = self.lift(torch.stack([x_coord, u], dim=-1)).permute(0, 2, 1)\n",
        "\n",
        "        for i, (conv, w) in enumerate(zip(self.convs, self.ws)):\n",
        "            x1 = conv(x)\n",
        "            x2 = w(x)\n",
        "            x = F.gelu(x1 + x2) if i < len(self.convs) - 1 else x1 + x2\n",
        "\n",
        "        return self.proj(x.permute(0, 2, 1)).squeeze(-1)\n",
        "\n",
        "\n",
        "class FNO2d(nn.Module):\n",
        "    \"\"\"2D Fourier Neural Operator.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig):\n",
        "        super().__init__()\n",
        "        self.lift = nn.Linear(3, cfg.width)  # x, y, u\n",
        "        self.convs = nn.ModuleList([SpectralConv2d(cfg.width, cfg.width, cfg.modes, cfg.modes) for _ in range(cfg.n_layers)])\n",
        "        self.ws = nn.ModuleList([nn.Conv2d(cfg.width, cfg.width, 1) for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        B, H, W = u.shape\n",
        "        device = u.device\n",
        "        gx = torch.linspace(0, 1, H, device=device).view(1, H, 1).expand(B, H, W)\n",
        "        gy = torch.linspace(0, 1, W, device=device).view(1, 1, W).expand(B, H, W)\n",
        "\n",
        "        x = self.lift(torch.stack([gx, gy, u], dim=-1)).permute(0, 3, 1, 2)\n",
        "\n",
        "        for i, (conv, w) in enumerate(zip(self.convs, self.ws)):\n",
        "            x1 = conv(x)\n",
        "            x2 = w(x)\n",
        "            x = F.gelu(x1 + x2) if i < len(self.convs) - 1 else x1 + x2\n",
        "\n",
        "        return self.proj(x.permute(0, 2, 3, 1)).squeeze(-1)\n",
        "\n",
        "\n",
        "class FNO2d_AR(nn.Module):\n",
        "    \"\"\"2D FNO for autoregressive temporal prediction.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig, T_in: int = 10):\n",
        "        super().__init__()\n",
        "        self.T_in = T_in\n",
        "        self.lift = nn.Linear(T_in + 2, cfg.width)\n",
        "        self.convs = nn.ModuleList([SpectralConv2d(cfg.width, cfg.width, cfg.modes, cfg.modes) for _ in range(cfg.n_layers)])\n",
        "        self.ws = nn.ModuleList([nn.Conv2d(cfg.width, cfg.width, 1) for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u_history: torch.Tensor) -> torch.Tensor:\n",
        "        # u_history: [B, T_in, H, W]\n",
        "        B, T, H, W = u_history.shape\n",
        "        device = u_history.device\n",
        "\n",
        "        gx = torch.linspace(0, 1, H, device=device).view(1, 1, H, 1).expand(B, 1, H, W)\n",
        "        gy = torch.linspace(0, 1, W, device=device).view(1, 1, 1, W).expand(B, 1, H, W)\n",
        "\n",
        "        x = torch.cat([u_history, gx, gy], dim=1).permute(0, 2, 3, 1)  # [B, H, W, T+2]\n",
        "        x = self.lift(x).permute(0, 3, 1, 2)  # [B, width, H, W]\n",
        "\n",
        "        for i, (conv, w) in enumerate(zip(self.convs, self.ws)):\n",
        "            x1 = conv(x)\n",
        "            x2 = w(x)\n",
        "            x = F.gelu(x1 + x2) if i < len(self.convs) - 1 else x1 + x2\n",
        "\n",
        "        return self.proj(x.permute(0, 2, 3, 1)).squeeze(-1)\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 34,
      "outputs": []
    },
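    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`FNO2d_AR` maps a window of `T_in` past frames `[B, T_in, H, W]` to the next frame `[B, H, W]`. The rollout loop is not shown in this part of the notebook; below is a minimal sketch of the usual sliding-window scheme, assuming each prediction is appended and the oldest frame dropped at every step. `rollout` and `step_model` are hypothetical stand-ins.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "def rollout(model, u_init, n_steps):\n",
        "    # u_init: [B, T_in, H, W]; each step predicts one frame [B, H, W].\n",
        "    window = u_init\n",
        "    preds = []\n",
        "    for _ in range(n_steps):\n",
        "        nxt = model(window)\n",
        "        preds.append(nxt)\n",
        "        # Slide the window: drop the oldest frame, append the prediction.\n",
        "        window = torch.cat([window[:, 1:], nxt.unsqueeze(1)], dim=1)\n",
        "    return torch.stack(preds, dim=1)  # [B, n_steps, H, W]\n",
        "\n",
        "def step_model(window):\n",
        "    # Hypothetical stand-in for FNO2d_AR: next frame = last frame + 1.\n",
        "    return window[:, -1] + 1.0\n",
        "\n",
        "u0 = torch.zeros(2, 10, 8, 8)\n",
        "out = rollout(step_model, u0, n_steps=5)\n",
        "print(out.shape, out[0, :, 0, 0].tolist())  # torch.Size([2, 5, 8, 8]) [1.0, 2.0, 3.0, 4.0, 5.0]\n",
        "```\n",
        "\n",
        "For evaluation, wrap the call in `torch.no_grad()`; when training on an unrolled loss, keep gradients so errors can propagate through the window."
      ]
    },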
    {
      "cell_type": "code",
      "metadata": {
        "id": "DR2IxBr6S25g",
        "cellView": "form"
      },
      "source": [
        "# @title 12. DEL-FNO MODELS\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "class DELFNO1d(nn.Module):\n",
        "    \"\"\"1D DEL-FNO with learnable fractional derivatives.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig):\n",
        "        super().__init__()\n",
        "        self.features = SpectraFeatures1D(cfg)\n",
        "        self.lift = nn.Linear(self.features.n_channels, cfg.width)\n",
        "        self.convs = nn.ModuleList([SpectralConv1d(cfg.width, cfg.width, cfg.modes) for _ in range(cfg.n_layers)])\n",
        "        self.ws = nn.ModuleList([nn.Conv1d(cfg.width, cfg.width, 1) for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        x = self.features(u)  # [B, N, n_channels]\n",
        "        x = self.lift(x).permute(0, 2, 1)\n",
        "\n",
        "        for i, (conv, w) in enumerate(zip(self.convs, self.ws)):\n",
        "            x1 = conv(x)\n",
        "            x2 = w(x)\n",
        "            x = F.gelu(x1 + x2) if i < len(self.convs) - 1 else x1 + x2\n",
        "\n",
        "        return self.proj(x.permute(0, 2, 1)).squeeze(-1)\n",
        "\n",
        "    def get_params(self): return self.features.get_params()\n",
        "    def beta_parameters(self): return self.features.beta_parameters()\n",
        "    def heff_parameters(self): return self.features.heff_parameters()\n",
        "\n",
        "\n",
        "class DELFNO2d(nn.Module):\n",
        "    \"\"\"2D DEL-FNO.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig):\n",
        "        super().__init__()\n",
        "        self.features = SpectraFeatures2D(cfg)\n",
        "        self.lift = nn.Linear(self.features.n_channels, cfg.width)\n",
        "        self.convs = nn.ModuleList([SpectralConv2d(cfg.width, cfg.width, cfg.modes, cfg.modes) for _ in range(cfg.n_layers)])\n",
        "        self.ws = nn.ModuleList([nn.Conv2d(cfg.width, cfg.width, 1) for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        x = self.features(u)  # [B, H, W, n_channels]\n",
        "        x = self.lift(x).permute(0, 3, 1, 2)\n",
        "\n",
        "        for i, (conv, w) in enumerate(zip(self.convs, self.ws)):\n",
        "            x1 = conv(x)\n",
        "            x2 = w(x)\n",
        "            x = F.gelu(x1 + x2) if i < len(self.convs) - 1 else x1 + x2\n",
        "\n",
        "        return self.proj(x.permute(0, 2, 3, 1)).squeeze(-1)\n",
        "\n",
        "    def get_params(self): return self.features.get_params()\n",
        "    def beta_parameters(self): return self.features.beta_parameters()\n",
        "    def heff_parameters(self): return self.features.heff_parameters()\n",
        "\n",
        "\n",
        "class DELFNO2d_AR(nn.Module):\n",
        "    \"\"\"2D DEL-FNO for autoregressive prediction.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig, T_in: int = 10):\n",
        "        super().__init__()\n",
        "        self.T_in = T_in\n",
        "        self.features = SpectraFeatures2D(cfg)\n",
        "        in_dim = T_in + self.features.n_channels - 1 + 2  # history + features (minus u) + grid\n",
        "\n",
        "        self.lift = nn.Linear(in_dim, cfg.width)\n",
        "        self.convs = nn.ModuleList([SpectralConv2d(cfg.width, cfg.width, cfg.modes, cfg.modes) for _ in range(cfg.n_layers)])\n",
        "        self.ws = nn.ModuleList([nn.Conv2d(cfg.width, cfg.width, 1) for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u_history: torch.Tensor) -> torch.Tensor:\n",
        "        B, T, H, W = u_history.shape\n",
        "        device = u_history.device\n",
        "\n",
        "        # Features from last timestep\n",
        "        feat = self.features(u_history[:, -1])  # [B, H, W, n_feat]\n",
        "\n",
        "        # Grid\n",
        "        gx = torch.linspace(0, 1, H, device=device).view(1, H, 1, 1).expand(B, H, W, 1)\n",
        "        gy = torch.linspace(0, 1, W, device=device).view(1, 1, W, 1).expand(B, H, W, 1)\n",
        "\n",
        "        # Combine: history + derivative features (skip u) + grid\n",
        "        x = torch.cat([u_history.permute(0, 2, 3, 1), feat[:, :, :, 1:], gx, gy], dim=-1)\n",
        "        x = self.lift(x).permute(0, 3, 1, 2)\n",
        "\n",
        "        for i, (conv, w) in enumerate(zip(self.convs, self.ws)):\n",
        "            x1 = conv(x)\n",
        "            x2 = w(x)\n",
        "            x = F.gelu(x1 + x2) if i < len(self.convs) - 1 else x1 + x2\n",
        "\n",
        "        return self.proj(x.permute(0, 2, 3, 1)).squeeze(-1)\n",
        "\n",
        "    def get_params(self): return self.features.get_params()\n",
        "    def beta_parameters(self): return self.features.beta_parameters()\n",
        "    def heff_parameters(self): return self.features.heff_parameters()\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 35,
      "outputs": []
    },
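    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The DEL models expose `beta_parameters()` and `heff_parameters()` separately from the rest of the network. One plausible use, assumed here rather than taken from the training code, is to give the fractional orders and step sizes their own learning rates through optimizer parameter groups. `ToyDerivModel`, `unique` and `make_param_groups` are hypothetical, and the learning rates are illustrative. The grouping filters by `id()` so that no tensor lands in two groups, which `torch.optim` rejects.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "class ToyDerivModel(nn.Module):\n",
        "    # Hypothetical stand-in exposing the same accessors as the DEL models.\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.beta_raw = nn.Parameter(torch.tensor(1.0))\n",
        "        self.h_eff_raw = nn.Parameter(torch.tensor(0.0))\n",
        "        self.body = nn.Linear(4, 4)\n",
        "\n",
        "    def beta_parameters(self):\n",
        "        return [self.beta_raw]\n",
        "\n",
        "    def heff_parameters(self):\n",
        "        return [self.h_eff_raw]\n",
        "\n",
        "def unique(params):\n",
        "    # Drop repeated tensors (e.g. if an accessor lists a shared parameter twice).\n",
        "    return list({id(p): p for p in params}.values())\n",
        "\n",
        "def make_param_groups(model, lr, beta_lr, heff_lr):\n",
        "    betas = unique(model.beta_parameters())\n",
        "    heffs = unique(model.heff_parameters())\n",
        "    special = {id(p) for p in betas + heffs}\n",
        "    rest = [p for p in model.parameters() if id(p) not in special]\n",
        "    return [\n",
        "        {'params': rest, 'lr': lr},\n",
        "        {'params': betas, 'lr': beta_lr},\n",
        "        {'params': heffs, 'lr': heff_lr},\n",
        "    ]\n",
        "\n",
        "model = ToyDerivModel()\n",
        "opt = torch.optim.Adam(make_param_groups(model, lr=1e-3, beta_lr=1e-2, heff_lr=1e-2))\n",
        "print([len(g['params']) for g in opt.param_groups])  # [2, 1, 1]\n",
        "```"
      ]
    },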
    {
      "cell_type": "code",
      "metadata": {
        "id": "FK68HwuwS25g",
        "cellView": "form"
      },
      "source": [
        "# @title 13. PHYSICS-ATTENTION (thuml-compatible)\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "class PhysicsAttention1D(nn.Module):\n",
        "    \"\"\"Physics-Attention for 1D data (thuml-compatible).\"\"\"\n",
        "    def __init__(self, hidden_dim: int, n_heads: int, slice_num: int, mlp_ratio: int = 2, dropout: float = 0.0):\n",
        "        super().__init__()\n",
        "        self.hidden_dim = hidden_dim\n",
        "        self.n_heads = n_heads\n",
        "        self.slice_num = slice_num\n",
        "        self.head_dim = hidden_dim // n_heads\n",
        "        inner_dim = self.head_dim * n_heads\n",
        "        self.scale = self.head_dim ** -0.5\n",
        "\n",
        "        self.softmax = nn.Softmax(dim=-1)\n",
        "        self.dropout_layer = nn.Dropout(dropout)\n",
        "        self.temperature = nn.Parameter(torch.ones(1, n_heads, 1, 1) * 0.5)\n",
        "\n",
        "        self.in_project_x = nn.Linear(hidden_dim, inner_dim)\n",
        "        self.in_project_fx = nn.Linear(hidden_dim, inner_dim)\n",
        "        self.in_project_slice = nn.Linear(self.head_dim, slice_num)\n",
        "        nn.init.orthogonal_(self.in_project_slice.weight)\n",
        "\n",
        "        self.to_q = nn.Linear(self.head_dim, self.head_dim, bias=False)\n",
        "        self.to_k = nn.Linear(self.head_dim, self.head_dim, bias=False)\n",
        "        self.to_v = nn.Linear(self.head_dim, self.head_dim, bias=False)\n",
        "\n",
        "        self.to_out = nn.Sequential(nn.Linear(inner_dim, hidden_dim), nn.Dropout(dropout))\n",
        "        self.mlp = nn.Sequential(\n",
        "            nn.Linear(hidden_dim, hidden_dim * mlp_ratio), nn.GELU(), nn.Dropout(dropout),\n",
        "            nn.Linear(hidden_dim * mlp_ratio, hidden_dim), nn.Dropout(dropout))\n",
        "        self.norm1 = nn.LayerNorm(hidden_dim)\n",
        "        self.norm2 = nn.LayerNorm(hidden_dim)\n",
        "\n",
        "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        B, N, C = x.shape\n",
        "\n",
        "        # Slice\n",
        "        fx_mid = self.in_project_fx(x).reshape(B, N, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n",
        "        x_mid = self.in_project_x(x).reshape(B, N, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n",
        "        slice_weights = self.softmax(self.in_project_slice(x_mid) / self.temperature)\n",
        "        slice_norm = slice_weights.sum(2)\n",
        "        slice_token = torch.einsum(\"bhnc,bhng->bhgc\", fx_mid, slice_weights) / (slice_norm[:, :, :, None] + 1e-5)\n",
        "\n",
        "        # Attention\n",
        "        q, k, v = self.to_q(slice_token), self.to_k(slice_token), self.to_v(slice_token)\n",
        "        attn = self.softmax(torch.matmul(q, k.transpose(-1, -2)) * self.scale)\n",
        "        out_slice_token = torch.matmul(self.dropout_layer(attn), v)\n",
        "\n",
        "        # Deslice\n",
        "        out_x = torch.einsum(\"bhgc,bhng->bhnc\", out_slice_token, slice_weights)\n",
        "        out_x = out_x.permute(0, 2, 1, 3).reshape(B, N, self.n_heads * self.head_dim)\n",
        "\n",
        "        x = self.norm1(x + self.to_out(out_x))\n",
        "        return self.norm2(x + self.mlp(x))\n",
        "\n",
        "\n",
        "class PhysicsAttention2D(nn.Module):\n",
        "    \"\"\"Physics-Attention for 2D data (thuml-compatible).\"\"\"\n",
        "    def __init__(self, hidden_dim: int, n_heads: int, slice_num: int, mlp_ratio: int = 2, dropout: float = 0.0):\n",
        "        super().__init__()\n",
        "        self.hidden_dim = hidden_dim\n",
        "        self.n_heads = n_heads\n",
        "        self.slice_num = slice_num\n",
        "        self.head_dim = hidden_dim // n_heads\n",
        "        inner_dim = self.head_dim * n_heads\n",
        "        self.scale = self.head_dim ** -0.5\n",
        "\n",
        "        self.softmax = nn.Softmax(dim=-1)\n",
        "        self.dropout_layer = nn.Dropout(dropout)\n",
        "        self.temperature = nn.Parameter(torch.ones(1, n_heads, 1, 1) * 0.5)\n",
        "\n",
        "        self.in_project_x = nn.Linear(hidden_dim, inner_dim)\n",
        "        self.in_project_fx = nn.Linear(hidden_dim, inner_dim)\n",
        "        self.in_project_slice = nn.Linear(self.head_dim, slice_num)\n",
        "        nn.init.orthogonal_(self.in_project_slice.weight)\n",
        "\n",
        "        self.to_q = nn.Linear(self.head_dim, self.head_dim, bias=False)\n",
        "        self.to_k = nn.Linear(self.head_dim, self.head_dim, bias=False)\n",
        "        self.to_v = nn.Linear(self.head_dim, self.head_dim, bias=False)\n",
        "\n",
        "        self.to_out = nn.Sequential(nn.Linear(inner_dim, hidden_dim), nn.Dropout(dropout))\n",
        "        self.mlp = nn.Sequential(\n",
        "            nn.Linear(hidden_dim, hidden_dim * mlp_ratio), nn.GELU(), nn.Dropout(dropout),\n",
        "            nn.Linear(hidden_dim * mlp_ratio, hidden_dim), nn.Dropout(dropout))\n",
        "        self.norm1 = nn.LayerNorm(hidden_dim)\n",
        "        self.norm2 = nn.LayerNorm(hidden_dim)\n",
        "\n",
        "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        B, N, C = x.shape\n",
        "\n",
        "        fx_mid = self.in_project_fx(x).reshape(B, N, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n",
        "        x_mid = self.in_project_x(x).reshape(B, N, self.n_heads, self.head_dim).permute(0, 2, 1, 3)\n",
        "        slice_weights = self.softmax(self.in_project_slice(x_mid) / torch.clamp(self.temperature, 0.1, 5))\n",
        "        slice_norm = slice_weights.sum(2)\n",
        "        slice_token = torch.einsum(\"bhnc,bhng->bhgc\", fx_mid, slice_weights) / (slice_norm[:, :, :, None] + 1e-5)\n",
        "\n",
        "        q, k, v = self.to_q(slice_token), self.to_k(slice_token), self.to_v(slice_token)\n",
        "        attn = self.softmax(torch.matmul(q, k.transpose(-1, -2)) * self.scale)\n",
        "        out_slice_token = torch.matmul(self.dropout_layer(attn), v)\n",
        "\n",
        "        out_x = torch.einsum(\"bhgc,bhng->bhnc\", out_slice_token, slice_weights)\n",
        "        out_x = out_x.permute(0, 2, 1, 3).reshape(B, N, self.n_heads * self.head_dim)\n",
        "\n",
        "        x = self.norm1(x + self.to_out(out_x))\n",
        "        return self.norm2(x + self.mlp(x))\n",
        "\n",
        "\n",
        "class TransolverBlock(nn.Module):\n",
        "    \"\"\"Single Transolver block.\"\"\"\n",
        "    def __init__(self, hidden_dim: int, n_heads: int, slice_num: int, mlp_ratio: int, dropout: float, dim: int = 1):\n",
        "        super().__init__()\n",
        "        AttnClass = PhysicsAttention1D if dim == 1 else PhysicsAttention2D\n",
        "        self.attn = AttnClass(hidden_dim, n_heads, slice_num, mlp_ratio, dropout)\n",
        "\n",
        "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        return self.attn(x)\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 36,
      "outputs": []
    },
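    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Physics-Attention does not attend over all N points directly. Each point softly assigns itself to `slice_num` slices (`slice_weights`, a softmax over slices), each slice token is the weighted mean of the point features, attention runs among the G slice tokens only (a G×G attention matrix instead of N×N), and the result is broadcast back to the points with the same weights. The single-head sketch below, with hypothetical names and without the learned projections, writes the slice and deslice einsums as matrix products.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "def slice_deslice(fx, logits, eps=1e-5):\n",
        "    # fx: [N, C] point features; logits: [N, G] slice scores per point.\n",
        "    w = torch.softmax(logits, dim=-1)                # [N, G], each row sums to 1\n",
        "    # 'bhnc,bhng->bhgc' followed by division by slice_norm:\n",
        "    tokens = (w.T @ fx) / (w.sum(0)[:, None] + eps)  # [G, C] weighted mean per slice\n",
        "    # 'bhgc,bhng->bhnc' (deslice):\n",
        "    back = w @ tokens                                # [N, C]\n",
        "    return w, tokens, back\n",
        "\n",
        "torch.manual_seed(0)\n",
        "N, G, C = 100, 8, 4\n",
        "const = torch.ones(N, C) * 3.0\n",
        "w, tokens, back = slice_deslice(const, torch.randn(N, G))\n",
        "print(tokens.shape, back.shape)  # torch.Size([8, 4]) torch.Size([100, 4])\n",
        "```\n",
        "\n",
        "A constant field is reproduced up to the `eps` term, which makes a quick sanity check for changes to the slicing code."
      ]
    },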
    {
      "cell_type": "code",
      "metadata": {
        "id": "MyoLXxElS25h",
        "cellView": "form"
      },
      "source": [
        "# @title 14. TRANSOLVER MODELS\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "class Transolver1D(nn.Module):\n",
        "    \"\"\"1D Transolver.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig):\n",
        "        super().__init__()\n",
        "        self.lift = nn.Linear(2, cfg.width)\n",
        "        self.blocks = nn.ModuleList([\n",
        "            TransolverBlock(cfg.width, cfg.n_heads, cfg.slice_num, cfg.mlp_ratio, cfg.dropout, dim=1)\n",
        "            for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        B, N = u.shape\n",
        "        x_coord = torch.linspace(0, 1, N, device=u.device).unsqueeze(0).expand(B, -1)\n",
        "        x = self.lift(torch.stack([x_coord, u], dim=-1))\n",
        "        for block in self.blocks:\n",
        "            x = block(x)\n",
        "        return self.proj(x).squeeze(-1)\n",
        "\n",
        "\n",
        "class Transolver2D(nn.Module):\n",
        "    \"\"\"2D Transolver.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig):\n",
        "        super().__init__()\n",
        "        self.lift = nn.Linear(3, cfg.width)\n",
        "        self.blocks = nn.ModuleList([\n",
        "            TransolverBlock(cfg.width, cfg.n_heads, cfg.slice_num, cfg.mlp_ratio, cfg.dropout, dim=2)\n",
        "            for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        B, H, W = u.shape\n",
        "        device = u.device\n",
        "        gx = torch.linspace(0, 1, H, device=device).view(1, H, 1).expand(B, H, W)\n",
        "        gy = torch.linspace(0, 1, W, device=device).view(1, 1, W).expand(B, H, W)\n",
        "        x = self.lift(torch.stack([gx, gy, u], dim=-1).reshape(B, H * W, 3))\n",
        "        for block in self.blocks:\n",
        "            x = block(x)\n",
        "        return self.proj(x).squeeze(-1).reshape(B, H, W)\n",
        "\n",
        "\n",
        "class Transolver2D_AR(nn.Module):\n",
        "    \"\"\"2D Transolver for autoregressive prediction.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig, T_in: int = 10):\n",
        "        super().__init__()\n",
        "        self.T_in = T_in\n",
        "        self.lift = nn.Linear(T_in + 2, cfg.width)\n",
        "        self.blocks = nn.ModuleList([\n",
        "            TransolverBlock(cfg.width, cfg.n_heads, cfg.slice_num, cfg.mlp_ratio, cfg.dropout, dim=2)\n",
        "            for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u_history: torch.Tensor) -> torch.Tensor:\n",
        "        B, T, H, W = u_history.shape\n",
        "        device = u_history.device\n",
        "        gx = torch.linspace(0, 1, H, device=device).view(1, H, 1).expand(B, H, W)\n",
        "        gy = torch.linspace(0, 1, W, device=device).view(1, 1, W).expand(B, H, W)\n",
        "        grid = torch.stack([gx, gy], dim=1)\n",
        "\n",
        "        x = torch.cat([u_history, grid], dim=1).permute(0, 2, 3, 1).reshape(B, H * W, T + 2)\n",
        "        x = self.lift(x)\n",
        "        for block in self.blocks:\n",
        "            x = block(x)\n",
        "        return self.proj(x).squeeze(-1).reshape(B, H, W)\n",
        "\n",
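        "# Illustrative sketch (hypothetical helper, never called by the benchmark): how\n",
        "# Transolver2D_AR turns a [B, T, H, W] history into [B, H*W, T+2] tokens, i.e. one\n",
        "# token per grid point holding its T history values followed by its (x, y) coordinates.\n",
        "# The tensor sizes are arbitrary assumptions chosen only for the check; e.g.\n",
        "# _demo_transolver_ar_tokens() returns torch.Size([2, 48, 12]).\n",
        "def _demo_transolver_ar_tokens(B: int = 2, T: int = 10, H: int = 8, W: int = 6):\n",
        "    import torch\n",
        "    u_history = torch.randn(B, T, H, W)\n",
        "    gx = torch.linspace(0, 1, H).view(1, H, 1).expand(B, H, W)\n",
        "    gy = torch.linspace(0, 1, W).view(1, 1, W).expand(B, H, W)\n",
        "    grid = torch.stack([gx, gy], dim=1)                      # [B, 2, H, W]\n",
        "    x = torch.cat([u_history, grid], dim=1)                  # [B, T+2, H, W]\n",
        "    tokens = x.permute(0, 2, 3, 1).reshape(B, H * W, T + 2)  # [B, H*W, T+2]\n",
        "    # Row-major flattening: grid point (h, w) becomes token h * W + w.\n",
        "    h, w = 3, 4\n",
        "    assert torch.equal(tokens[:, h * W + w, :T], u_history[:, :, h, w])\n",
        "    assert torch.allclose(tokens[0, h * W + w, T:], torch.tensor([h / (H - 1), w / (W - 1)]))\n",
        "    return tokens.shape\n",
        "\n",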
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 37,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5QxyTi_US25h",
        "cellView": "form"
      },
      "source": [
        "# @title 15. DEL-TRANSOLVER MODELS\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "class DELTransolver1D(nn.Module):\n",
        "    \"\"\"1D DEL-Transolver.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig):\n",
        "        super().__init__()\n",
        "        self.features = SpectraFeatures1D(cfg)\n",
        "        self.lift = nn.Linear(self.features.n_channels, cfg.width)\n",
        "        self.blocks = nn.ModuleList([\n",
        "            TransolverBlock(cfg.width, cfg.n_heads, cfg.slice_num, cfg.mlp_ratio, cfg.dropout, dim=1)\n",
        "            for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        x = self.lift(self.features(u))\n",
        "        for block in self.blocks:\n",
        "            x = block(x)\n",
        "        return self.proj(x).squeeze(-1)\n",
        "\n",
        "    def get_params(self): return self.features.get_params()\n",
        "    def beta_parameters(self): return self.features.beta_parameters()\n",
        "    def heff_parameters(self): return self.features.heff_parameters()\n",
        "\n",
        "\n",
        "class DELTransolver2D(nn.Module):\n",
        "    \"\"\"2D DEL-Transolver.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig):\n",
        "        super().__init__()\n",
        "        self.features = SpectraFeatures2D(cfg)\n",
        "        self.lift = nn.Linear(self.features.n_channels, cfg.width)\n",
        "        self.blocks = nn.ModuleList([\n",
        "            TransolverBlock(cfg.width, cfg.n_heads, cfg.slice_num, cfg.mlp_ratio, cfg.dropout, dim=2)\n",
        "            for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u: torch.Tensor) -> torch.Tensor:\n",
        "        B, H, W = u.shape\n",
        "        x = self.lift(self.features(u).reshape(B, H * W, -1))\n",
        "        for block in self.blocks:\n",
        "            x = block(x)\n",
        "        return self.proj(x).squeeze(-1).reshape(B, H, W)\n",
        "\n",
        "    def get_params(self): return self.features.get_params()\n",
        "    def beta_parameters(self): return self.features.beta_parameters()\n",
        "    def heff_parameters(self): return self.features.heff_parameters()\n",
        "\n",
        "\n",
        "class DELTransolver2D_AR(nn.Module):\n",
        "    \"\"\"2D DEL-Transolver for autoregressive prediction.\"\"\"\n",
        "    def __init__(self, cfg: SpectraConfig, T_in: int = 10):\n",
        "        super().__init__()\n",
        "        self.T_in = T_in\n",
        "        self.features = SpectraFeatures2D(cfg)\n",
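        "        # Per grid point: T_in history frames, feature channels 1.. of the last frame\n",
        "        # (channel 0 is skipped, presumably the raw field already in the history), and (x, y).\n",
        "        in_dim = T_in + self.features.n_channels - 1 + 2\n",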
        "\n",
        "        self.lift = nn.Linear(in_dim, cfg.width)\n",
        "        self.blocks = nn.ModuleList([\n",
        "            TransolverBlock(cfg.width, cfg.n_heads, cfg.slice_num, cfg.mlp_ratio, cfg.dropout, dim=2)\n",
        "            for _ in range(cfg.n_layers)])\n",
        "        self.proj = nn.Sequential(nn.Linear(cfg.width, 128), nn.GELU(), nn.Linear(128, 1))\n",
        "\n",
        "    def forward(self, u_history: torch.Tensor) -> torch.Tensor:\n",
        "        B, T, H, W = u_history.shape\n",
        "        device = u_history.device\n",
        "\n",
        "        feat = self.features(u_history[:, -1])\n",
        "        gx = torch.linspace(0, 1, H, device=device).view(1, H, 1, 1).expand(B, H, W, 1)\n",
        "        gy = torch.linspace(0, 1, W, device=device).view(1, 1, W, 1).expand(B, H, W, 1)\n",
        "\n",
        "        x = torch.cat([u_history.permute(0, 2, 3, 1), feat[:, :, :, 1:], gx, gy], dim=-1)\n",
        "        x = self.lift(x.reshape(B, H * W, -1))\n",
        "        for block in self.blocks:\n",
        "            x = block(x)\n",
        "        return self.proj(x).squeeze(-1).reshape(B, H, W)\n",
        "\n",
        "    def get_params(self): return self.features.get_params()\n",
        "    def beta_parameters(self): return self.features.beta_parameters()\n",
        "    def heff_parameters(self): return self.features.heff_parameters()\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 38,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TusvDeaAS25h",
        "cellView": "form"
      },
      "source": [
        "# @title 16. MODEL FACTORY\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "def create_model(cfg: SpectraConfig, use_del: bool = False, T_in: int = 10) -> nn.Module:\n",
        "    \"\"\"Create model based on configuration.\"\"\"\n",
        "    dataset = cfg.dataset.lower()\n",
        "    model_type = cfg.model_type.lower()\n",
        "    is_1d = 'burgers' in dataset\n",
        "    is_temporal = dataset.startswith('ns')  # 'ns_v1e3' / 'ns_v1e5'\n",
        "\n",
        "    if model_type == 'fno':\n",
        "        if is_1d:\n",
        "            return DELFNO1d(cfg) if use_del else FNO1d(cfg)\n",
        "        elif is_temporal:\n",
        "            return DELFNO2d_AR(cfg, T_in) if use_del else FNO2d_AR(cfg, T_in)\n",
        "        else:\n",
        "            return DELFNO2d(cfg) if use_del else FNO2d(cfg)\n",
        "    elif model_type == 'transolver':\n",
        "        if is_1d:\n",
        "            return DELTransolver1D(cfg) if use_del else Transolver1D(cfg)\n",
        "        elif is_temporal:\n",
        "            return DELTransolver2D_AR(cfg, T_in) if use_del else Transolver2D_AR(cfg, T_in)\n",
        "        else:\n",
        "            return DELTransolver2D(cfg) if use_del else Transolver2D(cfg)\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown model type: {model_type}\")\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 39,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ae20wL1KS25h",
        "cellView": "form"
      },
      "source": [
        "# @title 17. DATA LOADERS\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "def download_gdrive(file_id: str, output: str):\n",
        "    \"\"\"Download file from Google Drive.\"\"\"\n",
        "    if not HAS_GDOWN:\n",
        "        raise ImportError(\"Install gdown: pip install gdown\")\n",
        "    gdown.download(f'https://drive.google.com/uc?id={file_id}', output, quiet=False)\n",
        "\n",
        "\n",
        "def load_burgers(cfg: SpectraConfig, device: torch.device, data_dir: str = './data') -> Dict:\n",
        "    \"\"\"Load Burgers R10 data.\"\"\"\n",
        "    os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "    # Try to find existing file\n",
        "    patterns = [\n",
        "        f'{data_dir}/burgers_data_R10.mat',\n",
        "        f'{data_dir}/Burgers_R10*.mat',\n",
        "        f'{data_dir}/*burgers*.mat',\n",
        "    ]\n",
        "\n",
        "    filepath = None\n",
        "    for pattern in patterns:\n",
        "        if '*' in pattern:\n",
        "            files = glob.glob(pattern)\n",
        "            if files:\n",
        "                filepath = files[0]\n",
        "                break\n",
        "        elif os.path.exists(pattern):\n",
        "            filepath = pattern\n",
        "            break\n",
        "\n",
        "    if filepath is None and GDRIVE_IDS.get('burgers_r10'):\n",
        "        print(\"Downloading Burgers R10 data...\")\n",
        "        filepath = f'{data_dir}/burgers_r10.mat'\n",
        "        download_gdrive(GDRIVE_IDS['burgers_r10'], filepath)\n",
        "\n",
        "    if filepath is None:\n",
        "        raise FileNotFoundError(\"Could not find Burgers R10 data\")\n",
        "\n",
        "    # Load\n",
        "    try:\n",
        "        import h5py\n",
        "        with h5py.File(filepath, 'r') as f:\n",
        "            a = np.array(f['a']).T.astype(np.float32)\n",
        "            u = np.array(f['u']).T.astype(np.float32)\n",
        "    except (ImportError, OSError):  # no h5py, or not an HDF5 (MATLAB v7.3) file\n",
        "        mat = scipy.io.loadmat(filepath)\n",
        "        a, u = mat['a'].astype(np.float32), mat['u'].astype(np.float32)\n",
        "\n",
        "    a, u = torch.from_numpy(a), torch.from_numpy(u)\n",
        "\n",
        "    # Subsample\n",
        "    N_orig = a.shape[-1]\n",
        "    if cfg.resolution < N_orig:\n",
        "        step = N_orig // cfg.resolution\n",
        "        a = a[:, ::step][:, :cfg.resolution]\n",
        "        u = u[:, ::step][:, :cfg.resolution]\n",
        "\n",
        "    return {\n",
        "        'a_train': a[:cfg.n_train].to(device),\n",
        "        'u_train': u[:cfg.n_train].to(device),\n",
        "        'a_test': a[cfg.n_train:cfg.n_train + cfg.n_test].to(device),\n",
        "        'u_test': u[cfg.n_train:cfg.n_train + cfg.n_test].to(device),\n",
        "        'N': a.shape[-1],\n",
        "    }\n",
        "\n",
        "\n",
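        "# Illustrative sketch (hypothetical helper, never called by the loaders): strided\n",
        "# subsampling keeps every step-th sample, as in load_burgers (x[:, ::step][:, :res])\n",
        "# and load_darcy (x[:, ::5, ::5]). The sizes are those quoted in this notebook:\n",
        "# Burgers 8192 -> 1024 and Darcy 421x421 -> 85x85 (85 = ceil(421 / 5)).\n",
        "def _demo_strided_subsample():\n",
        "    import torch\n",
        "    burgers = torch.arange(8192).unsqueeze(0)   # [1, 8192] dummy signal\n",
        "    step = 8192 // 1024                         # 8\n",
        "    sub = burgers[:, ::step][:, :1024]\n",
        "    assert sub.shape == (1, 1024) and sub[0, 1].item() == step\n",
        "    darcy = torch.zeros(1, 421, 421)\n",
        "    assert darcy[:, ::5, ::5].shape == (1, 85, 85)\n",
        "    return sub.shape\n",
        "\n",
        "\n",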
        "def load_burgers_r1000(cfg: SpectraConfig, device: torch.device, data_dir: str = './data') -> Dict:\n",
        "    \"\"\"Load Burgers R1000 data.\"\"\"\n",
        "    os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "    patterns = [\n",
        "        f'{data_dir}/burgers_v1000*.mat',\n",
        "        f'{data_dir}/Burgers*1000*.mat',\n",
        "        f'{data_dir}/burgers_r1000*.mat',\n",
        "    ]\n",
        "\n",
        "    filepath = None\n",
        "    for pattern in patterns:\n",
        "        files = glob.glob(pattern)\n",
        "        if files:\n",
        "            filepath = files[0]\n",
        "            break\n",
        "\n",
        "    if filepath is None and GDRIVE_IDS.get('burgers_r1000'):\n",
        "        print(\"Downloading Burgers R1000 data...\")\n",
        "        filepath = f'{data_dir}/burgers_r1000.mat'\n",
        "        download_gdrive(GDRIVE_IDS['burgers_r1000'], filepath)\n",
        "\n",
        "    if filepath is None:\n",
        "        raise FileNotFoundError(\"Could not find Burgers R1000 data\")\n",
        "\n",
        "    # Load\n",
        "    try:\n",
        "        import h5py\n",
        "        with h5py.File(filepath, 'r') as f:\n",
        "            keys = list(f.keys())\n",
        "            if 'input' in f:\n",
        "                a = np.array(f['input']).astype(np.float32)\n",
        "                u = np.array(f['output']).astype(np.float32)\n",
        "                if u.ndim == 3:\n",
        "                    u = u[:, -1, :] if u.shape[1] < u.shape[2] else u[-1, :, :]\n",
        "            else:\n",
        "                a = np.array(f['a']).T.astype(np.float32)\n",
        "                u = np.array(f['u']).T.astype(np.float32)\n",
        "    except (ImportError, OSError):  # no h5py, or not an HDF5 (MATLAB v7.3) file\n",
        "        mat = scipy.io.loadmat(filepath)\n",
        "        a, u = mat.get('a', mat.get('input')).astype(np.float32), mat.get('u', mat.get('output')).astype(np.float32)\n",
        "        if u.ndim == 3:\n",
        "            u = u[:, -1, :]\n",
        "\n",
        "    a, u = torch.from_numpy(a), torch.from_numpy(u)\n",
        "\n",
        "    N_orig = a.shape[-1]\n",
        "    if cfg.resolution < N_orig:\n",
        "        step = N_orig // cfg.resolution\n",
        "        a = a[:, ::step][:, :cfg.resolution]\n",
        "        u = u[:, ::step][:, :cfg.resolution]\n",
        "\n",
        "    n_train = min(cfg.n_train, a.shape[0] - cfg.n_test)\n",
        "\n",
        "    return {\n",
        "        'a_train': a[:n_train].to(device),\n",
        "        'u_train': u[:n_train].to(device),\n",
        "        'a_test': a[n_train:n_train + cfg.n_test].to(device),\n",
        "        'u_test': u[n_train:n_train + cfg.n_test].to(device),\n",
        "        'N': a.shape[-1],\n",
        "    }\n",
        "\n",
        "\n",
        "def load_darcy(cfg: SpectraConfig, device: torch.device, data_dir: str = './data', use_normalizer: bool = True) -> Dict:\n",
        "    \"\"\"Load Darcy flow data.\"\"\"\n",
        "    os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "    train_path = f'{data_dir}/piececonst_r421_N1024_smooth1.mat'\n",
        "    test_path = f'{data_dir}/piececonst_r421_N1024_smooth2.mat'\n",
        "\n",
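        "    if not os.path.exists(train_path) and GDRIVE_IDS.get('darcy_421'):\n",
        "        print(\"Downloading Darcy data...\")\n",
        "        download_gdrive(GDRIVE_IDS['darcy_421'], f'{data_dir}/darcy.mat')\n",
        "        # The download is saved as darcy.mat and is not extracted or renamed; unpack it\n",
        "        # so that train_path and test_path exist before continuing.\n",
        "\n",
        "    if not (os.path.exists(train_path) and os.path.exists(test_path)):\n",
        "        raise FileNotFoundError(f\"Could not find Darcy data: expected {train_path} and {test_path}\")\n",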
        "\n",
        "    # Try loading\n",
        "    try:\n",
        "        import h5py\n",
        "        with h5py.File(train_path, 'r') as f:\n",
        "            a_train = np.array(f['coeff']).T.astype(np.float32)\n",
        "            u_train = np.array(f['sol']).T.astype(np.float32)\n",
        "        with h5py.File(test_path, 'r') as f:\n",
        "            a_test = np.array(f['coeff']).T.astype(np.float32)\n",
        "            u_test = np.array(f['sol']).T.astype(np.float32)\n",
        "    except (ImportError, OSError):  # no h5py, or not an HDF5 (MATLAB v7.3) file\n",
        "        mat = scipy.io.loadmat(train_path)\n",
        "        a_train, u_train = mat['coeff'].astype(np.float32), mat['sol'].astype(np.float32)\n",
        "        mat = scipy.io.loadmat(test_path)\n",
        "        a_test, u_test = mat['coeff'].astype(np.float32), mat['sol'].astype(np.float32)\n",
        "\n",
        "    # Downsample 421 → 85\n",
        "    s = 5\n",
        "    a_train = torch.from_numpy(a_train[:cfg.n_train, ::s, ::s])\n",
        "    u_train = torch.from_numpy(u_train[:cfg.n_train, ::s, ::s])\n",
        "    a_test = torch.from_numpy(a_test[:cfg.n_test, ::s, ::s])\n",
        "    u_test = torch.from_numpy(u_test[:cfg.n_test, ::s, ::s])\n",
        "\n",
        "    H, W = a_train.shape[1], a_train.shape[2]\n",
        "\n",
        "    data = {\n",
        "        'a_train': a_train.to(device), 'u_train': u_train.to(device),\n",
        "        'a_test': a_test.to(device), 'u_test': u_test.to(device),\n",
        "        'H': H, 'W': W, 'y_normalizer': None,\n",
        "    }\n",
        "\n",
        "    if use_normalizer:\n",
        "        y_norm = UnitGaussianNormalizer(u_train).to(device)\n",
        "        data['y_normalizer'] = y_norm\n",
        "        data['u_train_norm'] = y_norm.encode(data['u_train'])\n",
        "\n",
        "    return data\n",
        "\n",
        "\n",
        "def load_navier_stokes(cfg: SpectraConfig, device: torch.device, variant: str = 'v1e3',\n",
        "                       T_in: int = 10, T_out: int = 10, data_dir: str = './data') -> Dict:\n",
        "    \"\"\"Load Navier-Stokes vorticity data for autoregressive training.\"\"\"\n",
        "    os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
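        "    # The last pattern in each list matches the file name used by the download below.\n",
        "    if variant == 'v1e3':\n",
        "        patterns = [f'{data_dir}/ns_V1e-3*.mat', f'{data_dir}/NavierStokes_V1e-3*.mat', f'{data_dir}/*V1e-3*.mat',\n",
        "                    f'{data_dir}/ns_v1e3.mat']\n",
        "        gdrive_key = 'ns_v1e3'\n",
        "    else:\n",
        "        patterns = [f'{data_dir}/NavierStokes_V1e-5*.mat', f'{data_dir}/*V1e-5*.mat', f'{data_dir}/ns_v1e5.mat']\n",
        "        gdrive_key = 'ns_v1e5'\n",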
        "\n",
        "    filepath = None\n",
        "    for pattern in patterns:\n",
        "        files = glob.glob(pattern)\n",
        "        if files:\n",
        "            filepath = files[0]\n",
        "            break\n",
        "\n",
        "    if filepath is None and GDRIVE_IDS.get(gdrive_key):\n",
        "        print(f\"Downloading NS {variant} data...\")\n",
        "        filepath = f'{data_dir}/ns_{variant}.mat'\n",
        "        download_gdrive(GDRIVE_IDS[gdrive_key], filepath)\n",
        "\n",
        "    if filepath is None:\n",
        "        raise FileNotFoundError(f\"Could not find NS {variant} data\")\n",
        "\n",
        "    # Load\n",
        "    try:\n",
        "        import h5py\n",
        "        with h5py.File(filepath, 'r') as f:\n",
        "            keys = list(f.keys())\n",
        "            data = np.array(f.get('u', f.get('vorticity', f[keys[0]]))).astype(np.float32)\n",
        "    except (ImportError, OSError):  # no h5py, or not an HDF5 (MATLAB v7.3) file\n",
        "        mat = scipy.io.loadmat(filepath)\n",
        "        keys = [k for k in mat.keys() if not k.startswith('_')]\n",
        "        data = mat.get('u', mat.get('vorticity', mat[keys[0]])).astype(np.float32)\n",
        "\n",
        "    # Reshape to [N, T, H, W]\n",
        "    if data.ndim == 4:\n",
        "        dims = data.shape\n",
        "        if dims[-1] < dims[1] and dims[-1] < 100:\n",
        "            data = np.transpose(data, (0, 3, 1, 2))\n",
        "        elif dims[0] < 100 and dims[-1] > 1000:\n",
        "            data = np.transpose(data, (3, 0, 1, 2))\n",
        "\n",
        "    N_total, T_total, H_orig, W_orig = data.shape\n",
        "\n",
        "    # Subsample spatial\n",
        "    resolution = cfg.resolution or 64\n",
        "    if resolution < H_orig:\n",
        "        step = H_orig // resolution\n",
        "        data = data[:, :, ::step, ::step][:, :, :resolution, :resolution]\n",
        "\n",
        "    H, W = data.shape[2], data.shape[3]  # actual grid after subsampling (may be smaller than resolution)\n",
        "\n",
        "    if T_in + T_out > T_total:\n",
        "        T_out = T_total - T_in\n",
        "\n",
        "    a = torch.from_numpy(data[:, :T_in])\n",
        "    u = torch.from_numpy(data[:, T_in:T_in + T_out])\n",
        "\n",
        "    n_train = min(cfg.n_train, a.shape[0] - cfg.n_test)\n",
        "    n_test = min(cfg.n_test, a.shape[0] - n_train)\n",
        "\n",
        "    return {\n",
        "        'a_train': a[:n_train].to(device), 'u_train': u[:n_train].to(device),\n",
        "        'a_test': a[n_train:n_train + n_test].to(device), 'u_test': u[n_train:n_train + n_test].to(device),\n",
        "        'H': H, 'W': W, 'T_in': T_in, 'T_out': T_out,\n",
        "    }\n",
        "\n",
        "\n",
        "def load_data(cfg: SpectraConfig, device: torch.device, data_dir: str = './data') -> Dict:\n",
        "    \"\"\"Load data based on config.\"\"\"\n",
        "    dataset = cfg.dataset.lower()\n",
        "\n",
        "    if dataset == 'burgers_r10':\n",
        "        return load_burgers(cfg, device, data_dir)\n",
        "    elif dataset == 'burgers_r1000':\n",
        "        return load_burgers_r1000(cfg, device, data_dir)\n",
        "    elif dataset == 'darcy':\n",
        "        return load_darcy(cfg, device, data_dir)\n",
        "    elif dataset in ['ns_v1e3', 'ns_v1e5']:\n",
        "        variant = 'v1e3' if 'v1e3' in dataset else 'v1e5'\n",
        "        return load_navier_stokes(cfg, device, variant, data_dir=data_dir)\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown dataset: {dataset}\")\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 40,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6u8nW3H1S25h",
        "cellView": "form"
      },
      "source": [
        "# @title 18. TRAINING\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "def get_optimizer(model: nn.Module, cfg: SpectraConfig) -> torch.optim.Optimizer:\n",
        "    \"\"\"Get optimizer with separate parameter groups for beta and h_eff.\"\"\"\n",
        "    if hasattr(model, 'beta_parameters') and len(list(model.beta_parameters())) > 0:\n",
        "        beta_params = list(model.beta_parameters())\n",
        "        heff_params = list(model.heff_parameters())\n",
        "        special_ids = set(id(p) for p in beta_params + heff_params)\n",
        "        backbone_params = [p for p in model.parameters() if id(p) not in special_ids]\n",
        "\n",
        "        return torch.optim.AdamW([\n",
        "            {'params': backbone_params, 'lr': cfg.lr, 'weight_decay': cfg.weight_decay},\n",
        "            {'params': beta_params, 'lr': cfg.lr_beta, 'weight_decay': 0.0},\n",
        "            {'params': heff_params, 'lr': cfg.lr_heff, 'weight_decay': 0.0},\n",
        "        ])\n",
        "    return torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)\n",
        "\n",
        "\n",
        "def get_scheduler(optimizer: torch.optim.Optimizer, cfg: SpectraConfig, steps_per_epoch: int):\n",
        "    \"\"\"Get learning rate scheduler.\"\"\"\n",
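        "    if cfg.scheduler == 'onecycle':\n",
        "        # One peak lr per parameter group, so lr_beta / lr_heff are not overwritten by cfg.lr.\n",
        "        return torch.optim.lr_scheduler.OneCycleLR(\n",
        "            optimizer, max_lr=[g['lr'] for g in optimizer.param_groups],\n",
        "            epochs=cfg.epochs, steps_per_epoch=steps_per_epoch,\n",
        "            pct_start=cfg.pct_start, div_factor=25.0, final_div_factor=1e4)\n",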
        "    elif cfg.scheduler == 'step':\n",
        "        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.step_size, gamma=cfg.gamma)\n",
        "    elif cfg.scheduler == 'cosine':\n",
        "        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.epochs, eta_min=1e-6)\n",
        "    raise ValueError(f\"Unknown scheduler: {cfg.scheduler}\")\n",
        "\n",
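        "# Illustrative sketch (hypothetical helper, never called during training): OneCycleLR\n",
        "# accepts a list with one max_lr per parameter group, so backbone and beta groups can\n",
        "# peak at different learning rates. The rates and step count are arbitrary placeholders.\n",
        "def _demo_onecycle_per_group(total_steps: int = 10):\n",
        "    import torch\n",
        "    w_backbone = torch.nn.Parameter(torch.zeros(3))\n",
        "    w_beta = torch.nn.Parameter(torch.zeros(1))\n",
        "    opt = torch.optim.AdamW([\n",
        "        {'params': [w_backbone], 'lr': 1e-3},\n",
        "        {'params': [w_beta], 'lr': 1e-2, 'weight_decay': 0.0},\n",
        "    ])\n",
        "    sched = torch.optim.lr_scheduler.OneCycleLR(\n",
        "        opt, max_lr=[g['lr'] for g in opt.param_groups], total_steps=total_steps)\n",
        "    peaks = [0.0, 0.0]\n",
        "    for _ in range(total_steps):  # stepping more than total_steps times raises ValueError\n",
        "        opt.step()\n",
        "        sched.step()\n",
        "        peaks = [max(p, g['lr']) for p, g in zip(peaks, opt.param_groups)]\n",
        "    return peaks  # approximately [1e-3, 1e-2]: each group peaks at its own max_lr\n",
        "\n",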
        "\n",
        "def train_1d(model: nn.Module, data: Dict, cfg: SpectraConfig, device: torch.device, name: str = '') -> Dict:\n",
        "    \"\"\"Train 1D model.\"\"\"\n",
        "    model = model.to(device)\n",
        "    a_train, u_train = data['a_train'], data['u_train']\n",
        "    a_test, u_test = data['a_test'], data['u_test']\n",
        "\n",
        "    steps_per_epoch = max(1, (len(a_train) + cfg.batch_size - 1) // cfg.batch_size)  # ceil: includes the last partial batch\n",
        "    is_onecycle = cfg.scheduler == 'onecycle'\n",
        "\n",
        "    optimizer = get_optimizer(model, cfg)\n",
        "    scheduler = get_scheduler(optimizer, cfg, steps_per_epoch)\n",
        "\n",
        "    print(f\"\\n{'─'*60}\\n{name} | {count_params(model)/1e3:.1f}K params\\n{'─'*60}\")\n",
        "\n",
        "    best_loss = float('inf')\n",
        "    t0 = time.time()\n",
        "\n",
        "    for epoch in range(1, cfg.epochs + 1):\n",
        "        model.train()\n",
        "        perm = torch.randperm(len(a_train), device=device)\n",
        "        epoch_loss, n_batches = 0.0, 0\n",
        "\n",
        "        for i in range(0, len(a_train), cfg.batch_size):\n",
        "            idx = perm[i:i + cfg.batch_size]\n",
        "            pred = model(a_train[idx])\n",
        "            loss = rel_l2_loss(pred, u_train[idx])\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            loss.backward()\n",
        "            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)\n",
        "            optimizer.step()\n",
        "\n",
        "            if is_onecycle:\n",
        "                scheduler.step()\n",
        "\n",
        "            epoch_loss += loss.item()\n",
        "            n_batches += 1\n",
        "\n",
        "        if not is_onecycle:\n",
        "            scheduler.step()\n",
        "\n",
        "        train_loss = epoch_loss / n_batches\n",
        "\n",
        "        if epoch % cfg.eval_every == 0 or epoch == 1 or epoch == cfg.epochs:\n",
        "            model.eval()\n",
        "            with torch.no_grad():\n",
        "                test_loss = rel_l2_loss(model(a_test), u_test).item()\n",
        "\n",
        "            is_best = test_loss < best_loss\n",
        "            best_loss = min(best_loss, test_loss)\n",
        "\n",
        "            params = model.get_params() if hasattr(model, 'get_params') else {}\n",
        "            star = \" ★\" if is_best else \"\"\n",
        "            param_str = \" | \" + \", \".join([f\"β={b:.2f}\" for b in params.get('betas', [])[:4]]) if params else \"\"\n",
        "            print(f\"  Ep {epoch:4d}: Train={100*train_loss:.3f}% Test={100*test_loss:.3f}% ({time.time()-t0:.0f}s){star}{param_str}\")\n",
        "\n",
        "    return {'best_loss': best_loss, 'params': model.get_params() if hasattr(model, 'get_params') else {}}\n",
        "\n",
        "\n",
        "def train_2d(model: nn.Module, data: Dict, cfg: SpectraConfig, device: torch.device, name: str = '') -> Dict:\n",
        "    \"\"\"Train 2D model with optional normalization.\"\"\"\n",
        "    model = model.to(device)\n",
        "    a_train, u_train = data['a_train'], data['u_train']\n",
        "    a_test, u_test = data['a_test'], data['u_test']\n",
        "\n",
        "    y_norm = data.get('y_normalizer')\n",
        "    u_train_target = data.get('u_train_norm', u_train) if y_norm else u_train\n",
        "\n",
        "    steps_per_epoch = max(1, (len(a_train) + cfg.batch_size - 1) // cfg.batch_size)  # ceil: includes the last partial batch\n",
        "    is_onecycle = cfg.scheduler == 'onecycle'\n",
        "\n",
        "    optimizer = get_optimizer(model, cfg)\n",
        "    scheduler = get_scheduler(optimizer, cfg, steps_per_epoch)\n",
        "\n",
        "    norm_str = \" [normalized]\" if y_norm else \"\"\n",
        "    print(f\"\\n{'─'*60}\\n{name} | {count_params(model)/1e3:.1f}K params{norm_str}\\n{'─'*60}\")\n",
        "\n",
        "    best_loss = float('inf')\n",
        "    t0 = time.time()\n",
        "\n",
        "    for epoch in range(1, cfg.epochs + 1):\n",
        "        model.train()\n",
        "        perm = torch.randperm(len(a_train), device=device)\n",
        "        epoch_loss, n_batches = 0.0, 0\n",
        "\n",
        "        for i in range(0, len(a_train), cfg.batch_size):\n",
        "            idx = perm[i:i + cfg.batch_size]\n",
        "            pred = model(a_train[idx])\n",
        "            loss = rel_l2_loss(pred, u_train_target[idx])\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            loss.backward()\n",
        "            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)\n",
        "            optimizer.step()\n",
        "\n",
        "            if is_onecycle:\n",
        "                scheduler.step()\n",
        "\n",
        "            epoch_loss += loss.item()\n",
        "            n_batches += 1\n",
        "\n",
        "        if not is_onecycle:\n",
        "            scheduler.step()\n",
        "\n",
        "        train_loss = epoch_loss / n_batches\n",
        "\n",
        "        if epoch % cfg.eval_every == 0 or epoch == 1 or epoch == cfg.epochs:\n",
        "            model.eval()\n",
        "            with torch.no_grad():\n",
        "                pred_test = model(a_test)\n",
        "                if y_norm:\n",
        "                    pred_test = y_norm.decode(pred_test)\n",
        "                test_loss = rel_l2_loss(pred_test, u_test).item()\n",
        "\n",
        "            is_best = test_loss < best_loss\n",
        "            best_loss = min(best_loss, test_loss)\n",
        "\n",
        "            params = model.get_params() if hasattr(model, 'get_params') else {}\n",
        "            star = \" ★\" if is_best else \"\"\n",
        "            param_str = \" | \" + \", \".join([f\"β={b:.2f}\" for b in params.get('betas', [])[:4]]) if params else \"\"\n",
        "            print(f\"  Ep {epoch:4d}: Train={100*train_loss:.3f}% Test={100*test_loss:.3f}% ({time.time()-t0:.0f}s){star}{param_str}\")\n",
        "\n",
        "    return {'best_loss': best_loss, 'params': model.get_params() if hasattr(model, 'get_params') else {}}\n",
        "\n",
        "\n",
        "def train_2d_ar(model: nn.Module, data: Dict, cfg: SpectraConfig, device: torch.device,\n",
        "                name: str = '', T_in: int = 10, T_out: int = 10) -> Dict:\n",
        "    \"\"\"Train a 2D model on single-step targets; evaluate with a T_out-step autoregressive rollout.\"\"\"\n",
        "    model = model.to(device)\n",
        "\n",
        "    a_train, u_train = data['a_train'], data['u_train']\n",
        "    a_test, u_test = data['a_test'], data['u_test']\n",
        "\n",
        "    data_train = torch.cat([a_train, u_train], dim=1)\n",
        "    data_test = torch.cat([a_test, u_test], dim=1)\n",
        "\n",
        "    T_total = data_train.shape[1]\n",
        "    n_pairs = T_total - T_in\n",
        "\n",
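        "    # One optimizer step per batch of trajectories (one random window per trajectory).\n",
        "    steps_per_epoch = max(1, (len(data_train) + cfg.batch_size - 1) // cfg.batch_size)\n",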
        "    is_onecycle = cfg.scheduler == 'onecycle'\n",
        "\n",
        "    optimizer = get_optimizer(model, cfg)\n",
        "    scheduler = get_scheduler(optimizer, cfg, steps_per_epoch)\n",
        "\n",
        "    print(f\"\\n{'─'*60}\\n{name} | {count_params(model)/1e3:.1f}K | AR: T_in={T_in}→rollout {T_out}\\n{'─'*60}\")\n",
        "\n",
        "    best_loss = float('inf')\n",
        "    t0 = time.time()\n",
        "\n",
        "    for epoch in range(1, cfg.epochs + 1):\n",
        "        model.train()\n",
        "        perm = torch.randperm(len(data_train), device=device)\n",
        "        epoch_loss, n_batches = 0.0, 0\n",
        "\n",
        "        for i in range(0, len(data_train), cfg.batch_size):\n",
        "            idx = perm[i:i + cfg.batch_size]\n",
        "            batch = data_train[idx]\n",
        "            B_actual = len(idx)\n",
        "\n",
        "            t_starts = torch.randint(0, n_pairs, (B_actual,), device=device)\n",
        "            histories = torch.stack([batch[b, t:t+T_in] for b, t in enumerate(t_starts)])\n",
        "            targets = torch.stack([batch[b, t+T_in] for b, t in enumerate(t_starts)])\n",
        "\n",
        "            pred = model(histories)\n",
        "            loss = rel_l2_loss(pred, targets)\n",
        "\n",
        "            if torch.isnan(loss):\n",
        "                print(f\"  NaN loss at epoch {epoch}; stopping {name}\")\n",
        "                return {'best_loss': float('inf'), 'params': {}}\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            loss.backward()\n",
        "            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)\n",
        "            optimizer.step()\n",
        "\n",
        "            if is_onecycle:\n",
        "                scheduler.step()\n",
        "\n",
        "            epoch_loss += loss.item()\n",
        "            n_batches += 1\n",
        "\n",
        "        if not is_onecycle:\n",
        "            scheduler.step()\n",
        "\n",
        "        train_loss = epoch_loss / max(n_batches, 1)\n",
        "\n",
        "        if epoch % cfg.eval_every == 0 or epoch == 1 or epoch == cfg.epochs:\n",
        "            model.eval()\n",
        "            with torch.no_grad():\n",
        "                # Autoregressive rollout: each prediction is fed back as the newest input frame\n",
        "                current = a_test.clone()\n",
        "                preds = []\n",
        "                for _ in range(T_out):\n",
        "                    next_step = model(current)\n",
        "                    preds.append(next_step)\n",
        "                    current = torch.cat([current[:, 1:], next_step.unsqueeze(1)], dim=1)\n",
        "                pred_traj = torch.stack(preds, dim=1)\n",
        "                test_loss = rel_l2_loss(pred_traj, u_test).item()\n",
        "\n",
        "            is_best = test_loss < best_loss\n",
        "            best_loss = min(best_loss, test_loss)\n",
        "\n",
        "            params = model.get_params() if hasattr(model, 'get_params') else {}\n",
        "            star = \" ★\" if is_best else \"\"\n",
        "            param_str = (\" | \" + \", \".join([f\"β={b:.2f}\" for b in params.get('betas', [])[:4]])) if params else \"\"\n",
        "            print(f\"  Ep {epoch:4d}: Train={100*train_loss:.3f}% Test={100*test_loss:.3f}% ({time.time()-t0:.0f}s){star}{param_str}\")\n",
        "\n",
        "    return {'best_loss': best_loss, 'params': model.get_params() if hasattr(model, 'get_params') else {}}\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 41,
      "outputs": []
    },
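    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Note on the autoregressive indexing.** In `train_2d_ar`, each trajectory is the concatenation of `T_in` input frames and `T_out` target frames, so there are `n_pairs = T_total - T_in` valid window starts. Every epoch draws one random start `t` per trajectory: the model sees frames `[t, t + T_in)` and is trained to predict frame `t + T_in` from ground-truth history (teacher forcing). At evaluation time it instead rolls out `T_out` steps from the `T_in` test inputs, dropping the oldest frame and appending its own prediction at each step. The pure-Python sketch below mirrors only this indexing, not the models or tensors; the helper names `sample_windows`, `steps_per_epoch` and `rollout` are hypothetical and are not used elsewhere in the notebook.\n",
        "\n",
        "```python\n",
        "import math\n",
        "import random\n",
        "\n",
        "\n",
        "def sample_windows(n_traj, t_total, t_in, seed=0):\n",
        "    # One random (history, target) window per trajectory, as in the training loop.\n",
        "    rng = random.Random(seed)\n",
        "    n_pairs = t_total - t_in  # valid start indices: 0 .. n_pairs - 1\n",
        "    windows = []\n",
        "    for _ in range(n_traj):\n",
        "        t = rng.randrange(n_pairs)\n",
        "        history = list(range(t, t + t_in))  # frame indices fed to the model\n",
        "        target = t + t_in                   # frame index the model must predict\n",
        "        windows.append((history, target))\n",
        "    return windows\n",
        "\n",
        "\n",
        "def steps_per_epoch(n_traj, batch_size):\n",
        "    # The loop iterates over trajectories, not over all (trajectory, start) pairs.\n",
        "    return math.ceil(n_traj / batch_size)\n",
        "\n",
        "\n",
        "def rollout(step_fn, history, n_steps):\n",
        "    # Autoregressive rollout: each prediction becomes the newest input frame.\n",
        "    window = list(history)\n",
        "    preds = []\n",
        "    for _ in range(n_steps):\n",
        "        nxt = step_fn(window)\n",
        "        preds.append(nxt)\n",
        "        window = window[1:] + [nxt]  # drop the oldest frame, append the prediction\n",
        "    return preds\n",
        "\n",
        "\n",
        "# Toy step function: next value = sum of the two most recent frames.\n",
        "print(rollout(lambda w: w[-1] + w[-2], [1, 1], 5))  # [2, 3, 5, 8, 13]\n",
        "```\n",
        "\n",
        "For the Navier-Stokes setting used here (`T_in = 10`, `T_out = 10`, so 20 frames per trajectory), every sampled target index lies in 10..19; with 1000 training trajectories and a hypothetical batch size of 20, one epoch is 50 optimizer steps."
      ]
    },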
    {
      "cell_type": "code",
      "metadata": {
        "id": "J7BxQYMNS25i",
        "cellView": "form"
      },
      "source": [
        "# @title 19. EXPERIMENT RUNNER\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════\n",
        "def run_experiment(dataset: str, model_type: str, preset: str, use_del: bool,\n",
        "                   device: torch.device, data_dir: str = './data') -> Dict:\n",
        "    \"\"\"\n",
        "    Run a single training experiment.\n",
        "\n",
        "    Args:\n",
        "        dataset: 'burgers_r10', 'burgers_r1000', 'darcy', 'ns_v1e3', 'ns_v1e5'\n",
        "        model_type: 'fno' or 'transolver'\n",
        "        preset: 'original', 'thuml', 'custom'\n",
        "        use_del: If True, build the DEL (∂-augmented) variant of the model\n",
        "        device: torch device\n",
        "        data_dir: Directory containing data files\n",
        "\n",
        "    Returns:\n",
        "        Dict with 'best_loss', 'params', 'config', 'name'\n",
        "    \"\"\"\n",
        "    cfg = get_config(dataset, model_type, preset)\n",
        "    set_seed(cfg.seed)\n",
        "\n",
        "    # Determine experiment type\n",
        "    is_1d = 'burgers' in dataset\n",
        "    is_temporal = 'ns' in dataset\n",
        "\n",
        "    # Create model\n",
        "    model = create_model(cfg, use_del=use_del)\n",
        "    del_str = \"DEL-\" if use_del else \"\"\n",
        "    name = f\"{del_str}{model_type.upper()} [{dataset}]\"\n",
        "\n",
        "    # Load data\n",
        "    data = load_data(cfg, device, data_dir)\n",
        "\n",
        "    # Train\n",
        "    if is_1d:\n",
        "        result = train_1d(model, data, cfg, device, name)\n",
        "    elif is_temporal:\n",
        "        T_in = data.get('T_in', 10)\n",
        "        T_out = data.get('T_out', 10)\n",
        "        result = train_2d_ar(model, data, cfg, device, name, T_in, T_out)\n",
        "    else:\n",
        "        result = train_2d(model, data, cfg, device, name)\n",
        "\n",
        "    result['config'] = cfg\n",
        "    result['name'] = name\n",
        "    return result\n",
        "\n",
        "\n",
        "def run_fno_experiments(preset: str = 'thuml', data_dir: str = './data'):\n",
        "    \"\"\"\n",
        "    Run the baseline FNO and DEL-FNO on each configured dataset.\n",
        "\n",
        "    Datasets: burgers_r10, ns_v1e5\n",
        "    Models: FNO, DEL-FNO\n",
        "    \"\"\"\n",
        "    print(\"\\n\" + \"═\"*80)\n",
        "    print(\"  FNO EXPERIMENTS\")\n",
        "    print(\"═\"*80)\n",
        "\n",
        "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "    print(f\"\\nDevice: {device}\")\n",
        "    print(f\"Preset: {preset}\")\n",
        "    print(f\"Data directory: {data_dir}\")\n",
        "\n",
        "    datasets = ['burgers_r10', 'ns_v1e5']\n",
        "    results = []\n",
        "\n",
        "    for dataset in datasets:\n",
        "        print(f\"\\n{'━'*80}\")\n",
        "        print(f\"  Dataset: {dataset.upper()}\")\n",
        "        print(f\"{'━'*80}\")\n",
        "\n",
        "        # Run baseline FNO\n",
        "        result_base = run_experiment(dataset, 'fno', preset, use_del=False,\n",
        "                                     device=device, data_dir=data_dir)\n",
        "        results.append(('FNO', dataset, result_base))\n",
        "\n",
        "        # Run DEL-FNO\n",
        "        result_del = run_experiment(dataset, 'fno', preset, use_del=True,\n",
        "                                    device=device, data_dir=data_dir)\n",
        "        results.append(('DEL-FNO', dataset, result_del))\n",
        "\n",
        "    # Summary\n",
        "    print(\"\\n\" + \"═\"*80)\n",
        "    print(\"  FNO RESULTS SUMMARY\")\n",
        "    print(\"═\"*80)\n",
        "    print(f\"\\n{'Model':<15} {'Dataset':<15} {'Best Loss':<12} {'Learned βs'}\")\n",
        "    print(\"─\"*70)\n",
        "    for model_name, dataset, result in results:\n",
        "        betas = result.get('params', {}).get('betas', [])\n",
        "        beta_str = ', '.join([f'{b:.2f}' for b in betas[:4]]) if betas else 'N/A'\n",
        "        print(f\"{model_name:<15} {dataset:<15} {result['best_loss']:.6f}     {beta_str}\")\n",
        "\n",
        "    return results\n",
        "\n",
        "\n",
        "def run_transolver_experiments(preset: str = 'thuml', data_dir: str = './data'):\n",
        "    \"\"\"\n",
        "    Run the baseline Transolver and DEL-Transolver on each configured dataset.\n",
        "\n",
        "    Datasets: burgers_r10, darcy\n",
        "    Models: Transolver, DEL-Transolver\n",
        "    \"\"\"\n",
        "    print(\"\\n\" + \"═\"*80)\n",
        "    print(\"  TRANSOLVER EXPERIMENTS\")\n",
        "    print(\"═\"*80)\n",
        "\n",
        "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "    print(f\"\\nDevice: {device}\")\n",
        "    print(f\"Preset: {preset}\")\n",
        "    print(f\"Data directory: {data_dir}\")\n",
        "\n",
        "    datasets = ['burgers_r10', 'darcy']\n",
        "    results = []\n",
        "\n",
        "    for dataset in datasets:\n",
        "        print(f\"\\n{'━'*80}\")\n",
        "        print(f\"  Dataset: {dataset.upper()}\")\n",
        "        print(f\"{'━'*80}\")\n",
        "\n",
        "        # Run baseline Transolver\n",
        "        result_base = run_experiment(dataset, 'transolver', preset, use_del=False,\n",
        "                                     device=device, data_dir=data_dir)\n",
        "        results.append(('Transolver', dataset, result_base))\n",
        "\n",
        "        # Run DEL-Transolver\n",
        "        result_del = run_experiment(dataset, 'transolver', preset, use_del=True,\n",
        "                                    device=device, data_dir=data_dir)\n",
        "        results.append(('DEL-Transolver', dataset, result_del))\n",
        "\n",
        "    # Summary\n",
        "    print(\"\\n\" + \"═\"*80)\n",
        "    print(\"  TRANSOLVER RESULTS SUMMARY\")\n",
        "    print(\"═\"*80)\n",
        "    print(f\"\\n{'Model':<18} {'Dataset':<15} {'Best Loss':<12} {'Learned βs'}\")\n",
        "    print(\"─\"*75)\n",
        "    for model_name, dataset, result in results:\n",
        "        betas = result.get('params', {}).get('betas', [])\n",
        "        beta_str = ', '.join([f'{b:.2f}' for b in betas[:4]]) if betas else 'N/A'\n",
        "        print(f\"{model_name:<18} {dataset:<15} {result['best_loss']:.6f}     {beta_str}\")\n",
        "\n",
        "    return results\n",
        "\n",
        "\n",
        "def run_all_experiments(preset: str = 'thuml', data_dir: str = './data'):\n",
        "    \"\"\"\n",
        "    Run complete DELNO-Spectra benchmark.\n",
        "\n",
        "    Section 1: FNO experiments\n",
        "    Section 2: Transolver experiments\n",
        "    \"\"\"\n",
        "    print(\"\\n\" + \"╔\" + \"═\"*78 + \"╗\")\n",
        "    print(\"║\" + \" \"*20 + \"DELNO-Spectra v7: Complete Benchmark\" + \" \"*22 + \"║\")\n",
        "    print(\"║\" + \" \"*20 + \"ICML 2026 Submission\" + \" \"*38 + \"║\")\n",
        "    print(\"╚\" + \"═\"*78 + \"╝\")\n",
        "\n",
        "    timestamp = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
        "    print(f\"\\nStarted: {timestamp}\")\n",
        "    print(f\"Preset: {preset}\")\n",
        "    print(f\"Data directory: {data_dir}\")\n",
        "\n",
        "    all_results = {}\n",
        "\n",
        "    # Section 1: FNO\n",
        "    fno_results = run_fno_experiments(preset=preset, data_dir=data_dir)\n",
        "    all_results['fno'] = fno_results\n",
        "\n",
        "    # Section 2: Transolver\n",
        "    transolver_results = run_transolver_experiments(preset=preset, data_dir=data_dir)\n",
        "    all_results['transolver'] = transolver_results\n",
        "\n",
        "    # Final Summary\n",
        "    print(\"\\n\" + \"╔\" + \"═\"*78 + \"╗\")\n",
        "    print(\"║\" + \" \"*25 + \"FINAL RESULTS SUMMARY\" + \" \"*32 + \"║\")\n",
        "    print(\"╚\" + \"═\"*78 + \"╝\")\n",
        "\n",
        "    print(f\"\\n{'Model':<18} {'Dataset':<15} {'Baseline':<12} {'DEL':<12} {'Δ':<10}\")\n",
        "    print(\"─\"*75)\n",
        "\n",
        "    # Pair each baseline with its DEL variant, per model family and dataset\n",
        "    for model_type in ['fno', 'transolver']:\n",
        "        results = all_results[model_type]\n",
        "        model_name = 'FNO' if model_type == 'fno' else 'Transolver'\n",
        "\n",
        "        for dataset in ['burgers_r10', 'burgers_r1000', 'darcy', 'ns_v1e3', 'ns_v1e5']:\n",
        "            base = next((r for m, d, r in results if m == model_name and d == dataset), None)\n",
        "            del_model = next((r for m, d, r in results if m == f'DEL-{model_name}' and d == dataset), None)\n",
        "\n",
        "            if base and del_model:\n",
        "                # Relative improvement in %: positive means the DEL variant has lower test error\n",
        "                delta = (base['best_loss'] - del_model['best_loss']) / base['best_loss'] * 100\n",
        "                delta_str = f\"{delta:+.2f}%\" if delta != 0 else \"0.00%\"\n",
        "                print(f\"{model_name:<18} {dataset:<15} {base['best_loss']:.6f}     {del_model['best_loss']:.6f}     {delta_str}\")\n",
        "\n",
        "    print(\"\\n\" + \"═\"*80)\n",
        "    print(f\"Completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\")\n",
        "    print(\"═\"*80)\n",
        "\n",
        "    return all_results\n",
        "\n",
        "\n",
        "# ══════════════════════════════════════════════════════════════════════════════════════════"
      ],
      "execution_count": 42,
      "outputs": []
    },
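    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Reading the Δ column.** `run_all_experiments` reports Δ = (baseline - DEL) / baseline × 100, the relative reduction in best test relative-L2 error, so a positive Δ means the DEL variant is better. The sketch below reproduces only this arithmetic; `relative_improvement` is a hypothetical helper, and the inputs are the best losses printed (rounded to 6 decimals) in the FNO results summary further down, so the resulting percentages are approximate.\n",
        "\n",
        "```python\n",
        "def relative_improvement(base_loss, del_loss):\n",
        "    # Same formula as the final summary: positive means DEL has lower error.\n",
        "    return (base_loss - del_loss) / base_loss * 100\n",
        "\n",
        "\n",
        "fno_summary = {\n",
        "    # dataset: (FNO best loss, DEL-FNO best loss), copied from the logged summary\n",
        "    'burgers_r10': (0.000977, 0.000826),\n",
        "    'ns_v1e5': (0.096448, 0.094341),\n",
        "}\n",
        "for dataset, (base, new) in fno_summary.items():\n",
        "    print(f'{dataset:<15} {relative_improvement(base, new):+.2f}%')\n",
        "# burgers_r10: about +15.46%, ns_v1e5: about +2.18%\n",
        "```"
      ]
    },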
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HftyYh6FS25j"
      },
      "source": [
        "---\n",
        "# FNO Experiments\n",
        "Run FNO and DEL-FNO on `burgers_r10` and `ns_v1e5`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ki-4WDAGS25j",
        "outputId": "4fc305d9-4275-4523-c4ca-6c7c6a04a14e"
      },
      "source": [
        "fno_results = run_fno_experiments(preset='thuml', data_dir='./data')"
      ],
      "execution_count": 43,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "════════════════════════════════════════════════════════════════════════════════\n",
            "  FNO EXPERIMENTS\n",
            "════════════════════════════════════════════════════════════════════════════════\n",
            "\n",
            "Device: cuda\n",
            "Preset: thuml\n",
            "Data directory: ./data\n",
            "\n",
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n",
            "  Dataset: BURGERS_R10\n",
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n",
            "\n",
            "────────────────────────────────────────────────────────────\n",
            "FNO [burgers_r10] | 221.9K params\n",
            "────────────────────────────────────────────────────────────\n",
            "  Ep    1: Train=92.408% Test=64.432% (3s) ★\n",
            "  Ep  100: Train=2.148% Test=3.001% (100s) ★\n",
            "  Ep  200: Train=1.508% Test=1.685% (192s) ★\n",
            "  Ep  300: Train=0.788% Test=0.884% (283s) ★\n",
            "  Ep  400: Train=0.240% Test=0.218% (374s) ★\n",
            "  Ep  500: Train=0.038% Test=0.098% (465s) ★\n",
            "\n",
            "────────────────────────────────────────────────────────────\n",
            "DEL-FNO [burgers_r10] | 222.0K params\n",
            "────────────────────────────────────────────────────────────\n",
            "  Ep    1: Train=95.976% Test=80.014% (2s) ★ | β=0.99, β=2.00\n",
            "  Ep  100: Train=2.051% Test=1.869% (145s) ★ | β=1.01, β=2.00\n",
            "  Ep  200: Train=1.672% Test=1.194% (289s) ★ | β=1.23, β=2.00\n",
            "  Ep  300: Train=0.768% Test=0.580% (432s) ★ | β=1.33, β=2.00\n",
            "  Ep  400: Train=0.178% Test=0.158% (575s) ★ | β=1.35, β=2.00\n",
            "  Ep  500: Train=0.034% Test=0.083% (719s) ★ | β=1.35, β=2.00\n",
            "\n",
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n",
            "  Dataset: NS_V1E5\n",
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n",
            "\n",
            "────────────────────────────────────────────────────────────\n",
            "FNO [ns_v1e5] | 4744.5K | AR: T_in=10→rollout 10\n",
            "────────────────────────────────────────────────────────────\n",
            "  Ep    1: Train=56.488% Test=44.638% (1s) ★\n",
            "  Ep  100: Train=5.656% Test=15.634% (59s) ★\n",
            "  Ep  200: Train=4.458% Test=12.881% (118s) ★\n",
            "  Ep  300: Train=4.201% Test=11.677% (177s) ★\n",
            "  Ep  400: Train=3.828% Test=9.749% (237s) ★\n",
            "  Ep  500: Train=3.914% Test=9.645% (295s) ★\n",
            "\n",
            "────────────────────────────────────────────────────────────\n",
            "DEL-FNO [ns_v1e5] | 4745.0K | AR: T_in=10→rollout 10\n",
            "────────────────────────────────────────────────────────────\n",
            "  Ep    1: Train=53.030% Test=75.187% (2s) ★ | β=1.01, β=0.97, β=2.00, β=2.00\n",
            "  Ep  100: Train=5.585% Test=15.734% (127s) ★ | β=0.92, β=0.93, β=2.00, β=2.00\n",
            "  Ep  200: Train=4.372% Test=12.678% (254s) ★ | β=0.95, β=0.95, β=2.00, β=2.00\n",
            "  Ep  300: Train=4.090% Test=9.809% (380s) ★ | β=0.98, β=0.98, β=2.00, β=2.00\n",
            "  Ep  400: Train=3.694% Test=9.684% (507s) ★ | β=0.99, β=0.99, β=2.00, β=2.00\n",
            "  Ep  500: Train=3.764% Test=9.434% (633s) ★ | β=0.99, β=1.00, β=2.00, β=2.00\n",
            "\n",
            "════════════════════════════════════════════════════════════════════════════════\n",
            "  FNO RESULTS SUMMARY\n",
            "════════════════════════════════════════════════════════════════════════════════\n",
            "\n",
            "Model           Dataset         Best Loss    Learned βs\n",
            "──────────────────────────────────────────────────────────────────────\n",
            "FNO             burgers_r10     0.000977     N/A\n",
            "DEL-FNO         burgers_r10     0.000826     1.35, 2.00\n",
            "FNO             ns_v1e5         0.096448     N/A\n",
            "DEL-FNO         ns_v1e5         0.094341     0.99, 1.00, 2.00, 2.00\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hi1-jFtoS25j"
      },
      "source": [
        "---\n",
        "# Transolver Experiments\n",
        "Run Transolver and DEL-Transolver on `burgers_r10` and `darcy`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5Bs2eranS25j",
        "outputId": "2c8beee0-031f-4382-9ce2-b671c91959c0"
      },
      "source": [
        "transolver_results = run_transolver_experiments(preset='thuml', data_dir='./data')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "════════════════════════════════════════════════════════════════════════════════\n",
            "  TRANSOLVER EXPERIMENTS\n",
            "════════════════════════════════════════════════════════════════════════════════\n",
            "\n",
            "Device: cuda\n",
            "Preset: thuml\n",
            "Data directory: ./data\n",
            "\n",
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n",
            "  Dataset: BURGERS_R10\n",
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n",
            "\n",
            "────────────────────────────────────────────────────────────\n",
            "TRANSOLVER [burgers_r10] | 959.7K params\n",
            "────────────────────────────────────────────────────────────\n",
            "  Ep    1: Train=46.902% Test=40.673% (4s) ★\n",
            "  Ep  100: Train=8.125% Test=7.195% (403s) ★\n",
            "  Ep  200: Train=4.153% Test=4.114% (806s) ★\n",
            "  Ep  300: Train=1.806% Test=2.195% (1209s) ★\n",
            "  Ep  400: Train=0.618% Test=0.934% (1613s) ★\n",
            "  Ep  500: Train=0.150% Test=0.654% (2016s) ★\n",
            "\n",
            "────────────────────────────────────────────────────────────\n",
            "DEL-TRANSOLVER [burgers_r10] | 959.9K params\n",
            "────────────────────────────────────────────────────────────\n",
            "  Ep    1: Train=45.951% Test=40.148% (5s) ★ | β=0.99, β=2.00\n",
            "  Ep  100: Train=7.405% Test=5.849% (459s) ★ | β=0.95, β=2.00\n",
            "  Ep  200: Train=3.611% Test=4.472% (920s) ★ | β=0.99, β=2.00\n",
            "  Ep  300: Train=1.515% Test=1.565% (1379s) ★ | β=0.99, β=2.00\n",
            "  Ep  400: Train=0.580% Test=0.834% (1841s) ★ | β=0.98, β=2.00\n",
            "  Ep  500: Train=0.150% Test=0.530% (2300s) ★ | β=0.97, β=2.00\n",
            "\n",
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n",
            "  Dataset: DARCY\n",
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n",
            "\n",
            "────────────────────────────────────────────────────────────\n",
            "TRANSOLVER [darcy] | 959.8K params [normalized]\n",
            "────────────────────────────────────────────────────────────\n",
            "  Ep    1: Train=67.747% Test=17.191% (18s) ★\n"
          ]
        }
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.0"
    },
    "colab": {
      "provenance": [],
      "gpuType": "A100"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}