from __future__ import annotations

"""
Multi-Dimensional Porous Medium Equation (PME) solved by a RaNN.
This script trains a RaNN to solve the PME M_RUNS times (five by default) and writes the per-run log and summary statistics to a timestamped .txt file.
"""

import math
import time
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator

# -------------------------
# Global Config
# -------------------------
# Initial global seeds. Each training run later re-seeds torch / numpy / random
# with BASE_SEED + run index, so these only affect work done before the run loop.
torch.manual_seed(40)
np.random.seed(40)

# Problem params
SPATIAL_DIM = 1                    # number of spatial dimensions d (network input is d + 1 with time)
M_EXPONENT = 2.0                   # m in PME (m > 1)
T0_IC = 1e-2                       # small positive time for IC (avoid singularity)
T_DOMAIN = 0.025                   # final model time T; model time t ranges over [0, T]
M_RUNS = 5                         # number of independent training runs
N_EVAL = 20000                     # Monte Carlo samples used for each error estimate

# Domain Ω = ∏ [0, L_k]
L_VEC = [1.0] * SPATIAL_DIM        # lengths per dimension, defaults to [1.0]*d

# RaNN params
# NOTE: the closed-form (cached) PDE residual is only used when
# USE_FOURIER_FEATURES is True and POSITIVE_HEAD == 'square'; every other
# combination falls back to the autograd residual.
USE_FOURIER_FEATURES = False        # False to use a non-Fourier random feature map
NON_FOURIER_ACTIVATION = 'tanh'    # used if USE_FOURIER_FEATURES=False
NUM_RANDOM_FEATURES = 2000         # total hidden features; for Fourier must be EVEN (cos+sin)
FEATURE_SCALE = 10.0               # init scale for random weights
POSITIVE_HEAD = 'square'           # 'square' | 'softplus' | 'relu' | 'none'

# Input anisotropic scaling (per spatial dim, then time)
INPUT_SCALES = [1.0] * SPATIAL_DIM + [10.0]

# Collocation
# With custom collocation, 40% of the PDE points are drawn near the moving front
# and the rest in the bulk with times biased toward t = 0.
USE_CUSTOM_COLLOCATION = False      # False for uniform collocation
FRONT_BAND_REL = 0.02              # half-width of the near-front band, relative to min(L_VEC); custom collocation only

# Training params
LEARNING_RATE = 1e-3
NUM_EPOCHS = 5000
N_PDE = 4000                       # interior (PDE residual) collocation points
N_IC = 2000                        # initial-condition points at model time t = 0
N_BC = 1000                        # boundary points in total; split over the 2*d faces by integer division

# Loss weights
LAMBDA_PDE = 1.0
LAMBDA_IC = 200.0
LAMBDA_BC = 1.0

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# Permit faster, reduced-internal-precision float32 matmuls where the backend
# offers them (see the PyTorch docs for the exact meaning of "high").
torch.set_float32_matmul_precision("high")


# -------------------------
# Utilities
# -------------------------
def get_activation(name: str):
    """Look up a pointwise torch activation by its case-insensitive name.

    Supported names are ``'tanh'``, ``'relu'`` and ``'sigmoid'``.

    Raises:
        ValueError: if the (lower-cased) name is not one of the supported ones.
    """
    name = name.lower()
    known = {
        'tanh': torch.tanh,
        'relu': torch.relu,
        'sigmoid': torch.sigmoid,
    }
    fn = known.get(name)
    if fn is None:
        raise ValueError(f"Unsupported activation: {name}")
    return fn


def _to_tensor(x, device=DEVICE, dtype=torch.float32):
    """Return ``x`` as a tensor with the requested device and dtype.

    Existing tensors are converted with ``Tensor.to``; anything else is passed
    to ``torch.tensor`` to build a new tensor.
    """
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x, device=device, dtype=dtype)
    return x.to(device=device, dtype=dtype)


