#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

"""
Multi-Dimensional Porous Medium Equation (PME) with a Random-Features Network (RaNN).
"""

import math
import time
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator

# -------------------------
# Global Config
# -------------------------
# Global seeds for anything executed before the per-run seeding in the driver loop.
torch.manual_seed(42)
np.random.seed(42)

# Problem params
SPATIAL_DIM = 1                    # re-bound by the `for SPATIAL_DIM in D_LIST` driver loop
M_EXPONENT = 2.0                   # m in PME (m > 1)
T0_IC = 1e-2                       # small positive time for IC (avoid singularity); model t=0 <-> physical t=T0_IC
T_DOMAIN = 0.05                   # model time t in [0, T]; overridden per dimension by hyperparams_for_dim()
M_RUNS = 5                         # independently seeded training runs per dimension
N_EVAL = 20000                     # Monte-Carlo samples for the relative L2 error estimates

# Domain Ω = ∏ [0, L_k]
L_VEC = [1.0] * SPATIAL_DIM        # lengths per dimension, defaults to [1.0]*d; re-bound in the driver loop

# RaNN params
USE_FOURIER_FEATURES = True        # set False to use a non-Fourier random feature map
NON_FOURIER_ACTIVATION = 'tanh'    # used if USE_FOURIER_FEATURES=False
NUM_RANDOM_FEATURES = 2500         # total hidden features; for Fourier must be EVEN (cos+sin); overridden per dimension
FEATURE_SCALE = 10.0               # std multiplier of the Gaussian random weights (frequencies in Fourier mode)
POSITIVE_HEAD = 'square'           # 'square' | 'softplus' | 'relu' | 'none'

# Input anisotropic scaling (per spatial dim, then time); re-bound in the driver loop
INPUT_SCALES = [1.0] * SPATIAL_DIM + [10.0]

# Collocation toggles
USE_CUSTOM_COLLOCATION = True      # set False for uniform collocation
FRONT_BAND_REL = 0.02              # half-width of the band around the front, as a fraction of min(L_VEC)
SUPPORT_FRAC = 0.60                # initial support radius R0 = SUPPORT_FRAC * min(L_VEC) / 2
RESAMPLE_EVERY = 500   # epochs (collocation sets are redrawn during the Adam phase only)

# Training params
LEARNING_RATE = 1e-3
NUM_EPOCHS = 5000                  # Adam epochs + L-BFGS epochs
N_PDE = 2000                       # more points for higher d; overridden per dimension
N_IC = 1000                        # overridden per dimension
N_BC = 500                        # split across 2*d faces; overridden per dimension

# Loss weights
LAMBDA_PDE = 1.0
LAMBDA_IC = 200.0
LAMBDA_BC = 1.0

# L-BFGS fine-tuning (last M epochs)
# Each L-BFGS "epoch" is one optimizer.step(closure), i.e. up to LBFGS_MAX_ITER inner iterations.
LBFGS_LAST_M_EPOCHS = 50      
LBFGS_LR = 1.0
LBFGS_MAX_ITER = 20
LBFGS_HISTORY = 50


# Dimensions to run
D_LIST = [1,2,3,4,5]

def hyperparams_for_dim(d: int):
    """
    Look up the dimension-dependent sizes and time horizon used by the driver.

    For ``d < 4`` the feature and point counts grow linearly with ``d``; from
    ``d = 4`` on they are capped at fixed values. The time horizon is 0.05 for
    d=1, 0.025 for d=2 and 0.01 otherwise. These values are placeholders meant
    to be tuned.

    Returns a dict with keys NUM_RANDOM_FEATURES (kept even so the Fourier
    cos/sin split works), N_PDE, N_IC, N_BC and T_DOMAIN.
    """
    horizon = {1: 0.05, 2: 0.025}.get(d, 0.01)
    grows_linearly = d < 4

    def pick(per_dim, cap):
        return per_dim * d if grows_linearly else cap

    return {
        "NUM_RANDOM_FEATURES": pick(2500, 7500),
        "N_PDE": pick(2000, 8000),
        "N_IC": pick(1000, 4000),
        "N_BC": pick(500, 2000),
        "T_DOMAIN": horizon,
    }

# Device: prefer CUDA when available; allow TF32-style fast float32 matmuls.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")
torch.set_float32_matmul_precision("high")


# -------------------------
# Utilities
# -------------------------
def get_activation(name: str):
    """
    Map an activation name (case-insensitive) to the torch function.

    Supported names: 'tanh', 'relu', 'sigmoid'.
    Raises ValueError (quoting the lower-cased name) for anything else.
    """
    key = name.lower()
    table = {
        'tanh': torch.tanh,
        'relu': torch.relu,
        'sigmoid': torch.sigmoid,
    }
    try:
        return table[key]
    except KeyError:
        raise ValueError(f"Unsupported activation: {key}") from None


def _to_tensor(x, device=DEVICE, dtype=torch.float32):
    """
    Return ``x`` as a tensor on ``device`` with ``dtype``.

    Existing tensors are moved/cast with ``Tensor.to``. Anything else (lists,
    scalars, arrays) is copied into a new tensor via ``torch.tensor``.
    """
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x, device=device, dtype=dtype)
    return x.to(device=device, dtype=dtype)


