from src.datasets.mesh.mesh_gen import MeshGen,mesh_to_pyg_graph
#from src.equations.poisson import PoissonEquation
from Equations.poisson import PoissonEquation
from Dataset.mesh_generator import Gmsh
from src.datasets.generators.poisson import PoissonGen
from Trainer.models import SIGN, FrequencyMLPEncoder
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogLocator, LogFormatterSciNotation, ScalarFormatter
import os
from dataclasses import dataclass, field
from typing import List, Union
from itertools import product
from copy import deepcopy
import meshio
import numpy as np
from tqdm import tqdm


#########################
# Model Utilities
#########################
class Mock(nn.Module):
    """Placeholder "model" whose output is a free, directly trainable vector.

    The forward pass ignores its argument and returns the learnable parameter
    ``u`` of shape ``[num_nodes, 1]``, initialised from a standard normal.
    """

    def __init__(self, num_nodes: int):
        super().__init__()
        initial_values = torch.randn(num_nodes, 1)
        self.u = nn.Parameter(initial_values)

    def forward(self, graph):
        # ``graph`` is accepted only so the call signature matches the other models.
        return self.u

class MLP(nn.Module):
    """SIREN MLP with sinusoidal activation functions.

    Supports two modes:
    - in_channels=2: Input [x, y] only (original behavior)
    - in_channels=3: Input [x, y, f] (with source function)

    The network is five linear layers
    (in_channels -> 64 -> 64 -> 64 -> 64 -> 1); every layer except the last
    is followed by ``sin(omega_0 * z)``.
    """

    def __init__(self, omega_0=30.0, in_channels=2):
        """
        Args:
            omega_0: Frequency factor applied inside every sine activation and
                used to scale the weight initialisation of layers after the first.
            in_channels: 2 for coordinate-only input, 3 to append the source
                function supplied through ``set_f``.
        """
        super().__init__()
        self.omega_0 = omega_0
        self.in_channels = in_channels
        self.f = None  # Source function, set via set_f()

        self.layers = nn.ModuleList([
            nn.Linear(in_channels, 64),
            nn.Linear(64, 64),
            nn.Linear(64, 64),
            nn.Linear(64, 64),
            nn.Linear(64, 1)
        ])
        # SIREN-specific weight initialisation
        self._init_weights()

    def _init_weights(self):
        """Re-initialise the weights (biases keep the nn.Linear default)."""
        with torch.no_grad():
            # NOTE(review): the first-layer bound is fixed at 1/3 regardless of
            # in_channels; kept as-is so existing experiments are unchanged.
            self.layers[0].weight.uniform_(-1/3, 1/3)
            for layer in self.layers[1:]:
                # bound = sqrt(6 / fan_in) / omega_0
                bound = np.sqrt(6 / layer.weight.shape[1]) / self.omega_0
                layer.weight.uniform_(-bound, bound)

    def set_f(self, f: torch.Tensor):
        """Set source function for forward pass (only used when in_channels=3)."""
        self.f = f

    def forward(self, graph):
        """Evaluate the network on ``graph.x``.

        Args:
            graph: Any object exposing node coordinates as ``graph.x`` ([N, 2]).

        Returns:
            Tensor of shape [N]. Only the trailing output dimension is removed,
            so a single-node input yields shape [1] rather than a 0-d scalar.

        Raises:
            RuntimeError: If ``in_channels == 3`` and ``set_f`` was never called.
        """
        coords = graph.x.float()  # [N, 2]

        if self.in_channels == 3:
            # Use [x, y, f] as input
            if self.f is None:
                raise RuntimeError("Must call set_f(f) before forward() when in_channels=3")
            f = self.f.unsqueeze(-1) if self.f.dim() == 1 else self.f  # [N, 1]
            x = torch.cat([coords, f.float()], dim=-1)  # [N, 3]
        else:
            # Use [x, y] only (original behavior)
            x = coords

        for layer in self.layers[:-1]:
            x = torch.sin(self.omega_0 * layer(x))
        x = self.layers[-1](x)
        # squeeze(-1) instead of squeeze(): keep the node axis even when N == 1.
        return x.squeeze(-1)


class ResidualMLP(nn.Module):
    """MLP with residual connections and GELU activation, more stable for training.

    Architecture: Linear(2 -> hidden_dim) + GELU, then ``num_layers`` residual
    blocks ``x <- x + GELU(Linear(x))``, then Linear(hidden_dim -> 1).
    """

    def __init__(self, hidden_dim=64, num_layers=4):
        """
        Args:
            hidden_dim: Width of every hidden layer.
            num_layers: Number of residual hidden blocks.
        """
        super().__init__()
        self.input_layer = nn.Linear(2, hidden_dim)
        self.hidden_layers = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(hidden_dim, 1)
        self.activation = nn.GELU()
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights (ReLU gain) and zero biases for every Linear."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, graph):
        """Evaluate the network on ``graph.x``.

        Args:
            graph: Any object exposing node coordinates as ``graph.x`` ([N, 2]).

        Returns:
            Tensor of shape [N]. Only the trailing output dimension is removed,
            so a single-node input yields shape [1] rather than a 0-d scalar.
        """
        x = graph.x.float()
        x = self.activation(self.input_layer(x))

        for layer in self.hidden_layers:
            # Residual connection: gradient can flow directly through skip connection
            x = x + self.activation(layer(x))

        # squeeze(-1) instead of squeeze(): keep the node axis even when N == 1.
        return self.output_layer(x).squeeze(-1)


class SIGNWrapper(nn.Module):
    """Wrapper for SIGN model to match the interface of MLP models.

    Input: the source function f (one feature per node). Coordinates are
    stored by ``set_data`` but are currently NOT fed to the network; the
    [x, y, f] concatenation is disabled in ``forward``.
    Output: u (solution)

    Supports SIREN activation for better convergence on PDE problems.
    Optionally uses FrequencyMLPEncoder to encode input features and to decode
    the SIGN output.
    """
    def __init__(self, num_hidden=64, num_layers=3, num_hops=8, activation="relu", omega_0=30.0,
                 encoder_L=4, decoder_L=4):
        """
        Args:
            num_hidden: Hidden dimension for SIGN and encoder/decoder
            num_layers: Number of layers in SIGN
            num_hops: Number of graph hops in SIGN
            activation: Activation function ("siren", "relu", etc.)
            omega_0: SIREN frequency parameter
            encoder_L: Frequency encoder L parameter. 0 = no encoder (default: 4).
            decoder_L: Frequency decoder L parameter. 0 = no decoder (default: 4).
        """
        super().__init__()
        self.omega_0 = omega_0
        self.activation = activation
        self.encoder_L = encoder_L
        self.decoder_L = decoder_L
        self.num_hidden = num_hidden

        # Input features: only the source function f (1 feature per node)
        input_features = 1

        # Optional frequency encoder
        if encoder_L > 0:
            self.encoder = FrequencyMLPEncoder(
                num_features=input_features,  # [f]
                num_classes=num_hidden,       # Output to hidden dim
                L=encoder_L,
                num_hidden=num_hidden,
                num_layers=2,                 # Default 2 layers
                activation= self.activation
            )
            sign_input_dim = num_hidden
        else:
            self.encoder = None
            sign_input_dim = input_features

        # Optional decoder
        if decoder_L > 0:
            sign_output_dim = num_hidden
            self.decoder = FrequencyMLPEncoder(
                num_features=num_hidden,
                num_classes=1,
                L=decoder_L,
                num_hidden=num_hidden,
                num_layers=2,
                activation= self.activation
            )
        else:
            self.decoder = None
            sign_output_dim = 1

        # SIGN model
        self.sign = SIGN(
            num_features=sign_input_dim,
            num_classes=sign_output_dim, 
            num_hidden=num_hidden, 
            num_layers=num_layers, 
            num_hops=num_hops,
            activation=activation
        )

        # Store data for forward pass
        self.f = None
        self.coords = None  # [n_nodes, 2] - (x, y) coordinates
        self.edge_index = None

        # Apply SIREN-style initialization if using siren activation and no encoder
        if activation == "siren" and encoder_L == 0:
            self._siren_init()

    def _siren_init(self):
        """Apply SIREN-style initialization to all linear layers in SIGN."""
        with torch.no_grad():
            first_layer_done = False
            for name, module in self.sign.named_modules():
                if isinstance(module, nn.Linear):
                    # First layer (input layer) uses different initialization.
                    # NOTE(review): this test still looks for a 3-feature input
                    # ([x, y, f]) although the wrapper now feeds 1 feature, so the
                    # branch only fires if some layer happens to have 3 inputs.
                    # Left unchanged to keep initialisation numerics stable.
                    if not first_layer_done and module.weight.shape[1] == 3:
                        module.weight.uniform_(-1/3, 1/3)
                        first_layer_done = True
                    else:
                        bound = np.sqrt(6 / module.weight.shape[1]) / self.omega_0
                        module.weight.uniform_(-bound, bound)

    def set_data(self, f: torch.Tensor, edge_index: torch.Tensor, coords: torch.Tensor):
        """Set source function, edge index and coordinates for training.

        Args:
            f: Source function [n_nodes] or [n_nodes, 1]
            edge_index: Graph edge index [2, n_edges]
            coords: Node coordinates [n_nodes, 2]
        """
        self.f = f.unsqueeze(-1) if f.dim() == 1 else f  # [n_nodes, 1]
        self.edge_index = edge_index
        self.coords = coords  # [n_nodes, 2]

    def forward(self, graph):
        """Forward pass using the stored source function as input.

        Data flow:
            f -> (optional) FreqEncoder -> SIGN -> (optional) Decoder -> u

        Args:
            graph: Unused; present for interface compatibility with the MLP models.

        Returns:
            Tensor with the trailing singleton dimension removed ([n_nodes]).

        Raises:
            RuntimeError: If ``set_data`` has not been called.
        """
        if self.f is None or self.edge_index is None or self.coords is None:
            raise RuntimeError("Must call set_data(f, edge_index, coords) before forward()")

        # Coordinate concatenation is currently disabled; only f is used:
        # x = torch.cat([self.coords.float(), self.f.float()], dim=-1)  # [x, y, f]
        # Cast to float32 so a float64 source function does not clash with the
        # float32 weights of the encoder / SIGN layers.
        x = self.f.float()
        # Optional frequency encoder
        if self.encoder is not None:
            x = self.encoder(x)  # [n_nodes, num_hidden]

        # SIGN forward
        u = self.sign(x, self.edge_index)

        # Optional decoder
        if self.decoder is not None:
            u = self.decoder(u)  # [n_nodes, 1]

        return u.squeeze(-1)  # [n_nodes]


def compute_grad_norm(model: nn.Module) -> float:
    """Return the global L2 norm of all parameter gradients of ``model``.

    Parameters whose ``.grad`` is ``None`` are skipped; the result is 0.0 when
    no parameter has a gradient.
    """
    squared_norms = (
        p.grad.data.norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    )
    return sum(squared_norms, 0.0) ** 0.5


def compute_l2_error(u_pred: torch.Tensor, u_true: torch.Tensor) -> float:
    """
    Relative L2 error between a prediction and a reference solution.

    Evaluates ||u_pred - u_true||_2 / (||u_true||_2 + 1e-10); the small
    constant guards against division by zero for an all-zero reference.

    Args:
        u_pred: Predicted solution [n_nodes]
        u_true: Ground truth solution [n_nodes]

    Returns:
        Relative L2 error as a Python float
    """
    error = u_pred - u_true
    numerator = torch.norm(error, p=2)
    denominator = torch.norm(u_true, p=2) + 1e-10
    ratio = numerator / denominator
    return ratio.item()


