from __future__ import annotations

"""
Multi-Dimensional PINN (no RFF) which solves PME.
This script trains a PINN to solve the PME five times, and then saves the results to a .tvt file.
"""

import math
import time
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator

# -------------------------
# Global Config
# -------------------------
# Seed torch and NumPy once at import time. Each training run in the
# multi-run section reseeds again before it builds its data and model.
torch.manual_seed(42)
np.random.seed(42)

# Problem params
SPATIAL_DIM = 3                    # number of spatial dimensions d
M_EXPONENT = 2.0                   # m in PME (m > 1)
T0_IC = 1e-2                       # small positive time for IC (avoid singularity)
T_DOMAIN = 0.01                    # final model time T; model time t ranges over [0, T]

# Domain Ω = ∏ [0, L_k]
L_VEC = [1.0] * SPATIAL_DIM        # lengths per dimension, defaults to [1.0]*d
M_RUNS = 1                         # number of independent training runs


# PINN architecture params
USE_FOURIER_FEATURES = False       # False for standard MLP
RF_COUNT = 0                       # number of Fourier modes (first layer), feature dim = 2*RF_COUNT
FOURIER_INIT_SCALE = 10.0          # init scale for frequencies

NON_FOURIER_ACTIVATION = 'tanh'    # used when USE_FOURIER_FEATURES=False

HIDDEN_DEPTH = 2                   # number of hidden Linear layers (for the plain MLP this count includes the input layer)
HIDDEN_WIDTH = 100                 # width of those hidden layers

POSITIVE_HEAD = 'square'           # 'square' | 'softplus' | 'relu' | 'none'

# Input anisotropic scaling (per spatial dim, then time); the network
# multiplies its input elementwise by these factors.
INPUT_SCALES = [1.0] * SPATIAL_DIM + [10.0]

# Collocation toggles
USE_CUSTOM_COLLOCATION = False      #  False for uniform collocation
FRONT_BAND_REL = 0.02              # half-width of the front sampling band, as a fraction of min(L_VEC) (custom collocation only)

# Training params
LEARNING_RATE = 1e-3
NUM_EPOCHS = 5000
N_PDE = 5000                       # number of interior (PDE residual) collocation points
N_IC = 2500                        # number of initial-condition points at t = 0
N_BC = 1250                        # boundary point budget, split evenly over the 2*d faces

# Loss weights
LAMBDA_PDE = 1.0
LAMBDA_IC = 200.0
LAMBDA_BC = 1.0

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# Per the PyTorch docs, "high" lets float32 matmuls use faster
# reduced-precision kernels where the hardware provides them.
torch.set_float32_matmul_precision("high")


# -------------------------
# Utilities
# -------------------------
def get_activation(name: str):
    """Return the torch activation function registered under `name`.

    The lookup is case-insensitive. Supported names are 'tanh', 'relu' and
    'sigmoid'.

    Raises:
        ValueError: if the (lower-cased) name is not one of the supported ones.
    """
    key = name.lower()
    registry = {
        'tanh': torch.tanh,
        'relu': torch.relu,
        'sigmoid': torch.sigmoid,
    }
    activation = registry.get(key)
    if activation is None:
        raise ValueError(f"Unsupported activation: {key}")
    return activation


def _to_tensor(x, device=DEVICE, dtype=torch.float32):
    """Return `x` as a tensor with the requested device and dtype.

    An existing tensor is converted with `.to(...)`; anything else is passed
    to `torch.tensor(...)` to build a new tensor.
    """
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x, device=device, dtype=dtype)
    return x.to(device=device, dtype=dtype)

def init_xavier(layer):
    """Apply Xavier-uniform weights and zero bias to an `nn.Linear`, in place.

    Any other kind of layer is left untouched. Always returns None.
    """
    if not isinstance(layer, nn.Linear):
        return
    nn.init.xavier_uniform_(layer.weight)
    if layer.bias is None:
        return
    nn.init.zeros_(layer.bias)

