from __future__ import annotations

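"""
Multi-dimensional porous medium equation (PME) PINN with optional Fourier features.
This script trains a PINN to solve the PME M_RUNS times (fresh seed, data and model
each run) and writes the training log and relative L2 errors to a timestamped .txt file.
"""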

import math
import time
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator

# -------------------------
# Global params
# -------------------------
# Initial global seeds. Each training run reseeds again via _set_seeds(),
# and PINNnD.__init__ reseeds torch with its own `seed` argument.
torch.manual_seed(42)
np.random.seed(42)

# Problem params
SPATIAL_DIM = 3                    # number of spatial dimensions d
M_EXPONENT = 2.0                   # m in PME (m > 1)
T0_IC = 1e-2                       # small positive time for IC (avoid singularity)
T_DOMAIN = 0.025                    # T in [0, T]
M_RUNS = 1                         # number of independent training runs

# Domain
L_VEC = [1.0] * SPATIAL_DIM        # lengths per dimension, default is [1.0]*d

# PINN architecture params
USE_FOURIER_FEATURES = True        # set False for standard MLP
RF_COUNT = 1000                    # number of Fourier modes (first layer), feature dim = 2*RF_COUNT
FOURIER_INIT_SCALE = 10.0          # init scale for frequencies

NON_FOURIER_ACTIVATION = 'tanh'    # used when USE_FOURIER_FEATURES=False

HIDDEN_DEPTH = 0                   # number of extra hidden layers after first mapping
HIDDEN_WIDTH = 0                   # width of those hidden layers

POSITIVE_HEAD = 'square'           # 'square' | 'softplus' | 'relu' | 'none'

# Input anisotropic scaling (per spatial dim, then time)
INPUT_SCALES = [1.0] * SPATIAL_DIM + [10.0]

# Collocation toggles
USE_CUSTOM_COLLOCATION = False      #  False means uniform collocation
FRONT_BAND_REL = 0.02              # half-width of the front sampling band, as a fraction of min(L_VEC)

# Training params
LEARNING_RATE = 1e-3               # Adam step size
NUM_EPOCHS = 5000                  # full-batch optimizer steps per run
N_PDE = 6000                       # interior (PDE residual) collocation points
N_IC = 4000                        # initial-condition points at t = 0
N_BC = 2000                        # boundary points requested; split evenly over the 2*d faces

# Loss weights
LAMBDA_PDE = 1.0                   # weight of the PDE residual term
LAMBDA_IC = 200.0                  # weight of the initial-condition term
LAMBDA_BC = 1.0                    # weight of the zero-flux boundary term

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# Permit faster reduced-precision float32 matmul kernels where the backend offers them.
torch.set_float32_matmul_precision("high")


# -------------------------
# Utilities
# -------------------------
def get_activation(name: str):
    """Look up an element-wise activation by (case-insensitive) name.

    Supported names are 'tanh', 'relu' and 'sigmoid'; the matching
    ``torch`` function is returned.

    Raises:
        ValueError: if the lower-cased name is not one of the supported ones.
    """
    key = name.lower()
    table = {
        'tanh': torch.tanh,
        'relu': torch.relu,
        'sigmoid': torch.sigmoid,
    }
    fn = table.get(key)
    if fn is None:
        raise ValueError(f"Unsupported activation: {key}")
    return fn


def _to_tensor(x, device=DEVICE, dtype=torch.float32):
    """Return ``x`` as a tensor with the requested device and dtype.

    Existing tensors are converted with ``Tensor.to``; anything else is
    handed to ``torch.tensor`` to build a new tensor.
    """
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x, device=device, dtype=dtype)
    return x.to(device=device, dtype=dtype)

def init_xavier(layer):
    """Apply Xavier-uniform weights and zero bias to an ``nn.Linear`` in place.

    Any other module type is left untouched. Returns ``None``.
    """
    if not isinstance(layer, nn.Linear):
        return
    nn.init.xavier_uniform_(layer.weight)
    bias = layer.bias
    if bias is not None:
        nn.init.zeros_(bias)

# -------------------------
# PINN definition 
# -------------------------
class PINNnD(nn.Module):
    def __init__(
        self,
        spatial_dim: int,
        *,
        use_fourier: bool,
        rf_count: int = 0,
        fourier_init_scale: float = 1.0,
        non_fourier_activation: str = 'tanh',
        hidden_depth: int = 0,
        hidden_width: int = 256,
        positive: str = 'square',
        input_scales: list[float] | Tuple[float, ...] = None,
        seed: int = 12345,
    ) -> None:
        super().__init__()
        torch.manual_seed(seed)

        self.d = spatial_dim
        self.input_dim = self.d + 1   # (x1,...,xd, t)
        self.output_dim = 1

        self.use_fourier = use_fourier
        self.hidden_depth = int(hidden_depth)
        self.hidden_width = int(hidden_width)
        self.positive = positive
        self.non_fourier_activation = non_fourier_activation.lower()

        if input_scales is None:
            input_scales = [1.0] * self.d + [10.0]
        self.register_buffer("input_scales", _to_tensor(input_scales))

        if self.use_fourier:
            assert rf_count > 0 and rf_count % 1 == 0, "rf_count must be a positive integer"
            self.rf_count = int(rf_count)
            # Trainable frequency matrix and phase
            self.freq = nn.Parameter(torch.randn(self.rf_count, self.input_dim) * fourier_init_scale)  # [rf, d+1]
            self.phase = nn.Parameter(torch.rand(self.rf_count) * 2 * math.pi)                          # [rf]
            self._build_post_fourier_stack()
        else:
            # Standard MLP
            self.act = get_activation(self.non_fourier_activation)
            layers: list[nn.Module] = []
            if self.hidden_depth <= 0:
                self.inp = nn.Linear(self.input_dim, self.output_dim)
                self.out = None
                self.hidden = None
            else:
                layers.append(nn.Linear(self.input_dim, self.hidden_width))
                for _ in range(self.hidden_depth - 1):
                    layers.append(nn.Linear(self.hidden_width, self.hidden_width))
                self.hidden = nn.ModuleList(layers)
                self.out = nn.Linear(self.hidden_width, self.output_dim)

    # def _build_post_fourier_stack(self):
    #     base_dim = 2 * self.rf_count  # cos+sin
    #     if self.hidden_depth <= 0:
    #         self.project = None
    #         self.hidden = None
    #         self.out = nn.Linear(base_dim, self.output_dim)
    #     else:
    #         self.project = nn.Linear(base_dim, self.hidden_width)
    #         hidden_layers: list[nn.Module] = []
    #         for _ in range(self.hidden_depth - 1):
    #             hidden_layers.append(nn.Linear(self.hidden_width, self.hidden_width))
    #         self.hidden = nn.ModuleList(hidden_layers)
    #         self.out = nn.Linear(self.hidden_width, self.output_dim)
    #         self.act = torch.tanh


    def _build_post_fourier_stack(self):
        base_dim = 2 * self.rf_count  # cos+sin
    
        if self.hidden_depth <= 0:
            # No hidden layers after Fourier features
            self.project = None
            self.hidden = None
            self.out = nn.Linear(base_dim, self.output_dim)
            init_xavier(self.out)
        else:
            # First layer after Fourier features
            self.project = nn.Linear(base_dim, self.hidden_width)
            init_xavier(self.project)
    
            # Hidden layers
            hidden_layers: list[nn.Module] = []
            for _ in range(self.hidden_depth - 1):
                lay = nn.Linear(self.hidden_width, self.hidden_width)
                init_xavier(lay)
                hidden_layers.append(lay)
            self.hidden = nn.ModuleList(hidden_layers)
    
            # Output layer
            self.out = nn.Linear(self.hidden_width, self.output_dim)
            init_xavier(self.out)
    
            self.act = torch.tanh


    def _first_mapping(self, xt: torch.Tensor) -> torch.Tensor:
        x_scaled = xt * self.input_scales  # [N, d+1]
        if self.use_fourier:
            proj = x_scaled @ self.freq.t() + self.phase  # [N, rf]
            cosf = torch.cos(proj)
            sinf = torch.sin(proj)
            scale = (self.rf_count) ** (-0.5)
            h0 = scale * torch.cat([cosf, sinf], dim=1)   # [N, 2*rf]
            return h0
        else:
            if self.hidden_depth <= 0:
                return x_scaled
            z = self.hidden[0](x_scaled)
            return self.act(z)

    def forward(self, xt: torch.Tensor) -> torch.Tensor:
        h = self._first_mapping(xt)
        if self.use_fourier:
            if self.hidden_depth > 0:
                h = self.act(self.project(h))
                for layer in self.hidden:
                    h = self.act(layer(h))
            raw = self.out(h)
        else:
            if self.hidden_depth <= 0:
                raw = self.inp(h)  # h is already scaled xt
            else:
                for layer in self.hidden[1:]:
                    h = self.act(layer(h))
                raw = self.out(h)

        if self.positive == 'square':
            return raw ** 2
        elif self.positive == 'softplus':
            return torch.nn.functional.softplus(raw, beta=1.0)
        elif self.positive == 'relu':
            return torch.relu(raw)
        elif self.positive == 'none':
            return raw
        else:
            raise ValueError("positive must be 'square', 'softplus', 'relu', or 'none'")


# -------------------------
# Barenblatt profile
# -------------------------
def barenblatt_ic_dd(
    x: torch.Tensor,           # [N, d] spatial points
    L_vec: list[float],        # domain lengths
    m: float,
    t0: float,
    support_fraction: float = 0.25,
    x_center: torch.Tensor | None = None,  # [d]
) -> torch.Tensor:
    """
    Returns u(0,x) := Barenblatt(t0, x).
    """
    d = x.shape[1]
    if x_center is None:
        x0 = _to_tensor([0.5 * L for L in L_vec]).view(1, d)
    else:
        x0 = x_center.view(1, d)

    lam = 2.0 + d * (m - 1.0)
    alpha = d / lam
    beta = 1.0 / lam
    B_const = (m - 1.0) * beta / (2.0 * m)  # generalizes 1D formula

    R0 = support_fraction * (min(L_vec) / 1.0)  # safe inside the domain
    A = B_const * (R0**2) / (t0 ** (2.0 * beta))

    r2 = torch.sum((x - x0) ** 2, dim=1, keepdim=True)
    z = A - B_const * r2 / (t0 ** (2.0 * beta))
    z_clamped = torch.clamp(z, min=0.0)
    u_t0 = (t0 ** (-alpha)) * (z_clamped ** (1.0 / (m - 1.0)))
    return u_t0


def barenblatt_true_shifted_dd(
    t_model: torch.Tensor,     # [N, 1] or scalar
    x: torch.Tensor,           # [N, d]
    L_vec: list[float],
    m: float,
    t0: float,
    support_fraction: float = 0.25,
    x_center: torch.Tensor | None = None,  # [d]
) -> torch.Tensor:
    """
    u_true(x, t_model) = Barenblatt(t0 + t_model, x).
    """
    d = x.shape[1]
    if x_center is None:
        x0 = _to_tensor([0.5 * L for L in L_vec]).view(1, d)
    else:
        x0 = x_center.view(1, d)

    lam = 2.0 + d * (m - 1.0)
    alpha = d / lam
    beta = 1.0 / lam
    B_const = (m - 1.0) * beta / (2.0 * m)

    R0 = support_fraction * (min(L_vec) / 1.0)
    A = B_const * (R0**2) / (t0 ** (2.0 * beta))

    t_phys = t_model + t0
    if t_phys.ndim == 0:
        t_phys = t_phys * torch.ones(x.shape[0], 1, device=x.device, dtype=x.dtype)

    r2 = torch.sum((x - x0) ** 2, dim=1, keepdim=True)
    z = A - B_const * r2 / (t_phys ** (2.0 * beta))
    z = torch.clamp(z, min=0.0)
    u = (t_phys ** (-alpha)) * (z ** (1.0 / (m - 1.0)))
    return u


# -------------------------
# Collocation helpers (many only used for custom collocation)
# -------------------------
def rand_in_box(n: int, L_vec: list[float]) -> torch.Tensor:
    """Draw ``n`` points uniformly from the box ``[0, L_1] x ... x [0, L_d]``.

    Returns an ``[n, d]`` tensor on ``DEVICE`` with ``d = len(L_vec)``.
    """
    dim = len(L_vec)
    unit_samples = torch.rand(n, dim, device=DEVICE)
    # Stretch the unit cube to the box side lengths.
    side_lengths = _to_tensor(L_vec).view(1, dim)
    return unit_samples * side_lengths


def sample_near_front_nd(
    n: int,
    L_vec: list[float],
    t0: float,
    T: float,
    m: float,
    band_rel: float,
    support_fraction: float = 0.25,
) -> torch.Tensor:
    """
    Sample near the expanding free boundary radius R(t) = R0 * ((t0+t)/t0)^beta.

    R0 = support_fraction * min(L_vec) is the same initial front radius that
    barenblatt_ic_dd uses, so the band actually straddles the front. (The
    previous hard-coded 0.25 * Lmin / 2 put the band at half the front radius,
    entirely inside the support.)

    Args:
        n: number of space-time points to draw.
        L_vec: domain lengths per dimension; the front is centred at the midpoint.
        t0: time shift of the initial Barenblatt profile.
        T: model-time horizon; times are drawn as rand^3 * T, i.e. biased
            towards t = 0.
        m: PME exponent.
        band_rel: half-width of the radial band as a fraction of min(L_vec).
        support_fraction: initial front radius as a fraction of min(L_vec);
            keep equal to the value used for the initial condition.

    Returns:
        [n, d+1] tensor of (x, t) points on DEVICE, with x clamped to the box.
    """
    d = len(L_vec)
    lam = 2.0 + d * (m - 1.0)
    beta = 1.0 / lam
    Lmin = min(L_vec)
    R0 = support_fraction * Lmin  # initial front radius, matches the IC
    t = (torch.rand(n, 1, device=DEVICE) ** 3) * T
    R = R0 * ((t0 + t) / t0) ** beta

    # directions ~ uniform on sphere via Gaussian normalization
    dirs = torch.randn(n, d, device=DEVICE)
    dirs = dirs / (dirs.norm(dim=1, keepdim=True) + 1e-12)

    # radial jitter of +/- delta around the front
    delta = band_rel * Lmin
    radii = R.squeeze(1) + (2 * torch.rand(n, device=DEVICE) - 1.0) * delta
    radii = radii.clamp(min=0.0)

    x0 = _to_tensor([0.5 * L for L in L_vec]).view(1, d)
    x = x0 + dirs * radii.view(-1, 1)

    # clamp to box
    L = _to_tensor(L_vec).view(1, d)
    x = torch.clamp(x, min=torch.zeros(1, d, device=DEVICE), max=L)
    return torch.cat([x, t], dim=1)



def build_collocation_sets_nd(use_custom: bool, d: int):
    """Draw the PDE, initial-condition and boundary point sets for one run.

    Sizes and problem parameters come from the module-level constants
    (N_PDE, N_IC, N_BC, L_VEC, T0_IC, T_DOMAIN, M_EXPONENT, FRONT_BAND_REL).

    Args:
        use_custom: if True, 40% of the PDE points are drawn near the front
            and the remaining times are cubed (biased to t = 0); otherwise
            space and time are sampled uniformly.
        d: number of spatial dimensions (number of face pairs to sample).

    Returns:
        (pde_xt, ic_xt, u_ic_true, bc_xt). bc_xt is laid out face by face
        as (x_0 = 0), (x_0 = L_0), (x_1 = 0), ... with
        max(1, N_BC // (2*d)) points per face.
    """
    # --- interior (PDE) points ---
    if not use_custom:
        x_int = rand_in_box(N_PDE, L_VEC)
        t_int = torch.rand(N_PDE, 1, device=DEVICE) * T_DOMAIN
        pde_xt = torch.cat([x_int, t_int], dim=1)
    else:
        n_front = int(0.4 * N_PDE)
        n_bulk = N_PDE - n_front
        front_pts = sample_near_front_nd(n_front, L_VEC, T0_IC, T_DOMAIN, M_EXPONENT, FRONT_BAND_REL)
        x_int = rand_in_box(n_bulk, L_VEC)
        t_int = torch.rand(n_bulk, 1, device=DEVICE).pow(3) * T_DOMAIN
        bulk_pts = torch.cat([x_int, t_int], dim=1)
        pde_xt = torch.cat([front_pts, bulk_pts], dim=0)

    # --- initial condition at model time t = 0 (Barenblatt profile at T0_IC) ---
    ic_x = rand_in_box(N_IC, L_VEC)
    ic_xt = torch.cat([ic_x, torch.zeros(N_IC, 1, device=DEVICE)], dim=1)
    u_ic_true = barenblatt_ic_dd(ic_x, L_VEC, M_EXPONENT, T0_IC, support_fraction=0.25)

    # --- boundary faces x_i = 0 and x_i = L_i ---
    per_face = max(1, N_BC // (2 * d))
    face_blocks = []
    for axis in range(d):
        # Both faces of an axis share the same tangential coordinates and times.
        base_pts = rand_in_box(per_face, L_VEC)
        t_face = torch.rand(per_face, 1, device=DEVICE) * T_DOMAIN
        for wall in (0.0, L_VEC[axis]):
            pts = base_pts.clone()
            pts[:, axis] = wall
            face_blocks.append(torch.cat([pts, t_face], dim=1))
    bc_xt = torch.cat(face_blocks, dim=0)

    return pde_xt, ic_xt, u_ic_true, bc_xt


# -------------------------
# PDE residual 
# -------------------------
def pde_residual_autograd_nd(model: nn.Module, xt: torch.Tensor, m_exp: float, d: int):
    """
    Evaluate the PME residual R = u_t - Laplacian(u^m) with autograd.

    xt is [N, d+1] (d spatial columns, then time); it is cloned and detached,
    so the caller's tensor is left untouched.

    Returns: R, u, grad_u (only spatial components), grad_v (only spatial components)
    where v = u^m. All derivatives keep their graph (create_graph=True) so
    the outputs can be used inside a loss.
    """
    pts = xt.clone().detach().requires_grad_(True)       # [N, d+1]

    def _grad_wrt_inputs(out: torch.Tensor) -> torch.Tensor:
        # d(out)/d(pts), summed over the single output column -> [N, d+1]
        return torch.autograd.grad(out, pts, torch.ones_like(out), create_graph=True)[0]

    u = model(pts)                                       # [N, 1]
    du = _grad_wrt_inputs(u)
    u_x, u_t = du[:, :d], du[:, d:d + 1]

    v = u.pow(m_exp)                                     # [N, 1]
    v_x = _grad_wrt_inputs(v)[:, :d]                     # [N, d]

    # Laplacian of v: differentiate each first derivative once more and keep
    # only the matching (diagonal) second derivative.
    second_derivs = [
        _grad_wrt_inputs(v_x[:, k:k + 1])[:, k:k + 1]
        for k in range(d)
    ]
    v_lap = torch.stack(second_derivs, dim=0).sum(dim=0)  # [N, 1]

    return u_t - v_lap, u, u_x, v_x


# -------------------------
# Build data, model, optimizer
# -------------------------
# NOTE(review): every object created in this section (collocation sets, model,
# optimizer, loss) is rebuilt from scratch at the start of each run in the
# multi-run loop further down, and the RNGs are reseeded there as well. With
# M_RUNS >= 1 the instances built here are therefore never used for training;
# this section looks like a leftover from a single-run version and is a
# candidate for removal.
pde_xt, ic_xt, u_ic_true, bc_xt = build_collocation_sets_nd(USE_CUSTOM_COLLOCATION, SPATIAL_DIM)

model = PINNnD(
    SPATIAL_DIM,
    use_fourier=USE_FOURIER_FEATURES,
    rf_count=RF_COUNT,
    fourier_init_scale=FOURIER_INIT_SCALE,
    non_fourier_activation=NON_FOURIER_ACTIVATION,
    hidden_depth=HIDDEN_DEPTH,
    hidden_width=HIDDEN_WIDTH,
    positive=POSITIVE_HEAD,
    input_scales=INPUT_SCALES,
    seed=12345,
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
mse = nn.MSELoss()


# =========================
# Multi-run Training & Evaluation
# =========================
import random
from datetime import datetime


# Run k (1-based) is seeded with BASE_SEED + k.
BASE_SEED = 12345      
# Log file name carries the spatial dimension and the start timestamp.
LOG_FILE = f"pinn_multirun_{SPATIAL_DIM}D_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

def _set_seeds(s: int):
    """Seed the torch, NumPy and Python ``random`` generators with ``s``."""
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(s)

# Storage for metrics (one entry appended per completed run)
spacetime_errors = []
finaltime_errors = []
run_durations = []

# Number of Monte Carlo points used for each relative-L2 error estimate.
N_EVAL = 20000

# Points per boundary face. build_collocation_sets_nd lays bc_xt out as
# consecutive chunks of this size: (x_0 = 0), (x_0 = L_0), (x_1 = 0), ...
per_face = max(1, N_BC // (2 * SPATIAL_DIM))


def _spatial_grad_of_power(net, xt, m_exp, d):
    """Return the spatial gradient of v = u^m at xt, shape [N, d].

    The zero-flux boundary loss only needs first derivatives of v, so this
    skips the time derivative and the Laplacian (d extra double-backward
    passes) that pde_residual_autograd_nd would also build.
    """
    xt_req = xt.clone().detach().requires_grad_(True)
    v = net(xt_req).pow(m_exp)
    grads_v = torch.autograd.grad(v, xt_req, torch.ones_like(v), create_graph=True)[0]
    return grads_v[:, :d]


with open(LOG_FILE, "w", encoding="utf-8") as logf:

    def _log(msg: str):
        """Echo msg to stdout and append it to the log file immediately."""
        print(msg)
        logf.write(msg + "\n")
        logf.flush()

    # Log configuration header
    _log(f"== {SPATIAL_DIM}D PINN for PME (u_t - Δ(u^m) = 0) ==")
    _log(f"m (M_EXPONENT)         : {M_EXPONENT}")
    _log(f"T0_IC / T_DOMAIN       : {T0_IC} / {T_DOMAIN}")
    _log(f"L_VEC                  : {L_VEC}")
    _log(f"USE_FOURIER_FEATURES  : {USE_FOURIER_FEATURES}")
    _log(f"RF_COUNT               : {RF_COUNT}")
    _log(f"HIDDEN_DEPTH / WIDTH  : {HIDDEN_DEPTH} / {HIDDEN_WIDTH}")
    _log(f"POSITIVE_HEAD         : {POSITIVE_HEAD}")
    _log(f"INPUT_SCALES          : {INPUT_SCALES}")
    _log(f"USE_CUSTOM_COLLOCATION: {USE_CUSTOM_COLLOCATION}")
    _log(f"LEARNING_RATE         : {LEARNING_RATE}")
    _log(f"NUM_EPOCHS            : {NUM_EPOCHS}")
    _log(f"N_PDE / N_IC / N_BC   : {N_PDE} / {N_IC} / {N_BC}")
    _log(f"N_EVAL (MC samples)   : {N_EVAL}")
    _log(f"M_RUNS                : {M_RUNS}")
    _log("")

    for run_idx in range(1, M_RUNS + 1):
        _log(f"--- Run {run_idx}/{M_RUNS} ---")
        seed = BASE_SEED + run_idx
        _set_seeds(seed)
        _log(f"Seed: {seed}")

        # Fresh collocation sets each run
        pde_xt, ic_xt, u_ic_true, bc_xt = build_collocation_sets_nd(USE_CUSTOM_COLLOCATION, SPATIAL_DIM)

        # Fresh model & optimizer each run
        model = PINNnD(
            SPATIAL_DIM,
            use_fourier=USE_FOURIER_FEATURES,
            rf_count=RF_COUNT,
            fourier_init_scale=FOURIER_INIT_SCALE,
            non_fourier_activation=NON_FOURIER_ACTIVATION,
            hidden_depth=HIDDEN_DEPTH,
            hidden_width=HIDDEN_WIDTH,
            positive=POSITIVE_HEAD,
            input_scales=INPUT_SCALES,
            seed=seed,
        ).to(DEVICE)
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        mse = nn.MSELoss()

        # Zero target for one boundary face; constant across epochs.
        zero_face = torch.zeros(per_face, 1, device=DEVICE)

        # ---- Training ----
        _log("Training...")
        t_start = time.time()
        for epoch in range(1, NUM_EPOCHS + 1):
            optimizer.zero_grad(set_to_none=True)

            # PDE residual
            R_pde, _, _, _ = pde_residual_autograd_nd(model, pde_xt, M_EXPONENT, SPATIAL_DIM)
            loss_pde = mse(R_pde, torch.zeros_like(R_pde))

            # IC
            u_ic_pred = model(ic_xt)
            loss_ic = mse(u_ic_pred, u_ic_true)

            # BC zero-flux: on the faces x_i = 0 and x_i = L_i the outward
            # normal is -/+ e_i, so only d(u^m)/dx_i has to vanish there.
            v_x_bc = _spatial_grad_of_power(model, bc_xt, M_EXPONENT, SPATIAL_DIM)
            losses = []
            for face_idx in range(2 * SPATIAL_DIM):
                axis = face_idx // 2  # faces come in (x_i = 0, x_i = L_i) pairs
                sel = slice(face_idx * per_face, (face_idx + 1) * per_face)
                losses.append(mse(v_x_bc[sel, axis:axis + 1], zero_face))
            loss_bc = torch.stack(losses).mean()

            # Total loss
            loss = LAMBDA_PDE * loss_pde + LAMBDA_IC * loss_ic + LAMBDA_BC * loss_bc
            loss.backward()
            optimizer.step()

            if epoch % 1000 == 0 or epoch == 1 or epoch == NUM_EPOCHS:
                _log(
                    f"Epoch {epoch:5d}/{NUM_EPOCHS} "
                    f"| Total: {loss.item():.3e} "
                    f"| PDE: {loss_pde.item():.3e} "
                    f"| IC: {loss_ic.item():.3e} "
                    f"| BC: {loss_bc.item():.3e}"
                )

        dur = time.time() - t_start
        run_durations.append(dur)
        _log(f"Finished run {run_idx} in {dur:.2f}s")

        # ---- Evaluation ----
        model.eval()

        # Monte Carlo space–time L2 error
        x_eval = rand_in_box(N_EVAL, L_VEC)
        t_eval = torch.rand(N_EVAL, 1, device=DEVICE) * T_DOMAIN
        xt_eval = torch.cat([x_eval, t_eval], dim=1)

        with torch.no_grad():
            u_pred_eval = model(xt_eval)
            u_true_eval = barenblatt_true_shifted_dd(t_eval, x_eval, L_VEC, M_EXPONENT, T0_IC, support_fraction=0.25)

        num = torch.sum((u_pred_eval - u_true_eval) ** 2)
        den = torch.sum(u_true_eval ** 2) + 1e-16  # guard against an all-zero reference
        rel_L2_spacetime = torch.sqrt(num / den).item()
        spacetime_errors.append(rel_L2_spacetime)
        _log(f"Rel L2 (space–time, MC): {rel_L2_spacetime:.6e}")

        # Final-time spatial L2 error
        x_evalT = rand_in_box(N_EVAL, L_VEC)
        tT = torch.full((N_EVAL, 1), T_DOMAIN, device=DEVICE)
        with torch.no_grad():
            u_pred_T = model(torch.cat([x_evalT, tT], dim=1))
            u_true_T = barenblatt_true_shifted_dd(tT, x_evalT, L_VEC, M_EXPONENT, T0_IC, support_fraction=0.25)

        numT = torch.sum((u_pred_T - u_true_T) ** 2)
        denT = torch.sum(u_true_T ** 2) + 1e-16
        rel_L2_final = torch.sqrt(numT / denT).item()
        finaltime_errors.append(rel_L2_final)
        _log(f"Rel L2 (final time T={T_DOMAIN}): {rel_L2_final:.6e}")
        _log("")


    # ---- Summary stats ----
    _log("=== Summary over runs ===")
    if spacetime_errors:
        spacetime_np = np.array(spacetime_errors, dtype=np.float64)
        finaltime_np = np.array(finaltime_errors, dtype=np.float64)

        # Sample standard deviation needs at least two runs.
        st_mean = float(spacetime_np.mean())
        st_std  = float(spacetime_np.std(ddof=1)) if len(spacetime_np) > 1 else 0.0
        ft_mean = float(finaltime_np.mean())
        ft_std  = float(finaltime_np.std(ddof=1)) if len(finaltime_np) > 1 else 0.0

        _log(f"Space–time rel L2: mean={st_mean:.6e}, std={st_std:.6e}")
        _log(f"Final-time rel L2: mean={ft_mean:.6e}, std={ft_std:.6e}")
    else:
        # M_RUNS <= 0: avoid taking the mean of an empty array (NaN + warning).
        _log("No runs were executed; no error statistics available.")
    _log(f"Durations (s)    : {', '.join(f'{d:.2f}' for d in run_durations)}")
    _log(f"Saved full log to: {LOG_FILE}")