#########################
# Loss Landscape Utilities
#########################
@dataclass
class GroupTensors:
    """Helper class for tensor group operations (for Hessian eigenvector computation).

    Wraps a list of tensors (typically one per model parameter) and treats the
    whole list as a single flat vector: arithmetic is applied tensor-by-tensor,
    ``@`` is the inner product over all entries, and ``norm`` is the
    corresponding Euclidean norm.
    """
    # One tensor per group member; binary operations pair members by position.
    tensors: List[torch.Tensor]

    def __add__(self, other: "GroupTensors"):
        return GroupTensors([t1 + t2 for t1, t2 in zip(self.tensors, other.tensors)])

    def __mul__(self, other: Union[torch.Tensor, "GroupTensors", float, int]):
        """Element-wise product with another group, or scaling by a scalar/tensor."""
        if isinstance(other, GroupTensors):
            return GroupTensors([t1 * t2 for t1, t2 in zip(self.tensors, other.tensors)])
        else:
            return GroupTensors([t * other for t in self.tensors])

    def __rmul__(self, other: Union[torch.Tensor, "GroupTensors", float, int]):
        return self.__mul__(other)

    def __truediv__(self, other: "GroupTensors"):
        return GroupTensors([t1 / t2 for t1, t2 in zip(self.tensors, other.tensors)])

    def __sub__(self, other: "GroupTensors"):
        return GroupTensors([t1 - t2 for t1, t2 in zip(self.tensors, other.tensors)])

    def __matmul__(self, other: "GroupTensors"):
        """Inner product: sum over all entries of the element-wise product."""
        return sum([(t1 * t2).sum() for t1, t2 in zip(self.tensors, other.tensors)])

    def __len__(self):
        return len(self.tensors)

    def __getitem__(self, idx):
        return self.tensors[idx]

    def __setitem__(self, idx, value):
        self.tensors[idx] = value

    def sum(self):
        """Sum of every entry of every tensor in the group."""
        return sum([t.sum() for t in self.tensors])

    @property
    def device(self):
        # Device of the first tensor; the group is assumed to be non-empty.
        return self.tensors[0].device

    @property
    def norm(self):
        """Euclidean norm of the group viewed as one flat vector."""
        return (self * self).sum() ** 0.5

    @staticmethod
    def randn_like(other: "GroupTensors"):
        """Standard-normal group with the same shapes as ``other``."""
        target_device = other.device
        return GroupTensors([torch.randn_like(t).to(target_device) for t in other.tensors])

    def orth(self, others: List["GroupTensors"]):
        """Orthogonalize against a list of vectors using Gram-Schmidt.

        Each vector in ``others`` is expected to be unit-norm; its projection
        is subtracted from the running result in turn.
        """
        result = self
        for other in others:
            # GroupTensors on the left so the product dispatches to
            # GroupTensors.__mul__ directly instead of relying on the
            # Tensor.__mul__ -> __rmul__ fallback.
            result = result - other * (result @ other)
        return result

    def normalize(self):
        """Return the group scaled to (approximately) unit norm.

        A small epsilon is added to the norm to avoid division by zero.
        """
        # Compute the norm once; it is identical for every member tensor.
        denom = self.norm + 1e-6
        return GroupTensors([t / denom for t in self.tensors])

#########################
# Galerkin Loss Computation
#########################
def compute_galerkin_loss(model: nn.Module, data: "PoissonData", use_fast: bool = False) -> torch.Tensor:
    """Mean squared Galerkin residual of the model prediction.

    Args:
        model: Neural network model; called as ``model(data.graph)``.
        data: PoissonData instance providing ``graph``, ``f`` and ``equation``.
        use_fast: If True, the residual comes from
                  ``equation.compute_residual_fast`` (precomputed stiffness
                  matrix, Au - b). If False, from ``equation.compute_residual``
                  (original vmap-based computation).

    Returns:
        Scalar tensor: mean of the squared residual entries.
    """
    equation = data.equation
    # Drop a trailing singleton axis, then work on a copy of the prediction.
    nodal_u = model(data.graph).squeeze(-1).clone()

    if not use_fast:
        # Original version: full Galerkin integration (longer computation graph)
        residual = equation.compute_residual(nodal_u, f=data.f)
    else:
        # Fast version: R = Au - b (shorter computation graph)
        residual = equation.compute_residual_fast(nodal_u, data.f)

    return (residual ** 2).mean()


#########################
# PINN Loss Computation
#########################
def compute_pinn_residual(model: nn.Module, data: "PoissonData") -> tuple:
    """
    Strong-form PDE residual of the model, obtained with automatic differentiation.

    The model is evaluated on a gradient-tracking copy of the node coordinates,
    and the residual ``f + (u_xx + u_yy)`` of ``-Δu = f`` is returned.

    Args:
        model: Neural network model; must read coordinates from ``graph.x``
        data: PoissonData instance (uses ``data.graph.x`` and ``data.f``)

    Returns:
        (residual, u_pred): PDE residual and predicted solution
    """
    class _CoordinateCarrier:
        """Minimal stand-in for a graph: exposes only ``x``."""
        def __init__(self, x):
            self.x = x

    # Differentiable copy of the coordinates
    xs = data.graph.x.clone().requires_grad_(True)
    u_pred = model(_CoordinateCarrier(xs))

    # First derivatives ∇u, keeping the graph so they can be differentiated again
    (first_derivs,) = torch.autograd.grad(
        u_pred, xs,
        create_graph=True,
        grad_outputs=torch.ones_like(u_pred),
    )

    # Unmixed second derivatives: column k of d(sum of ∂u/∂x_k)/dx is ∂²u/∂x_k²
    second_derivs = [
        torch.autograd.grad(first_derivs[:, axis].sum(), xs, create_graph=True)[0][:, axis]
        for axis in (0, 1)
    ]
    u_xx, u_yy = second_derivs

    # -Δu = f  =>  f + (u_xx + u_yy) = 0
    residual = data.f + (u_xx + u_yy)
    return residual, u_pred


def compute_pinn_loss(model: nn.Module, data: "PoissonData", lambda_bc: float = 10.0) -> torch.Tensor:
    """
    PINN objective: interior PDE loss plus weighted boundary loss.

    Loss = mean(residual² on interior nodes) + lambda_bc * mean(u² on boundary nodes)

    Args:
        model: Neural network model
        data: PoissonData instance; ``data.equation.boundary_mask`` selects boundary nodes
        lambda_bc: Weight for boundary condition loss

    Returns:
        Total PINN loss
    """
    pde_residual, prediction = compute_pinn_residual(model, data)
    on_boundary = data.equation.boundary_mask

    # Strong-form residual is penalised away from the boundary only
    interior_term = (pde_residual[~on_boundary] ** 2).mean()
    # Homogeneous Dirichlet condition (u = 0) enforced softly on the boundary
    boundary_term = (prediction[on_boundary] ** 2).mean()

    return interior_term + lambda_bc * boundary_term


#########################
# VPINN Loss Computation
#########################
def compute_vpinn_residual(model: nn.Module, data: "PoissonData") -> tuple:
    """
    Variational (weak-form) residual with autograd-derived nodal gradients.

    Weak form: R^I = ∫ ∇u·∇N^I dx - ∫ f·N^I dx

    Nodal ∇u comes from differentiating the model w.r.t. its input coordinates
    (first order only); it and f are interpolated to quadrature points with the
    shape-function values, integrated with the quadrature weights, and
    scattered into a global per-node vector. Entries at boundary nodes are
    zeroed.

    Args:
        model: Neural network model; must read coordinates from ``graph.x``
        data: PoissonData instance

    Returns:
        (residual, u_pred): Variational residual and predicted solution
    """
    class _CoordinateCarrier:
        """Minimal stand-in for a graph: exposes only ``x``."""
        def __init__(self, x):
            self.x = x

    equation = data.equation
    source_nodal = data.f.float()  # [n_nodes]

    # Differentiable float32 copy of the coordinates
    xy = data.graph.x.clone().float().requires_grad_(True)  # [n_nodes, 2]
    u_pred = model(_CoordinateCarrier(xy))  # [n_nodes]

    # ∇u at the nodes (kept differentiable for the outer optimisation)
    (nodal_grad,) = torch.autograd.grad(u_pred.sum(), xy, create_graph=True)

    # FEM quantities, all cast to float32 once
    connectivity = equation.elements              # [n_elem, n_basis]
    basis_at_quad = equation.shape_val.float()    # [n_quad, n_basis]
    basis_grad = equation.shape_grad.float()      # [n_elem, n_quad, n_basis, 2]
    quad_weights = equation.JxW.float()           # [n_elem, n_quad, 1]

    # Gather nodal values per element and interpolate to quadrature points
    grad_at_quad = torch.einsum(
        'gb,ebd->egd', basis_at_quad, nodal_grad[connectivity].float()
    )  # [n_elem, n_quad, 2]
    source_at_quad = torch.einsum(
        'gb,eb->eg', basis_at_quad, source_nodal[connectivity]
    )  # [n_elem, n_quad]

    # Integrand per basis function: ∇u·∇N^I - f·N^I
    stiffness_term = torch.einsum('egd,egbd->egb', grad_at_quad, basis_grad)
    load_term = source_at_quad.unsqueeze(-1) * basis_at_quad.unsqueeze(0)
    per_element = ((stiffness_term - load_term) * quad_weights).sum(dim=1)  # [n_elem, n_basis]

    # Scatter element contributions into the global vector
    residual = torch.zeros(data.graph.num_nodes, device=xy.device, dtype=xy.dtype)
    residual.scatter_add_(0, connectivity.flatten(), per_element.flatten())

    # No residual is imposed on boundary nodes
    residual[equation.boundary_mask] = 0.0

    return residual, u_pred


def compute_vpinn_loss(model: nn.Module, data: "PoissonData", lambda_bc: float = 10.0) -> torch.Tensor:
    """
    VPINN objective: variational_loss + lambda_bc * bc_loss

    Args:
        model: Neural network model
        data: PoissonData instance; ``data.equation.boundary_mask`` selects boundary nodes
        lambda_bc: Weight for boundary condition loss

    Returns:
        Total VPINN loss
    """
    weak_residual, prediction = compute_vpinn_residual(model, data)
    on_boundary = data.equation.boundary_mask

    # Mean over ALL nodes; boundary entries of the residual are already zero
    variational_term = (weak_residual ** 2).mean()
    # Soft homogeneous Dirichlet constraint (u = 0 on the boundary)
    boundary_term = (prediction[on_boundary] ** 2).mean()

    return variational_term + lambda_bc * boundary_term