# -------------------------
# PINN definition 
# -------------------------
class PINNnD(nn.Module):
    """Scalar-output network u(x_1, ..., x_d, t) used as the PINN ansatz.

    The input is a batch `xt` of shape [N, d+1] (spatial coordinates followed
    by time) which is first multiplied elementwise by `input_scales`. Two
    architectures are available:

    * ``use_fourier=True``: a trainable Fourier-feature layer
      (cos/sin of ``x @ freq^T + phase``) followed by a tanh MLP whose Linear
      layers are Xavier-initialised.
    * ``use_fourier=False``: a plain MLP with the activation named by
      ``non_fourier_activation``. These layers keep ``nn.Linear``'s default
      initialisation (they are not passed through ``init_xavier``).

    The raw network output is mapped through the ``positive`` head
    ('square', 'softplus', 'relu' or 'none') before being returned.

    Args:
        spatial_dim: number of spatial dimensions d.
        use_fourier: select the Fourier-feature architecture.
        rf_count: number of Fourier modes (only used when ``use_fourier``).
        fourier_init_scale: std-dev multiplier for the initial frequencies.
        non_fourier_activation: activation name for the plain MLP.
        hidden_depth: number of hidden Linear layers; ``<= 0`` means the first
            mapping feeds a single output Linear layer directly.
        hidden_width: width of the hidden layers.
        positive: output head, one of 'square', 'softplus', 'relu', 'none'.
        input_scales: d+1 multiplicative input scales; defaults to
            ``[1.0] * d + [10.0]``.
        seed: passed to ``torch.manual_seed`` before parameters are created
            (note: this reseeds the global torch RNG).

    Raises:
        ValueError: if ``positive`` is not a supported head, if ``rf_count``
            is not a positive integer while ``use_fourier`` is set, or if the
            activation name is unsupported.
    """

    _POSITIVE_HEADS = ('square', 'softplus', 'relu', 'none')

    def __init__(
        self,
        spatial_dim: int,
        *,
        use_fourier: bool,
        rf_count: int = 0,
        fourier_init_scale: float = 1.0,
        non_fourier_activation: str = 'tanh',
        hidden_depth: int = 0,
        hidden_width: int = 256,
        positive: str = 'square',
        input_scales: list[float] | Tuple[float, ...] | None = None,
        seed: int = 12345,
    ) -> None:
        super().__init__()

        # Validate with real exceptions: `assert` is stripped under `python -O`,
        # and a bad `positive` would otherwise only surface at the first forward.
        if positive not in self._POSITIVE_HEADS:
            raise ValueError("positive must be 'square', 'softplus', 'relu', or 'none'")
        if use_fourier and not (rf_count > 0 and rf_count % 1 == 0):
            raise ValueError("rf_count must be a positive integer")

        torch.manual_seed(seed)

        self.d = spatial_dim
        self.input_dim = self.d + 1   # (x1,...,xd, t)
        self.output_dim = 1

        self.use_fourier = use_fourier
        self.hidden_depth = int(hidden_depth)
        self.hidden_width = int(hidden_width)
        self.positive = positive
        self.non_fourier_activation = non_fourier_activation.lower()

        if input_scales is None:
            input_scales = [1.0] * self.d + [10.0]
        # Buffer (not a Parameter): follows .to(device) but is never trained.
        self.register_buffer("input_scales", _to_tensor(input_scales))

        if self.use_fourier:
            self.rf_count = int(rf_count)
            # Trainable frequency matrix and phase
            self.freq = nn.Parameter(torch.randn(self.rf_count, self.input_dim) * fourier_init_scale)  # [rf, d+1]
            self.phase = nn.Parameter(torch.rand(self.rf_count) * 2 * math.pi)                          # [rf]
            self._build_post_fourier_stack()
        else:
            # Standard MLP
            self.act = get_activation(self.non_fourier_activation)
            if self.hidden_depth <= 0:
                # Degenerate case: a single affine map from scaled inputs to the output.
                self.inp = nn.Linear(self.input_dim, self.output_dim)
                self.out = None
                self.hidden = None
            else:
                # hidden[0] is the input layer; hidden[1:] are width-to-width layers.
                layers: list[nn.Module] = [nn.Linear(self.input_dim, self.hidden_width)]
                for _ in range(self.hidden_depth - 1):
                    layers.append(nn.Linear(self.hidden_width, self.hidden_width))
                self.hidden = nn.ModuleList(layers)
                self.out = nn.Linear(self.hidden_width, self.output_dim)

    def _build_post_fourier_stack(self):
        """Create the Xavier-initialised layers that follow the Fourier features."""
        base_dim = 2 * self.rf_count  # cos+sin

        if self.hidden_depth <= 0:
            # No hidden layers after Fourier features
            self.project = None
            self.hidden = None
            self.out = nn.Linear(base_dim, self.output_dim)
            init_xavier(self.out)
        else:
            # First layer after Fourier features
            self.project = nn.Linear(base_dim, self.hidden_width)
            init_xavier(self.project)

            # Hidden layers
            hidden_layers: list[nn.Module] = []
            for _ in range(self.hidden_depth - 1):
                lay = nn.Linear(self.hidden_width, self.hidden_width)
                init_xavier(lay)
                hidden_layers.append(lay)
            self.hidden = nn.ModuleList(hidden_layers)

            # Output layer
            self.out = nn.Linear(self.hidden_width, self.output_dim)
            init_xavier(self.out)

            self.act = torch.tanh

    def _first_mapping(self, xt: torch.Tensor) -> torch.Tensor:
        """Scale the inputs and apply the first feature map.

        Returns the Fourier features [N, 2*rf] when ``use_fourier`` is set,
        the scaled inputs themselves when there are no hidden layers, and
        otherwise the activated output of the first hidden layer.
        """
        x_scaled = xt * self.input_scales  # [N, d+1]
        if self.use_fourier:
            proj = x_scaled @ self.freq.t() + self.phase  # [N, rf]
            cosf = torch.cos(proj)
            sinf = torch.sin(proj)
            scale = (self.rf_count) ** (-0.5)
            h0 = scale * torch.cat([cosf, sinf], dim=1)   # [N, 2*rf]
            return h0
        if self.hidden_depth <= 0:
            return x_scaled
        z = self.hidden[0](x_scaled)
        return self.act(z)

    def forward(self, xt: torch.Tensor) -> torch.Tensor:
        """Evaluate the network on `xt` ([N, d+1]) and apply the output head.

        Returns a [N, 1] tensor. It is non-negative for the 'square',
        'softplus' and 'relu' heads.

        Raises:
            ValueError: if ``self.positive`` was changed to an unsupported
                value after construction.
        """
        h = self._first_mapping(xt)
        if self.use_fourier:
            if self.hidden_depth > 0:
                h = self.act(self.project(h))
                for layer in self.hidden:
                    h = self.act(layer(h))
            raw = self.out(h)
        else:
            if self.hidden_depth <= 0:
                raw = self.inp(h)  # h is already scaled xt
            else:
                # hidden[0] was already applied in _first_mapping.
                for layer in self.hidden[1:]:
                    h = self.act(layer(h))
                raw = self.out(h)

        if self.positive == 'square':
            return raw ** 2
        elif self.positive == 'softplus':
            return torch.nn.functional.softplus(raw, beta=1.0)
        elif self.positive == 'relu':
            return torch.relu(raw)
        elif self.positive == 'none':
            return raw
        else:
            raise ValueError("positive must be 'square', 'softplus', 'relu', or 'none'")


