from __future__ import annotations

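"""
Multi-dimensional physics-informed neural network (PINN) for the porous medium
equation (PME) u_t = Δ(u^m), using trainable Fourier features and a Barenblatt
reference solution.
"""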

import math
import time
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator

# -------------------------
# Global params
# -------------------------
# Seed the global torch / NumPy RNGs at import time. Each training run in the
# run loop reseeds again with BASE_SEED + run_idx before sampling data/models.
torch.manual_seed(42)
np.random.seed(42)

# Problem params
SPATIAL_DIM = 1                    # default spatial dimension d; the run loop rebinds it for each entry of D_LIST
M_EXPONENT = 2.0                   # m in PME (m > 1)
T0_IC = 1e-2                       # small positive time for IC (avoid singularity)
T_DOMAIN = 0.05                    # model time horizon: t in [0, T]; reference solution is evaluated at T0_IC + t (overridden per dimension by hyperparams_for_dim)
M_RUNS = 5                         # independent training runs (different seeds) per dimension

# Domain
L_VEC = [1.0] * SPATIAL_DIM        # box side lengths: domain is [0, L_1] x ... x [0, L_d]; rebound per dimension in the run loop

# PINN architecture params
USE_FOURIER_FEATURES = True        # set False for standard MLP
RF_COUNT = 32                    # number of Fourier modes (first layer), feature dim = 2*RF_COUNT; overridden per dimension by hyperparams_for_dim
FOURIER_INIT_SCALE = 10.0          # init scale for frequencies (std of the Gaussian frequency init)

NON_FOURIER_ACTIVATION = 'tanh'    # used when USE_FOURIER_FEATURES=False

HIDDEN_DEPTH = 0                   # number of extra hidden layers after first mapping (0 => linear readout of the first mapping)
HIDDEN_WIDTH = 0                   # width of those hidden layers (overridden per dimension by hyperparams_for_dim)

POSITIVE_HEAD = 'square'           # 'square' | 'softplus' | 'relu' | 'none' -- transform of the raw output ('none' leaves it unconstrained)

# Input anisotropic scaling (per spatial dim, then time): multiplies the
# (x_1, ..., x_d, t) columns before the first layer. NOTE(review): the larger
# time factor presumably compensates for the short time horizon -- confirm.
INPUT_SCALES = [1.0] * SPATIAL_DIM + [10.0]

# Collocation toggles
USE_CUSTOM_COLLOCATION = True      #  False means uniform collocation; True puts 40% of PDE points near the moving front
FRONT_BAND_REL = 0.02              # half-width of the sampling band around the front, as a fraction of min(L_VEC)
SUPPORT_FRAC = 0.60                # initial support radius R0 = SUPPORT_FRAC * min(L_VEC) / 2 (the run loop raises it for d >= 4)

# Training params
LEARNING_RATE = 1e-3               # Adam learning rate
NUM_EPOCHS = 5000                  # total epochs (Adam phase + final L-BFGS phase)
N_PDE = 2000                       # interior collocation points (overridden per dimension)
N_IC = 1000                        # initial-condition points (overridden per dimension)
N_BC = 500                         # boundary points, split evenly over the 2*d faces (overridden per dimension)

# Loss weights
LAMBDA_PDE = 1.0
LAMBDA_IC = 200.0                  # the run loop raises this for d >= 4
LAMBDA_BC = 1.0

# L-BFGS fine-tuning
LBFGS_LAST_M_EPOCHS = 200          # the last this-many epochs use L-BFGS instead of Adam (one optimizer.step per epoch)
LBFGS_LR = 1.0             
LBFGS_MAX_ITER = 20                # max inner L-BFGS iterations per optimizer.step

# Dimensions to run (each must be a key of dim_to_params)
D_LIST = [4,5]

# dim -> [time, width]; 'time' becomes T_DOMAIN. NOTE: the 'width' entry is
# currently not read -- hyperparams_for_dim always returns HIDDEN_WIDTH = 0.
dim_to_params = { 1 : [0.05, 20], 2 : [0.025, 30], 3 : [0.01, 40], 4 :[0.01, 48], 5 : [0.01, 54], 10 : [0.01, 100] }

def hyperparams_for_dim(d: int):
    """
    Return the dimension-dependent hyperparameters for a run in ``d`` spatial dims.

    The time horizon comes from ``dim_to_params[d][0]``. The sampling sizes and
    Fourier-mode count grow with ``d`` for ``d < 4`` and are fixed for ``d >= 4``.

    Args:
        d: spatial dimension; must be a key of ``dim_to_params``.

    Returns:
        dict with keys "RF_COUNT", "N_PDE", "N_IC", "N_BC", "T_DOMAIN" and
        "HIDDEN_WIDTH".

    Raises:
        KeyError: if ``d`` has no entry in ``dim_to_params``.
    """
    try:
        # The second column ('width') is currently unused; see "HIDDEN_WIDTH" below.
        tau, _width = dim_to_params[d]
    except KeyError:
        raise KeyError(
            f"No hyperparameters configured for d={d}; add an entry to "
            f"dim_to_params (configured dims: {sorted(dim_to_params)})"
        ) from None

    high_dim = d >= 4
    return {
        # Number of Fourier modes; the feature dimension is 2 * RF_COUNT
        # (cos + sin), so any positive integer is valid.
        "RF_COUNT": 3750 if high_dim else 1250 * d,
        "N_PDE": 8000 if high_dim else 2000 * d,
        "N_IC": 4000 if high_dim else 1000,
        "N_BC": 2000 if high_dim else 500,
        "T_DOMAIN": tau,
        # NOTE: always 0, independent of dim_to_params' width column.
        "HIDDEN_WIDTH": 0,
    }


# Device: use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")
# NOTE(review): "high" allows reduced-precision internal matmul kernels (e.g. TF32)
# on supporting GPUs. The PDE residual uses second derivatives, so consider
# "highest" if accuracy matters more than speed.
torch.set_float32_matmul_precision("high")


# -------------------------
# Utilities
# -------------------------
def get_activation(name: str):
    """
    Look up an element-wise activation function by case-insensitive name.

    Supported names: 'tanh', 'relu', 'sigmoid'.

    Raises:
        ValueError: for any other name.
    """
    table = {
        'tanh': torch.tanh,
        'relu': torch.relu,
        'sigmoid': torch.sigmoid,
    }
    key = name.lower()
    fn = table.get(key)
    if fn is None:
        raise ValueError(f"Unsupported activation: {key}")
    return fn


def _to_tensor(x, device=DEVICE, dtype=torch.float32):
    """
    Convert ``x`` to a tensor on ``device`` with ``dtype``.

    Tensors are moved/cast with ``Tensor.to``; any other input (lists, scalars,
    arrays) is copied into a new tensor via ``torch.tensor``. The default device
    is bound to the module-level DEVICE when this function is defined.
    """
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x, device=device, dtype=dtype)
    return x.to(device=device, dtype=dtype)

def init_xavier(layer):
    """
    Xavier-uniform initialize an ``nn.Linear`` weight in place and zero its bias.

    Any other module type is left untouched.
    """
    if not isinstance(layer, nn.Linear):
        return
    nn.init.xavier_uniform_(layer.weight)
    bias = layer.bias
    if bias is not None:
        nn.init.zeros_(bias)

# -------------------------
# PINN definition 
# -------------------------
class PINNnD(nn.Module):
    """
    Neural ansatz u_theta(x, t) for the d-dimensional porous medium equation.

    Inputs ``xt`` have shape [N, d+1] (spatial coordinates, then time) and are
    multiplied column-wise by ``input_scales`` before the first mapping.

    * ``use_fourier=True``: a trainable frequency matrix ``freq`` [rf_count, d+1]
      and phases ``phase`` [rf_count] give features
      ``rf_count**-0.5 * [cos(p), sin(p)]`` with ``p = xt_scaled @ freq.T + phase``.
      These feed ``hidden_depth`` xavier-initialized tanh layers, or go straight to a
      linear readout when ``hidden_depth <= 0``.
    * ``use_fourier=False``: a plain MLP with ``hidden_depth`` layers using
      ``non_fourier_activation``, or a single linear layer when ``hidden_depth <= 0``.

    The raw scalar output [N, 1] is passed through the ``positive`` head
    ('square', 'softplus', 'relu' or 'none').

    Note: the constructor calls ``torch.manual_seed(seed)``, i.e. it reseeds the
    *global* torch RNG so that parameter initialization is reproducible per seed.
    """

    _POSITIVE_HEADS = ('square', 'softplus', 'relu', 'none')

    def __init__(
        self,
        spatial_dim: int,
        *,
        use_fourier: bool,
        rf_count: int = 0,
        fourier_init_scale: float = 1.0,
        non_fourier_activation: str = 'tanh',
        hidden_depth: int = 0,
        hidden_width: int = 256,
        positive: str = 'square',
        input_scales: list[float] | Tuple[float, ...] | None = None,
        seed: int = 12345,
    ) -> None:
        """
        Args:
            spatial_dim: number of spatial coordinates d.
            use_fourier: use the Fourier-feature trunk instead of a plain MLP.
            rf_count: number of Fourier modes (required when ``use_fourier``).
            fourier_init_scale: std of the Gaussian frequency initialization.
            non_fourier_activation: activation name for the MLP trunk.
            hidden_depth: number of hidden layers after the first mapping.
            hidden_width: width of those hidden layers.
            positive: output head, one of 'square', 'softplus', 'relu', 'none'.
            input_scales: per-column input multipliers; defaults to
                ``[1.0] * d + [10.0]``.
            seed: value passed to ``torch.manual_seed`` before initialization.

        Raises:
            ValueError: for an unknown ``positive`` head, a non-positive or
                non-integer ``rf_count`` when ``use_fourier``, or an unknown
                ``non_fourier_activation`` for the MLP trunk.
        """
        super().__init__()
        # Validate configuration up front with real exceptions (``assert`` is
        # stripped under ``python -O``). No RNG is consumed before seeding below.
        if positive not in self._POSITIVE_HEADS:
            raise ValueError("positive must be 'square', 'softplus', 'relu', or 'none'")
        if use_fourier and not (rf_count > 0 and rf_count % 1 == 0):
            raise ValueError("rf_count must be a positive integer")

        torch.manual_seed(seed)

        self.d = spatial_dim
        self.input_dim = self.d + 1   # (x1,...,xd, t)
        self.output_dim = 1

        self.use_fourier = use_fourier
        self.hidden_depth = int(hidden_depth)
        self.hidden_width = int(hidden_width)
        self.positive = positive
        self.non_fourier_activation = non_fourier_activation.lower()

        if input_scales is None:
            input_scales = [1.0] * self.d + [10.0]
        self.register_buffer("input_scales", _to_tensor(input_scales))

        if self.use_fourier:
            self.rf_count = int(rf_count)
            # Trainable frequency matrix and phase (drawn in this order for reproducibility)
            self.freq = nn.Parameter(torch.randn(self.rf_count, self.input_dim) * fourier_init_scale)  # [rf, d+1]
            self.phase = nn.Parameter(torch.rand(self.rf_count) * 2 * math.pi)                          # [rf]
            self._build_post_fourier_stack()
        else:
            # Standard MLP (PyTorch default layer initialization)
            self.act = get_activation(self.non_fourier_activation)
            layers: list[nn.Module] = []
            if self.hidden_depth <= 0:
                self.inp = nn.Linear(self.input_dim, self.output_dim)
                self.out = None
                self.hidden = None
            else:
                # hidden[0] is applied in _first_mapping, hidden[1:] in forward
                layers.append(nn.Linear(self.input_dim, self.hidden_width))
                for _ in range(self.hidden_depth - 1):
                    layers.append(nn.Linear(self.hidden_width, self.hidden_width))
                self.hidden = nn.ModuleList(layers)
                self.out = nn.Linear(self.hidden_width, self.output_dim)

    def _build_post_fourier_stack(self):
        """Create the (xavier-initialized) layers that follow the Fourier features."""
        base_dim = 2 * self.rf_count  # cos+sin

        if self.hidden_depth <= 0:
            # Linear readout directly on the Fourier features.
            self.project = None
            self.hidden = None
            self.out = nn.Linear(base_dim, self.output_dim)
            init_xavier(self.out)
        else:
            self.project = nn.Linear(base_dim, self.hidden_width)
            init_xavier(self.project)

            hidden_layers = []
            for _ in range(self.hidden_depth - 1):
                lay = nn.Linear(self.hidden_width, self.hidden_width)
                init_xavier(lay)
                hidden_layers.append(lay)
            self.hidden = nn.ModuleList(hidden_layers)

            self.out = nn.Linear(self.hidden_width, self.output_dim)
            init_xavier(self.out)

            self.act = torch.tanh

    def _first_mapping(self, xt: torch.Tensor) -> torch.Tensor:
        """Scale the inputs and apply the first mapping (Fourier features or first MLP layer)."""
        x_scaled = xt * self.input_scales  # [N, d+1]
        if self.use_fourier:
            proj = x_scaled @ self.freq.t() + self.phase  # [N, rf]
            cosf = torch.cos(proj)
            sinf = torch.sin(proj)
            # 1/sqrt(rf) keeps the feature vector's norm independent of rf_count.
            scale = (self.rf_count) ** (-0.5)
            h0 = scale * torch.cat([cosf, sinf], dim=1)   # [N, 2*rf]
            return h0
        else:
            if self.hidden_depth <= 0:
                return x_scaled
            z = self.hidden[0](x_scaled)
            return self.act(z)

    def forward(self, xt: torch.Tensor) -> torch.Tensor:
        """Map points ``xt`` [N, d+1] to u [N, 1] after the positive head."""
        h = self._first_mapping(xt)
        if self.use_fourier:
            if self.hidden_depth > 0:
                h = self.act(self.project(h))
                for layer in self.hidden:
                    h = self.act(layer(h))
            raw = self.out(h)
        else:
            if self.hidden_depth <= 0:
                raw = self.inp(h)  # h is already scaled xt
            else:
                for layer in self.hidden[1:]:
                    h = self.act(layer(h))
                raw = self.out(h)

        if self.positive == 'square':
            return raw ** 2
        elif self.positive == 'softplus':
            return torch.nn.functional.softplus(raw, beta=1.0)
        elif self.positive == 'relu':
            return torch.relu(raw)
        elif self.positive == 'none':
            return raw
        else:
            # Only reachable if ``self.positive`` was changed after construction.
            raise ValueError("positive must be 'square', 'softplus', 'relu', or 'none'")


# -------------------------
# Barenblatt profile
# -------------------------
def barenblatt_ic_dd(
    x: torch.Tensor,           # [N, d] spatial points
    L_vec: list[float],        # domain lengths
    m: float,
    t0: float,
    support_fraction: float = SUPPORT_FRAC,
    x_center: torch.Tensor | None = None,  # [d]
) -> torch.Tensor:
    """
    Returns u(0,x) := Barenblatt(t0, x), the initial condition in model time.

    The profile is centred at ``x_center`` (default: the box centre ``L_vec / 2``)
    and scaled so that its compact support at physical time ``t0`` is the ball of
    radius ``R0 = support_fraction * min(L_vec) / 2``:

        u = t0^(-alpha) * max(A - k |x - x0|^2 / t0^(2 beta), 0)^(1 / (m - 1))

    with lam = 2 + d (m - 1), alpha = d / lam, beta = 1 / lam,
    k = (m - 1) beta / (2 m) and A = k R0^2 / t0^(2 beta).

    Args:
        x: spatial points [N, d].
        L_vec: box side lengths (length d).
        m: PME exponent; must be > 1.
        t0: positive physical time at which the profile is evaluated.
        support_fraction: support radius as a fraction of min(L_vec) / 2.
        x_center: optional centre [d]; moved to ``x``'s device and dtype.

    Returns:
        Tensor [N, 1] on ``x``'s device and dtype.

    Raises:
        ValueError: if ``m <= 1`` (the exponent 1 / (m - 1) is undefined or negative).
    """
    if m <= 1.0:
        raise ValueError(f"Barenblatt profile requires m > 1, got m={m}")

    d = x.shape[1]
    # Build the centre on x's device/dtype so CPU evaluation works even when
    # the module-level DEVICE is CUDA (and vice versa).
    if x_center is None:
        x0 = torch.tensor([0.5 * L for L in L_vec], device=x.device, dtype=x.dtype).view(1, d)
    else:
        x0 = x_center.to(device=x.device, dtype=x.dtype).view(1, d)

    lam = 2.0 + d * (m - 1.0)
    alpha = d / lam
    beta = 1.0 / lam
    B_const = (m - 1.0) * beta / (2.0 * m)  # generalizes 1D formula

    R0 = support_fraction * (min(L_vec) / 2.0)  # safe inside the domain
    A = B_const * (R0**2) / (t0 ** (2.0 * beta))

    r2 = torch.sum((x - x0) ** 2, dim=1, keepdim=True)
    z = A - B_const * r2 / (t0 ** (2.0 * beta))
    z_clamped = torch.clamp(z, min=0.0)  # zero outside the support
    u_t0 = (t0 ** (-alpha)) * (z_clamped ** (1.0 / (m - 1.0)))
    return u_t0


def barenblatt_true_shifted_dd(
    t_model: torch.Tensor,     # [N, 1], [N], or scalar
    x: torch.Tensor,           # [N, d]
    L_vec: list[float],
    m: float,
    t0: float,
    support_fraction: float = SUPPORT_FRAC,
    x_center: torch.Tensor | None = None,  # [d]
) -> torch.Tensor:
    """
    u_true(x, t_model) = Barenblatt(t0 + t_model, x).

    Uses the same centre, support radius and constants as ``barenblatt_ic_dd``,
    so ``t_model = 0`` reproduces the initial condition.

    Args:
        t_model: model time. Accepts a tensor of shape [N, 1] or [N], a 0-d
            tensor, or a Python number (a scalar is applied to every point).
        x: spatial points [N, d].
        L_vec: box side lengths (length d).
        m: PME exponent; must be > 1.
        t0: time shift between model time and physical time (> 0).
        support_fraction: support radius at t0 as a fraction of min(L_vec) / 2.
        x_center: optional centre [d]; moved to ``x``'s device and dtype.

    Returns:
        Tensor [N, 1].

    Raises:
        ValueError: if ``m <= 1``.
    """
    if m <= 1.0:
        raise ValueError(f"Barenblatt profile requires m > 1, got m={m}")

    d = x.shape[1]
    if x_center is None:
        x0 = torch.tensor([0.5 * L for L in L_vec], device=x.device, dtype=x.dtype).view(1, d)
    else:
        x0 = x_center.to(device=x.device, dtype=x.dtype).view(1, d)

    lam = 2.0 + d * (m - 1.0)
    alpha = d / lam
    beta = 1.0 / lam
    B_const = (m - 1.0) * beta / (2.0 * m)

    R0 = support_fraction * (min(L_vec) / 2.0)
    A = B_const * (R0**2) / (t0 ** (2.0 * beta))

    if not isinstance(t_model, torch.Tensor):
        # Python numbers previously failed on the ``.ndim`` check below.
        t_model = torch.as_tensor(t_model, device=x.device, dtype=x.dtype)
    t_phys = t_model + t0
    if t_phys.ndim == 0:
        t_phys = t_phys * torch.ones(x.shape[0], 1, device=x.device, dtype=x.dtype)
    elif t_phys.ndim == 1:
        # A flat [N] vector would otherwise broadcast against [N, 1] into [N, N].
        t_phys = t_phys.reshape(-1, 1)

    r2 = torch.sum((x - x0) ** 2, dim=1, keepdim=True)
    z = A - B_const * r2 / (t_phys ** (2.0 * beta))
    z = torch.clamp(z, min=0.0)  # zero outside the (time-dependent) support
    u = (t_phys ** (-alpha)) * (z ** (1.0 / (m - 1.0)))
    return u


# -------------------------
# Collocation helpers (many only used for custom collocation)
# -------------------------
def rand_in_box(n: int, L_vec: list[float]) -> torch.Tensor:
    """
    Draw ``n`` points uniformly from the box [0, L_1] x ... x [0, L_d] on DEVICE.

    Returns a tensor of shape [n, d] with d = len(L_vec).
    """
    dim = len(L_vec)
    unit = torch.rand(n, dim, device=DEVICE)
    lengths = _to_tensor(L_vec).view(1, dim)
    return unit * lengths


def sample_near_front_nd(
    n: int,
    L_vec: list[float],
    t0: float,
    T: float,
    m: float,
    band_rel: float,
    *,
    support_fraction: float | None = None,
) -> torch.Tensor:
    """
    Sample ``n`` space-time points in a thin shell around the Barenblatt free boundary.

    The front radius is R(t) = R0 * ((t0 + t) / t0)^beta with
    beta = 1 / (2 + d (m - 1)) and R0 = support_fraction * min(L_vec) / 2.
    This is the same initial radius used by the Barenblatt initial condition.

    How points are drawn:
      * times as t = T * U^3, which concentrates samples near t = 0;
      * directions uniformly on the unit sphere (normalized Gaussians);
      * radii uniformly in [R(t) - delta, R(t) + delta], clamped at 0, with
        delta = band_rel * min(L_vec).

    Args:
        n: number of points.
        L_vec: box side lengths (length d).
        t0: physical time corresponding to model time 0.
        T: model time horizon.
        m: PME exponent.
        band_rel: band half-width as a fraction of min(L_vec).
        support_fraction: initial support radius as a fraction of min(L_vec) / 2.
            ``None`` (default) reads the module-level ``SUPPORT_FRAC`` at call
            time, since the run loop rebinds it per dimension.

    Returns:
        Tensor [n, d+1] of (x_1, ..., x_d, t); spatial coordinates are clamped
        to the box.
    """
    if support_fraction is None:
        support_fraction = SUPPORT_FRAC
    d = len(L_vec)
    lam = 2.0 + d * (m - 1.0)
    beta = 1.0 / lam
    Lmin = min(L_vec)
    R0 = support_fraction * (Lmin / 2.0)  # same initial radius as the IC
    t = (torch.rand(n, 1, device=DEVICE) ** 3) * T
    R = R0 * ((t0 + t) / t0) ** beta

    # directions ~ uniform on sphere via Gaussian normalization
    dirs = torch.randn(n, d, device=DEVICE)
    dirs = dirs / (dirs.norm(dim=1, keepdim=True) + 1e-12)

    delta = band_rel * Lmin
    radii = R.squeeze(1) + (2 * torch.rand(n, device=DEVICE) - 1.0) * delta
    radii = radii.clamp(min=0.0)

    x0 = _to_tensor([0.5 * L for L in L_vec]).view(1, d)
    x = x0 + dirs * radii.view(-1, 1)

    # clamp to box
    L = _to_tensor(L_vec).view(1, d)
    x = torch.clamp(x, min=torch.zeros(1, d, device=DEVICE), max=L)
    return torch.cat([x, t], dim=1)



def _sample_pde_points(use_custom: bool) -> torch.Tensor:
    """Interior PDE collocation points [N_PDE, d+1]; front-biased when ``use_custom``."""
    if use_custom:
        # 40% in a band around the moving free boundary, 60% in the bulk with
        # times skewed toward t = 0 (t = T * U^3).
        n_front = int(0.4 * N_PDE)
        n_bulk = N_PDE - n_front
        pde_xt_front = sample_near_front_nd(n_front, L_VEC, T0_IC, T_DOMAIN, M_EXPONENT, FRONT_BAND_REL)
        x_bulk = rand_in_box(n_bulk, L_VEC)
        t_bulk = (torch.rand(n_bulk, 1, device=DEVICE) ** 3) * T_DOMAIN
        pde_xt_bulk = torch.cat([x_bulk, t_bulk], dim=1)
        return torch.cat([pde_xt_front, pde_xt_bulk], dim=0)
    x_bulk = rand_in_box(N_PDE, L_VEC)
    t_bulk = torch.rand(N_PDE, 1, device=DEVICE) * T_DOMAIN
    return torch.cat([x_bulk, t_bulk], dim=1)


def _sample_ic_points(d: int) -> tuple[torch.Tensor, torch.Tensor]:
    """IC points at t = 0, uniform in the initial support ball, with Barenblatt targets."""
    R0_true = SUPPORT_FRAC * (min(L_VEC) / 2.0)

    # Uniform direction (normalized Gaussian) times radius R0 * U^(1/d) gives
    # points distributed uniformly inside the d-ball.
    dirs = torch.randn(N_IC, d, device=DEVICE)
    dirs = dirs / (dirs.norm(dim=1, keepdim=True) + 1e-12)
    radii = (torch.rand(N_IC, 1, device=DEVICE) ** (1 / d)) * R0_true

    x0_center = _to_tensor([0.5 * L for L in L_VEC]).view(1, d)
    ic_x = x0_center + dirs * radii
    # Clamp to the box as a safeguard (R0 < min(L)/2 keeps the ball inside anyway).
    L_tensor = _to_tensor(L_VEC).view(1, d)
    ic_x = torch.clamp(ic_x, min=torch.zeros(1, d, device=DEVICE), max=L_tensor)

    ic_t = torch.zeros(N_IC, 1, device=DEVICE)
    u_ic_true = barenblatt_ic_dd(ic_x, L_VEC, M_EXPONENT, T0_IC, support_fraction=SUPPORT_FRAC)
    return torch.cat([ic_x, ic_t], dim=1), u_ic_true


def _sample_bc_points(d: int) -> torch.Tensor:
    """Boundary points, per_face on each face in order x_1=0, x_1=L_1, x_2=0, ..."""
    faces = []
    per_face = max(1, N_BC // (2 * d))
    for i in range(d):
        # Other coordinates uniform; the i-th coordinate is pinned to the face.
        x_rest = rand_in_box(per_face, L_VEC)
        x0_face = x_rest.clone()
        xL_face = x_rest.clone()
        x0_face[:, i] = 0.0
        xL_face[:, i] = L_VEC[i]
        t_face = torch.rand(per_face, 1, device=DEVICE) * T_DOMAIN
        faces.append(torch.cat([x0_face, t_face], dim=1))
        faces.append(torch.cat([xL_face, t_face], dim=1))
    return torch.cat(faces, dim=0)


def build_collocation_sets_nd(use_custom: bool, d: int):
    """
    Sample fresh PDE, initial-condition and boundary collocation sets.

    Sizes, domain and physics parameters are read from the module-level globals
    (N_PDE, N_IC, N_BC, L_VEC, T_DOMAIN, T0_IC, M_EXPONENT, SUPPORT_FRAC,
    FRONT_BAND_REL), which the run loop rebinds for each dimension. Random draws
    happen in a fixed order (PDE, then IC, then BC), so results are reproducible
    for a given seed.

    NOTE: IC points are only drawn inside the initial support ball, so the IC
    loss does not directly enforce u(x, 0) = 0 outside it.

    Args:
        use_custom: front-biased PDE sampling if True, uniform otherwise.
        d: spatial dimension; must equal ``len(L_VEC)``.

    Returns:
        (pde_xt, ic_xt, u_ic_true, bc_xt):
          * pde_xt: [N_PDE, d+1]
          * ic_xt: [N_IC, d+1], with t = 0
          * u_ic_true: [N_IC, 1]
          * bc_xt: [2*d*per_face, d+1], where per_face = max(1, N_BC // (2*d)),
            laid out face by face as x_1=0, x_1=L_1, x_2=0, ...

    Raises:
        ValueError: if ``d`` does not match ``len(L_VEC)``.
    """
    if d != len(L_VEC):
        raise ValueError(f"d={d} does not match len(L_VEC)={len(L_VEC)}")
    pde_xt = _sample_pde_points(use_custom)
    ic_xt, u_ic_true = _sample_ic_points(d)
    bc_xt = _sample_bc_points(d)
    return pde_xt, ic_xt, u_ic_true, bc_xt


# -------------------------
# PDE residual 
# -------------------------
def pde_residual_autograd_nd(model: nn.Module, xt: torch.Tensor, m_exp: float, d: int):
    """
    Evaluate the PME residual R = u_t - Δ(u^m) with autograd.

    Args:
        model: callable mapping points [N, d+1] (x_1..x_d, t) to u [N, 1].
        xt: collocation points. A detached copy with ``requires_grad`` is
            differentiated, so the caller's tensor is left untouched.
        m_exp: PME exponent m.
        d: number of spatial coordinates (the leading columns of ``xt``).

    Returns:
        (R, u, u_x, v_x) where v = u^m: residual [N, 1], prediction [N, 1],
        spatial gradient of u [N, d], and spatial gradient of v [N, d]. All are
        built with ``create_graph=True`` so they can be used inside a loss.
    """
    pts = xt.clone().detach().requires_grad_(True)

    def _grad(y):
        # dy/d(pts) for all d+1 inputs, keeping the graph for further differentiation.
        return torch.autograd.grad(y, pts, torch.ones_like(y), create_graph=True)[0]

    u = model(pts)
    du = _grad(u)
    u_t = du[:, d:d + 1]
    u_x = du[:, :d]

    v = u.pow(m_exp)
    v_x = _grad(v)[:, :d]

    # Laplacian of v: sum over k of d/dx_k (dv/dx_k).
    diag_second = [_grad(v_x[:, k:k + 1])[:, k:k + 1] for k in range(d)]
    lap_v = torch.stack(diag_second, dim=0).sum(dim=0)

    return u_t - lap_v, u, u_x, v_x


def compute_losses(model, pde_xt, ic_xt, u_ic_true, bc_xt, d, mse):
    """
    Compute the weighted PINN loss and its components.

    Args:
        model: network mapping [N, d+1] -> [N, 1].
        pde_xt: interior collocation points [N_pde, d+1].
        ic_xt: initial-condition points [N_ic, d+1].
        u_ic_true: targets at ``ic_xt``, [N_ic, 1].
        bc_xt: boundary points as 2*d consecutive equal-sized chunks ordered
            x_1=0, x_1=L_1, x_2=0, x_2=L_2, ... (the layout produced by
            ``build_collocation_sets_nd``).
        d: spatial dimension.
        mse: loss callable ``mse(pred, target)``.

    Returns:
        (total, loss_pde, loss_ic, loss_bc), where
        total = LAMBDA_PDE * loss_pde + LAMBDA_IC * loss_ic + LAMBDA_BC * loss_bc.
    """
    # PDE residual R = u_t - Δ(u^m) at interior points
    R_pde, _, _, _ = pde_residual_autograd_nd(model, pde_xt, M_EXPONENT, d)
    loss_pde = mse(R_pde, torch.zeros_like(R_pde))

    # IC
    u_ic_pred = model(ic_xt)
    loss_ic = mse(u_ic_pred, u_ic_true)

    # BC zero-flux: d(u^m)/dx_i = 0 on the faces normal to x_i.
    # Only first derivatives of v = u^m are needed here, so differentiate v
    # directly instead of building the full residual (which would spend d extra
    # backward passes on a Laplacian that is then discarded).
    xt_bc = bc_xt.clone().detach().requires_grad_(True)
    v_bc = model(xt_bc).pow(M_EXPONENT)
    v_x_bc = torch.autograd.grad(v_bc, xt_bc, torch.ones_like(v_bc), create_graph=True)[0][:, :d]

    # Derive the chunk size from bc_xt itself so it always matches its layout.
    per_face = bc_xt.shape[0] // (2 * d)
    losses = []
    for face in range(2 * d):
        i = face // 2  # faces come in (x_i = 0, x_i = L_i) pairs
        flux = v_x_bc[face * per_face:(face + 1) * per_face, i:i + 1]
        losses.append(mse(flux, torch.zeros_like(flux)))
    loss_bc = torch.stack(losses).mean()

    loss = LAMBDA_PDE * loss_pde + LAMBDA_IC * loss_ic + LAMBDA_BC * loss_bc
    return loss, loss_pde, loss_ic, loss_bc





# =========================
# Multi-run Training & Evaluation
# =========================
import random
from datetime import datetime

BASE_SEED = 12345  # run k of every dimension is seeded with BASE_SEED + k


def _set_seeds(s: int):
    """Seed torch's, NumPy's global and Python's ``random`` RNGs with ``s``."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(s)

# Per-dimension overrides (LAMBDA_IC, SUPPORT_FRAC) rebind module-level names.
# Remember the configured defaults once so that an override made for one
# dimension does not leak into the next entry of D_LIST.
_DEFAULT_LAMBDA_IC = LAMBDA_IC
_DEFAULT_SUPPORT_FRAC = SUPPORT_FRAC

for SPATIAL_DIM in D_LIST:
    # ---- dimension-specific globals ----
    # These rebind module-level names that the helper functions read at call time.
    L_VEC = [1.0] * SPATIAL_DIM
    INPUT_SCALES = [1.0] * SPATIAL_DIM + [10.0]

    hp = hyperparams_for_dim(SPATIAL_DIM)
    RF_COUNT = hp["RF_COUNT"]
    N_PDE = hp["N_PDE"]
    N_IC = hp["N_IC"]
    N_BC = hp["N_BC"]
    T_DOMAIN = hp["T_DOMAIN"]
    HIDDEN_WIDTH = hp["HIDDEN_WIDTH"]

    # Start from the defaults every iteration, then apply high-dimension overrides.
    LAMBDA_IC = _DEFAULT_LAMBDA_IC
    SUPPORT_FRAC = _DEFAULT_SUPPORT_FRAC
    if SPATIAL_DIM >= 4:
        LAMBDA_IC = 500.0
        SUPPORT_FRAC = 0.80

    LOG_FILE = f"pinn_multirun_{SPATIAL_DIM}D_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

    # Storage for metrics **for this dimension**
    spacetime_errors = []
    finaltime_errors = []
    run_durations = []

    print(f"\n\n=== Starting runs for d = {SPATIAL_DIM} ===")

    with open(LOG_FILE, "w", encoding="utf-8") as logf:

        def _log(msg: str):
            # Echo to stdout and append to this dimension's log file.
            print(msg)
            logf.write(msg + "\n")
            logf.flush()

        # Log configuration header
        _log(f"== {SPATIAL_DIM}D PINN for PME (u_t - Δ(u^m) = 0) ==")
        _log(f"m (M_EXPONENT)         : {M_EXPONENT}")
        _log(f"T0_IC / T_DOMAIN       : {T0_IC} / {T_DOMAIN}")
        _log(f"L_VEC                  : {L_VEC}")
        _log(f"USE_FOURIER_FEATURES  : {USE_FOURIER_FEATURES}")
        _log(f"RF_COUNT               : {RF_COUNT}")
        _log(f"HIDDEN_DEPTH / WIDTH  : {HIDDEN_DEPTH} / {HIDDEN_WIDTH}")
        _log(f"POSITIVE_HEAD         : {POSITIVE_HEAD}")
        _log(f"INPUT_SCALES          : {INPUT_SCALES}")
        _log(f"USE_CUSTOM_COLLOCATION: {USE_CUSTOM_COLLOCATION}")
        _log(f"LEARNING_RATE         : {LEARNING_RATE}")
        _log(f"NUM_EPOCHS            : {NUM_EPOCHS}")
        _log(f"N_PDE / N_IC / N_BC   : {N_PDE} / {N_IC} / {N_BC}")
        # These two vary per dimension, so record them alongside the rest.
        _log(f"LAMBDA PDE / IC / BC  : {LAMBDA_PDE} / {LAMBDA_IC} / {LAMBDA_BC}")
        _log(f"SUPPORT_FRAC          : {SUPPORT_FRAC}")
        _log(f"N_EVAL (MC samples)   : 20000")
        _log(f"M_RUNS                : {M_RUNS}")
        _log("")

        for run_idx in range(1, M_RUNS + 1):
            _log(f"--- Run {run_idx}/{M_RUNS} for d={SPATIAL_DIM} ---")
            seed = BASE_SEED + run_idx
            _set_seeds(seed)
            _log(f"Seed: {seed}")

            # Fresh collocation sets
            pde_xt, ic_xt, u_ic_true, bc_xt = build_collocation_sets_nd(
                USE_CUSTOM_COLLOCATION, SPATIAL_DIM
            )

            # Fresh model & optimizers
            model = PINNnD(
                SPATIAL_DIM,
                use_fourier=USE_FOURIER_FEATURES,
                rf_count=RF_COUNT,
                fourier_init_scale=FOURIER_INIT_SCALE,
                non_fourier_activation=NON_FOURIER_ACTIVATION,
                hidden_depth=HIDDEN_DEPTH,
                hidden_width=HIDDEN_WIDTH,
                positive=POSITIVE_HEAD,
                input_scales=INPUT_SCALES,
                seed=seed,
            ).to(DEVICE)

            # Adam for initial epochs
            adam_optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

            # L-BFGS for last M epochs
            M_LBFGS = min(LBFGS_LAST_M_EPOCHS, NUM_EPOCHS)
            use_lbfgs = M_LBFGS > 0
            adam_epochs = NUM_EPOCHS - M_LBFGS if use_lbfgs else NUM_EPOCHS

            lbfgs_optimizer = None
            if use_lbfgs:
                lbfgs_optimizer = torch.optim.LBFGS(
                    model.parameters(),
                    lr=LBFGS_LR,
                    max_iter=LBFGS_MAX_ITER,
                    history_size=50,
                    line_search_fn="strong_wolfe",
                )

            mse = nn.MSELoss()

            # ---- Training ----
            _log("Training...")
            run_start = time.time()

            # helper: full loss computation on this run's collocation sets
            def compute_loss_full():
                return compute_losses(
                    model, pde_xt, ic_xt, u_ic_true, bc_xt, SPATIAL_DIM, mse
                )

            ###########################################
            # 1) Adam Phase
            ###########################################
            for epoch in range(1, adam_epochs + 1):
                adam_optimizer.zero_grad(set_to_none=True)
                total, lp, li, lb = compute_loss_full()
                total.backward()
                adam_optimizer.step()

                if epoch == 1 or epoch % 1000 == 0 or epoch == adam_epochs:
                    _log(
                        f"[Adam] Epoch {epoch}/{NUM_EPOCHS} "
                        f"| Total {total.item():.3e} "
                        f"| PDE {lp.item():.3e} "
                        f"| IC {li.item():.3e} "
                        f"| BC {lb.item():.3e}"
                    )

            ###########################################
            # 2) L-BFGS Phase
            ###########################################
            # Loss components from the most recent closure evaluation (for logging).
            metrics = {
                "total": None,
                "pde":   None,
                "ic":    None,
                "bc":    None,
            }

            if use_lbfgs:
                for epoch in range(adam_epochs + 1, NUM_EPOCHS + 1):

                    def closure():
                        lbfgs_optimizer.zero_grad(set_to_none=True)
                        total, lp, li, lb = compute_loss_full()
                        total.backward()

                        # store for logging
                        metrics["total"] = total.detach()
                        metrics["pde"]   = lp.detach()
                        metrics["ic"]    = li.detach()
                        metrics["bc"]    = lb.detach()
                        return total

                    # One "epoch" = one L-BFGS step (up to LBFGS_MAX_ITER inner iterations).
                    lbfgs_optimizer.step(closure)

                    if (
                        epoch == adam_epochs + 1
                        or epoch % 100 == 0
                        or epoch == NUM_EPOCHS
                    ):
                        _log(
                            f"[LBFGS] Epoch {epoch}/{NUM_EPOCHS} "
                            f"| Total {metrics['total'].item():.3e} "
                            f"| PDE {metrics['pde'].item():.3e} "
                            f"| IC {metrics['ic'].item():.3e} "
                            f"| BC {metrics['bc'].item():.3e}"
                        )

            dur = time.time() - run_start
            run_durations.append(dur)
            _log(f"Finished run {run_idx} in {dur:.2f}s")

            # ---- Evaluation ----
            model.eval()
            N_EVAL = 20000

            # space–time error (Monte Carlo over the box x [0, T])
            x_eval = rand_in_box(N_EVAL, L_VEC)
            t_eval = torch.rand(N_EVAL, 1, device=DEVICE) * T_DOMAIN
            xt_eval = torch.cat([x_eval, t_eval], dim=1)

            with torch.no_grad():
                u_pred_eval = model(xt_eval)
                u_true_eval = barenblatt_true_shifted_dd(t_eval, x_eval, L_VEC, M_EXPONENT, T0_IC, support_fraction=SUPPORT_FRAC)

            num = torch.sum((u_pred_eval - u_true_eval) ** 2)
            den = torch.sum(u_true_eval ** 2) + 1e-16
            rel_L2_spacetime = torch.sqrt(num / den).item()
            spacetime_errors.append(rel_L2_spacetime)
            _log(f"Rel L2 (space–time, MC): {rel_L2_spacetime:.6e}")

            # final-time error
            x_evalT = rand_in_box(N_EVAL, L_VEC)
            tT = torch.full((N_EVAL, 1), T_DOMAIN, device=DEVICE)
            with torch.no_grad():
                u_pred_T = model(torch.cat([x_evalT, tT], dim=1))
                u_true_T = barenblatt_true_shifted_dd(
                    tT, x_evalT, L_VEC, M_EXPONENT, T0_IC,
                    support_fraction=SUPPORT_FRAC
                )

            numT = torch.sum((u_pred_T - u_true_T) ** 2)
            denT = torch.sum(u_true_T ** 2) + 1e-16
            rel_L2_final = torch.sqrt(numT / denT).item()
            finaltime_errors.append(rel_L2_final)
            _log(f"Rel L2 (final time T={T_DOMAIN}): {rel_L2_final:.6e}")
            _log("")

            # Free memory: the optimizers hold references to the parameters (and
            # their Adam / L-BFGS state), so deleting the model alone frees nothing.
            del model, adam_optimizer, lbfgs_optimizer
            if DEVICE.type == "cuda":
                torch.cuda.empty_cache()

        # ---- Summary stats for this dimension ----
        spacetime_np = np.array(spacetime_errors, dtype=np.float64)
        finaltime_np = np.array(finaltime_errors, dtype=np.float64)

        st_mean = float(spacetime_np.mean())
        st_std  = float(spacetime_np.std(ddof=1)) if len(spacetime_np) > 1 else 0.0
        ft_mean = float(finaltime_np.mean())
        ft_std  = float(finaltime_np.std(ddof=1)) if len(finaltime_np) > 1 else 0.0

        _log("=== Summary over runs ===")
        _log(f"Space–time rel L2: mean={st_mean:.6e}, std={st_std:.6e}")
        _log(f"Final-time rel L2: mean={ft_mean:.6e}, std={ft_std:.6e}")
        _log(f"Durations (s)    : {', '.join(f'{dur_i:.2f}' for dur_i in run_durations)}")
        _log(f"Saved full log to: {LOG_FILE}")