#########################
# Deep Ritz Loss Computation
#########################
def compute_deepritz_energy(model: nn.Module, data: "PoissonData") -> tuple:
    """
    Deep Ritz energy functional evaluated with autograd gradients.

    Energy: J(u) = (1/2) ∫|∇u|² dx - ∫f·u dx

    Nodal ∇u is obtained by differentiating the model w.r.t. its input
    coordinates; ∇u, u and f are interpolated to quadrature points with the
    shape-function values and the energy density is integrated with the
    quadrature weights.

    Args:
        model: Neural network model; must read coordinates from ``graph.x``
        data: PoissonData instance

    Returns:
        (energy, u_pred): Scalar energy and predicted solution
    """
    class _CoordinateCarrier:
        """Minimal stand-in for a graph: exposes only ``x``."""
        def __init__(self, x):
            self.x = x

    equation = data.equation
    source_nodal = data.f.float()

    # Differentiable float32 copy of the coordinates
    xy = data.graph.x.clone().float().requires_grad_(True)
    u_pred = model(_CoordinateCarrier(xy))  # [n_nodes]

    # ∇u at the nodes (first order only, kept differentiable)
    (nodal_grad,) = torch.autograd.grad(u_pred.sum(), xy, create_graph=True)

    # FEM quantities in float32
    connectivity = equation.elements            # [n_elem, n_basis]
    basis_at_quad = equation.shape_val.float()  # [n_quad, n_basis]
    quad_weights = equation.JxW.float()         # [n_elem, n_quad, 1]

    def to_quadrature(nodal_values, subscripts):
        # Gather per-element nodal values and interpolate to quadrature points
        return torch.einsum(subscripts, basis_at_quad, nodal_values)

    grad_at_quad = to_quadrature(nodal_grad[connectivity].float(), 'gb,ebd->egd')  # [n_elem, n_quad, 2]
    u_at_quad = to_quadrature(u_pred[connectivity].float(), 'gb,eb->eg')           # [n_elem, n_quad]
    f_at_quad = to_quadrature(source_nodal[connectivity], 'gb,eb->eg')             # [n_elem, n_quad]

    # Energy density (1/2)|∇u|² - f·u at every quadrature point
    grad_norm_sq = (grad_at_quad ** 2).sum(dim=-1)
    density = 0.5 * grad_norm_sq - f_at_quad * u_at_quad

    # Quadrature: weight each point and sum over the whole mesh
    energy = (density.unsqueeze(-1) * quad_weights).sum()

    return energy, u_pred


def compute_deepritz_loss(model: nn.Module, data: "PoissonData", lambda_bc: float = 10.0) -> torch.Tensor:
    """
    Deep Ritz objective: energy + lambda_bc * bc_loss

    Note: Deep Ritz may need a larger BC weight than PINN/VPINN because the
    energy can be made arbitrarily negative by making u large in the interior.
    The default here is 10.0; pass a larger value explicitly if required.

    Args:
        model: Neural network model
        data: PoissonData instance; ``data.equation.boundary_mask`` selects boundary nodes
        lambda_bc: Weight for boundary condition loss (default 10.0)

    Returns:
        Total Deep Ritz loss
    """
    ritz_energy, prediction = compute_deepritz_energy(model, data)
    on_boundary = data.equation.boundary_mask

    # Soft homogeneous Dirichlet constraint (u = 0 on the boundary)
    boundary_term = (prediction[on_boundary] ** 2).mean()

    return ritz_energy + lambda_bc * boundary_term


def compute_hessian_eigenvectors(
    model: nn.Module, 
    data: "PoissonData", 
    top_n: int = 2, 
    max_iter: int = 100
) -> tuple:
    """
    Compute top-n Hessian eigenvectors of the Galerkin loss using power iteration.

    Each eigenpair is found by repeatedly applying the Hessian (via
    Hessian-vector products) to a random start vector that is kept orthogonal
    to the eigenvectors already found. Iteration stops when the Rayleigh
    quotient changes by less than 1e-4 (relative) or after ``max_iter`` steps.

    The first-order gradients are obtained with ``torch.autograd.grad`` so the
    parameters' ``.grad`` fields are neither populated nor cleared.

    Args:
        model: Trained neural network
        data: PoissonData instance
        top_n: Number of top eigenvectors to compute
        max_iter: Maximum iterations for power iteration (must be >= 1)

    Returns:
        (eigenvalues, eigenvectors): Lists of eigenvalues (0-dim tensors) and
        unit-norm eigenvector GroupTensors, one tensor per model parameter.

    Raises:
        ValueError: If ``max_iter`` is smaller than 1.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    # Differentiable first-order gradients. autograd.grad is used instead of
    # loss.backward(create_graph=True), which creates a parameter<->grad
    # reference cycle and whose result can be clobbered by model.zero_grad().
    loss = compute_galerkin_loss(model, data)
    param_list = list(model.parameters())
    first_grads = torch.autograd.grad(loss, param_list, create_graph=True)

    params = GroupTensors(param_list)

    eigenvalues = []
    eigenvectors = []

    for i_eigen in range(top_n):
        eigenvalue = None
        v = GroupTensors.randn_like(params)

        for i_iter in range(max_iter):
            # Deflate against previously found eigenvectors, then renormalise
            v = v.orth(eigenvectors).normalize()

            # Compute Hessian-vector product
            Hv = GroupTensors(list(torch.autograd.grad(
                first_grads, param_list,
                grad_outputs=v.tensors,
                only_inputs=True, retain_graph=True
            )))
            new_eigenvalue = Hv @ v  # Rayleigh quotient (v has unit norm)

            # Power step: the next iterate is H v. Without this assignment the
            # loop would re-evaluate the same vector and stop after one step.
            v = Hv

            converged = False
            if eigenvalue is not None:
                relative_residual = abs(new_eigenvalue - eigenvalue) / (abs(new_eigenvalue) + 1e-6)
                converged = bool(relative_residual < 1e-4)
            eigenvalue = new_eigenvalue
            if converged:
                break

        # Final iterate, orthogonal to earlier eigenvectors and unit-norm
        v = v.orth(eigenvectors).normalize()

        print(f"  Eigenvector {i_eigen}: converged at iter {i_iter}, eigenvalue={eigenvalue:.4e}")
        eigenvalues.append(eigenvalue)
        eigenvectors.append(v)

    return eigenvalues, eigenvectors


def plot_landscape(
    x: np.ndarray, 
    y: np.ndarray, 
    z: np.ndarray, 
    title: str, 
    out_prefix: str, 
    use_log: bool = True, 
    levels: int = 40
):
    """
    Plot loss landscape as 3D surface and 2D contour.

    Two PNG files are written to ``output/landscape/`` (created if missing):
    ``{out_prefix}_surface.png`` and ``{out_prefix}_contour.png``.

    In log mode, values at or below a small positive floor are raised to that
    floor so the logarithm is defined. In linear mode the data is plotted
    unmodified. Axis/colour limits are the 1st and 99th percentiles, computed
    ignoring NaNs.

    Args:
        x, y, z: Meshgrid arrays for plotting
        title: Plot title
        out_prefix: Output filename prefix
        use_log: Whether to use log scale for z-axis
        levels: Number of contour levels
    """
    z = np.asarray(z, dtype=float)

    if use_log:
        # Floor: 1e-6 of the median positive value, but never below 1e-12
        positive = z[z > 0]
        eps = max(1e-12, 1e-6 * np.nanmedian(positive)) if positive.size else 1e-12
        z_safe = z.copy()
        z_safe[z_safe <= eps] = eps  # avoid log(0)

        vmin = max(np.nanpercentile(z_safe, 1), eps)
        vmax = np.nanpercentile(z_safe, 99)
        norm = LogNorm(vmin=vmin, vmax=vmax)
        zlim = (vmin, vmax)
    else:
        # Linear scale: no clamping, so non-positive values are shown as-is
        z_safe = z
        norm = None
        zlim = (np.nanpercentile(z_safe, 1), np.nanpercentile(z_safe, 99))

    # 3D surface plot
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(x, y, z_safe, cmap='viridis', rstride=1, cstride=1, linewidth=0, antialiased=False)
    ax.set_xlabel('Hessian Eigenvector 1')
    ax.set_ylabel('Hessian Eigenvector 2')
    ax.set_zlabel('Loss')
    ax.set_title(title)
    if use_log:
        ax.set_zscale('log')
        ax.set_zlim(*zlim)
        ax.zaxis.set_major_locator(LogLocator(base=10))
        ax.zaxis.set_major_formatter(LogFormatterSciNotation())
    else:
        ax.set_zlim(*zlim)
        ax.zaxis.set_major_formatter(ScalarFormatter())
    fig.colorbar(surf, ax=ax, shrink=0.7)

    os.makedirs("output/landscape", exist_ok=True)
    fig.savefig(f"output/landscape/{out_prefix}_surface.png", dpi=200)
    plt.close(fig)

    # 2D contour plot
    fig2, ax2 = plt.subplots(1, 1, figsize=(6, 5))
    cf = ax2.contourf(x, y, z_safe, levels=levels, cmap='viridis', norm=norm)
    ax2.set_xlabel('Hessian Eigenvector 1')
    ax2.set_ylabel('Hessian Eigenvector 2')
    ax2.set_title(title + " (contour)")
    ax2.plot(0, 0, 'r*', markersize=15, label='Optimum')  # Mark the trained model position
    ax2.legend()
    cbar = fig2.colorbar(cf, ax=ax2)
    cbar.set_label('Loss')
    fig2.savefig(f"output/landscape/{out_prefix}_contour.png", dpi=200, bbox_inches='tight')
    plt.close(fig2)

    print(f"  Loss landscape saved to output/landscape/{out_prefix}_*.png")


def compute_galerkin_loss_landscape(
    model: nn.Module,
    data: "PoissonData",
    config: "TrainConfig",
    n_points: int = 20,
    scale: float = 1.0,
):
    """
    Compute and plot loss landscape around a trained model.

    The Galerkin loss is evaluated on an ``n_points x n_points`` grid spanned
    by the two leading Hessian eigenvectors, centred on the current
    parameters. Parameters are perturbed in place during the scan and are
    always restored afterwards, even if the scan raises.

    Args:
        model: Trained neural network
        data: PoissonData instance
        config: Training configuration (for naming; reads ``config.K`` and ``config.model``)
        n_points: Grid resolution (points per axis). With an even value the
            grid does not contain the origin exactly.
        scale: Scan radius along eigenvector directions
    """
    print("\n" + "=" * 50)
    print("Computing Loss Landscape")
    print("=" * 50)

    # ===== DEBUG: check the loss before computing the Hessian =====
    with torch.no_grad():
        loss_before = compute_galerkin_loss(model, data)
        print(f"[DEBUG] Loss BEFORE Hessian computation: {loss_before.item():.6e}")

    # Step 1: Compute Hessian eigenvectors
    print("\n[1/3] Computing Hessian eigenvectors...")
    eigenvalues, eigenvectors = compute_hessian_eigenvectors(model, data, top_n=2)

    # ===== DEBUG: check the loss after computing the Hessian =====
    with torch.no_grad():
        loss_after = compute_galerkin_loss(model, data)
        print(f"[DEBUG] Loss AFTER Hessian computation: {loss_after.item():.6e}")

    xvector = eigenvectors[0]
    yvector = eigenvectors[1]

    # Step 2: Store original parameters (detached copies, independent of the model)
    print("\n[2/3] Scanning loss landscape...")
    params = GroupTensors([p.detach().clone() for p in model.parameters()])

    # Create meshgrid
    x, y = torch.meshgrid(
        torch.linspace(-scale, scale, n_points),
        torch.linspace(-scale, scale, n_points),
        indexing='ij'
    )
    z = torch.zeros_like(x)

    # Compute loss at each grid point
    try:
        with torch.no_grad():
            loss_after_copy = compute_galerkin_loss(model, data)
            print(f"[DEBUG] Loss after deepcopy: {loss_after_copy.item():.6e}")

            for i, j in tqdm(product(range(n_points), range(n_points)), 
                             total=n_points * n_points, 
                             desc="  Scanning", unit="point"):
                # Perturb parameters: theta = theta* + x * v1 + y * v2
                new_params = params + xvector * float(x[i, j]) + yvector * float(y[i, j])
                for param, new_param in zip(model.parameters(), new_params.tensors):
                    # In-place copy keeps the parameter's storage (and anything
                    # referencing it) intact, unlike rebinding ``param.data``.
                    param.copy_(new_param)

                # Compute loss (without gradient)
                loss = compute_galerkin_loss(model, data)
                z[i, j] = loss.item()

                # ===== DEBUG: inspect values near the grid centre and corner =====
                if i == n_points // 2 and j == n_points // 2:
                    print(f"[DEBUG] Grid center ({i},{j}): x={x[i,j].item():.4f}, y={y[i,j].item():.4f}, loss={loss.item():.6e}")
                if i == 0 and j == 0:
                    print(f"[DEBUG] Grid corner (0,0): x={x[i,j].item():.4f}, y={y[i,j].item():.4f}, loss={loss.item():.6e}")
    finally:
        # Restore original parameters, even if the scan failed part-way
        with torch.no_grad():
            for param, orig_param in zip(model.parameters(), params.tensors):
                param.copy_(orig_param)

    with torch.no_grad():
        loss_at_origin = compute_galerkin_loss(model, data)
        print(f"[DEBUG] Loss at TRUE origin (0, 0): {loss_at_origin.item():.6e}")
    # Step 3: Plot
    print("\n[3/3] Plotting loss landscape...")
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()
    z_np = z.cpu().numpy()

    out_prefix = f"ggl_K{config.K}_{config.model}"
    plot_landscape(
        x_np, y_np, z_np,
        title=f'GGL Loss Landscape (K={config.K}, model={config.model})',
        out_prefix=out_prefix,
        use_log=True
    )

    # Report the loss at the unperturbed parameters; the grid-centre index is
    # not the origin when n_points is even.
    print(f"\nLoss at optimum: {loss_at_origin.item():.4e}")
    print(f"Min loss in landscape: {z_np.min():.4e}")
    print(f"Max loss in landscape: {z_np.max():.4e}")
    print("=" * 50)


def compute_pinn_hessian_eigenvectors(
    model: nn.Module, 
    data: "PoissonData", 
    top_n: int = 2, 
    max_iter: int = 100,
    lambda_bc: float = 10.0
) -> tuple:
    """
    Compute top-n Hessian eigenvectors for PINN loss using power iteration.

    Each eigenpair is found by repeatedly applying the Hessian (via
    Hessian-vector products) to a random start vector that is kept orthogonal
    to the eigenvectors already found. Iteration stops when the Rayleigh
    quotient changes by less than 1e-4 (relative) or after ``max_iter`` steps.

    The first-order gradients are obtained with ``torch.autograd.grad`` so the
    parameters' ``.grad`` fields are neither populated nor cleared.

    Args:
        model: Trained neural network
        data: PoissonData instance
        top_n: Number of top eigenvectors to compute
        max_iter: Maximum iterations for power iteration (must be >= 1)
        lambda_bc: Weight for boundary condition loss

    Returns:
        (eigenvalues, eigenvectors): Lists of eigenvalues (0-dim tensors) and
        unit-norm eigenvector GroupTensors, one tensor per model parameter.

    Raises:
        ValueError: If ``max_iter`` is smaller than 1.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    # Differentiable first-order gradients of the PINN loss. autograd.grad is
    # used instead of loss.backward(create_graph=True), which creates a
    # parameter<->grad reference cycle and whose result can be clobbered by
    # model.zero_grad().
    loss = compute_pinn_loss(model, data, lambda_bc)
    param_list = list(model.parameters())
    first_grads = torch.autograd.grad(loss, param_list, create_graph=True)

    params = GroupTensors(param_list)

    eigenvalues = []
    eigenvectors = []

    for i_eigen in range(top_n):
        eigenvalue = None
        v = GroupTensors.randn_like(params)

        for i_iter in range(max_iter):
            # Deflate against previously found eigenvectors, then renormalise
            v = v.orth(eigenvectors).normalize()

            # Compute Hessian-vector product
            Hv = GroupTensors(list(torch.autograd.grad(
                first_grads, param_list,
                grad_outputs=v.tensors,
                only_inputs=True, retain_graph=True
            )))
            new_eigenvalue = Hv @ v  # Rayleigh quotient (v has unit norm)

            # Power step: the next iterate is H v. Without this assignment the
            # loop would re-evaluate the same vector and stop after one step.
            v = Hv

            converged = False
            if eigenvalue is not None:
                relative_residual = abs(new_eigenvalue - eigenvalue) / (abs(new_eigenvalue) + 1e-6)
                converged = bool(relative_residual < 1e-4)
            eigenvalue = new_eigenvalue
            if converged:
                break

        # Final iterate, orthogonal to earlier eigenvectors and unit-norm
        v = v.orth(eigenvectors).normalize()

        print(f"  Eigenvector {i_eigen}: converged at iter {i_iter}, eigenvalue={eigenvalue:.4e}")
        eigenvalues.append(eigenvalue)
        eigenvectors.append(v)

    return eigenvalues, eigenvectors