# -------------------------
# Barenblatt in d-D
# -------------------------
def barenblatt_ic_dd(
    x: torch.Tensor,           # [N, d] spatial points
    L_vec: list[float],        # domain lengths
    m: float,
    t0: float,
    support_fraction: float = 0.25,
    x_center: torch.Tensor | None = None,  # [d]
) -> torch.Tensor:
    """
    Returns u(0,x) := Barenblatt(t0, x).

    The profile is centred at `x_center` (default: the box midpoint
    0.5 * L_vec) and its constant A is chosen so that the support radius at
    physical time t0 equals ``support_fraction * min(L_vec)``.

    Args:
        x: [N, d] evaluation points.
        L_vec: box side lengths, one per spatial dimension.
        m: PME exponent (the formula divides by m - 1, so m must differ from 1).
        t0: physical time at which the profile is evaluated.
        support_fraction: support radius at t0 as a fraction of min(L_vec).
        x_center: optional [d] centre; used as given (caller keeps it on
            x's device).

    Returns:
        [N, 1] tensor of profile values, zero outside the support.
    """
    d = x.shape[1]
    if x_center is None:
        # Build the default centre next to `x` (same device, matching float
        # dtype) instead of on a global device, so the function also works for
        # inputs that live elsewhere.
        center_dtype = x.dtype if x.is_floating_point() else torch.float32
        x0 = torch.tensor(
            [0.5 * L for L in L_vec], device=x.device, dtype=center_dtype
        ).view(1, d)
    else:
        x0 = x_center.view(1, d)

    lam = 2.0 + d * (m - 1.0)
    alpha = d / lam
    beta = 1.0 / lam
    B_const = (m - 1.0) * beta / (2.0 * m)  # generalizes 1D formula

    R0 = support_fraction * min(L_vec)  # safe inside the domain
    # A is fixed so that z below vanishes exactly at r = R0.
    A = B_const * (R0**2) / (t0 ** (2.0 * beta))

    r2 = torch.sum((x - x0) ** 2, dim=1, keepdim=True)
    z = A - B_const * r2 / (t0 ** (2.0 * beta))
    z_clamped = torch.clamp(z, min=0.0)  # compact support: zero beyond R0
    u_t0 = (t0 ** (-alpha)) * (z_clamped ** (1.0 / (m - 1.0)))
    return u_t0