# -------------------------
# RaNN definition
# -------------------------
class RFNNnD(nn.Module):
    """Random-feature network for u(x, t) with a frozen hidden layer and a trainable readout.

    The input ``[x_1, ..., x_d, t]`` is multiplied per coordinate by
    ``input_scales`` and projected through fixed random weights. The projection
    is turned into features either as scaled cos/sin pairs (Fourier mode) or by
    a pointwise activation of the random affine map. A linear ``output_layer``
    maps the features to one scalar, to which an optional positivity head is
    applied.

    Args:
        spatial_dim: number of spatial coordinates d (the input has d + 1 columns).
        num_random_features: total feature count; must be even in Fourier mode
            because half are cosines and half are sines.
        feature_scale: multiplier applied to the standard-normal random weights.
        use_fourier: choose Fourier features (True) or an activation map (False).
        non_fourier_activation: activation name understood by ``get_activation``;
            only consulted when ``use_fourier`` is False.
        positive: output head, one of 'square', 'softplus', 'relu', 'none'.
        input_scales: d + 1 per-coordinate input multipliers; ``None`` means
            1.0 for every spatial coordinate and 10.0 for time.
        seed: passed to ``torch.manual_seed``. NOTE: this reseeds the *global*
            torch RNG as a side effect of construction.

    Raises:
        ValueError: for an unknown ``positive`` head, an unsupported activation
            (non-Fourier mode), an odd feature count in Fourier mode, or an
            ``input_scales`` whose length is not d + 1.
    """

    _POSITIVE_HEADS = ('square', 'softplus', 'relu', 'none')

    def __init__(
        self,
        spatial_dim: int,
        num_random_features: int,
        feature_scale: float = 1.0,
        *,
        use_fourier: bool = True,
        non_fourier_activation: str = 'tanh',
        positive: str = 'square',
        input_scales: list[float] | Tuple[float, ...] | None = None,
        seed: int = 12345,
    ) -> None:
        super().__init__()
        # Fail at construction rather than on the first forward pass.
        if positive not in self._POSITIVE_HEADS:
            raise ValueError("positive must be 'square', 'softplus', 'relu', or 'none'")
        torch.manual_seed(seed)

        self.d = spatial_dim
        self.input_dim = self.d + 1    # (x1,...,xd, t)
        self.output_dim = 1
        self.num_random_features = int(num_random_features)
        self.use_fourier = bool(use_fourier)
        self.positive = positive
        self.non_fourier_activation = non_fourier_activation.lower()
        # Resolve the activation once; this also rejects unsupported names early.
        self._activation = None if self.use_fourier else get_activation(self.non_fourier_activation)

        if input_scales is None:
            input_scales = [1.0] * self.d + [10.0]
        scales = _to_tensor(input_scales).view(1, -1)
        if scales.shape[1] != self.input_dim:
            # A wrong length would otherwise broadcast silently (length 1) or
            # fail with an opaque shape error inside the feature map.
            raise ValueError(
                f"input_scales must have {self.input_dim} entries (d spatial + time), "
                f"got {scales.shape[1]}"
            )
        self.register_buffer("input_scales", scales)

        if self.use_fourier:
            if self.num_random_features % 2 != 0:
                raise ValueError("For Fourier features, NUM_RANDOM_FEATURES must be even.")
            rf = self.num_random_features // 2
            # Fixed random frequencies & phases
            random_weights = torch.randn(rf, self.input_dim, device=DEVICE) * feature_scale  # [rf, d+1]
            random_biases = torch.rand(rf, device=DEVICE) * 2 * math.pi                      # [rf]
        else:
            # Fixed random affine map (frozen)
            random_weights = torch.randn(self.num_random_features, self.input_dim, device=DEVICE) * feature_scale
            random_biases = torch.rand(self.num_random_features, device=DEVICE)

        # Stored as parameters with gradients disabled so they stay frozen.
        self.random_weights = nn.Parameter(random_weights, requires_grad=False)
        self.random_biases = nn.Parameter(random_biases, requires_grad=False)

        # Trainable readout
        self.output_layer = nn.Linear(self.num_random_features, self.output_dim, device=DEVICE)

    def _feature_map(self, xt: torch.Tensor) -> torch.Tensor:
        """Map points ``xt`` of shape [N, d+1] to random features of shape [N, num_random_features]."""
        x_scaled = xt * self.input_scales  # [N, d+1]
        phase = x_scaled @ self.random_weights.t() + self.random_biases

        if self.use_fourier:
            # Cosine block first, sine block second; the readout weights follow this order.
            scale = (self.num_random_features // 2) ** (-0.5)
            return scale * torch.cat([torch.cos(phase), torch.sin(phase)], dim=1)
        return self._activation(phase)

    def forward(self, xt: torch.Tensor) -> torch.Tensor:
        """Evaluate the network at ``xt`` ([N, d+1]) and return u of shape [N, 1]."""
        phi = self._feature_map(xt)           # [N, M]
        raw = self.output_layer(phi)          # [N, 1]

        if self.positive == 'square':
            return raw ** 2
        elif self.positive == 'softplus':
            return torch.nn.functional.softplus(raw, beta=1.0)
        elif self.positive == 'relu':
            return torch.relu(raw)
        elif self.positive == 'none':
            return raw
        else:
            # Only reachable if ``positive`` was reassigned after construction.
            raise ValueError("positive must be 'square', 'softplus', 'relu', or 'none'")


# -------------------------
# Barenblatt in d-D
# -------------------------
def barenblatt_ic_dd(
    x: torch.Tensor,           # [N, d] spatial points
    L_vec: list[float],        # domain lengths
    m: float,
    t0: float,
    support_fraction: float = 0.25,
    x_center: torch.Tensor | None = None,  # [d]
) -> torch.Tensor:
    """Evaluate the Barenblatt profile at physical time ``t0`` (the initial condition).

    The profile is ``t0^(-alpha) * max(A - B |x - x0|^2 / t0^(2 beta), 0)^(1/(m-1))``
    with ``A`` chosen so that the support radius at ``t0`` equals
    ``support_fraction * min(L_vec)``.

    Args:
        x: spatial points, shape [N, d].
        L_vec: box side lengths; used for the default centre and the support radius.
        m: PME exponent, must be > 1.
        t0: positive physical time at which the profile is evaluated.
        support_fraction: support radius at ``t0`` as a fraction of ``min(L_vec)``.
        x_center: profile centre with d entries; defaults to the box midpoint.

    Returns:
        Tensor of shape [N, 1] on ``x``'s device.

    Raises:
        ValueError: if ``m <= 1`` or ``t0 <= 0``.
    """
    if m <= 1.0:
        raise ValueError(f"Barenblatt profile requires m > 1, got m={m}")
    if t0 <= 0.0:
        raise ValueError(f"Barenblatt profile requires t0 > 0, got t0={t0}")

    d = x.shape[1]
    # Build the centre next to ``x`` so the function also works for inputs that
    # do not live on the global default device / dtype.
    if x_center is None:
        x0 = torch.tensor([0.5 * L for L in L_vec], device=x.device, dtype=x.dtype).view(1, d)
    else:
        x0 = x_center.to(device=x.device, dtype=x.dtype).view(1, d)

    lam = 2.0 + d * (m - 1.0)
    alpha = d / lam
    beta = 1.0 / lam
    B_const = (m - 1.0) * beta / (2.0 * m)

    R0 = support_fraction * min(L_vec)
    # A is fixed by requiring the profile to vanish exactly at |x - x0| = R0.
    A = B_const * (R0**2) / (t0 ** (2.0 * beta))

    r2 = torch.sum((x - x0) ** 2, dim=1, keepdim=True)
    z = A - B_const * r2 / (t0 ** (2.0 * beta))
    z_clamped = torch.clamp(z, min=0.0)   # zero outside the support
    u_t0 = (t0 ** (-alpha)) * (z_clamped ** (1.0 / (m - 1.0)))
    return u_t0


def barenblatt_true_shifted_dd(
    t_model: torch.Tensor,     # [N, 1] or scalar
    x: torch.Tensor,           # [N, d]
    L_vec: list[float],
    m: float,
    t0: float,
    support_fraction: float = 0.25,
    x_center: torch.Tensor | None = None,  # [d]
) -> torch.Tensor:
    """Evaluate the time-shifted Barenblatt solution at model time ``t_model``.

    Model time 0 corresponds to physical time ``t0``, i.e. the profile is
    evaluated at ``t_phys = t_model + t0`` with the same constants as the
    initial condition (support radius ``support_fraction * min(L_vec)`` at ``t0``).

    Args:
        t_model: model times as a Python number, a 0-dim tensor, a tensor of
            shape [N], or a tensor of shape [N, 1].
        x: spatial points, shape [N, d].
        L_vec: box side lengths; used for the default centre and the support radius.
        m: PME exponent, must be > 1.
        t0: positive physical time assigned to model time 0.
        support_fraction: support radius at ``t0`` as a fraction of ``min(L_vec)``.
        x_center: profile centre with d entries; defaults to the box midpoint.

    Returns:
        Tensor of shape [N, 1].

    Raises:
        ValueError: if ``m <= 1`` or ``t0 <= 0``.
    """
    if m <= 1.0:
        raise ValueError(f"Barenblatt solution requires m > 1, got m={m}")
    if t0 <= 0.0:
        raise ValueError(f"Barenblatt solution requires t0 > 0, got t0={t0}")

    d = x.shape[1]
    # Build the centre next to ``x`` so the function is independent of the
    # global default device / dtype.
    if x_center is None:
        x0 = torch.tensor([0.5 * L for L in L_vec], device=x.device, dtype=x.dtype).view(1, d)
    else:
        x0 = x_center.to(device=x.device, dtype=x.dtype).view(1, d)

    lam = 2.0 + d * (m - 1.0)
    alpha = d / lam
    beta = 1.0 / lam
    B_const = (m - 1.0) * beta / (2.0 * m)

    R0 = support_fraction * min(L_vec)
    A = B_const * (R0**2) / (t0 ** (2.0 * beta))

    # Accept plain Python numbers for the documented "scalar" case.
    if not isinstance(t_model, torch.Tensor):
        t_model = torch.as_tensor(t_model, dtype=x.dtype, device=x.device)
    t_phys = t_model + t0
    if t_phys.ndim == 0:
        t_phys = t_phys * torch.ones(x.shape[0], 1, device=x.device, dtype=x.dtype)
    elif t_phys.ndim == 1:
        # A flat [N] vector would broadcast against [N, 1] into an [N, N] result.
        t_phys = t_phys.reshape(-1, 1)

    r2 = torch.sum((x - x0) ** 2, dim=1, keepdim=True)
    z = A - B_const * r2 / (t_phys ** (2.0 * beta))
    z = torch.clamp(z, min=0.0)   # zero outside the (growing) support
    u = (t_phys ** (-alpha)) * (z ** (1.0 / (m - 1.0)))
    return u


# -------------------------
# Collocation Builders
# -------------------------
def rand_in_box(n: int, L_vec: list[float]) -> torch.Tensor:
    """Draw ``n`` points uniformly from the box with side lengths ``L_vec``.

    Returns a tensor of shape [n, len(L_vec)] on ``DEVICE``; column k lies in
    [0, L_vec[k]).
    """
    dims = len(L_vec)
    unit_samples = torch.rand(n, dims, device=DEVICE)
    side_lengths = _to_tensor(L_vec).view(1, dims)
    return unit_samples * side_lengths


def sample_near_front_nd(
    n: int,
    L_vec: list[float],
    t0: float,
    T: float,
    m: float,
    band_rel: float,
    support_fraction: float = 0.25,
) -> torch.Tensor:
    """Sample ``n`` space-time points in a thin shell around the moving front.

    The front of the shifted Barenblatt profile used in this file sits at
    radius ``R(t) = R0 * ((t0 + t) / t0) ** beta`` from the box centre, where
    ``R0 = support_fraction * min(L_vec)`` is the support radius of the initial
    condition. Points are placed at ``R(t)`` plus a uniform offset in
    ``[-band_rel * min(L_vec), +band_rel * min(L_vec)]`` along a random
    direction, then clamped into the box.

    Args:
        n: number of points.
        L_vec: box side lengths.
        t0: physical time assigned to model time 0.
        T: final model time; sampled times lie in [0, T) and are biased toward 0.
        m: PME exponent.
        band_rel: half-width of the shell relative to ``min(L_vec)``.
        support_fraction: must match the value used for the initial condition
            so that the shell tracks the actual front.

    Returns:
        Tensor of shape [n, d + 1] with columns (x_1..x_d, t), on ``DEVICE``.
    """
    d = len(L_vec)
    lam = 2.0 + d * (m - 1.0)
    beta = 1.0 / lam
    Lmin = min(L_vec)
    # Same support radius as the initial condition; the previous value
    # 0.25 * Lmin / 2 centred the band at half the true front radius.
    R0 = support_fraction * Lmin
    # Cubing a uniform sample concentrates times near t = 0.
    t = (torch.rand(n, 1, device=DEVICE) ** 3) * T
    R = R0 * ((t0 + t) / t0) ** beta

    # Uniform random directions on the unit sphere (normalised Gaussians).
    dirs = torch.randn(n, d, device=DEVICE)
    dirs = dirs / (dirs.norm(dim=1, keepdim=True) + 1e-12)

    delta = band_rel * Lmin
    radii = R.squeeze(1) + (2 * torch.rand(n, device=DEVICE) - 1.0) * delta
    radii = radii.clamp(min=0.0)

    x0 = _to_tensor([0.5 * L for L in L_vec]).view(1, d)
    x = x0 + dirs * radii.view(-1, 1)

    # Points pushed outside the box are projected onto its boundary.
    L = _to_tensor(L_vec).view(1, d)
    x = torch.clamp(x, min=torch.zeros(1, d, device=DEVICE), max=L)
    return torch.cat([x, t], dim=1)



def build_collocation_sets_nd(use_custom: bool, d: int):
    """Draw the PDE, initial-condition and boundary point sets from the module config.

    Args:
        use_custom: if True, 40% of the PDE points come from the near-front
            sampler and the rest from the bulk with times biased toward 0;
            otherwise all PDE points are uniform in space and time.
        d: number of spatial dimensions (used to lay out the boundary faces).

    Returns:
        ``(pde_xt, ic_xt, u_ic_true, bc_xt)`` where the ``*_xt`` tensors have
        d + 1 columns (space, then time), ``u_ic_true`` holds the Barenblatt
        initial values at ``ic_xt``, and ``bc_xt`` stacks, for each axis i in
        order, the face x_i = 0 followed by the face x_i = L_i.
    """
    # --- interior (PDE) points ---
    if not use_custom:
        uniform_x = rand_in_box(N_PDE, L_VEC)
        uniform_t = torch.rand(N_PDE, 1, device=DEVICE) * T_DOMAIN
        pde_xt = torch.cat([uniform_x, uniform_t], dim=1)
    else:
        front_count = int(0.4 * N_PDE)
        bulk_count = N_PDE - front_count
        front_pts = sample_near_front_nd(
            front_count, L_VEC, T0_IC, T_DOMAIN, M_EXPONENT, FRONT_BAND_REL
        )
        bulk_x = rand_in_box(bulk_count, L_VEC)
        bulk_t = (torch.rand(bulk_count, 1, device=DEVICE) ** 3) * T_DOMAIN
        bulk_pts = torch.cat([bulk_x, bulk_t], dim=1)
        pde_xt = torch.cat([front_pts, bulk_pts], dim=0)

    # --- initial condition at model time t = 0 ---
    ic_x = rand_in_box(N_IC, L_VEC)
    ic_xt = torch.cat([ic_x, torch.zeros(N_IC, 1, device=DEVICE)], dim=1)
    u_ic_true = barenblatt_ic_dd(ic_x, L_VEC, M_EXPONENT, T0_IC, support_fraction=0.25)

    # --- boundary faces: both faces of an axis share positions and times ---
    per_face = max(1, N_BC // (2 * d))
    face_blocks = []
    for axis in range(d):
        base = rand_in_box(per_face, L_VEC)
        lower = base.clone()
        lower[:, axis] = 0.0
        upper = base.clone()
        upper[:, axis] = L_VEC[axis]
        face_t = torch.rand(per_face, 1, device=DEVICE) * T_DOMAIN
        face_blocks.extend(
            [torch.cat([lower, face_t], dim=1), torch.cat([upper, face_t], dim=1)]
        )
    bc_xt = torch.cat(face_blocks, dim=0)

    return pde_xt, ic_xt, u_ic_true, bc_xt


# -------------------------
# PDE residuals
# -------------------------
def pde_residual_autograd_nd(model: nn.Module, xt: torch.Tensor, m_exp: float, d: int):
    """Compute the PME residual u_t - Laplacian(u^m) with automatic differentiation.

    Args:
        model: callable mapping points [N, d+1] to u of shape [N, 1].
        xt: evaluation points [N, d+1]; a detached copy is differentiated.
        m_exp: PME exponent m.
        d: number of spatial columns (time is column ``d``).

    Returns:
        ``(R, u, v_x)``: residual [N, 1], solution [N, 1] and the spatial
        gradient of ``v = u^m`` with shape [N, d]. All keep their graph so a
        loss built from them can be back-propagated.
    """
    pts = xt.clone().detach().requires_grad_(True)       # [N, d+1]
    u = model(pts)                                       # [N, 1]

    def _grad(out: torch.Tensor) -> torch.Tensor:
        # Per-sample gradient of ``out`` w.r.t. the input points, graph retained.
        return torch.autograd.grad(out, pts, torch.ones_like(out), create_graph=True)[0]

    u_t = _grad(u)[:, d:d+1]
    v_x = _grad(u.pow(m_exp))[:, :d]                     # [N, d]

    # Laplacian: differentiate each first derivative along its own axis.
    second_derivs = [_grad(v_x[:, i:i+1])[:, i:i+1] for i in range(d)]
    v_lap = torch.stack(second_derivs, dim=0).sum(dim=0)  # [N, 1]

    return u_t - v_lap, u, v_x


def pde_residual_fast_nd(model: RFNNnD, xt: torch.Tensor, m_exp: float, d: int):
    """
    Closed-form residual for Fourier features with 'square' head.
    Assumes: model.use_fourier == True and model.positive == 'square'.

    With ``raw = sum_j a_j cos(phase_j) + c_j sin(phase_j) + bias`` and
    ``u = raw^2``, all derivatives are available analytically:

        u_t   = 2 raw raw_t
        u_kk  = 2 raw_k^2 + 2 raw raw_kk
        v_kk  = m u^(m-1) (4 (m-1) raw_k^2 + u_kk),   v = u^m

    The last line is the usual ``m(m-1) u^(m-2) u_k^2 + m u^(m-1) u_kk`` with
    ``u_k^2 = 4 u raw_k^2`` substituted, which avoids the ``u^(m-2)`` factor
    that becomes infinite at u = 0 when m < 2.

    Args:
        model: network exposing ``num_random_features``, ``random_weights``,
            ``random_biases``, ``input_scales`` and ``output_layer``.
        xt: evaluation points [N, d+1] (space, then time).
        m_exp: PME exponent m.
        d: number of spatial columns.

    Returns:
        ``(R, u)``: residual ``u_t - Laplacian(u^m)`` and solution, both [N, 1].
    """
    assert model.use_fourier, "fast residual requires Fourier features"
    assert model.positive == 'square', "fast residual currently assumes positive='square'"
    # Shorthands
    rf = model.num_random_features // 2
    W = model.random_weights[:rf, :]  # [rf, d+1]
    b = model.random_biases[:rf]      # [rf]
    scales = model.input_scales.view(-1)  # [d+1]
    scale_rf = (rf) ** (-0.5)

    # Project & basis
    phase = (xt * model.input_scales) @ W.t() + b  # [N, rf]
    cosv = torch.cos(phase)
    sinv = torch.sin(phase)

    # Readout weights split: cosine coefficients first, sine coefficients second
    Wout = model.output_layer.weight.squeeze(0)  # [2*rf]
    a = scale_rf * Wout[:rf]                      # [rf]
    c = scale_rf * Wout[rf:]                      # [rf]
    bias = model.output_layer.bias                # [1]

    raw = (cosv @ a) + (sinv @ c) + bias.squeeze()  # [N]
    u = raw * raw                                   # [N]

    # Time derivative
    freq_t = scales[d] * W[:, d]                    # [rf]
    raw_t = (-sinv * freq_t) @ a + (cosv * freq_t) @ c
    u_t = 2.0 * raw * raw_t

    # Laplacian of v = u^m, one spatial axis at a time (each derivative computed once)
    u_m1 = None if m_exp == 1.0 else u.pow(m_exp - 1.0)
    v_xx_terms = []
    for k in range(d):
        freq_k = scales[k] * W[:, k]  # [rf]
        raw_k = (-sinv * freq_k) @ a + (cosv * freq_k) @ c                      # [N]
        raw_kk = (-(cosv * (freq_k ** 2)) @ a) + (-(sinv * (freq_k ** 2)) @ c)  # [N]
        raw_k_sq = raw_k * raw_k
        u_xx_k = 2.0 * raw_k_sq + 2.0 * raw * raw_kk
        if u_m1 is None:
            v_xx_terms.append(u_xx_k)
        else:
            v_xx_terms.append(m_exp * u_m1 * (4.0 * (m_exp - 1.0) * raw_k_sq + u_xx_k))

    v_lap = torch.stack(v_xx_terms, dim=0).sum(dim=0)  # [N]

    R = u_t - v_lap
    return R.unsqueeze(1), u.unsqueeze(1)


# Simple feature cache for Fourier fast path
class FeatureCache:
    """Fourier projections of a fixed point set, precomputed for the cached fast residual.

    Because the random weights and phases are taken from the model once, the
    cosine / sine values at the given points can be stored and reused.
    Attributes: ``rf`` (number of frequency pairs), ``d`` (spatial dimension),
    ``W`` ([rf, d+1] frequencies), ``scales`` (flattened input scales),
    ``proj`` ([N, rf] scaled projection without phase), ``cosv`` / ``sinv``
    ([N, rf] cosine and sine of the phase-shifted projection).
    """

    def __init__(self, model, xt, d):
        half = model.num_random_features // 2
        phase_shift = model.random_biases[:half]

        self.rf = half
        self.d = d
        self.W = model.random_weights[:half, :]
        self.scales = model.input_scales.view(-1)

        self.proj = (xt * model.input_scales) @ self.W.t()  # [N, rf]
        self.cosv = torch.cos(self.proj + phase_shift)      # [N, rf]
        self.sinv = torch.sin(self.proj + phase_shift)      # [N, rf]

# Cached fast residual using precomputed cosv/sinv
def pde_residual_fast_nd_cached(model, cache, m_exp):
    """Closed-form PME residual using cosine / sine values stored in ``cache``.

    Only the readout (``model.output_layer``) is read from the model; the
    point set, frequencies and basis values come from ``cache`` (attributes
    ``rf``, ``d``, ``cosv``, ``sinv``, ``W``, ``scales``). The formulas assume
    the 'square' head, i.e. ``u = raw^2``:

        u_t   = 2 raw raw_t
        u_kk  = 2 raw_k^2 + 2 raw raw_kk
        v_kk  = m u^(m-1) (4 (m-1) raw_k^2 + u_kk),   v = u^m

    The last line equals ``m(m-1) u^(m-2) u_k^2 + m u^(m-1) u_kk`` with
    ``u_k^2 = 4 u raw_k^2`` substituted, so no ``u^(m-2)`` factor (infinite at
    u = 0 for m < 2) is ever formed.

    Returns:
        ``(R, u)``: residual ``u_t - Laplacian(u^m)`` and solution, both [N, 1].
    """
    rf, d = cache.rf, cache.d
    cosv, sinv, W, scales = cache.cosv, cache.sinv, cache.W, cache.scales
    scale_rf = (rf)**(-0.5)
    Wout = model.output_layer.weight.squeeze(0)
    a = scale_rf * Wout[:rf]   # cosine coefficients
    c = scale_rf * Wout[rf:]   # sine coefficients
    bias = model.output_layer.bias

    raw = (cosv @ a) + (sinv @ c) + bias           # [N]
    u = raw * raw

    # time deriv
    proj_t = scales[d] * W[:, d]
    raw_t = (-sinv * proj_t) @ a + (cosv * proj_t) @ c
    u_t = 2.0 * raw * raw_t

    # spatial Laplacian of v = u^m; u is a square, so no clamping is needed
    u_m1 = None if m_exp == 1.0 else u.pow(m_exp - 1.0)
    v_xx_terms = []
    for k in range(d):
        proj_k = scales[k] * W[:, k]
        raw_k = (-sinv * proj_k) @ a + (cosv * proj_k) @ c
        raw_kk = (-(cosv * (proj_k**2)) @ a) + (-(sinv * (proj_k**2)) @ c)
        raw_k_sq = raw_k * raw_k
        u_xxk = 2.0 * raw_k_sq + 2.0 * raw * raw_kk
        if u_m1 is None:
            v_xxk = u_xxk
        else:
            v_xxk = m_exp * u_m1 * (4.0 * (m_exp - 1.0) * raw_k_sq + u_xxk)
        v_xx_terms.append(v_xxk)
    v_lap = torch.stack(v_xx_terms, 0).sum(0)
    R = u_t - v_lap
    return R.unsqueeze(1), u.unsqueeze(1)



def residual_and_flux(model: RFNNnD, xt: torch.Tensor, m_exp: float, d: int):
    """Return ``(R, u)`` for the PME residual, picking the cheapest valid method.

    The closed-form residual is used for Fourier features with the 'square'
    head; every other configuration goes through autograd. Despite the name,
    the flux (gradient of u^m) computed by the autograd path is discarded.
    """
    closed_form_ok = model.use_fourier and model.positive == 'square'
    if not closed_form_ok:
        residual, u, _unused_flux = pde_residual_autograd_nd(model, xt, m_exp, d)
        return residual, u
    return pde_residual_fast_nd(model, xt, m_exp, d)


# -------------------------
# Build data, model, optimizer
# -------------------------
# NOTE(review): every training run below re-seeds the RNGs and rebuilds the
# collocation sets, the model, the optimizer and the loss, overwriting all of
# the names assigned here. This initial construction therefore looks redundant
# for the multi-run workflow — confirm nothing else relies on it before removing.
pde_xt, ic_xt, u_ic_true, bc_xt = build_collocation_sets_nd(USE_CUSTOM_COLLOCATION, SPATIAL_DIM)

model = RFNNnD(
    SPATIAL_DIM,
    NUM_RANDOM_FEATURES,
    feature_scale=FEATURE_SCALE,
    use_fourier=USE_FOURIER_FEATURES,
    non_fourier_activation=NON_FOURIER_ACTIVATION,
    positive=POSITIVE_HEAD,
    input_scales=INPUT_SCALES,
    seed=12345,
).to(DEVICE)

# Only the linear readout is optimised; the random hidden layer stays frozen.
optimizer = torch.optim.Adam(model.output_layer.parameters(), lr=LEARNING_RATE)
mse = nn.MSELoss()


# =========================
# Multi-run Training & Evaluation (RaNN)
# =========================
import random
from datetime import datetime

# Run k (1-based) is seeded with BASE_SEED + k.
BASE_SEED = 12345            
# Timestamped log file name (relative path), so repeated launches do not overwrite each other.
LOG_FILE = f"rann_multirun_{SPATIAL_DIM}D_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

def _set_seeds(s: int):
    torch.manual_seed(s)
    np.random.seed(s)
    random.seed(s)

# Storage for metrics (one entry appended per training run)
spacetime_errors = []   # relative L2 error over random space-time samples
finaltime_errors = []   # relative L2 error over random spatial samples at t = T_DOMAIN
run_durations = []      # wall-clock training time per run, in seconds

# Train and evaluate M_RUNS independently seeded models, logging every step to
# both stdout and LOG_FILE, then report mean / std of the error metrics.
with open(LOG_FILE, "w", encoding="utf-8") as logf:

    def _log(msg: str):
        """Print ``msg`` and append it to the log file, flushing so a crash loses nothing."""
        print(msg)
        logf.write(msg + "\n")
        logf.flush()

    # Log configuration header
    _log(f"== {SPATIAL_DIM}D RaNN for PME (u_t - Δ(u^m) = 0) ==")
    _log(f"m (M_EXPONENT)         : {M_EXPONENT}")
    _log(f"T0_IC / T_DOMAIN       : {T0_IC} / {T_DOMAIN}")
    _log(f"L_VEC                  : {L_VEC}")
    _log(f"USE_FOURIER_FEATURES  : {USE_FOURIER_FEATURES}")
    _log(f"NUM_RANDOM_FEATURES   : {NUM_RANDOM_FEATURES}")
    _log(f"FEATURE_SCALE         : {FEATURE_SCALE}")
    _log(f"POSITIVE_HEAD         : {POSITIVE_HEAD}")
    _log(f"INPUT_SCALES          : {INPUT_SCALES}")
    _log(f"USE_CUSTOM_COLLOCATION: {USE_CUSTOM_COLLOCATION}")
    _log(f"LEARNING_RATE         : {LEARNING_RATE}")
    _log(f"NUM_EPOCHS            : {NUM_EPOCHS}")
    _log(f"N_PDE / N_IC / N_BC   : {N_PDE} / {N_IC} / {N_BC}")
    _log(f"N_EVAL (MC samples)   : {N_EVAL}")
    _log(f"M_RUNS                : {M_RUNS}")
    _log("")

    # Loop invariants, computed once instead of every epoch.
    # The cached closed-form residual only applies to Fourier features + 'square' head.
    use_cached_residual = USE_FOURIER_FEATURES and POSITIVE_HEAD == 'square'
    # bc_xt is laid out as 2 * d consecutive faces of per_face points each
    # (x_i = 0 then x_i = L_i for each axis), matching build_collocation_sets_nd.
    per_face = max(1, N_BC // (2 * SPATIAL_DIM))
    bc_zero = torch.zeros(per_face, 1, device=DEVICE)

    for run_idx in range(1, M_RUNS + 1):
        _log(f"--- Run {run_idx}/{M_RUNS} ---")
        seed = BASE_SEED + run_idx
        _set_seeds(seed)
        _log(f"Seed: {seed}")

        # Fresh collocation sets each run
        pde_xt, ic_xt, u_ic_true, bc_xt = build_collocation_sets_nd(USE_CUSTOM_COLLOCATION, SPATIAL_DIM)

        # Restart model & optimizer
        model = RFNNnD(
            SPATIAL_DIM,
            NUM_RANDOM_FEATURES,
            feature_scale=FEATURE_SCALE,
            use_fourier=USE_FOURIER_FEATURES,
            non_fourier_activation=NON_FOURIER_ACTIVATION,
            positive=POSITIVE_HEAD,
            input_scales=INPUT_SCALES,
            seed=seed,
        ).to(DEVICE)
        optimizer = torch.optim.Adam(model.output_layer.parameters(), lr=LEARNING_RATE)
        mse = nn.MSELoss()

        # Build the PDE feature cache once per run (points and random features are fixed).
        # No cache is built for the boundary points: the BC loss below uses autograd.
        if use_cached_residual:
            pde_cache = FeatureCache(model, pde_xt, SPATIAL_DIM)

        # ---- Training ----
        _log("Training...")
        t0 = time.time()
        for epoch in range(1, NUM_EPOCHS + 1):
            optimizer.zero_grad(set_to_none=True)
            if use_cached_residual:
                R_pde, _ = pde_residual_fast_nd_cached(model, pde_cache, M_EXPONENT)
            else:
                R_pde, _ = residual_and_flux(model, pde_xt, M_EXPONENT, SPATIAL_DIM)

            loss_pde = mse(R_pde, torch.zeros_like(R_pde))

            # IC loss
            u_ic_pred = model(ic_xt)
            loss_ic = mse(u_ic_pred, u_ic_true)

            # BC loss (zero flux): the normal derivative of v = u^m must vanish on each face
            xt_req = bc_xt.clone().detach().requires_grad_(True)
            u_bc = model(xt_req)
            v_bc = u_bc.pow(M_EXPONENT)
            grads_v = torch.autograd.grad(v_bc, xt_req, torch.ones_like(v_bc), create_graph=True)[0]  # [Nb, d+1]
            v_x_bc = grads_v[:, :SPATIAL_DIM]

            losses = []
            offset = 0
            for i in range(SPATIAL_DIM):
                sel = slice(offset, offset + per_face)   # x_i = 0 (normal -e_i)
                losses.append(mse(v_x_bc[sel, i:i+1], bc_zero))
                offset += per_face
                sel = slice(offset, offset + per_face)   # x_i = L_i (normal +e_i)
                losses.append(mse(v_x_bc[sel, i:i+1], bc_zero))
                offset += per_face
            loss_bc = torch.stack(losses).mean()

            # Total loss
            loss = LAMBDA_PDE * loss_pde + LAMBDA_IC * loss_ic + LAMBDA_BC * loss_bc
            loss.backward()
            optimizer.step()

            if epoch % 1000 == 0 or epoch == 1 or epoch == NUM_EPOCHS:
                _log(
                    f"Epoch {epoch:5d}/{NUM_EPOCHS} "
                    f"| Total: {loss.item():.3e} "
                    f"| PDE: {loss_pde.item():.3e} "
                    f"| IC: {loss_ic.item():.3e} "
                    f"| BC: {loss_bc.item():.3e}"
                )

        dur = time.time() - t0
        run_durations.append(dur)
        _log(f"Finished run {run_idx} in {dur:.2f}s")

        # ---- Evaluation ----
        model.eval()

        # Monte Carlo space–time L2 error
        x_eval = rand_in_box(N_EVAL, L_VEC)
        t_eval = torch.rand(N_EVAL, 1, device=DEVICE) * T_DOMAIN
        xt_eval = torch.cat([x_eval, t_eval], dim=1)

        with torch.no_grad():
            u_pred_eval = model(xt_eval)
            u_true_eval = barenblatt_true_shifted_dd(t_eval, x_eval, L_VEC, M_EXPONENT, T0_IC, support_fraction=0.25)

        # Relative error; the small constant guards against a zero denominator.
        num = torch.sum((u_pred_eval - u_true_eval) ** 2)
        den = torch.sum(u_true_eval ** 2) + 1e-16
        rel_L2_spacetime = torch.sqrt(num / den).item()
        spacetime_errors.append(rel_L2_spacetime)
        _log(f"Rel L2 (space–time, MC): {rel_L2_spacetime:.6e}")

        # Final-time spatial L2 error
        x_evalT = rand_in_box(N_EVAL, L_VEC)
        tT = torch.full((N_EVAL, 1), T_DOMAIN, device=DEVICE)
        with torch.no_grad():
            u_pred_T = model(torch.cat([x_evalT, tT], dim=1))
            u_true_T = barenblatt_true_shifted_dd(tT, x_evalT, L_VEC, M_EXPONENT, T0_IC, support_fraction=0.25)

        numT = torch.sum((u_pred_T - u_true_T) ** 2)
        denT = torch.sum(u_true_T ** 2) + 1e-16
        rel_L2_final = torch.sqrt(numT / denT).item()
        finaltime_errors.append(rel_L2_final)
        _log(f"Rel L2 (final time T={T_DOMAIN}): {rel_L2_final:.6e}")
        _log("")


    # ---- Summary stats ----
    spacetime_np = np.array(spacetime_errors, dtype=np.float64)
    finaltime_np = np.array(finaltime_errors, dtype=np.float64)

    # Sample standard deviation (ddof=1) needs at least two runs.
    st_mean = float(spacetime_np.mean())
    st_std  = float(spacetime_np.std(ddof=1)) if len(spacetime_np) > 1 else 0.0
    ft_mean = float(finaltime_np.mean())
    ft_std  = float(finaltime_np.std(ddof=1)) if len(finaltime_np) > 1 else 0.0

    _log("=== Summary over runs ===")
    _log(f"Space–time rel L2: mean={st_mean:.6e}, std={st_std:.6e}")
    _log(f"Final-time rel L2: mean={ft_mean:.6e}, std={ft_std:.6e}")
    _log(f"Durations (s)    : {', '.join(f'{d:.2f}' for d in run_durations)}")
    _log(f"Saved full log to: {LOG_FILE}")