def compute_pinn_loss_landscape(
    model: nn.Module,
    data: "PoissonData",
    config: "TrainConfig",
    n_points: int = 21,
    scale: float = 1.0,
    lambda_bc: float = 10.0,
):
    """
    Compute and plot PINN loss landscape around a trained model.

    The two leading Hessian eigenvectors span a 2D slice of parameter space.
    The loss is evaluated on an ``n_points x n_points`` grid of offsets in
    that slice and handed to ``plot_landscape``. The model parameters are
    overwritten during the scan and are restored afterwards, even if the
    scan raises.

    Args:
        model: Trained neural network
        data: PoissonData instance
        config: Training configuration (for naming)
        n_points: Grid resolution (points per axis)
        scale: Scan radius along eigenvector directions
        lambda_bc: Weight for boundary condition loss
    """
    print("\n" + "=" * 50)
    print("Computing PINN Loss Landscape")
    print("=" * 50)
    
    # NOTE: the loss is evaluated with autograd enabled throughout (no
    # torch.no_grad()), as the PINN loss needs derivatives of the network.
    
    # DEBUG: Check loss before Hessian computation
    loss_before = compute_pinn_loss(model, data, lambda_bc)
    print(f"[DEBUG] Loss BEFORE Hessian computation: {loss_before.item():.6e}")
    model.zero_grad()  # Clear gradients after loss computation
    
    # Step 1: Compute Hessian eigenvectors (the eigenvalues are not needed here)
    print("\n[1/3] Computing Hessian eigenvectors...")
    _, eigenvectors = compute_pinn_hessian_eigenvectors(model, data, top_n=2, lambda_bc=lambda_bc)
    
    # DEBUG: Check loss after Hessian computation
    loss_after = compute_pinn_loss(model, data, lambda_bc)
    print(f"[DEBUG] Loss AFTER Hessian computation: {loss_after.item():.6e}")
    model.zero_grad()
    
    xvector = eigenvectors[0]
    yvector = eigenvectors[1]
    
    # Step 2: Store original parameters
    print("\n[2/3] Scanning loss landscape...")
    params = deepcopy(GroupTensors([p for p in model.parameters()]))
    
    # Create meshgrid of offsets along the two eigenvector directions
    x, y = torch.meshgrid(
        torch.linspace(-scale, scale, n_points),
        torch.linspace(-scale, scale, n_points),
        indexing='ij'
    )
    z = torch.zeros_like(x)
    
    loss_after_copy = compute_pinn_loss(model, data, lambda_bc)
    print(f"[DEBUG] Loss after deepcopy: {loss_after_copy.item():.6e}")
    model.zero_grad()
    
    try:
        # Compute loss at each grid point
        for i, j in tqdm(product(range(n_points), range(n_points)), 
                         total=n_points * n_points, 
                         desc="  Scanning", unit="point"):
            # Perturb parameters
            new_params = params + x[i, j] * xvector + y[i, j] * yvector
            for param, new_param in zip(model.parameters(), new_params.tensors):
                param.data = new_param.detach()
            
            loss = compute_pinn_loss(model, data, lambda_bc)
            z[i, j] = loss.detach().item()
            model.zero_grad()  # Clear gradients to avoid memory buildup
            
            # DEBUG: Check center and corner points
            if i == n_points // 2 and j == n_points // 2:
                print(f"[DEBUG] Grid center ({i},{j}): x={x[i,j].item():.4f}, y={y[i,j].item():.4f}, loss={z[i,j]:.6e}")
            if i == 0 and j == 0:
                print(f"[DEBUG] Grid corner (0,0): x={x[i,j].item():.4f}, y={y[i,j].item():.4f}, loss={z[i,j]:.6e}")
    finally:
        # Restore original parameters even when the scan fails part-way, so
        # the caller never gets back a model stuck at a perturbed point.
        for param, orig_param in zip(model.parameters(), params.tensors):
            param.data = orig_param.detach()
    
    loss_at_origin = compute_pinn_loss(model, data, lambda_bc)
    origin_loss = loss_at_origin.item()
    print(f"[DEBUG] Loss at TRUE origin (0, 0): {origin_loss:.6e}")
    model.zero_grad()
    
    # Step 3: Plot
    print("\n[3/3] Plotting loss landscape...")
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()
    z_np = z.cpu().numpy()
    
    out_prefix = f"pinn_K{config.K}_{config.model}"
    plot_landscape(
        x_np, y_np, z_np,
        title=f'PINN Loss Landscape (K={config.K}, model={config.model})',
        out_prefix=out_prefix,
        use_log=True
    )
    
    # Report the loss at the unperturbed parameters: the grid centre only
    # coincides with the origin when n_points is odd.
    print(f"\nLoss at optimum: {origin_loss:.4e}")
    print(f"Min loss in landscape: {z_np.min():.4e}")
    print(f"Max loss in landscape: {z_np.max():.4e}")
    print("=" * 50)