def barenblatt_true_shifted_dd(
    t_model: torch.Tensor,     # [N, 1] or scalar
    x: torch.Tensor,           # [N, d]
    L_vec: list[float],
    m: float,
    t0: float,
    support_fraction: float = 0.25,
    x_center: torch.Tensor | None = None,  # [d]
) -> torch.Tensor:
    """
    u_true(x, t_model) = Barenblatt(t0 + t_model, x).

    The constant A is fixed by requiring a support radius of
    ``support_fraction * min(L_vec)`` at physical time t0 (i.e. t_model = 0).

    Args:
        t_model: model time. Accepts an [N, 1] tensor, a 1-D tensor of
            length N (reshaped to a column), a 0-dim tensor, or a plain
            Python number.
        x: [N, d] evaluation points.
        L_vec: box side lengths, one per spatial dimension.
        m: PME exponent (the formula divides by m - 1, so m must differ from 1).
        t0: physical-time shift added to `t_model`.
        support_fraction: support radius at t0 as a fraction of min(L_vec).
        x_center: optional [d] centre; defaults to the box midpoint.

    Returns:
        [N, 1] tensor of exact solution values.
    """
    d = x.shape[1]
    float_dtype = x.dtype if x.is_floating_point() else torch.float32
    if x_center is None:
        # Keep the default centre on x's device rather than a global one.
        x0 = torch.tensor(
            [0.5 * L for L in L_vec], device=x.device, dtype=float_dtype
        ).view(1, d)
    else:
        x0 = x_center.view(1, d)

    lam = 2.0 + d * (m - 1.0)
    alpha = d / lam
    beta = 1.0 / lam
    B_const = (m - 1.0) * beta / (2.0 * m)

    R0 = support_fraction * min(L_vec)
    A = B_const * (R0**2) / (t0 ** (2.0 * beta))

    # A Python number has no `.ndim`; promote it so "scalar" input really works.
    if not isinstance(t_model, torch.Tensor):
        t_model = torch.tensor(t_model, device=x.device, dtype=float_dtype)

    t_phys = t_model + t0
    if t_phys.ndim == 0:
        t_phys = t_phys * torch.ones(x.shape[0], 1, device=x.device, dtype=x.dtype)
    elif t_phys.ndim == 1:
        # A flat [N] time vector would broadcast against r2 ([N, 1]) into an
        # [N, N] result; make it a column instead.
        t_phys = t_phys.reshape(-1, 1)

    r2 = torch.sum((x - x0) ** 2, dim=1, keepdim=True)
    z = A - B_const * r2 / (t_phys ** (2.0 * beta))
    z = torch.clamp(z, min=0.0)  # compact support
    u = (t_phys ** (-alpha)) * (z ** (1.0 / (m - 1.0)))
    return u


# -------------------------
# Collocation helpers (some support the custom sampling path, which is off by default)
# -------------------------
def rand_in_box(n: int, L_vec: list[float]) -> torch.Tensor:
    """Draw `n` points uniformly from the box ∏ [0, L_k].

    Returns an [n, len(L_vec)] tensor created on DEVICE.
    """
    dim = len(L_vec)
    side_lengths = _to_tensor(L_vec).view(1, dim)
    unit_samples = torch.rand(n, dim, device=DEVICE)
    return unit_samples * side_lengths