# -------------------------
# RaNN definition (multi-d, readout-only trainable)
# -------------------------
class RFNNnD(nn.Module):
    """
    Random-feature network u(x, t) for the d-dimensional PME.

    The hidden layer is a fixed (frozen) random map of the scaled input
    z = xt * input_scales, where xt = (x_1, ..., x_d, t). Only the linear readout
    ``output_layer`` is trainable.

    * Fourier mode: phi(z) = rf^{-1/2} [cos(W z + b), sin(W z + b)], with
      rf = num_random_features / 2, W ~ feature_scale * N(0, 1) and b ~ U[0, 2*pi).
    * Non-Fourier mode: phi(z) = act(W z + b), with b ~ U[0, 1).

    The readout is passed through a positivity head: 'square', 'softplus',
    'relu' or 'none'.

    Note: the constructor calls ``torch.manual_seed(seed)``. It therefore reseeds
    the *global* torch RNG as a side effect.

    Raises ValueError for an unknown head or activation, or for an odd
    ``num_random_features`` in Fourier mode.
    """

    _POSITIVE_HEADS = ('square', 'softplus', 'relu', 'none')

    def __init__(
        self,
        spatial_dim: int,
        num_random_features: int,
        feature_scale: float = 1.0,
        *,
        use_fourier: bool = True,
        non_fourier_activation: str = 'tanh',
        positive: str = 'square',
        input_scales: list[float] | Tuple[float, ...] | None = None,
        seed: int = 12345,
    ) -> None:
        super().__init__()
        torch.manual_seed(seed)  # global reseed, kept for run-to-run reproducibility

        self.d = spatial_dim
        self.input_dim = self.d + 1    # (x1,...,xd, t)
        self.output_dim = 1
        self.num_random_features = int(num_random_features)
        self.use_fourier = bool(use_fourier)
        self.positive = positive
        self.non_fourier_activation = non_fourier_activation.lower()

        # Validate eagerly. Plain `assert`s disappear under `python -O`, and a bad
        # head/activation would otherwise only surface on the first forward pass.
        if positive not in self._POSITIVE_HEADS:
            raise ValueError("positive must be 'square', 'softplus', 'relu', or 'none'")
        if self.use_fourier and self.num_random_features % 2 != 0:
            raise ValueError("For Fourier features, NUM_RANDOM_FEATURES must be even.")
        if not self.use_fourier:
            get_activation(self.non_fourier_activation)  # raises ValueError if unknown

        if input_scales is None:
            input_scales = [1.0] * self.d + [10.0]
        self.register_buffer("input_scales", _to_tensor(input_scales).view(1, -1))

        # Fourier mode draws rf = M/2 frequencies (each feeds a cos and a sin column).
        n_rows = self.num_random_features // 2 if self.use_fourier else self.num_random_features
        weights = torch.randn(n_rows, self.input_dim, device=DEVICE) * feature_scale  # [n_rows, d+1]
        biases = torch.rand(n_rows, device=DEVICE)                                     # [n_rows]
        if self.use_fourier:
            biases = biases * 2 * math.pi  # phases in [0, 2*pi)

        # Frozen random map (registered as non-trainable parameters).
        self.random_weights = nn.Parameter(weights, requires_grad=False)
        self.random_biases = nn.Parameter(biases, requires_grad=False)

        # Trainable linear readout
        self.output_layer = nn.Linear(self.num_random_features, self.output_dim, device=DEVICE)

    def _feature_map(self, xt: torch.Tensor) -> torch.Tensor:
        """Fixed random features of the scaled input; xt is [N, d+1], result [N, M]."""
        proj = (xt * self.input_scales) @ self.random_weights.t()
        pre = proj + self.random_biases  # computed once, shared by cos and sin

        if self.use_fourier:
            scale = (self.num_random_features // 2) ** (-0.5)
            return scale * torch.cat([torch.cos(pre), torch.sin(pre)], dim=1)
        return get_activation(self.non_fourier_activation)(pre)

    def forward(self, xt: torch.Tensor) -> torch.Tensor:
        """Evaluate u at points xt ([N, d+1]); returns [N, 1] after the positivity head."""
        raw = self.output_layer(self._feature_map(xt))  # [N, 1]

        if self.positive == 'square':
            return raw ** 2
        if self.positive == 'softplus':
            return torch.nn.functional.softplus(raw, beta=1.0)
        if self.positive == 'relu':
            return torch.relu(raw)
        if self.positive == 'none':
            return raw
        # Reachable only if `positive` is mutated after construction.
        raise ValueError("positive must be 'square', 'softplus', 'relu', or 'none'")


# -------------------------
# Barenblatt in d-D
# -------------------------
def barenblatt_ic_dd(
    x: torch.Tensor,           # [N, d] spatial points
    L_vec: list[float],        # domain lengths
    m: float,
    t0: float,
    support_fraction: float = SUPPORT_FRAC,
    x_center: torch.Tensor | None = None,  # [d]
) -> torch.Tensor:
    """
    Barenblatt profile of u_t = Δ(u^m) at physical time t0, used as the IC.

    u = t0^{-alpha} * max(A - k |x - x0|^2 / t0^{2 beta}, 0)^{1/(m-1)}, with
    lambda = 2 + d(m-1), alpha = d/lambda, beta = 1/lambda, k = (m-1) beta / (2m).
    A is chosen so that the support radius at t0 equals
    R0 = support_fraction * min(L_vec) / 2. The centre x0 defaults to the middle
    of the box.

    Returns a [N, 1] tensor.
    """
    dim = x.shape[1]
    center = (
        _to_tensor([0.5 * L for L in L_vec]).view(1, dim)
        if x_center is None
        else x_center.view(1, dim)
    )

    lam = 2.0 + dim * (m - 1.0)
    alpha, beta = dim / lam, 1.0 / lam
    k = (m - 1.0) * beta / (2.0 * m)

    radius0 = support_fraction * (min(L_vec) / 2.0)
    t0_pow = t0 ** (2.0 * beta)
    amplitude = k * (radius0**2) / t0_pow

    dist2 = ((x - center) ** 2).sum(dim=1, keepdim=True)
    profile = (amplitude - k * dist2 / t0_pow).clamp(min=0.0)
    return (t0 ** (-alpha)) * (profile ** (1.0 / (m - 1.0)))


def barenblatt_true_shifted_dd(
    t_model: torch.Tensor | float,     # [N, 1], [N], 0-d tensor or Python scalar
    x: torch.Tensor,           # [N, d]
    L_vec: list[float],
    m: float,
    t0: float,
    support_fraction: float = SUPPORT_FRAC,
    x_center: torch.Tensor | None = None,  # [d]
) -> torch.Tensor:
    """
    Exact Barenblatt solution at model time ``t_model`` (physical time t_model + t0).

    Same profile as the IC (support radius R0 = support_fraction * min(L_vec)/2
    at physical time t0), evolved to physical time t = t_model + t0:
    u = t^{-alpha} * max(A - k |x - x0|^2 / t^{2 beta}, 0)^{1/(m-1)}.

    ``t_model`` may be a Python scalar, a 0-d tensor, a [N] tensor or a [N, 1]
    tensor. Scalars are broadcast to all points. A [N] tensor is reshaped to
    [N, 1]; previously it broadcast against the [N, 1] radii into a wrong [N, N]
    result. Returns a [N, 1] tensor.
    """
    d = x.shape[1]
    if x_center is None:
        x0 = _to_tensor([0.5 * L for L in L_vec]).view(1, d)
    else:
        x0 = x_center.view(1, d)

    lam = 2.0 + d * (m - 1.0)
    alpha = d / lam
    beta = 1.0 / lam
    B_const = (m - 1.0) * beta / (2.0 * m)

    R0 = support_fraction * (min(L_vec) / 2.0)
    A = B_const * (R0**2) / (t0 ** (2.0 * beta))

    # Python scalars have no `.ndim`; lift them to a tensor on x's device/dtype.
    if not isinstance(t_model, torch.Tensor):
        t_model = torch.as_tensor(t_model, device=x.device, dtype=x.dtype)

    t_phys = t_model + t0
    if t_phys.ndim == 0:
        t_phys = t_phys * torch.ones(x.shape[0], 1, device=x.device, dtype=x.dtype)
    elif t_phys.ndim == 1:
        t_phys = t_phys.view(-1, 1)

    r2 = torch.sum((x - x0) ** 2, dim=1, keepdim=True)
    z = A - B_const * r2 / (t_phys ** (2.0 * beta))
    z = torch.clamp(z, min=0.0)
    u = (t_phys ** (-alpha)) * (z ** (1.0 / (m - 1.0)))
    return u


# -------------------------
# Collocation Builders
# -------------------------
def rand_in_box(n: int, L_vec: list[float]) -> torch.Tensor:
    """Draw ``n`` points uniformly in the box ∏ [0, L_k]; returns [n, len(L_vec)] on DEVICE."""
    dim = len(L_vec)
    unit = torch.rand(n, dim, device=DEVICE)
    lengths = _to_tensor(L_vec).view(1, dim)
    return unit * lengths


def sample_near_front_nd(n: int, L_vec: list[float], t0: float, T: float, m: float, band_rel: float) -> torch.Tensor:
    """
    Draw ``n`` space-time points in a band around the Barenblatt front.

    Times are t = T * U^3 (biased toward t = 0). The front radius at model time t
    is R(t) = R0 * ((t0 + t) / t0)^beta, with R0 = SUPPORT_FRAC * min(L_vec) / 2
    and beta = 1 / (2 + d(m-1)). Each point lies in a uniformly random direction
    from the box centre, at distance R(t) + U(-1, 1) * band_rel * min(L_vec)
    (clamped at 0). Points leaving the box are clamped onto its faces.

    Returns a [n, d+1] tensor of (x, t).
    """
    dim = len(L_vec)
    l_min = min(L_vec)
    beta = 1.0 / (2.0 + dim * (m - 1.0))
    r0 = SUPPORT_FRAC * (l_min / 2.0)

    # RNG draws happen in a fixed order: times, directions, radial jitter.
    t = (torch.rand(n, 1, device=DEVICE) ** 3) * T
    front_radius = r0 * ((t0 + t) / t0) ** beta

    directions = torch.randn(n, dim, device=DEVICE)
    directions = directions / (directions.norm(dim=1, keepdim=True) + 1e-12)

    half_width = band_rel * l_min
    jitter = (2 * torch.rand(n, device=DEVICE) - 1.0) * half_width
    radii = (front_radius.squeeze(1) + jitter).clamp(min=0.0)

    center = _to_tensor([0.5 * L for L in L_vec]).view(1, dim)
    points = center + directions * radii.view(-1, 1)

    lengths = _to_tensor(L_vec).view(1, dim)
    points = points.clamp(min=torch.zeros_like(lengths), max=lengths)
    return torch.cat([points, t], dim=1)



def build_collocation_sets_nd(use_custom: bool, d: int):
    """
    Draw one set of training points for the d-dimensional PME.

    Reads the module-level configuration (L_VEC, N_PDE, N_IC, N_BC, T_DOMAIN,
    T0_IC, M_EXPONENT, SUPPORT_FRAC, FRONT_BAND_REL) and consumes the global
    torch RNG.

    Returns (pde_xt, ic_xt, u_ic_true, bc_xt):
      pde_xt    [N_PDE, d+1]: with ``use_custom``, int(0.4 * N_PDE) points near
                the moving front, the rest uniform in space with t = T * U^3;
                otherwise (x, t) uniform.
      ic_xt     [N_IC, d+1]: points at t = 0, uniform inside the ball of radius
                R0 = SUPPORT_FRAC * min(L_VEC) / 2 around the box centre.
      u_ic_true [N_IC, 1]: Barenblatt IC evaluated at ic_xt.
      bc_xt     [2*d*per_face, d+1], with per_face = max(1, N_BC // (2*d)):
                ordered face x_0=0, x_0=L_0, x_1=0, x_1=L_1, ...
                (the zero-flux BC loss slices by this ordering).
    """
    # --- interior PDE points ---
    if use_custom:
        n_front = int(0.4 * N_PDE)
        n_bulk = N_PDE - n_front
        front_pts = sample_near_front_nd(n_front, L_VEC, T0_IC, T_DOMAIN, M_EXPONENT, FRONT_BAND_REL)
        bulk_x = rand_in_box(n_bulk, L_VEC)
        bulk_t = (torch.rand(n_bulk, 1, device=DEVICE) ** 3) * T_DOMAIN
        pde_xt = torch.cat([front_pts, torch.cat([bulk_x, bulk_t], dim=1)], dim=0)
    else:
        uniform_x = rand_in_box(N_PDE, L_VEC)
        uniform_t = torch.rand(N_PDE, 1, device=DEVICE) * T_DOMAIN
        pde_xt = torch.cat([uniform_x, uniform_t], dim=1)

    # --- initial-condition points: uniform in the ball of radius R0 ---
    # NOTE(review): IC points lie only inside the initial support, so the zero
    # initial state outside it is not penalised by the IC term. Confirm this is intended.
    r0 = SUPPORT_FRAC * (min(L_VEC) / 2.0)
    directions = torch.randn(N_IC, d, device=DEVICE)
    directions = directions / (directions.norm(dim=1, keepdim=True) + 1e-12)
    radii = (torch.rand(N_IC, 1, device=DEVICE) ** (1/d)) * r0  # U^(1/d) => uniform in the ball

    center = _to_tensor([0.5 * L for L in L_VEC]).view(1, d)
    lengths = _to_tensor(L_VEC).view(1, d)
    ic_x = (center + directions * radii).clamp(min=torch.zeros(1, d, device=DEVICE), max=lengths)
    ic_xt = torch.cat([ic_x, torch.zeros(N_IC, 1, device=DEVICE)], dim=1)
    u_ic_true = barenblatt_ic_dd(ic_x, L_VEC, M_EXPONENT, T0_IC, support_fraction=SUPPORT_FRAC)

    # --- boundary points: two faces (x_i = 0, x_i = L_i) per axis ---
    per_face = max(1, N_BC // (2 * d))
    faces = []
    for axis in range(d):
        base = rand_in_box(per_face, L_VEC)
        low = base.clone()
        low[:, axis] = 0.0
        high = base.clone()
        high[:, axis] = L_VEC[axis]
        t_face = torch.rand(per_face, 1, device=DEVICE) * T_DOMAIN
        faces += [torch.cat([low, t_face], dim=1), torch.cat([high, t_face], dim=1)]
    bc_xt = torch.cat(faces, dim=0)

    return pde_xt, ic_xt, u_ic_true, bc_xt


# -------------------------
# PDE residuals
# -------------------------
def pde_residual_autograd_nd(model: nn.Module, xt: torch.Tensor, m_exp: float, d: int):
    """
    PME residual R = u_t - Δ(u^m) via autograd, for any model.

    ``xt`` holds (x_1..x_d, t) in its columns; a detached copy is used as the
    differentiation leaf. Assumes the model returns [N, 1].
    Returns (R [N, 1], u [N, 1], v_x [N, d]), where v = u^m. The graph is kept
    (create_graph=True) so the residual can be back-propagated.
    """
    inp = xt.clone().detach().requires_grad_(True)
    u = model(inp)

    def _grad(out):
        return torch.autograd.grad(out, inp, torch.ones_like(out), create_graph=True)[0]

    u_t = _grad(u)[:, d:d + 1]
    v_x = _grad(u.pow(m_exp))[:, :d]

    # Laplacian: sum of the diagonal second derivatives d/dx_i (dv/dx_i).
    diag = [_grad(v_x[:, i:i + 1])[:, i:i + 1] for i in range(d)]
    v_lap = torch.stack(diag, dim=0).sum(dim=0)

    return u_t - v_lap, u, v_x


def pde_residual_fast_nd(model: "RFNNnD", xt: torch.Tensor, m_exp: float, d: int):
    """
    Closed-form PME residual R = u_t - Δ(u^m) for Fourier features with a 'square' head.

    The readout is raw = a·cos(p) + c·sin(p) + bias, with p = (xt * input_scales) Wᵀ + b
    and u = raw². For the scaled frequency w_k = input_scales[k] * W[:, k]:
        raw_k  = c·(cos(p) w_k) - a·(sin(p) w_k)
        raw_kk = -(a·(cos(p) w_k²) + c·(sin(p) w_k²))
        u_xx_k = 2 (raw_k² + raw raw_kk)
    Using u_x² = 4 u raw_k², the second derivative of v = u^m is
        v_xx_k = m u^{m-1} (u_xx_k + 4 (m-1) raw_k²).
    This avoids u^{m-2}, which is infinite where raw == 0 when m < 2 and gave
    NaN (inf * 0).

    Returns (R [N, 1], u [N, 1]).
    Raises ValueError unless model.use_fourier and model.positive == 'square'.
    """
    if not model.use_fourier:
        raise ValueError("fast residual requires Fourier features")
    if model.positive != 'square':
        raise ValueError("fast residual currently assumes positive='square'")

    rf = model.num_random_features // 2
    W = model.random_weights[:rf, :]       # [rf, d+1]
    b = model.random_biases[:rf]           # [rf]
    scales = model.input_scales.view(-1)   # [d+1]
    scale_rf = rf ** (-0.5)

    phase = (xt * model.input_scales) @ W.t() + b  # [N, rf]
    cosv = torch.cos(phase)
    sinv = torch.sin(phase)

    # Readout weights split into cos / sin halves (feature-map normalisation folded in)
    Wout = model.output_layer.weight.squeeze(0)  # [2*rf]
    a = scale_rf * Wout[:rf]
    c = scale_rf * Wout[rf:]
    raw = (cosv @ a) + (sinv @ c) + model.output_layer.bias.squeeze()  # [N]
    u = raw * raw

    # Time derivative (last input column)
    w_t = scales[d] * W[:, d]
    raw_t = (cosv * w_t) @ c - (sinv * w_t) @ a
    u_t = 2.0 * raw * raw_t

    # Prefactor m u^{m-1}; for m == 1 the Laplacian of v is just Δu.
    coef = None if m_exp == 1.0 else m_exp * u.pow(m_exp - 1.0)
    lap_terms = []
    for k in range(d):
        w_k = scales[k] * W[:, k]
        w_k2 = w_k * w_k
        raw_k = (cosv * w_k) @ c - (sinv * w_k) @ a
        raw_kk = -((cosv * w_k2) @ a + (sinv * w_k2) @ c)
        u_xx_k = 2.0 * (raw_k * raw_k + raw * raw_kk)
        if coef is None:
            lap_terms.append(u_xx_k)
        else:
            lap_terms.append(coef * (u_xx_k + 4.0 * (m_exp - 1.0) * raw_k * raw_k))

    v_lap = torch.stack(lap_terms, dim=0).sum(dim=0)  # [N]
    R = u_t - v_lap
    return R.unsqueeze(1), u.unsqueeze(1)


class FeatureCache:
    """
    Precomputed Fourier features for a fixed point set (fast-path helper).

    Because the random weights/phases are frozen, cos/sin of the projected
    points only need to be computed once per collocation set. Exposes:
      rf     -- number of frequencies (num_random_features // 2)
      d      -- spatial dimension
      W      -- frequency matrix [rf, d+1]
      scales -- flattened input scales [d+1]
      proj   -- (xt * input_scales) @ Wᵀ, [N, rf]
      cosv, sinv -- cos/sin of proj + phases, [N, rf]
    """

    def __init__(self, model, xt, d):
        n_freq = model.num_random_features // 2
        freqs = model.random_weights[:n_freq, :]
        phases = model.random_biases[:n_freq]

        self.rf = n_freq
        self.d = d
        self.W = freqs
        self.scales = model.input_scales.view(-1)
        self.proj = (xt * model.input_scales) @ freqs.t()

        shifted = self.proj + phases
        self.cosv = torch.cos(shifted)
        self.sinv = torch.sin(shifted)

def pde_residual_fast_nd_cached(model, cache, m_exp):
    """
    Closed-form PME residual R = u_t - Δ(u^m) using precomputed Fourier features.

    Computes the same quantity as pde_residual_fast_nd (Fourier features,
    'square' head), but reads rf, d, W, scales, cosv and sinv from ``cache``
    (a FeatureCache built for the same model and points). Only the trainable
    readout ``model.output_layer`` is read from the model.

    Second derivatives of v = u^m use v_xx_k = m u^{m-1} (u_xx_k + 4 (m-1) raw_k²),
    which follows from u = raw² (so u_x² = 4 u raw_k²). This avoids u^{m-2}: that
    term is infinite where raw == 0 for m < 2 and produced NaN (inf * 0).
    NOTE(review): for 1 < m < 2 the *backward* pass through u^{m-1} can still be
    non-finite exactly at raw == 0.

    Returns (R [N, 1], u [N, 1]).
    """
    rf, d = cache.rf, cache.d
    cosv, sinv, W, scales = cache.cosv, cache.sinv, cache.W, cache.scales
    scale_rf = rf ** (-0.5)

    Wout = model.output_layer.weight.squeeze(0)
    a = scale_rf * Wout[:rf]
    c = scale_rf * Wout[rf:]

    raw = (cosv @ a) + (sinv @ c) + model.output_layer.bias  # [N]
    u = raw * raw                                            # >= 0 by construction

    # Time derivative (last input column)
    w_t = scales[d] * W[:, d]
    raw_t = (cosv * w_t) @ c - (sinv * w_t) @ a
    u_t = 2.0 * raw * raw_t

    # Spatial Laplacian of v = u^m
    coef = None if m_exp == 1.0 else m_exp * u.pow(m_exp - 1.0)
    lap_terms = []
    for k in range(d):
        w_k = scales[k] * W[:, k]
        w_k2 = w_k * w_k
        raw_k = (cosv * w_k) @ c - (sinv * w_k) @ a
        raw_kk = -((cosv * w_k2) @ a + (sinv * w_k2) @ c)
        u_xx_k = 2.0 * (raw_k * raw_k + raw * raw_kk)
        if coef is None:
            lap_terms.append(u_xx_k)
        else:
            lap_terms.append(coef * (u_xx_k + 4.0 * (m_exp - 1.0) * raw_k * raw_k))

    v_lap = torch.stack(lap_terms, 0).sum(0)
    R = u_t - v_lap
    return R.unsqueeze(1), u.unsqueeze(1)



def residual_and_flux(model: RFNNnD, xt: torch.Tensor, m_exp: float, d: int):
    """
    Dispatch to the fastest available PME residual for ``model``.

    Uses the closed-form path for Fourier features with a 'square' head and
    autograd otherwise. Despite the name, it returns only (R [N, 1], u [N, 1]);
    the flux v_x produced by the autograd path is discarded.
    """
    closed_form_ok = model.use_fourier and model.positive == 'square'
    if closed_form_ok:
        return pde_residual_fast_nd(model, xt, m_exp, d)

    residual, u, _unused_flux = pde_residual_autograd_nd(model, xt, m_exp, d)
    return residual, u




# =========================
# Multi-run Training & Evaluation (RaNN)
# =========================
import random
from datetime import datetime

# Driver: for each dimension in D_LIST, rebind the module-level configuration
# globals, then train and evaluate M_RUNS independently seeded models. Per-run
# and summary errors are logged. The helpers above read L_VEC, N_PDE, N_IC,
# N_BC, T_DOMAIN, ... as globals, so this loop must stay at module scope.

BASE_SEED = 12345

# The closed-form residual with cached features only applies to Fourier
# features + 'square' head; any other configuration falls back to autograd.
_USE_FAST_PATH = USE_FOURIER_FEATURES and POSITIVE_HEAD == 'square'


def _set_seeds(s: int):
    """Seed torch, NumPy and Python's ``random`` for one reproducible run."""
    torch.manual_seed(s)
    np.random.seed(s)
    random.seed(s)


for SPATIAL_DIM in D_LIST:
    # Update derived quantities for this dimension
    L_VEC = [1.0] * SPATIAL_DIM
    INPUT_SCALES = [1.0] * SPATIAL_DIM + [10.0]

    # Dimension-dependent hyperparams
    dims_hp = hyperparams_for_dim(SPATIAL_DIM)
    NUM_RANDOM_FEATURES = dims_hp["NUM_RANDOM_FEATURES"]
    N_PDE = dims_hp["N_PDE"]
    N_IC = dims_hp["N_IC"]
    N_BC = dims_hp["N_BC"]
    T_DOMAIN = dims_hp["T_DOMAIN"]

    if SPATIAL_DIM >= 4:
        LAMBDA_IC = 200.0  # tuning hook; currently equal to the global default

    print(f"\n\n=== Starting runs for d = {SPATIAL_DIM} ===")

    LOG_FILE = f"rann_multirun_{SPATIAL_DIM}D_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

    spacetime_errors = []
    finaltime_errors = []
    run_durations = []

    with open(LOG_FILE, "w", encoding="utf-8") as logf:

        def _log(msg: str):
            """Print *msg* and append it (flushed) to this dimension's log file."""
            print(msg)
            logf.write(msg + "\n")
            logf.flush()

        # Log configuration header
        _log(f"== {SPATIAL_DIM}D RaNN for PME (u_t - Δ(u^m) = 0) ==")
        _log(f"m (M_EXPONENT)         : {M_EXPONENT}")
        _log(f"T0_IC / T_DOMAIN       : {T0_IC} / {T_DOMAIN}")
        _log(f"L_VEC                  : {L_VEC}")
        _log(f"USE_FOURIER_FEATURES  : {USE_FOURIER_FEATURES}")
        _log(f"NUM_RANDOM_FEATURES   : {NUM_RANDOM_FEATURES}")
        _log(f"FEATURE_SCALE         : {FEATURE_SCALE}")
        _log(f"POSITIVE_HEAD         : {POSITIVE_HEAD}")
        _log(f"INPUT_SCALES          : {INPUT_SCALES}")
        _log(f"USE_CUSTOM_COLLOCATION: {USE_CUSTOM_COLLOCATION}")
        _log(f"LEARNING_RATE         : {LEARNING_RATE}")
        _log(f"NUM_EPOCHS            : {NUM_EPOCHS}")
        _log(f"N_PDE / N_IC / N_BC   : {N_PDE} / {N_IC} / {N_BC}")
        _log(f"N_EVAL (MC samples)   : {N_EVAL}")
        _log(f"M_RUNS                : {M_RUNS}")
        _log("")

        for run_idx in range(1, M_RUNS + 1):
            _log(f"--- Run {run_idx}/{M_RUNS} for d={SPATIAL_DIM} ---")
            seed = BASE_SEED + run_idx
            _set_seeds(seed)
            _log(f"Seed: {seed}")

            # Fresh collocation sets each run
            pde_xt, ic_xt, u_ic_true, bc_xt = build_collocation_sets_nd(
                USE_CUSTOM_COLLOCATION, SPATIAL_DIM
            )

            # Fresh model each run (only the readout is trainable)
            model = RFNNnD(
                SPATIAL_DIM,
                NUM_RANDOM_FEATURES,
                feature_scale=FEATURE_SCALE,
                use_fourier=USE_FOURIER_FEATURES,
                non_fourier_activation=NON_FOURIER_ACTIVATION,
                positive=POSITIVE_HEAD,
                input_scales=INPUT_SCALES,
                seed=seed,
            ).to(DEVICE)
            mse = nn.MSELoss()

            # Feature caches (frozen random map => compute cos/sin once per point set)
            pde_cache = bc_cache = None
            if _USE_FAST_PATH:
                pde_cache = FeatureCache(model, pde_xt, SPATIAL_DIM)
                bc_cache = FeatureCache(model, bc_xt, SPATIAL_DIM)

            ###########################################
            # Hybrid Adam → L-BFGS Training
            ###########################################
            _log("Training...")
            t_start = time.perf_counter()

            # Adam for the first NUM_EPOCHS - M epochs, L-BFGS for the last M
            M_LBFGS = min(LBFGS_LAST_M_EPOCHS, NUM_EPOCHS)
            adam_epochs = NUM_EPOCHS - M_LBFGS

            adam_optim = torch.optim.Adam(model.output_layer.parameters(), lr=LEARNING_RATE)

            lbfgs_optim = None
            if M_LBFGS > 0:
                lbfgs_optim = torch.optim.LBFGS(
                    model.output_layer.parameters(),
                    lr=LBFGS_LR,
                    max_iter=LBFGS_MAX_ITER,
                    history_size=LBFGS_HISTORY,
                    line_search_fn="strong_wolfe",
                )

            def compute_loss():
                """Return (total, pde, ic, bc) losses on the current collocation sets."""
                # PDE residual loss
                if _USE_FAST_PATH:
                    R_pde, _ = pde_residual_fast_nd_cached(model, pde_cache, M_EXPONENT)
                else:
                    R_pde, _ = residual_and_flux(model, pde_xt, M_EXPONENT, SPATIAL_DIM)
                loss_pde = mse(R_pde, torch.zeros_like(R_pde))

                # Initial condition (model t = 0) against the Barenblatt profile
                u_ic_pred = model(ic_xt)
                loss_ic = mse(u_ic_pred, u_ic_true)

                # Boundary loss
                if _USE_FAST_PATH:
                    # NOTE(review): on this path the "BC" term is the PDE residual at
                    # boundary points; unlike the autograd branch it does not penalise
                    # the normal flux of u^m. Confirm this is intended.
                    R_bc, _ = pde_residual_fast_nd_cached(model, bc_cache, M_EXPONENT)
                    loss_bc = mse(R_bc, torch.zeros_like(R_bc))
                else:
                    # Zero-flux penalty on d(u^m)/dx_i at faces x_i = 0 and x_i = L_i.
                    # Relies on the face ordering produced by build_collocation_sets_nd.
                    xt_req = bc_xt.clone().detach().requires_grad_(True)
                    u_bc = model(xt_req)
                    v_bc = u_bc.pow(M_EXPONENT)
                    grads_v = torch.autograd.grad(v_bc, xt_req, torch.ones_like(v_bc), create_graph=True)[0]
                    v_x_bc = grads_v[:, :SPATIAL_DIM]

                    per_face = max(1, N_BC // (2 * SPATIAL_DIM))
                    losses = []
                    offset = 0
                    for i in range(SPATIAL_DIM):
                        for _face in range(2):  # x_i = 0 face, then x_i = L_i face
                            sel = slice(offset, offset + per_face)
                            losses.append(mse(v_x_bc[sel, i:i+1], torch.zeros(per_face, 1, device=DEVICE)))
                            offset += per_face
                    loss_bc = torch.stack(losses).mean()

                total = (LAMBDA_PDE * loss_pde +
                         LAMBDA_IC  * loss_ic +
                         LAMBDA_BC  * loss_bc)
                return total, loss_pde, loss_ic, loss_bc

            ###########################################
            # 1) Adam Phase
            ###########################################
            for epoch in range(1, adam_epochs + 1):
                # Resample at epochs 1 + k*RESAMPLE_EVERY (k >= 1). This also works for
                # RESAMPLE_EVERY == 1, and a non-positive value disables resampling
                # (previously `epoch % 0` would raise ZeroDivisionError).
                if RESAMPLE_EVERY > 0 and epoch > 1 and (epoch - 1) % RESAMPLE_EVERY == 0:
                    pde_xt, ic_xt, u_ic_true, bc_xt = build_collocation_sets_nd(
                        USE_CUSTOM_COLLOCATION, SPATIAL_DIM
                    )
                    if _USE_FAST_PATH:
                        pde_cache = FeatureCache(model, pde_xt, SPATIAL_DIM)
                        bc_cache = FeatureCache(model, bc_xt, SPATIAL_DIM)

                adam_optim.zero_grad(set_to_none=True)
                total, loss_pde, loss_ic, loss_bc = compute_loss()
                total.backward()
                adam_optim.step()

                if epoch % 1000 == 0 or epoch == 1:
                    _log(f"[Adam] Epoch {epoch}/{NUM_EPOCHS} "
                         f"| Total: {total.item():.3e} "
                         f"| PDE: {loss_pde.item():.3e} "
                         f"| IC: {loss_ic.item():.3e} "
                         f"| BC: {loss_bc.item():.3e}")

            ###########################################
            # 2) L-BFGS Phase
            ###########################################
            # Latest loss components recorded by the closure, for logging.
            metrics = {
                "total": None,
                "pde":   None,
                "ic":    None,
                "bc":    None,
            }

            def closure():
                """L-BFGS closure: recompute the loss and gradients, record components."""
                lbfgs_optim.zero_grad(set_to_none=True)
                total, lp, li, lb = compute_loss()
                total.backward()

                metrics["total"] = total.detach()
                metrics["pde"]   = lp.detach()
                metrics["ic"]    = li.detach()
                metrics["bc"]    = lb.detach()

                return total

            for epoch in range(adam_epochs + 1, NUM_EPOCHS + 1):
                lbfgs_optim.step(closure)

                if epoch % 100 == 0 or epoch == adam_epochs + 1 or epoch == NUM_EPOCHS:
                    _log(
                        f"[LBFGS] Epoch {epoch}/{NUM_EPOCHS} "
                        f"| Total: {metrics['total'].item():.3e} "
                        f"| PDE: {metrics['pde'].item():.3e} "
                        f"| IC: {metrics['ic'].item():.3e} "
                        f"| BC: {metrics['bc'].item():.3e}"
                    )

            dur = time.perf_counter() - t_start
            _log(f"Finished run {run_idx} in {dur:.2f}s")
            run_durations.append(dur)

            # ---- Evaluation ----
            model.eval()

            # Monte Carlo space–time relative L2 error (uniform samples)
            x_eval = rand_in_box(N_EVAL, L_VEC)
            t_eval = torch.rand(N_EVAL, 1, device=DEVICE) * T_DOMAIN
            xt_eval = torch.cat([x_eval, t_eval], dim=1)

            with torch.no_grad():
                u_pred_eval = model(xt_eval)
                u_true_eval = barenblatt_true_shifted_dd(t_eval, x_eval, L_VEC, M_EXPONENT, T0_IC, support_fraction=SUPPORT_FRAC)

            num = torch.sum((u_pred_eval - u_true_eval) ** 2)
            den = torch.sum(u_true_eval ** 2) + 1e-16
            rel_L2_spacetime = torch.sqrt(num / den).item()
            spacetime_errors.append(rel_L2_spacetime)
            _log(f"Rel L2 (space–time, MC): {rel_L2_spacetime:.6e}")

            # Final-time spatial relative L2 error
            x_evalT = rand_in_box(N_EVAL, L_VEC)
            tT = torch.full((N_EVAL, 1), T_DOMAIN, device=DEVICE)
            with torch.no_grad():
                u_pred_T = model(torch.cat([x_evalT, tT], dim=1))
                u_true_T = barenblatt_true_shifted_dd(tT, x_evalT, L_VEC, M_EXPONENT, T0_IC, support_fraction=SUPPORT_FRAC)

            numT = torch.sum((u_pred_T - u_true_T) ** 2)
            denT = torch.sum(u_true_T ** 2) + 1e-16
            rel_L2_final = torch.sqrt(numT / denT).item()
            finaltime_errors.append(rel_L2_final)
            _log(f"Rel L2 (final time T={T_DOMAIN}): {rel_L2_final:.6e}")
            _log("")

            # Free per-run state before clearing the CUDA cache. The optimizers keep
            # references to the readout parameters, and the feature caches hold
            # [N, NUM_RANDOM_FEATURES // 2] proj/cos/sin tensors. Deleting only
            # `model` therefore left most of this memory allocated.
            del model, adam_optim, lbfgs_optim, pde_cache, bc_cache
            del pde_xt, ic_xt, u_ic_true, bc_xt
            if DEVICE.type == "cuda":
                torch.cuda.empty_cache()

        # ---- Summary stats ----
        spacetime_np = np.array(spacetime_errors, dtype=np.float64)
        finaltime_np = np.array(finaltime_errors, dtype=np.float64)

        st_mean = float(spacetime_np.mean())
        st_std  = float(spacetime_np.std(ddof=1)) if len(spacetime_np) > 1 else 0.0
        ft_mean = float(finaltime_np.mean())
        ft_std  = float(finaltime_np.std(ddof=1)) if len(finaltime_np) > 1 else 0.0

        _log("=== Summary over runs ===")
        _log(f"Space–time rel L2: mean={st_mean:.6e}, std={st_std:.6e}")
        _log(f"Final-time rel L2: mean={ft_mean:.6e}, std={ft_std:.6e}")
        _log(f"Durations (s)    : {', '.join(f'{d:.2f}' for d in run_durations)}")
        _log(f"Saved full log to: {LOG_FILE}")