def compute_vpinn_hessian_eigenvectors(
    model: nn.Module, 
    data: "PoissonData", 
    top_n: int = 2, 
    max_iter: int = 100,
    lambda_bc: float = 10.0
) -> tuple:
    """
    Compute top-n Hessian eigenvectors for VPINN loss using power iteration.

    The Hessian is never formed explicitly. Each step applies it to the
    current vector through a Hessian-vector product (differentiating the
    first-order gradients a second time), and the result becomes the next
    iterate. Previously found eigenvectors are projected out before every
    step so that later runs converge to the next dominant direction.

    Args:
        model: Trained neural network
        data: PoissonData instance
        top_n: Number of top eigenvectors to compute
        max_iter: Maximum iterations for power iteration (must be >= 1)
        lambda_bc: Weight for boundary condition loss

    Returns:
        (eigenvalues, eigenvectors): Lists of eigenvalues and eigenvector GroupTensors

    Raises:
        ValueError: If ``max_iter`` is smaller than 1.
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}")

    param_list = list(model.parameters())

    # First-order gradients, kept differentiable for the second pass.
    # torch.autograd.grad is used instead of loss.backward(create_graph=True)
    # so nothing is accumulated into p.grad: stale gradients cannot leak in
    # and a later model.zero_grad() cannot touch the graph we differentiate.
    loss = compute_vpinn_loss(model, data, lambda_bc)
    first_grads = torch.autograd.grad(loss, param_list, create_graph=True)

    params = GroupTensors(param_list)
    gradsH = GroupTensors(list(first_grads))

    eigenvalues = []
    eigenvectors = []

    for i_eigen in range(top_n):
        eigenvalue = None
        v = GroupTensors.randn_like(params)

        for i_iter in range(max_iter):
            # Deflate against the eigenvectors found so far, then rescale.
            v = v.orth(eigenvectors).normalize()

            # Compute Hessian-vector product
            Hv = GroupTensors(torch.autograd.grad(
                gradsH.tensors, params.tensors, 
                grad_outputs=v.tensors, 
                only_inputs=True, retain_graph=True
            ))

            last_eigenvalue = eigenvalue
            eigenvalue = Hv @ v  # Rayleigh quotient, since v has unit norm

            # Power-iteration update: the next iterate is H @ v. Without this
            # assignment v never changes and the loop stops after two passes
            # with an arbitrary direction.
            v = Hv

            if last_eigenvalue is not None:
                relative_residual = abs(eigenvalue - last_eigenvalue) / (abs(eigenvalue) + 1e-6)
                if relative_residual < 1e-4:
                    break

        print(f"  Eigenvector {i_eigen}: converged at iter {i_iter}, eigenvalue={eigenvalue:.4e}")
        # Store the most recent estimate and a deflated, unit-norm direction.
        eigenvalues.append(eigenvalue)
        eigenvectors.append(v.orth(eigenvectors).normalize())

    return eigenvalues, eigenvectors


def compute_vpinn_loss_landscape(
    model: nn.Module,
    data: "PoissonData",
    config: "TrainConfig",
    n_points: int = 21,
    scale: float = 1.0,
    lambda_bc: float = 10.0,
):
    """
    Compute and plot VPINN loss landscape around a trained model.

    The two leading Hessian eigenvectors span a 2D slice of parameter space.
    The loss is evaluated on an ``n_points x n_points`` grid of offsets in
    that slice and handed to ``plot_landscape``. The model parameters are
    overwritten during the scan and are restored afterwards, even if the
    scan raises.

    Args:
        model: Trained neural network
        data: PoissonData instance
        config: Training configuration (for naming)
        n_points: Grid resolution (points per axis)
        scale: Scan radius along eigenvector directions
        lambda_bc: Weight for boundary condition loss
    """
    print("\n" + "=" * 50)
    print("Computing VPINN Loss Landscape")
    print("=" * 50)
    
    # DEBUG: Check loss before Hessian computation
    loss_before = compute_vpinn_loss(model, data, lambda_bc)
    print(f"[DEBUG] Loss BEFORE Hessian computation: {loss_before.item():.6e}")
    model.zero_grad()
    
    # Step 1: Compute Hessian eigenvectors (the eigenvalues are not needed here)
    print("\n[1/3] Computing Hessian eigenvectors...")
    _, eigenvectors = compute_vpinn_hessian_eigenvectors(model, data, top_n=2, lambda_bc=lambda_bc)
    
    # DEBUG: Check loss after Hessian computation
    loss_after = compute_vpinn_loss(model, data, lambda_bc)
    print(f"[DEBUG] Loss AFTER Hessian computation: {loss_after.item():.6e}")
    model.zero_grad()
    
    xvector = eigenvectors[0]
    yvector = eigenvectors[1]
    
    # Step 2: Store original parameters
    print("\n[2/3] Scanning loss landscape...")
    params = deepcopy(GroupTensors([p for p in model.parameters()]))
    
    # Create meshgrid of offsets along the two eigenvector directions
    x, y = torch.meshgrid(
        torch.linspace(-scale, scale, n_points),
        torch.linspace(-scale, scale, n_points),
        indexing='ij'
    )
    z = torch.zeros_like(x)
    
    loss_after_copy = compute_vpinn_loss(model, data, lambda_bc)
    print(f"[DEBUG] Loss after deepcopy: {loss_after_copy.item():.6e}")
    model.zero_grad()
    
    try:
        # Compute loss at each grid point
        for i, j in tqdm(product(range(n_points), range(n_points)), 
                         total=n_points * n_points, 
                         desc="  Scanning", unit="point"):
            # Perturb parameters
            new_params = params + x[i, j] * xvector + y[i, j] * yvector
            for param, new_param in zip(model.parameters(), new_params.tensors):
                param.data = new_param.detach()
            
            # Compute VPINN loss
            loss = compute_vpinn_loss(model, data, lambda_bc)
            z[i, j] = loss.detach().item()
            model.zero_grad()
            
            # DEBUG: Check center and corner points
            if i == n_points // 2 and j == n_points // 2:
                print(f"[DEBUG] Grid center ({i},{j}): x={x[i,j].item():.4f}, y={y[i,j].item():.4f}, loss={z[i,j]:.6e}")
            if i == 0 and j == 0:
                print(f"[DEBUG] Grid corner (0,0): x={x[i,j].item():.4f}, y={y[i,j].item():.4f}, loss={z[i,j]:.6e}")
    finally:
        # Restore original parameters even when the scan fails part-way, so
        # the caller never gets back a model stuck at a perturbed point.
        for param, orig_param in zip(model.parameters(), params.tensors):
            param.data = orig_param.detach()
    
    loss_at_origin = compute_vpinn_loss(model, data, lambda_bc)
    origin_loss = loss_at_origin.item()
    print(f"[DEBUG] Loss at TRUE origin (0, 0): {origin_loss:.6e}")
    model.zero_grad()
    
    # Step 3: Plot
    print("\n[3/3] Plotting loss landscape...")
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()
    z_np = z.cpu().numpy()
    
    out_prefix = f"vpinn_K{config.K}_{config.model}"
    plot_landscape(
        x_np, y_np, z_np,
        title=f'VPINN Loss Landscape (K={config.K}, model={config.model})',
        out_prefix=out_prefix,
        use_log=True
    )
    
    # Report the loss at the unperturbed parameters: the grid centre only
    # coincides with the origin when n_points is odd.
    print(f"\nLoss at optimum: {origin_loss:.4e}")
    print(f"Min loss in landscape: {z_np.min():.4e}")
    print(f"Max loss in landscape: {z_np.max():.4e}")
    print("=" * 50)


#########################
# Data Generation and Training
#########################

@dataclass 
class TrainConfig:
    """Settings for one training run; ``str(config)`` is a compact run tag."""

    K: int = 1  # size parameter of the source: K x K coefficients, or frequency / tile count
    r: float = 1  # forwarded as ``r`` to the analytical source/solution generators
    nx: int = 10
    optimizer: str = "adam"  # "adam" or "adam+lbfgs"
    lr: float = 0.01
    lr_min: float = 1e-6  # final learning rate (target of the cosine decay)
    epochs: int = 100
    lbfgs_epochs: int = 200  # number of epochs of the separate L-BFGS phase
    model: str = "mock"  # "mock", "mlp", "mlp2", "resmlp", "sign"
    source_type: str = "analytical"  # "analytical", "abs_sin", "step", "indicator_circle", "checkerboard", "constant"
    domain: str = "rectangle"  # "rectangle", "L_shape", "circle"

    def __str__(self):
        # nx and lr_min are deliberately absent: the tag format is unchanged.
        tag_parts = (
            f"dom{self.domain}",
            f"src{self.source_type}",
            f"K{self.K}",
            f"opt{self.optimizer}",
            f"lr{self.lr}",
            f"ep{self.epochs}",
            f"lbfgs{self.lbfgs_epochs}",
            f"m{self.model}",
            f"r{self.r}",
        )
        return "_".join(tag_parts)


class PoissonData:
    """Mesh, Poisson operator, source ``f`` and reference solution ``u`` for one run.

    Attributes set by ``__init__``: ``mesh``, ``equation``, ``f``, ``u``
    (float tensors) and ``graph`` (result of ``mesh_to_pyg_graph`` with the
    reference solution attached as ``graph.u``).
    """

    def __init__(self, config:TrainConfig):
        # Domain name -> name of the Gmsh factory. Resolved lazily with
        # getattr so that only the requested factory is ever looked up.
        mesh_factory_names = {
            "rectangle": "gen_rectangle",
            "L_shape": "gen_L_shape",
            "circle": "gen_cirlce",
        }
        factory_name = mesh_factory_names.get(config.domain)
        if factory_name is None:
            raise ValueError(f"Unknown domain: {config.domain}. "
                           f"Available: rectangle, L_shape, circle")
        self.mesh = getattr(Gmsh, factory_name)(chara_length=0.02)

        # Zero-filled boundary values, shaped like the mesh's boundary mask.
        boundary_mask = self.mesh.point_data['boundary_mask']
        self.mesh.point_data['boundary_value'] = np.zeros_like(boundary_mask)
        self.equation = PoissonEquation(self.mesh)

        if config.source_type == "analytical":
            # Closed-form source/solution pair from random coefficients in [-1, 1).
            coeffs = torch.rand(config.K, config.K).numpy() * 2 - 1
            f_np = PoissonGen.MultiAnalytical.source(self.mesh.points, coeffs, r=config.r)
            u_np = PoissonGen.MultiAnalytical.solution(self.mesh.points, coeffs, r=config.r)
        else:
            # Other sources have no closed-form solution here: build f first,
            # then obtain the reference u from PoissonGen.Random.solution.
            source_builders = {
                "abs_sin": lambda: PoissonGen.Discontinuous.abs_sin(self.mesh.points, k=config.K),
                "step": lambda: PoissonGen.Discontinuous.step(self.mesh.points, x0=0.5),
                "indicator_circle": lambda: PoissonGen.Discontinuous.indicator_circle(
                    self.mesh.points, cx=0.5, cy=0.5, r=0.25),
                "checkerboard": lambda: PoissonGen.Discontinuous.checkerboard(self.mesh.points, n=config.K),
                "constant": lambda: PoissonGen.Discontinuous.constant(self.mesh.points, value=1.0),
            }
            build_source = source_builders.get(config.source_type)
            if build_source is None:
                raise ValueError(f"Unknown source_type: {config.source_type}. "
                               f"Available: analytical, abs_sin, step, indicator_circle, checkerboard, constant")
            f_np = build_source()
            u_np = PoissonGen.Random.solution(self.mesh, f_np)

        self.f = torch.from_numpy(f_np).float()
        self.u = torch.from_numpy(u_np).float()

        graph = mesh_to_pyg_graph(self.mesh)
        graph.u = self.u  # attach the ground truth to the graph
        self.graph = graph


@dataclass 
class Stat:
    """Per-epoch training history (loss, MSE, prediction) plus its plots."""

    data: "PoissonData"  # problem the history belongs to (mesh, f, u are read by plot)
    losses: List[float] = field(default_factory=list)   
    mse: List[float] = field(default_factory=list)
    us: List[np.ndarray] = field(default_factory=list)  # one prediction snapshot per epoch

    def record(
        self, 
        loss:float, 
        mse:float, 
        prediction:torch.Tensor,
    ):
        """Append one epoch's loss, MSE and prediction to the history.

        The prediction is detached before conversion, so a tensor that still
        requires grad is accepted. On CPU the stored array may share memory
        with ``prediction``; pass a clone if the tensor is modified later.
        """
        self.losses.append(loss)
        self.mse.append(mse)
        self.us.append(prediction.detach().cpu().numpy())

    def plot(self, config: "TrainConfig", method: str = "ggl"):
        """Write the training curves, the field comparison and the raw predictions.

        Files go to ``output/training_results/{method}/``, all named after
        ``str(config)``: ``<config>.png`` (loss and MSE curves),
        ``<config>_final.png`` (source, true solution, best prediction) and
        ``<config>_us.npy`` (all recorded predictions).

        Args:
            config: Training configuration; its string form names the files
                and is used as the figure title.
            method: Sub-directory name. For ``"deepritz"`` the loss axis stays
                linear because that loss can be negative.

        Raises:
            ValueError: If nothing has been recorded yet, or if the mesh has
                neither triangle nor quad cells.
        """
        if not self.losses:
            # Fail before touching the filesystem; np.argmin below would
            # raise on an empty history anyway, after writing one figure.
            raise ValueError("Stat.plot() needs at least one recorded epoch")

        base_path = f"output/training_results/{method}/{config}"
        os.makedirs(os.path.dirname(base_path), exist_ok=True)

        # ICML style settings (note: this updates matplotlib's global rcParams)
        plt.rcParams.update({
            'font.family': 'serif',
            'font.serif': ['DejaVu Serif', 'Liberation Serif'],
            'font.size': 10,
            'axes.labelsize': 10,
            'axes.titlesize': 10,
            'xtick.labelsize': 9,
            'ytick.labelsize': 9,
            'legend.fontsize': 9,
            'figure.titlesize': 11,
            'lines.linewidth': 1.5,
            'axes.linewidth': 0.8,
            'grid.linewidth': 0.5,
            'grid.alpha': 0.3,
        })

        # ---- Figure 1: loss and MSE curves ----
        fig, axes = plt.subplots(1, 2, figsize=(7, 2.5))
        curves = (
            # (series, colour, y label, title, use log scale)
            # Deep Ritz can have negative loss (energy), so no log scale there.
            (self.losses, '#1f77b4', 'Loss', 'Training Loss', method != "deepritz"),
            (self.mse, '#ff7f0e', 'MSE', 'Mean Squared Error', True),
        )
        for ax, (series, color, ylabel, title, log_scale) in zip(axes, curves):
            ax.plot(series, color=color, linewidth=1.5)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True, linestyle='--', alpha=0.3)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            if log_scale:
                ax.set_yscale('log')
        plt.suptitle(str(config))
        plt.tight_layout()
        plt.savefig(base_path + ".png", dpi=300, bbox_inches='tight')
        plt.close(fig)

        # ---- Figure 2: f, true u and the best prediction ----
        # Use the best prediction (lowest loss) instead of the last epoch.
        best_idx = np.argmin(self.losses)
        best_prediction = self.us[best_idx].squeeze()
        best_loss = self.losses[best_idx]
        best_mse = self.mse[best_idx]
        print(f"Best model at epoch {best_idx}: loss={best_loss:.6e}, mse={best_mse:.6e}")

        # Mesh points and triangulation for plotting
        points = self.data.mesh.points[:, :2]  # Extract x, y coordinates
        cells = self.data.mesh.cells_dict
        if "triangle" in cells:
            triangles = cells["triangle"]
        elif "quad" in cells:
            # Split every quad into two triangles for plotting
            quads = cells["quad"]
            triangles = np.vstack([
                quads[:, [0, 1, 2]],
                quads[:, [0, 2, 3]]
            ])
        else:
            raise ValueError("Unsupported cell type")

        fig, axes = plt.subplots(1, 3, figsize=(10.5, 2.8))
        panels = (
            ('Source Function f', self.data.f.cpu().numpy()),
            ('True Solution u', self.data.u.cpu().numpy()),
            (f'Best Prediction (epoch {best_idx})', best_prediction),
        )
        for ax, (title, values) in zip(axes, panels):
            image = ax.tripcolor(points[:, 0], points[:, 1], triangles, values, shading='flat', cmap='jet')
            ax.triplot(points[:, 0], points[:, 1], triangles, 'k-', linewidth=0.3, alpha=0.3)
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_title(title)
            ax.set_aspect('equal')
            plt.colorbar(image, ax=ax)

        plt.suptitle(str(config))
        plt.tight_layout()
        plt.savefig(base_path + "_final.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

        # Save every recorded prediction
        np.save(base_path + "_us.npy", self.us)

#########################
# Training
#########################


def train_ggl(data:PoissonData, config:TrainConfig = TrainConfig()) -> tuple:
    """
    Train using Graph Galerkin Learning.

    The loss is the mean squared Galerkin residual returned by
    ``data.equation.compute_residual_fast(u, data.f)``.

    Optimizers:
        - ``"adam"``: Adam with cosine learning-rate decay and gradient
          clipping (max norm 1.0).
        - ``"adam+lbfgs"``: the same Adam schedule without clipping, followed
          by ``config.lbfgs_epochs`` L-BFGS steps.

    Returns:
        (model, stat): Trained model and training statistics

    Raises:
        ValueError: If ``config.model`` or ``config.optimizer`` is unknown.
    """
    if config.model == "mock":
        model = Mock(data.graph.num_nodes)
    elif config.model == "mlp":
        model = MLP(in_channels=3)  # Use [x, y, f] as input
        model.set_f(data.f)  # Set source function
    elif config.model == "mlp2":
        model = MLP(in_channels=2)  # Use [x, y] only (original behavior)
    elif config.model == "resmlp":
        model = ResidualMLP()
    elif config.model == "sign":
        model = SIGNWrapper(num_hidden=64, num_layers=3, num_hops=3)
        model.set_data(data.f, data.graph.edge_index, data.graph.x)  # x contains (x, y) coordinates
    else:
        raise ValueError(f"Unknown model {config.model}. Available: mock, mlp, mlp2, resmlp, sign")

    # Reject unknown optimizers explicitly instead of silently returning None
    # (same message as train_pinn).
    if config.optimizer not in ("adam", "adam+lbfgs"):
        raise ValueError(f"Unknown optimizer {config.optimizer}. Available: adam, adam+lbfgs")

    bd_mask = data.equation.boundary_mask  # mask of boundary nodes

    stat = Stat(data)

    def galerkin_forward():
        """Return (mean squared Galerkin residual, flat prediction [n_nodes])."""
        u_flat = model(data.graph).squeeze(-1)  # [n_nodes, 1] -> [n_nodes]
        # Boundary values are NOT overwritten during training (the hard
        # u = 0 clamp is disabled). The clone keeps the recorded snapshot
        # independent of the model output, which for Mock is the parameter
        # tensor itself.
        u_bc = u_flat.clone()
        residual = data.equation.compute_residual_fast(u_bc, data.f)  # R = Au - b
        return (residual ** 2).mean(), u_bc

    def run_adam_phase(desc, clip_grad):
        """Adam with cosine decay from config.lr down to config.lr_min."""
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config.epochs, eta_min=config.lr_min
        )
        with tqdm(range(config.epochs), desc=desc) as pbar:
            for _ in pbar:
                optimizer.zero_grad()
                galerkin_loss, u_bc = galerkin_forward()
                galerkin_loss.backward()

                # Gradient norm is measured before any clipping
                grad_norm = compute_grad_norm(model)
                if clip_grad:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()  # update the learning rate

                # NOTE(review): this MSE uses the unclamped prediction, while
                # the L-BFGS phase below zeroes boundary nodes first. Kept
                # as-is; confirm which one is intended.
                with torch.no_grad():
                    mse = nn.MSELoss()(u_bc, data.u)
                stat.record(galerkin_loss.item(), mse.item(), u_bc.detach())
                current_lr = scheduler.get_last_lr()[0]
                pbar.set_postfix(loss=f"{galerkin_loss.item():.2e}", mse=f"{mse.item():.2e}", lr=f"{current_lr:.2e}", grad=f"{grad_norm:.2e}")

    if config.optimizer == "adam":
        run_adam_phase(desc=None, clip_grad=True)
    else:
        # Phase 1: Adam warm-up (no gradient clipping in this mode)
        run_adam_phase(desc="Adam", clip_grad=False)

        # Phase 2: L-BFGS with proper settings
        lbfgs = torch.optim.LBFGS(
            model.parameters(), 
            lr=1.0,           # L-BFGS typically uses lr=1.0
            max_iter=20,      # Max iterations per step() call
            history_size=50,
            line_search_fn='strong_wolfe'
        )

        def closure():
            # L-BFGS may call this several times within a single step()
            lbfgs.zero_grad()
            galerkin_loss, _ = galerkin_forward()
            galerkin_loss.backward()
            return galerkin_loss

        with tqdm(range(config.lbfgs_epochs), desc="LBFGS") as pbar:
            for _ in pbar:
                loss = lbfgs.step(closure)

                # Record stats outside closure, with boundary nodes set to zero
                with torch.no_grad():
                    u_bc = model(data.graph).squeeze(-1).clone()
                    u_bc[bd_mask] = 0.0
                    mse = nn.MSELoss()(u_bc, data.u)

                stat.record(loss.item(), mse.item(), u_bc.detach())
                pbar.set_postfix(loss=f"{loss.item():.2e}", mse=f"{mse.item():.2e}")

    stat.plot(config, method="ggl")
    return model, stat


def train_pinn(data: PoissonData, config: TrainConfig = TrainConfig()) -> tuple:
    """
    Train using PINN (Physics-Informed Neural Networks).

    The loss is ``pde_loss + lambda_bc * bc_loss`` with ``lambda_bc = 10``:
    the mean squared residual from ``compute_pinn_residual`` on interior
    nodes, plus the mean squared prediction on boundary nodes.

    Optimizers:
        - ``"adam"``: Adam with cosine learning-rate decay and gradient
          clipping (max norm 1.0).
        - ``"adam+lbfgs"``: the same Adam schedule without clipping, followed
          by ``config.lbfgs_epochs`` L-BFGS steps.

    Returns:
        (model, stat): Trained model and training statistics

    Raises:
        ValueError: If ``config.model`` is "sign" or unknown, or if
            ``config.optimizer`` is unknown.
    """
    if config.model == "sign":
        raise ValueError("SIGN model is not compatible with PINN. "
                        "PINN requires coordinate-based models (mlp, mlp2, resmlp) for auto-differentiation. "
                        "Use --method ggl for SIGN model.")
    
    if config.model == "mock":
        model = Mock(data.graph.num_nodes)
    elif config.model == "mlp":
        # For PINN, mlp uses [x, y, f] but f is set externally
        model = MLP(in_channels=3)
        model.set_f(data.f)
    elif config.model == "mlp2":
        model = MLP(in_channels=2)  # Use [x, y] only
    elif config.model == "resmlp":
        model = ResidualMLP()
    else:
        raise ValueError(f"Unknown model {config.model}. Available: mock, mlp, mlp2, resmlp")

    if config.optimizer not in ("adam", "adam+lbfgs"):
        raise ValueError(f"Unknown optimizer {config.optimizer}. Available: adam, adam+lbfgs")
    
    bd_mask = data.equation.boundary_mask
    interior_mask = ~bd_mask
    lambda_bc = 10.0  # Boundary condition weight
    
    stat = Stat(data)

    def pinn_losses():
        """Return (total, pde, bc) losses for the current parameters."""
        residual, u_pred = compute_pinn_residual(model, data)
        pde_loss = (residual[interior_mask] ** 2).mean()  # PDE residual on interior points
        bc_loss = (u_pred[bd_mask] ** 2).mean()           # soft u = 0 on the boundary
        return pde_loss + lambda_bc * bc_loss, pde_loss, bc_loss

    def evaluate():
        """Return (mse, prediction) with boundary nodes set to zero."""
        with torch.no_grad():
            u_eval = model(data.graph)
            # Bring the prediction to the ground truth's shape before
            # comparing. A (n_nodes, 1) output against a (n_nodes,) target
            # would otherwise broadcast to (n_nodes, n_nodes) inside MSELoss
            # and report a meaningless error.
            u_bc = u_eval.reshape(data.u.shape).clone()
            u_bc[bd_mask] = 0.0
            mse = nn.MSELoss()(u_bc, data.u)
        return mse, u_bc

    def run_adam_phase(clip_grad):
        """Adam with cosine decay from config.lr down to config.lr_min."""
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config.epochs, eta_min=config.lr_min
        )
        with tqdm(range(config.epochs), desc="PINN[Adam]") as pbar:
            for _ in pbar:
                optimizer.zero_grad()
                loss, pde_loss, bc_loss = pinn_losses()
                loss.backward()

                # Gradient norm is measured before any clipping
                grad_norm = compute_grad_norm(model)
                if clip_grad:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()

                # MSE against the ground truth, evaluated after the update
                mse, u_bc = evaluate()
                stat.record(loss.item(), mse.item(), u_bc.detach())
                current_lr = scheduler.get_last_lr()[0]
                pbar.set_postfix(
                    loss=f"{loss.item():.2e}", 
                    pde=f"{pde_loss.item():.2e}",
                    bc=f"{bc_loss.item():.2e}",
                    mse=f"{mse.item():.2e}", 
                    lr=f"{current_lr:.2e}",
                    grad=f"{grad_norm:.2e}"
                )

    if config.optimizer == "adam":
        run_adam_phase(clip_grad=True)
    else:
        # Phase 1: Adam warm-up (no gradient clipping in this mode)
        run_adam_phase(clip_grad=False)

        # Phase 2: L-BFGS
        lbfgs = torch.optim.LBFGS(
            model.parameters(), 
            lr=1.0, 
            max_iter=20,
            history_size=50,
            line_search_fn="strong_wolfe"
        )

        def closure():
            # L-BFGS may call this several times within a single step()
            lbfgs.zero_grad()
            loss, _, _ = pinn_losses()
            loss.backward()
            return loss

        with tqdm(range(config.lbfgs_epochs), desc="PINN[LBFGS]") as pbar:
            for _ in pbar:
                loss = lbfgs.step(closure)

                # Record stats
                mse, u_bc = evaluate()
                stat.record(loss.item(), mse.item(), u_bc.detach())
                pbar.set_postfix(loss=f"{loss.item():.2e}", mse=f"{mse.item():.2e}")

    stat.plot(config, method="pinn")
    return model, stat


def train_vpinn(data: PoissonData, config: TrainConfig = TrainConfig()) -> tuple:
    """
    Fit a coordinate-based network with the Variational PINN (VPINN) objective.

    The loss is the mean squared residual returned by
    ``compute_vpinn_residual`` plus a soft boundary penalty: the mean of
    ``u_pred ** 2`` over boundary nodes, weighted by a fixed factor of 10.

    ``config.optimizer`` selects the schedule:
    - "adam": ``config.epochs`` Adam steps with a cosine-annealed learning
      rate and gradient-norm clipping (max norm 1.0).
    - "adam+lbfgs": the same Adam phase *without* clipping, followed by
      ``config.lbfgs_epochs`` outer L-BFGS steps (strong-Wolfe line search).

    After every optimizer step the prediction, with its boundary entries
    overwritten by zero, is compared against ``data.u`` and recorded in a
    ``Stat``; the curves are plotted once training finishes.

    Args:
        data: Problem instance providing ``graph``, ``f``, ``u`` and
            ``equation.boundary_mask``.
        config: Training configuration (model, optimizer, epochs, lr, ...).

    Returns:
        (model, stat): Trained model and training statistics

    Raises:
        ValueError: if ``config.model`` is "sign" or unknown, or if
            ``config.optimizer`` is neither "adam" nor "adam+lbfgs".
    """
    if config.model == "sign":
        raise ValueError("SIGN model is not compatible with VPINN. "
                        "VPINN requires coordinate-based models (mlp, mlp2, resmlp) for auto-differentiation. "
                        "Use --method ggl for SIGN model.")

    # Lazy factories so that only the requested network is ever constructed.
    factories = {
        "mock": lambda: Mock(data.graph.num_nodes),
        "mlp": lambda: MLP(in_channels=3),
        "mlp2": lambda: MLP(in_channels=2),
        "resmlp": lambda: ResidualMLP(),
    }
    make_model = factories.get(config.model)
    if make_model is None:
        raise ValueError(f"Unknown model {config.model}. Available: mock, mlp, mlp2, resmlp")
    model = make_model()
    if config.model == "mlp":
        # The 3-channel MLP takes the source term as an extra input feature.
        model.set_f(data.f)

    boundary = data.equation.boundary_mask
    bc_weight = 10.0  # Boundary condition weight

    stat = Stat(data)

    if config.optimizer not in ("adam", "adam+lbfgs"):
        raise ValueError(f"Unknown optimizer {config.optimizer}. Available: adam, adam+lbfgs")
    # Gradient clipping is applied only in the pure-Adam schedule.
    adam_only = config.optimizer == "adam"

    def objective():
        """Return (total loss, variational term, boundary term) for the current weights."""
        residual, u_pred = compute_vpinn_residual(model, data)
        weak_term = (residual ** 2).mean()
        bc_term = (u_pred[boundary] ** 2).mean()
        return weak_term + bc_weight * bc_term, weak_term, bc_term

    def record_step(total):
        """Evaluate the boundary-zeroed prediction against data.u and log it."""
        with torch.no_grad():
            u_fixed = model(data.graph).clone()
            u_fixed[boundary] = 0.0
            err = nn.MSELoss()(u_fixed, data.u)
        stat.record(total.item(), err.item(), u_fixed.detach())
        return err

    # Phase 1: Adam with cosine-annealed learning rate
    adam = torch.optim.Adam(model.parameters(), lr=config.lr)
    annealer = torch.optim.lr_scheduler.CosineAnnealingLR(
        adam, T_max=config.epochs, eta_min=config.lr_min
    )

    with tqdm(range(config.epochs), desc="VPINN[Adam]") as bar:
        for _ in bar:
            adam.zero_grad()
            total, weak_term, bc_term = objective()
            total.backward()

            # Norm is measured before any clipping so the bar shows the raw value.
            grad_norm = compute_grad_norm(model)
            if adam_only:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            adam.step()
            annealer.step()

            err = record_step(total)
            bar.set_postfix(
                loss=f"{total.item():.2e}",
                var=f"{weak_term.item():.2e}",
                bc=f"{bc_term.item():.2e}",
                mse=f"{err.item():.2e}",
                lr=f"{annealer.get_last_lr()[0]:.2e}",
                grad=f"{grad_norm:.2e}"
            )

    # Phase 2: L-BFGS refinement (only for "adam+lbfgs")
    if not adam_only:
        lbfgs = torch.optim.LBFGS(
            model.parameters(),
            lr=1.0,
            max_iter=20,
            history_size=50,
            line_search_fn="strong_wolfe"
        )

        def closure():
            # L-BFGS re-evaluates the objective several times per outer step.
            lbfgs.zero_grad()
            total, _, _ = objective()
            total.backward()
            return total

        with tqdm(range(config.lbfgs_epochs), desc="VPINN[LBFGS]") as bar:
            for _ in bar:
                total = lbfgs.step(closure)
                err = record_step(total)
                bar.set_postfix(loss=f"{total.item():.2e}", mse=f"{err.item():.2e}")

    stat.plot(config, method="vpinn")
    return model, stat


def train_deepritz(data: PoissonData, config: TrainConfig = TrainConfig()) -> tuple:
    """
    Fit a coordinate-based network with the Deep Ritz method.

    The objective is the energy returned by ``compute_deepritz_energy``
    (intended to be the Ritz functional
    J(u) = (1/2) ∫|∇u|² dx - ∫f·u dx; the helper is defined outside this
    function) plus a soft boundary penalty: the mean of ``u_pred ** 2`` over
    boundary nodes, weighted by 500. The energy enters the loss directly
    rather than squared, so the total loss is not bounded below by zero.
    The boundary weight is deliberately larger than in the PINN/VPINN
    trainers so the network cannot lower the energy by inflating u while
    ignoring the boundary condition.

    ``config.optimizer`` selects the schedule:
    - "adam": ``config.epochs`` Adam steps with a cosine-annealed learning
      rate and gradient-norm clipping (max norm 1.0).
    - "adam+lbfgs": the same Adam phase *without* clipping, followed by
      ``config.lbfgs_epochs`` outer L-BFGS steps (strong-Wolfe line search).

    After every optimizer step the prediction, with its boundary entries
    overwritten by zero, is compared against ``data.u`` and recorded in a
    ``Stat``; the curves are plotted once training finishes.

    Args:
        data: Problem instance providing ``graph``, ``f``, ``u`` and
            ``equation.boundary_mask``.
        config: Training configuration (model, optimizer, epochs, lr, ...).

    Returns:
        (model, stat): Trained model and training statistics

    Raises:
        ValueError: if ``config.model`` is "sign" or unknown, or if
            ``config.optimizer`` is neither "adam" nor "adam+lbfgs".
    """
    if config.model == "sign":
        raise ValueError("SIGN model is not compatible with Deep Ritz. "
                        "Deep Ritz requires coordinate-based models (mlp, mlp2, resmlp) for auto-differentiation. "
                        "Use --method ggl for SIGN model.")

    # Lazy factories so that only the requested network is ever constructed.
    factories = {
        "mock": lambda: Mock(data.graph.num_nodes),
        "mlp": lambda: MLP(in_channels=3),
        "mlp2": lambda: MLP(in_channels=2),
        "resmlp": lambda: ResidualMLP(),
    }
    make_model = factories.get(config.model)
    if make_model is None:
        raise ValueError(f"Unknown model {config.model}. Available: mock, mlp, mlp2, resmlp")
    model = make_model()
    if config.model == "mlp":
        # The 3-channel MLP takes the source term as an extra input feature.
        model.set_f(data.f)

    boundary = data.equation.boundary_mask
    # Deep Ritz needs larger BC weight to prevent "cheating" (making u large to minimize energy)
    bc_weight = 500.0  # Boundary condition weight (larger than PINN/VPINN)

    stat = Stat(data)

    if config.optimizer not in ("adam", "adam+lbfgs"):
        raise ValueError(f"Unknown optimizer {config.optimizer}. Available: adam, adam+lbfgs")
    # Gradient clipping is applied only in the pure-Adam schedule.
    adam_only = config.optimizer == "adam"

    def objective():
        """Return (total loss, energy term, boundary term) for the current weights."""
        energy, u_pred = compute_deepritz_energy(model, data)
        bc_term = (u_pred[boundary] ** 2).mean()
        return energy + bc_weight * bc_term, energy, bc_term

    def record_step(total):
        """Evaluate the boundary-zeroed prediction against data.u and log it."""
        with torch.no_grad():
            u_fixed = model(data.graph).clone()
            u_fixed[boundary] = 0.0
            err = nn.MSELoss()(u_fixed, data.u)
        stat.record(total.item(), err.item(), u_fixed.detach())
        return err

    # Phase 1: Adam with cosine-annealed learning rate
    adam = torch.optim.Adam(model.parameters(), lr=config.lr)
    annealer = torch.optim.lr_scheduler.CosineAnnealingLR(
        adam, T_max=config.epochs, eta_min=config.lr_min
    )

    with tqdm(range(config.epochs), desc="DeepRitz[Adam]") as bar:
        for _ in bar:
            adam.zero_grad()
            total, energy, bc_term = objective()
            total.backward()

            # Norm is measured before any clipping so the bar shows the raw value.
            grad_norm = compute_grad_norm(model)
            if adam_only:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            adam.step()
            annealer.step()

            err = record_step(total)
            bar.set_postfix(
                loss=f"{total.item():.2e}",
                energy=f"{energy.item():.2e}",
                bc=f"{bc_term.item():.2e}",
                mse=f"{err.item():.2e}",
                lr=f"{annealer.get_last_lr()[0]:.2e}",
                grad=f"{grad_norm:.2e}"
            )

    # Phase 2: L-BFGS refinement (only for "adam+lbfgs")
    if not adam_only:
        lbfgs = torch.optim.LBFGS(
            model.parameters(),
            lr=1.0,
            max_iter=20,
            history_size=50,
            line_search_fn="strong_wolfe"
        )

        def closure():
            # L-BFGS re-evaluates the objective several times per outer step.
            lbfgs.zero_grad()
            total, _, _ = objective()
            total.backward()
            return total

        with tqdm(range(config.lbfgs_epochs), desc="DeepRitz[LBFGS]") as bar:
            for _ in bar:
                total = lbfgs.step(closure)
                err = record_step(total)
                bar.set_postfix(loss=f"{total.item():.2e}", mse=f"{err.item():.2e}")

    stat.plot(config, method="deepritz")
    return model, stat


def create_model(model_type: str, num_nodes: int = None, 
                 f: torch.Tensor = None, edge_index: torch.Tensor = None,
                 coords: torch.Tensor = None) -> nn.Module:
    """Build a fresh, untrained model for the given type name.

    Args:
        model_type: One of "mock", "mlp", "mlp2", "resmlp", "sign".
        num_nodes: Size of the learnable vector for the "mock" model
            (mandatory for that type, ignored otherwise).
        f: Source function tensor. Attached to the "mlp" model when given;
            one of the three inputs the "sign" model needs.
        edge_index: Graph edge index, used only by the "sign" model.
        coords: Node coordinates, used only by the "sign" model.

    Returns:
        The constructed module. Note that "mlp" without ``f`` and "sign"
        without all of ``f``/``edge_index``/``coords`` are returned
        without their data attached; no error is raised in that case.

    Raises:
        ValueError: for "mock" without ``num_nodes`` or an unknown type.
    """
    if model_type == "mock":
        if num_nodes is None:
            raise ValueError("num_nodes required for Mock model")
        return Mock(num_nodes)

    if model_type == "mlp":
        # Three input channels: [x, y, f]
        net = MLP(in_channels=3)
        if f is not None:
            net.set_f(f)
        return net

    if model_type == "mlp2":
        # Two input channels: [x, y] only (original behavior)
        return MLP(in_channels=2)

    if model_type == "resmlp":
        return ResidualMLP()

    if model_type == "sign":
        net = SIGNWrapper(num_hidden=64, num_layers=3, num_hops=3)
        # Data is attached only when every graph input was supplied.
        missing_input = f is None or edge_index is None or coords is None
        if not missing_input:
            net.set_data(f, edge_index, coords)
        return net

    raise ValueError(f"Unknown model type: {model_type}. Available: mock, mlp, mlp2, resmlp, sign")


def get_model_cache_path(config: TrainConfig, method: str = "ggl") -> str:
    """Return the checkpoint path for ``method``/``config``, creating the cache directory if needed.

    The file name embeds the string form of ``config``, so two configs
    share a cache entry exactly when they format to the same text.
    """
    directory = "output/cache"
    os.makedirs(directory, exist_ok=True)
    filename = "{}_{}.pth".format(method, config)
    return directory + "/" + filename


def load_or_train(
    data: PoissonData, 
    config: TrainConfig, 
    method: str = "ggl",
    rerun: bool = False
) -> tuple:
    """
    Load cached model or train a new one.

    Two files make up a cache entry: the model ``state_dict`` (``.pth``)
    and the training statistics (``_stat.npz``). The cache is used only
    when both exist and ``rerun`` is False; otherwise the model is trained
    with the requested method and both files are (re)written.

    Args:
        data: PoissonData instance
        config: Training configuration
        method: Training method ("ggl", "pinn", "vpinn" or "deepritz")
        rerun: If True, ignore cache and retrain

    Returns:
        (model, stat): Trained model and training statistics

    Raises:
        ValueError: if training is needed and ``method`` is not one of the
            four supported methods.
    """
    cache_path = get_model_cache_path(config, method)
    stat_cache_path = cache_path.replace('.pth', '_stat.npz')
    
    if os.path.exists(cache_path) and os.path.exists(stat_cache_path) and not rerun:
        print(f"Loading cached {method.upper()} model from {cache_path}")
        model = create_model(config.model, data.graph.num_nodes, 
                            f=data.f, edge_index=data.graph.edge_index,
                            coords=data.graph.x)
        model.load_state_dict(torch.load(cache_path, weights_only=True))
        
        # Load stat. np.load on an .npz returns a lazily-read NpzFile that
        # holds the file open; the context manager closes it once the arrays
        # have been copied out.
        # NOTE(review): allow_pickle=True is needed because `us` is saved as
        # an object array below; only load cache files this script produced.
        with np.load(stat_cache_path, allow_pickle=True) as saved_stat:
            stat = Stat(data)
            stat.losses = saved_stat['losses'].tolist()
            stat.mse = saved_stat['mse'].tolist()
            # Convert each element to float32 numpy array
            stat.us = [np.array(u, dtype=np.float32) for u in saved_stat['us']]
        
        print("Model and stat loaded successfully!")
        
        # Plot the saved training curves
        stat.plot(config, method=method)
        
        return model, stat
    else:
        if rerun and os.path.exists(cache_path):
            print(f"Rerun requested, ignoring cache at {cache_path}")
        
        # Train model based on method
        if method == "ggl":
            model, stat = train_ggl(data, config)
        elif method == "pinn":
            model, stat = train_pinn(data, config)
        elif method == "vpinn":
            model, stat = train_vpinn(data, config)
        elif method == "deepritz":
            model, stat = train_deepritz(data, config)
        else:
            raise ValueError(f"Unknown method: {method}. Available: ggl, pinn, vpinn, deepritz")
        
        # Save model and stat
        torch.save(model.state_dict(), cache_path)
        np.savez(stat_cache_path, 
                 losses=np.array(stat.losses),
                 mse=np.array(stat.mse),
                 us=np.array(stat.us, dtype=object))
        print(f"Model saved to {cache_path}")
        print(f"Stat saved to {stat_cache_path}")
        
        return model, stat


if __name__ == "__main__":
    """
    Example configurations:
    K: 1 -> config = TrainConfig(nx=100, optimizer="adam+lbfgs", K=1, lr=1e-4, lr_min=1e-5, epochs=10000, model="mlp", r=.5)
    K: 2 -> config = TrainConfig(nx=100, optimizer="adam+lbfgs", K=2, lr=1e-4, lr_min=1e-5, epochs=10000, model="mlp", r=.5)
    K: 4 -> config = TrainConfig(nx=100, optimizer="adam+lbfgs", K=4, lr=1e-4, lr_min=1e-5, epochs=10000, model="mlp", r=.5)
    K: 8 -> config = TrainConfig(nx=100, optimizer="adam+lbfgs", K=8, lr=1e-3, lr_min=1e-5, epochs=10000, model="mlp", r=.5)
    
    Usage:
        python test_poisson.py                            # Train GGL (default)
        python test_poisson.py --method pinn              # Train PINN
        python test_poisson.py --method vpinn             # Train VPINN (Variational PINN)
        python test_poisson.py --rerun                    # Force retrain, ignore cache
        python test_poisson.py --landscape                # Train + plot loss landscape
        python test_poisson.py -K 8 --landscape           # Train with K=8 + plot landscape
        python test_poisson.py --method ggl --landscape   # GGL + landscape
        python test_poisson.py --method pinn --landscape  # PINN + landscape
        python test_poisson.py --method vpinn --landscape # VPINN + landscape
        python test_poisson.py --method deepritz          # Deep Ritz (energy minimization)
    """
    import argparse
    
    # Help texts below state the actual argparse defaults.
    parser = argparse.ArgumentParser(description="Train GGL/PINN and optionally plot loss landscape")
    parser.add_argument("--method", type=str, default="ggl", choices=["ggl", "pinn", "vpinn", "deepritz"], 
                        help="Training method: ggl, pinn, vpinn, or deepritz (Deep Ritz energy minimization)")
    parser.add_argument("--landscape", action="store_true", help="Compute and plot loss landscape after training")
    parser.add_argument("--rerun", "-r", action="store_true", help="Ignore cache and retrain model")
    parser.add_argument("--optimizer", type=str, default="adam+lbfgs", choices=["adam", "adam+lbfgs"], 
                        help="Optimizer: adam or adam+lbfgs (default: adam+lbfgs)")
    parser.add_argument("--n_points", type=int, default=21, help="Grid resolution for landscape (default: 21)")
    parser.add_argument("--scale", type=float, default=1.0, help="Scan radius for landscape (default: 1.0)")
    parser.add_argument("-K", type=int, default=4, help="Fourier mode number (default: 4)")
    parser.add_argument("--epochs", "-ep",type=int, default=10000, help="Adam training epochs (default: 10000)")
    parser.add_argument("--lbfgs_epochs", type=int, default=200, help="L-BFGS training epochs (default: 200)")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate (default: 1e-4)")
    parser.add_argument("--lr_min", type=float, default=1e-5, help="Minimum learning rate (default: 1e-5)")
    parser.add_argument("--model", type=str, default="mlp2", choices=["mock", "mlp", "mlp2", "resmlp", "sign"], 
                        help="Model type: mlp=[x,y,f], mlp2=[x,y] only (sign only works with ggl method)")
    parser.add_argument("--source", type=str, default="analytical", 
                        choices=["analytical", "abs_sin", "step", "indicator_circle", "checkerboard", "constant"],
                        help="Source function type: analytical (smooth), or discontinuous types")
    parser.add_argument("--domain", type=str, default="rectangle",
                        choices=["rectangle", "L_shape", "circle"],
                        help="Domain shape: rectangle, L_shape (has corner singularity), circle")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()
    
    # Seed every RNG in use and request deterministic cuDNN kernels so that
    # repeated runs with the same --seed are comparable.
    SEED = args.seed
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(SEED)
        torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Configure training (nx and r are fixed here; everything else comes from the CLI)
    config = TrainConfig(
        nx=100, 
        optimizer=args.optimizer, 
        K=args.K, 
        lr=args.lr, 
        lr_min=args.lr_min, 
        epochs=args.epochs,
        lbfgs_epochs=args.lbfgs_epochs,
        model=args.model,
        source_type=args.source,
        domain=args.domain,
        r=0.5
    )
    
    print(f"Method: {args.method.upper()}")
    print(f"Config: {config}")
    
    # Load or train model
    data = PoissonData(config)
    print(f"DEBUG: f checksum = {data.f.sum().item():.6f}")

    model, stat = load_or_train(data, config, method=args.method, rerun=args.rerun)
    
    # Evaluate model and print L2 error
    print("\n" + "=" * 50)
    print("Model Evaluation")
    print("=" * 50)
    with torch.no_grad():
        u_pred = model(data.graph)
        u_pred_bc = u_pred.clone()
        
        # GGL uses hard constraint (post-processing); the other methods use a soft constraint
        if args.method == "ggl":
            u_pred_bc[data.equation.boundary_mask] = 0.0  # Hard constraint for GGL
            print("Boundary handling: Hard constraint (post-processing)")
        else:
            # For PINN/VPINN/Deep Ritz: use raw model output to evaluate soft constraint quality
            print("Boundary handling: Soft constraint (model learned)")
            bd_error = torch.abs(u_pred[data.equation.boundary_mask]).mean().item()
            print(f"Boundary Error (mean |u| on boundary): {bd_error:.6e}")
        
        # compute_l2_error's result is scaled by 100 to print it as a percentage
        l2_error = compute_l2_error(u_pred_bc, data.u) * 100
        mse = nn.MSELoss()(u_pred_bc, data.u).item()
        
        print(f"Relative L2 Error: {l2_error:.2f}%")
        print(f"MSE: {mse:.6e}")
        print(f"Number of nodes: {data.graph.num_nodes}")
    print("=" * 50 + "\n")
    
    # Optionally compute loss landscape
    if args.landscape:
        if args.method == "ggl":
            compute_galerkin_loss_landscape(
                model=model,
                data=data,
                config=config,
                n_points=args.n_points,
                scale=args.scale
            )
        elif args.method == "pinn":
            compute_pinn_loss_landscape(
                model=model,
                data=data,
                config=config,
                n_points=args.n_points,
                scale=args.scale
            )
        elif args.method == "vpinn":
            compute_vpinn_loss_landscape(
                model=model,
                data=data,
                config=config,
                n_points=args.n_points,
                scale=args.scale
            )