def sample_near_front_nd(
    n: int,
    L_vec: list[float],
    t0: float,
    T: float,
    m: float,
    band_rel: float,
    support_fraction: float = 0.25,
) -> torch.Tensor:
    """
    Sample near the expanding free boundary radius R(t) = R0 * ((t0+t)/t0)^beta.

    R0 is ``support_fraction * min(L_vec)``, the same support radius the
    Barenblatt initial condition in this file uses, so the band actually
    straddles the front. (Previously R0 was hard-coded to half of that, and
    with a band half-width of ``band_rel * min(L_vec)`` the samples did not
    reach the true front.)

    Args:
        n: number of points to draw.
        L_vec: box side lengths.
        t0: physical-time shift of the Barenblatt profile.
        T: final model time; sampled times are ``rand^3 * T`` (biased to early times).
        m: PME exponent.
        band_rel: half-width of the radial band as a fraction of min(L_vec).
        support_fraction: support radius at t0 as a fraction of min(L_vec);
            keep it equal to the value used for the initial condition.

    Returns:
        [n, d+1] tensor of (x, t) points on DEVICE. Points that fall outside
        the box are clamped onto its boundary.
    """
    d = len(L_vec)
    lam = 2.0 + d * (m - 1.0)
    beta = 1.0 / lam
    Lmin = min(L_vec)
    R0 = support_fraction * Lmin  # front radius at model time t = 0
    t = (torch.rand(n, 1, device=DEVICE) ** 3) * T
    R = R0 * ((t0 + t) / t0) ** beta

    # directions - uniform on sphere via Gaussian normalization
    dirs = torch.randn(n, d, device=DEVICE)
    dirs = dirs / (dirs.norm(dim=1, keepdim=True) + 1e-12)  # epsilon guards a zero vector

    # radii uniform in [R - delta, R + delta], never negative
    delta = band_rel * Lmin
    radii = R.squeeze(1) + (2 * torch.rand(n, device=DEVICE) - 1.0) * delta
    radii = radii.clamp(min=0.0)

    x0 = _to_tensor([0.5 * L for L in L_vec]).view(1, d)
    x = x0 + dirs * radii.view(-1, 1)

    # clamp to box
    L = _to_tensor(L_vec).view(1, d)
    x = torch.clamp(x, min=torch.zeros(1, d, device=DEVICE), max=L)
    xt = torch.cat([x, t], dim=1)
    return xt



def build_collocation_sets_nd(use_custom: bool, d: int):
    """Draw the PDE, initial-condition and boundary point sets for one run.

    Args:
        use_custom: if True, 40% of the PDE points come from
            `sample_near_front_nd` and the rest are uniform in space with
            times ``rand^3 * T_DOMAIN``; if False, all PDE points are uniform
            in space and time.
        d: number of spatial dimensions used to lay out the boundary faces.

    Returns:
        (pde_xt, ic_xt, u_ic_true, bc_xt). `bc_xt` stacks the faces in the
        order x_0=0, x_0=L_0, x_1=0, x_1=L_1, ..., each with
        ``max(1, N_BC // (2*d))`` rows.
    """
    # --- PDE points ---
    if not use_custom:
        x_interior = rand_in_box(N_PDE, L_VEC)
        t_interior = torch.rand(N_PDE, 1, device=DEVICE) * T_DOMAIN
        pde_xt = torch.cat([x_interior, t_interior], dim=1)
    else:
        n_front = int(0.4 * N_PDE)
        n_bulk = N_PDE - n_front
        front_pts = sample_near_front_nd(n_front, L_VEC, T0_IC, T_DOMAIN, M_EXPONENT, FRONT_BAND_REL)
        x_interior = rand_in_box(n_bulk, L_VEC)
        t_interior = (torch.rand(n_bulk, 1, device=DEVICE) ** 3) * T_DOMAIN
        bulk_pts = torch.cat([x_interior, t_interior], dim=1)
        pde_xt = torch.cat([front_pts, bulk_pts], dim=0)

    # --- IC points @ t=0 ---
    ic_x = rand_in_box(N_IC, L_VEC)
    ic_xt = torch.cat([ic_x, torch.zeros(N_IC, 1, device=DEVICE)], dim=1)
    u_ic_true = barenblatt_ic_dd(ic_x, L_VEC, M_EXPONENT, T0_IC, support_fraction=0.25)

    # --- BC points: for every axis, the x_i = 0 face followed by x_i = L_i ---
    per_face = max(1, N_BC // (2 * d))
    face_blocks = []
    for axis in range(d):
        # Both faces of an axis share the same tangential coordinates and times.
        tangential = rand_in_box(per_face, L_VEC)
        t_face = torch.rand(per_face, 1, device=DEVICE) * T_DOMAIN
        for wall in (0.0, L_VEC[axis]):
            on_face = tangential.clone()
            on_face[:, axis] = wall
            face_blocks.append(torch.cat([on_face, t_face], dim=1))
    bc_xt = torch.cat(face_blocks, dim=0)

    return pde_xt, ic_xt, u_ic_true, bc_xt


# -------------------------
# PDE residual 
# -------------------------
def pde_residual_autograd_nd(model: nn.Module, xt: torch.Tensor, m_exp: float, d: int):
    """
    Evaluate the PME residual R = u_t - sum_i ∂_{x_i x_i}(u^m) by autograd.

    `xt` ([N, d+1]) is cloned and detached, so the caller's tensor is not
    modified. All derivatives are taken with create_graph=True so the
    results can be differentiated again with respect to model parameters.

    Returns: (R, u, u_x, v_x) where v = u^m, u_x and v_x hold only the d
    spatial components ([N, d]) and R, u are [N, 1].
    """
    pts = xt.clone().detach().requires_grad_(True)    # [N, d+1]

    def _grad_wrt_inputs(out: torch.Tensor) -> torch.Tensor:
        # Per-sample gradient of a [N, 1] output with respect to pts -> [N, d+1]
        return torch.autograd.grad(out, pts, torch.ones_like(out), create_graph=True)[0]

    u = model(pts)                                    # [N, 1]
    du = _grad_wrt_inputs(u)
    u_x, u_t = du[:, :d], du[:, d:d+1]

    v = u.pow(m_exp)                                  # [N, 1]
    v_x = _grad_wrt_inputs(v)[:, :d]                  # [N, d]

    # Laplacian of v: differentiate each ∂v/∂x_k once more and keep its k-th column.
    second_derivs = [_grad_wrt_inputs(v_x[:, k:k+1])[:, k:k+1] for k in range(d)]
    v_lap = torch.stack(second_derivs, dim=0).sum(dim=0)  # [N,1]

    return u_t - v_lap, u, u_x, v_x


# -------------------------
# Build data, model, optimizer
# -------------------------
# NOTE(review): in the code visible in this file, every object created here
# (collocation sets, model, optimizer, loss) is rebuilt at the start of each
# run inside the multi-run loop below, after that run's reseeding. These
# module-level instances therefore only consume RNG state and memory.
# Statement order matters: PINNnD reseeds torch's global RNG with `seed`.
pde_xt, ic_xt, u_ic_true, bc_xt = build_collocation_sets_nd(USE_CUSTOM_COLLOCATION, SPATIAL_DIM)

model = PINNnD(
    SPATIAL_DIM,
    use_fourier=USE_FOURIER_FEATURES,
    rf_count=RF_COUNT,
    fourier_init_scale=FOURIER_INIT_SCALE,
    non_fourier_activation=NON_FOURIER_ACTIVATION,
    hidden_depth=HIDDEN_DEPTH,
    hidden_width=HIDDEN_WIDTH,
    positive=POSITIVE_HEAD,
    input_scales=INPUT_SCALES,
    seed=12345,
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
mse = nn.MSELoss()


# =========================
# Multi-run Training & Evaluation
# =========================
import random
from datetime import datetime

BASE_SEED = 12345          # run k (1-based) is seeded with BASE_SEED + k
# Log file name: spatial dimension plus a local-time timestamp taken at import.
LOG_FILE = f"pinn_multirun_{SPATIAL_DIM}D_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

def _set_seeds(s: int):
    """Seed the torch, NumPy and stdlib `random` generators with the same value `s`."""
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(s)

# Storage for metrics (one entry is appended per completed run)
spacetime_errors = []   # relative L2 error over Monte Carlo space-time samples
finaltime_errors = []   # relative L2 error at the final time t = T_DOMAIN
run_durations = []      # wall-clock training time of each run, in seconds

with open(LOG_FILE, "w", encoding="utf-8") as logf:

    def _log(msg: str):
        """Echo `msg` to stdout and append it to the log file, flushing at once."""
        print(msg)
        logf.write(msg + "\n")
        logf.flush()

    # Number of Monte Carlo evaluation samples. Defined once here so the
    # header below and the evaluation code cannot drift apart.
    N_EVAL = 20000

    # Log configuration header
    _log(f"== {SPATIAL_DIM}D PINN for PME (u_t - Δ(u^m) = 0) ==")
    _log(f"m (M_EXPONENT)         : {M_EXPONENT}")
    _log(f"T0_IC / T_DOMAIN       : {T0_IC} / {T_DOMAIN}")
    _log(f"L_VEC                  : {L_VEC}")
    _log(f"USE_FOURIER_FEATURES  : {USE_FOURIER_FEATURES}")
    _log(f"RF_COUNT               : {RF_COUNT}")
    _log(f"HIDDEN_DEPTH / WIDTH  : {HIDDEN_DEPTH} / {HIDDEN_WIDTH}")
    _log(f"POSITIVE_HEAD         : {POSITIVE_HEAD}")
    _log(f"INPUT_SCALES          : {INPUT_SCALES}")
    _log(f"USE_CUSTOM_COLLOCATION: {USE_CUSTOM_COLLOCATION}")
    _log(f"LEARNING_RATE         : {LEARNING_RATE}")
    _log(f"NUM_EPOCHS            : {NUM_EPOCHS}")
    _log(f"N_PDE / N_IC / N_BC   : {N_PDE} / {N_IC} / {N_BC}")
    _log(f"N_EVAL (MC samples)   : {N_EVAL}")
    _log(f"M_RUNS                : {M_RUNS}")
    _log("")

    # Boundary-loss invariants, hoisted out of the epoch loop. The face layout
    # must match build_collocation_sets_nd: faces are stacked as
    # x_0=0, x_0=L_0, x_1=0, x_1=L_1, ... with `per_face` rows each.
    per_face = max(1, N_BC // (2 * SPATIAL_DIM))
    zero_flux = torch.zeros(per_face, 1, device=DEVICE)

    for run_idx in range(1, M_RUNS + 1):
        _log(f"--- Run {run_idx}/{M_RUNS} ---")
        seed = BASE_SEED + run_idx
        _set_seeds(seed)
        _log(f"Seed: {seed}")

        # Fresh collocation sets each run
        pde_xt, ic_xt, u_ic_true, bc_xt = build_collocation_sets_nd(USE_CUSTOM_COLLOCATION, SPATIAL_DIM)

        # restart model & optimizer
        model = PINNnD(
            SPATIAL_DIM,
            use_fourier=USE_FOURIER_FEATURES,
            rf_count=RF_COUNT,
            fourier_init_scale=FOURIER_INIT_SCALE,
            non_fourier_activation=NON_FOURIER_ACTIVATION,
            hidden_depth=HIDDEN_DEPTH,
            hidden_width=HIDDEN_WIDTH,
            positive=POSITIVE_HEAD,
            input_scales=INPUT_SCALES,
            seed=seed,
        ).to(DEVICE)
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        mse = nn.MSELoss()

        # ---- Training ----
        _log("Training...")
        t_start = time.time()
        for epoch in range(1, NUM_EPOCHS + 1):
            optimizer.zero_grad(set_to_none=True)

            # PDE
            R_pde, _, _, _ = pde_residual_autograd_nd(model, pde_xt, M_EXPONENT, SPATIAL_DIM)
            loss_pde = mse(R_pde, torch.zeros_like(R_pde))

            # IC
            u_ic_pred = model(ic_xt)
            loss_ic = mse(u_ic_pred, u_ic_true)

            # BC zero-flux: only the spatial gradient of v = u^m is needed, so
            # compute just that instead of the full residual (which would also
            # build u_t and the Laplacian and then throw them away).
            bc_req = bc_xt.clone().detach().requires_grad_(True)
            v_bc = model(bc_req).pow(M_EXPONENT)
            v_x_bc = torch.autograd.grad(
                v_bc, bc_req, torch.ones_like(v_bc), create_graph=True
            )[0][:, :SPATIAL_DIM]

            face_losses = []
            for face_idx in range(2 * SPATIAL_DIM):
                axis = face_idx // 2  # faces 2i and 2i+1 are x_i = 0 and x_i = L_i
                rows = slice(face_idx * per_face, (face_idx + 1) * per_face)
                # Normal derivative on an axis-aligned face is ±∂v/∂x_axis.
                face_losses.append(mse(v_x_bc[rows, axis:axis + 1], zero_flux))
            loss_bc = torch.stack(face_losses).mean()

            # Total loss
            loss = LAMBDA_PDE * loss_pde + LAMBDA_IC * loss_ic + LAMBDA_BC * loss_bc
            loss.backward()
            optimizer.step()

            if epoch % 1000 == 0 or epoch == 1 or epoch == NUM_EPOCHS:
                _log(
                    f"Epoch {epoch:5d}/{NUM_EPOCHS} "
                    f"| Total: {loss.item():.3e} "
                    f"| PDE: {loss_pde.item():.3e} "
                    f"| IC: {loss_ic.item():.3e} "
                    f"| BC: {loss_bc.item():.3e}"
                )

        dur = time.time() - t_start
        run_durations.append(dur)
        _log(f"Finished run {run_idx} in {dur:.2f}s")

        # ---- Evaluation ----
        model.eval()

        # Monte Carlo space–time L2 error
        x_eval = rand_in_box(N_EVAL, L_VEC)
        t_eval = torch.rand(N_EVAL, 1, device=DEVICE) * T_DOMAIN
        xt_eval = torch.cat([x_eval, t_eval], dim=1)

        with torch.no_grad():
            u_pred_eval = model(xt_eval)
            u_true_eval = barenblatt_true_shifted_dd(t_eval, x_eval, L_VEC, M_EXPONENT, T0_IC, support_fraction=0.25)

        num = torch.sum((u_pred_eval - u_true_eval) ** 2)
        den = torch.sum(u_true_eval ** 2) + 1e-16  # epsilon avoids 0/0 if u_true is all zero
        rel_L2_spacetime = torch.sqrt(num / den).item()
        spacetime_errors.append(rel_L2_spacetime)
        _log(f"Rel L2 (space–time, MC): {rel_L2_spacetime:.6e}")

        # Final-time spatial L2 error
        x_evalT = rand_in_box(N_EVAL, L_VEC)
        tT = torch.full((N_EVAL, 1), T_DOMAIN, device=DEVICE)
        with torch.no_grad():
            u_pred_T = model(torch.cat([x_evalT, tT], dim=1))
            u_true_T = barenblatt_true_shifted_dd(tT, x_evalT, L_VEC, M_EXPONENT, T0_IC, support_fraction=0.25)

        numT = torch.sum((u_pred_T - u_true_T) ** 2)
        denT = torch.sum(u_true_T ** 2) + 1e-16
        rel_L2_final = torch.sqrt(numT / denT).item()
        finaltime_errors.append(rel_L2_final)
        _log(f"Rel L2 (final time T={T_DOMAIN}): {rel_L2_final:.6e}")
        _log("")


    # ---- Summary stats ----
    spacetime_np = np.array(spacetime_errors, dtype=np.float64)
    finaltime_np = np.array(finaltime_errors, dtype=np.float64)

    # Sample standard deviation (ddof=1) needs at least two runs; report 0.0 otherwise.
    st_mean = float(spacetime_np.mean())
    st_std  = float(spacetime_np.std(ddof=1)) if len(spacetime_np) > 1 else 0.0
    ft_mean = float(finaltime_np.mean())
    ft_std  = float(finaltime_np.std(ddof=1)) if len(finaltime_np) > 1 else 0.0

    _log("=== Summary over runs ===")
    _log(f"Space–time rel L2: mean={st_mean:.6e}, std={st_std:.6e}")
    _log(f"Final-time rel L2: mean={ft_mean:.6e}, std={ft_std:.6e}")
    _log(f"Durations (s)    : {', '.join(f'{sec:.2f}' for sec in run_durations)}")
    _log(f"Saved full log to: {LOG_FILE}")
