"""
Wave equation data generators for TensorGalerkin

Solves the wave equation:
    ∂²u/∂t² = c² * Δu  in Ω
    u = g              on ∂Ω (Dirichlet boundary)
    u(x, 0) = u0       (initial condition)
    ∂u/∂t(x, 0) = v0   (initial velocity)
"""

import numpy as np
import scipy.sparse
import scipy.sparse.linalg
import torch
import torch.nn as nn
from typing import Tuple, Union, List
from tqdm import tqdm

from ...discretization import tri3, quad4, tri_gauss_points, quad_gauss_points


def _uniform(min_val: float, max_val: float, shape) -> torch.Tensor:
    """Generate uniform random tensor in [min_val, max_val]"""
    return torch.rand(shape) * (max_val - min_val) + min_val


class WaveGen:
    """Data generators for wave equation problems"""

    class MultiAnalytical:
        """
        Multi-mode analytical fields for the wave equation.

        Every mode sin(πix)·sin(πjy) vanishes on the boundary of [0, 1]².
        """

        @staticmethod
        def _mode_grid(points: np.ndarray, a: np.ndarray):
            """
            Broadcast coefficients, mode indices and coordinates to a common layout.

            Returns (K, a, i, j, x, y). For ``a`` of shape (K, K), the arrays
            broadcast to (n_points, K, K). For ``a`` of shape (N, K, K), they
            broadcast to (N, n_points, K, K). Mode indices run from 1 to K:
            ``i`` follows the rows of ``a`` (x-frequency) and ``j`` follows the
            columns (y-frequency).
            """
            K = a.shape[-1]
            j, i = np.meshgrid(np.arange(1, K + 1), np.arange(1, K + 1))  # (K, K)
            if len(a.shape) == 2:
                x = points[:, 0][:, None, None]  # (n_points, 1, 1)
                y = points[:, 1][:, None, None]
                return K, a[None, :, :], i[None, :, :], j[None, :, :], x, y
            x = points[:, 0][None, :, None, None]  # (1, n_points, 1, 1)
            y = points[:, 1][None, :, None, None]
            return K, a[:, None, :, :], i[None, None, :, :], j[None, None, :, :], x, y

        @staticmethod
        def initial_condition(points: np.ndarray,
                              a: np.ndarray,
                              r: float = 0.5) -> Tuple[np.ndarray, np.ndarray]:
            """
            Generate the wave equation initial condition at each point.

            The initial displacement is a superposition of sinusoidal modes:
            u0(x,y) = π/K² * Σ_{i,j} (a_{ij} * (i²+j²)^{-r} * sin(πix) * sin(πjy))

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                a: np.ndarray (K, K) or (N, K, K)
                    Mode coefficients (a batch of N coefficient matrices is allowed)
                r: float
                    Decay exponent for frequency modes (default: 0.5). Integer
                    values are accepted.

            Returns:
            --------
                u0: np.ndarray (n_points,) or (N, n_points)
                    Initial displacement
                v0: np.ndarray, same shape as u0
                    Initial velocity, all ones. This is not the zero velocity
                    that ``solution`` assumes.
            """
            K, a, i, j, x, y = WaveGen.MultiAnalytical._mode_grid(points, a)
            # Cast before the power so integer r (e.g. r=1) does not hit NumPy's
            # "integers to negative integer powers" error.
            decay = (i * i + j * j).astype(np.float64) ** (-r)
            u0 = (np.pi / K / K * (a * decay *
                  np.sin(np.pi * i * x) * np.sin(np.pi * j * y))).sum((-2, -1))
            # Note: v0 is ones for non-zero initial velocity; for true analytical, use zeros
            v0 = np.ones(u0.shape)
            return u0, v0

        @staticmethod
        def solution(points: np.ndarray,
                     a: np.ndarray,
                     r: float = 0.5,
                     c: float = 1.0,
                     t: float = 0.1) -> np.ndarray:
            """
            Evaluate the analytical wave equation solution at time t.

            The solution is:
            u(x,y,t) = π/K² * Σ_{i,j} (a_{ij} * (i²+j²)^{-r} * sin(πix) * sin(πjy)
                                       * cos(cπt√(i²+j²)))

            At t = 0 this equals the ``u0`` returned by ``initial_condition``.
            Each mode's time factor is a cosine, which corresponds to zero
            initial velocity.

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                a: np.ndarray (K, K) or (N, K, K)
                    Mode coefficients
                r: float
                    Decay exponent for frequency modes (default: 0.5)
                c: float
                    Wave speed (default: 1.0)
                t: float
                    Time at which to evaluate the solution (default: 0.1)

            Returns:
            --------
                ut: np.ndarray (n_points,) or (N, n_points)
                    Solution values at time t
            """
            K, a, i, j, x, y = WaveGen.MultiAnalytical._mode_grid(points, a)
            decay = (i * i + j * j).astype(np.float64) ** (-r)
            ut = (np.pi / K / K * (a * decay *
                  np.sin(np.pi * i * x) * np.sin(np.pi * j * y) *
                  np.cos(c * np.pi * t * np.sqrt(i * i + j * j)))).sum((-2, -1))
            return ut

    class Random:
        """Random data generators and an FEM reference solver for the wave equation"""

        @staticmethod
        def initial_condition_random(points: np.ndarray,
                                     low: float = -1.0,
                                     high: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
            """
            Generate an i.i.d. uniform random initial condition.

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates. Only the number of rows is used.
                low: float
                    Lower bound of the uniform distribution (default: -1.0)
                high: float
                    Upper bound of the uniform distribution (default: 1.0)

            Returns:
            --------
                u0: np.ndarray (n_points,)
                    Random initial displacement
                v0: np.ndarray (n_points,)
                    Random initial velocity
            """
            u = np.random.uniform(low=low, high=high, size=(points.shape[0],))
            v = np.random.uniform(low=low, high=high, size=(points.shape[0],))
            return u, v

        @staticmethod
        def initial_condition_single_gaussian(points: np.ndarray,
                                              sig: float = 0.2) -> Tuple[np.ndarray, np.ndarray]:
            """
            Generate a single bump centred in the bounding box of ``points``.

            The bump is exp(-|x - center| / sig). It decays with the distance
            itself, not the squared distance, so despite the name it is an
            exponential (Laplace-type) bump rather than a true Gaussian.

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                sig: float
                    Length scale of the bump (default: 0.2)

            Returns:
            --------
                u0: np.ndarray (n_points,)
                    Initial displacement
                v0: np.ndarray (n_points,)
                    Zero initial velocity
            """
            center_of_domain = (points.max(0) + points.min(0)) / 2
            u = np.exp(-np.linalg.norm(points - center_of_domain, axis=1) / sig)
            v = np.zeros_like(u)
            return u, v

        @staticmethod
        def initial_condition_gaussian(points: np.ndarray,
                                       num_centers: Tuple[int, int] = (2, 6)) -> Tuple[np.ndarray, np.ndarray]:
            """
            Generate a multi-center Gaussian initial condition.

            Draws n ~ randint(num_centers[0], num_centers[1]) centers (upper
            bound exclusive) in [1/6, 5/6]² with widths s in [0.039, 0.156].
            A candidate center is rejected when it lies closer than 2·s_i to an
            already accepted center i.
            Reference: https://arxiv.org/abs/2405.19101, page 30

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                num_centers: Tuple[int, int]
                    Half-open range for the number of Gaussian centers (default: (2, 6))

            Returns:
            --------
                u0: np.ndarray (n_points,)
                    Sum of the Gaussian bumps
                v0: np.ndarray (n_points,)
                    Zero initial velocity
            """
            u0 = np.zeros_like(points[:, 0])
            n = np.random.randint(num_centers[0], num_centers[1])
            centers = []

            for _ in range(n):
                while True:
                    xc = np.random.uniform(1/6, 5/6)
                    yc = np.random.uniform(1/6, 5/6)
                    s = np.random.uniform(0.039, 0.156)
                    valid = True
                    for (xc_i, yc_i, s_i) in centers:
                        if np.sqrt((xc - xc_i)**2 + (yc - yc_i)**2) < 2 * s_i:
                            valid = False
                            break
                    if valid:
                        centers.append((xc, yc, s))
                        break

            for xc, yc, s in centers:
                u0 += np.exp(-((points[:, 0] - xc)**2 + (points[:, 1] - yc)**2) / (2 * s**2))

            v0 = np.zeros(u0.shape)
            return u0, v0

        @staticmethod
        def initial_propagation_speed_gaussian(points: np.ndarray,
                                               c0_lim: Tuple[float, float] = (1000, 1500),
                                               v_lim: Tuple[float, float] = (1000, 1500)) -> np.ndarray:
            """
            Generate a spatially varying wave speed with Gaussian perturbations.

            The speed is a base value c0 plus four Gaussian bumps. Each bump is
            anchored near (0.25|0.75, 0.25|0.75), shifted by up to ±0.3125 in
            each coordinate, and has a width drawn from [1/12, 1/6].

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                c0_lim: Tuple[float, float]
                    Range for the base wave speed (default: (1000, 1500))
                v_lim: Tuple[float, float]
                    Range for the bump amplitudes (default: (1000, 1500))

            Returns:
            --------
                c: np.ndarray (n_points,)
                    Spatially varying wave speed
            """
            c0 = np.random.uniform(*c0_lim)
            anchors = np.array([(0.25, 0.25), (0.25, 0.75), (0.75, 0.25), (0.75, 0.75)])
            c = np.ones_like(points[:, 0]) * c0

            for (xi, yi) in anchors:
                dxi = np.random.uniform(-0.3125, 0.3125)
                dyi = np.random.uniform(-0.3125, 0.3125)
                vi = np.random.uniform(*v_lim)
                sigma_i = np.random.uniform(1/12, 1/6)

                c += vi * np.exp(-((points[:, 0] - (xi + dxi))**2 +
                                   (points[:, 1] - (yi + dyi))**2) / (2 * sigma_i**2))

            return c

        @staticmethod
        def solution(mesh,
                     u0: np.ndarray,
                     v0: np.ndarray,
                     c: Union[float, np.ndarray],
                     T: float = 1.0,
                     dt: float = 0.01,
                     recording: bool = False,
                     verbose: bool = False) -> Union[np.ndarray, List[np.ndarray]]:
            """
            Solve the wave equation with finite elements and explicit
            central-difference time stepping.

            Solves ∂²u/∂t² = c² Δu with Dirichlet data taken from the mesh. The
            semi-discrete system M ü + K u = 0 is advanced with
                M u^{n+1} = 2 M u^n - M u^{n-1} - dt² K u^n,
            and the first step uses the Taylor start
                M u^1 = M u^0 + dt M v^0 - dt²/2 K u^0.
            The inner-node mass block is factorized once and reused every step.

            Parameters:
            -----------
                mesh: meshio.Mesh
                    The computational mesh. It must contain triangle or quad
                    cells and have 'boundary_mask' (truthy on Dirichlet nodes)
                    and 'boundary_value' in point_data.
                u0: np.ndarray (n_points,)
                    Initial displacement at the mesh nodes. It is not modified;
                    boundary values are imposed on a copy.
                v0: np.ndarray (n_points,)
                    Initial velocity at the mesh nodes
                c: float or np.ndarray (n_points,)
                    Wave speed: a scalar (Python/NumPy number or 0-d array) or
                    one value per node
                T: float
                    Final time (default: 1.0)
                dt: float
                    Time step size (default: 0.01)
                recording: bool
                    If True, return the solution at all time steps (default: False)
                verbose: bool
                    If True, show a progress bar (default: False)

            Returns:
            --------
                If recording=True:
                    List[np.ndarray]: states [u(0), u(dt), u(2dt), ...]. The list
                    holds 2 + len(np.arange(dt, T + dt/10, dt)) entries, so the
                    last state is at t ≈ T + dt. NOTE(review): this is one step
                    past T. It is kept for backward compatibility; confirm intent.
                Else:
                    np.ndarray (n_points,): the last state of that same sequence.

            Raises:
            -------
                AssertionError: if the mesh lacks boundary_mask/boundary_value.
                NotImplementedError: for unsupported cell types or wave-speed shapes.
            """
            # Validate mesh data
            assert "boundary_mask" in mesh.point_data, "Mesh must have boundary_mask in point_data"
            assert "boundary_value" in mesh.point_data, "Mesh must have boundary_value in point_data"

            points = mesh.points

            # Determine element type and set up quadrature
            if "triangle" in mesh.cells_dict:
                elements = mesh.cells_dict['triangle']
                qpoints = tri_gauss_points(ngp=1)
                shape_fn = tri3
            elif "quad" in mesh.cells_dict:
                elements = mesh.cells_dict['quad']
                qpoints = quad_gauss_points()
                shape_fn = quad4
            else:
                raise NotImplementedError("Mesh should have triangle or quadrilateral cells")
            quadrature_weight, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
            shape_val, shape_grad, jac_det = shape_fn(xi, eta, points[elements][:, :, :2])

            # Cast to bool so an integer/uint8 mask is negated rather than
            # bit-inverted by `~` below.
            boundary_mask = np.asarray(mesh.point_data['boundary_mask']).astype(bool)
            boundary_value = np.asarray(mesh.point_data['boundary_value'])
            num_points = points.shape[0]
            num_elements, num_basis = elements.shape
            num_local = num_elements * num_basis * num_basis

            # Every (i, j) basis pair of every element, flattened in the same
            # (element, i, j) order as K_elem.ravel() / M_elem.ravel().
            elem_u = np.repeat(elements, num_basis, axis=1).ravel()
            elem_v = np.tile(elements, (1, num_basis)).ravel()

            # Unique sparsity pattern: the CSR conversion merges duplicated pairs.
            pattern = scipy.sparse.coo_matrix(
                (np.ones_like(elem_u), (elem_u, elem_v)),
                shape=(num_points, num_points)
            ).tocsr().tocoo()
            edge_u, edge_v = pattern.row, pattern.col
            num_edges = len(edge_u)

            # Map each element-local entry to its global edge id, then build the
            # 0/1 scatter matrix that sums element contributions per edge.
            eids_csr = scipy.sparse.coo_matrix(
                (np.arange(num_edges), (edge_u, edge_v)),
                shape=(num_points, num_points)
            ).tocsr()
            elem_eids = np.array(eids_csr[elem_u, elem_v]).ravel()
            ele2msh_edge = scipy.sparse.coo_matrix(
                (np.ones(num_local, dtype=np.float64), (elem_eids, np.arange(num_local))),
                shape=(num_edges, num_local)
            )

            # Jacobian-weighted quadrature weights, [num_elements, num_quadrature_points]
            JxW = np.abs(jac_det) * quadrature_weight

            # Stiffness (scaled by c²). A scalar c of any numeric kind (including
            # NumPy scalars and 0-d arrays) is treated as constant.
            c_arr = np.asarray(c)
            if not np.issubdtype(c_arr.dtype, np.number):
                raise NotImplementedError(f"Wave speed should be float or np.ndarray of shape [n_node], got {type(c)}")
            if c_arr.ndim == 0:
                K_elem = c_arr * c_arr * np.einsum("eqib,eqjb,eq->eqij", shape_grad, shape_grad, JxW)
            elif c_arr.ndim == 1 and c_arr.shape[0] == num_points:
                # Spatially varying speed, interpolated to the quadrature points
                c_quad = np.einsum('gb,eb->eg', shape_val, c_arr[elements])
                K_elem = np.einsum("eq,eq,eqib,eqjb,eq->eqij", c_quad, c_quad, shape_grad, shape_grad, JxW)
            else:
                raise NotImplementedError(f"Wave speed should be float or np.ndarray of shape [n_node], got {type(c)}")

            M_elem = np.einsum("qi,qj,eq->eqij", shape_val, shape_val, JxW)

            # Integrate over quadrature points, then scatter to global edges
            K_edges = ele2msh_edge @ K_elem.sum(1).ravel()  # [num_edges]
            M_edges = ele2msh_edge @ M_elem.sum(1).ravel()  # [num_edges]

            K_global = scipy.sparse.coo_matrix(
                (K_edges, (edge_u, edge_v)), shape=(num_points, num_points)
            ).tocsr()
            M_global = scipy.sparse.coo_matrix(
                (M_edges, (edge_u, edge_v)), shape=(num_points, num_points)
            ).tocsr()

            # Dirichlet conditions by static condensation onto the inner nodes
            is_inner_node = ~boundary_mask
            is_outer_node = boundary_mask
            is_inner_u = is_inner_node[edge_u]
            is_inner_edge = is_inner_u & is_inner_node[edge_v]
            is_ou2in_edge = is_inner_u & is_outer_node[edge_v]

            n_inner_nodes = int(is_inner_node.sum())
            n_outer_nodes = int(is_outer_node.sum())

            # Local ids within the inner and the outer node sets
            local_nids = np.full((num_points,), -1, dtype=np.int64)
            local_nids[is_inner_node] = np.arange(n_inner_nodes)
            local_nids[is_outer_node] = np.arange(n_outer_nodes)

            M_inner = scipy.sparse.coo_matrix(
                (M_edges[is_inner_edge],
                 (local_nids[edge_u[is_inner_edge]], local_nids[edge_v[is_inner_edge]])),
                shape=(n_inner_nodes, n_inner_nodes)
            ).tocsc()
            M_ou2in = scipy.sparse.coo_matrix(
                (M_edges[is_ou2in_edge],
                 (local_nids[edge_u[is_ou2in_edge]], local_nids[edge_v[is_ou2in_edge]])),
                shape=(n_inner_nodes, n_outer_nodes)
            ).tocsr()

            # Loop invariants: factorize M_inner once, and precompute the
            # boundary data and its coupling into the inner rows.
            solve_inner = scipy.sparse.linalg.factorized(M_inner)
            g = boundary_value[is_outer_node]
            bc_rhs = M_ou2in @ g

            def advance(rhs_full: np.ndarray) -> np.ndarray:
                """Solve M u = rhs on the inner nodes and impose the Dirichlet values."""
                u_new = np.zeros((num_points,))
                u_new[is_outer_node] = g
                u_new[is_inner_node] = solve_inner(rhs_full[is_inner_node] - bc_rhs)
                return u_new

            # Copy so the caller's u0 array is never modified
            u_prev = np.array(u0, dtype=np.float64)
            u_prev[is_outer_node] = g
            v0 = np.asarray(v0)

            # Taylor start: M u^1 = M u^0 + dt M v^0 - dt²/2 K u^0
            u_curr = advance(
                (-dt * dt * K_global @ u_prev + 2 * M_global @ u_prev + 2 * dt * M_global @ v0) / 2
            )

            steps = np.arange(dt, T + dt / 10, dt)
            if verbose:
                steps = tqdm(steps, desc="Wave equation")

            history = [u_prev, u_curr] if recording else None
            for _ in steps:
                # Central difference: M u^{n+1} = 2M u^n - M u^{n-1} - dt² K u^n
                u_next = advance(M_global @ (2.0 * u_curr - u_prev) - dt * dt * (K_global @ u_curr))
                u_prev, u_curr = u_curr, u_next
                if recording:
                    history.append(u_next)

            return history if recording else u_curr


# ============================================================================
# Parametric PyTorch Modules for Wave Equation
# ============================================================================

class ParametricMultiAnalytical(nn.Module):
    """
    Parametric multi-mode analytical setup for the wave equation.

    Holds a constant wave speed (learnable or fixed), a (K, K) coefficient
    matrix ``a`` and a decay exponent ``r``. ``initial_condition`` draws a
    fresh random multi-mode field on every call. It uses ``a`` only for its
    device and dtype, and does not read the stored ``r``.
    ``initial_prop_speed`` broadcasts the stored speed to every point.

    Storage note: a learnable speed is the parameter ``c``, while a fixed speed
    is the buffer ``c0``. Both names are kept so existing state_dicts still
    load; use ``initial_prop_speed`` to read the speed in either case.

    Attributes:
        c_lim: Tuple or float for wave speed bounds
        r_lim: Tuple or float for decay exponent bounds
        K: Number of Fourier modes of the stored coefficient matrix ``a``
    """
    c: torch.Tensor
    a: torch.Tensor
    r: torch.Tensor

    def __init__(self,
                 c_lim: Union[Tuple[float, float], float] = (1.0, 2.0),
                 r_lim: Union[Tuple[float, float], float] = 0.5,
                 K: int = 4,
                 is_c_parameter: bool = True,
                 is_a_parameter: bool = False,
                 is_r_parameter: bool = False):
        """
        Initialize the parametric analytical solution generator.

        Parameters:
        -----------
            c_lim: Wave speed limits (tuple) or fixed value (float/int)
            r_lim: Decay exponent limits (tuple) or fixed value (float/int)
            K: Number of Fourier modes
            is_c_parameter: If True, c is learnable (c_lim must then be a tuple)
            is_a_parameter: If True, a is learnable
            is_r_parameter: If True, r is learnable (r_lim must then be a tuple)
        """
        super().__init__()
        self.c_lim = c_lim
        self.r_lim = r_lim
        self.K = K

        if is_c_parameter:
            assert not isinstance(c_lim, (float, int)), \
                f"c_lim should be a tuple when is_c_parameter=True, got {c_lim}"
            self.c = nn.Parameter(_uniform(c_lim[0], c_lim[1], (1,)))
        else:
            # float() so an integer limit still gives a floating-point speed
            c = torch.tensor([float(c_lim)]) if isinstance(c_lim, (float, int)) else _uniform(c_lim[0], c_lim[1], (1,))
            self.register_buffer("c0", c)

        a = torch.from_numpy(np.random.uniform(low=-1.0, high=1.0, size=[K, K]))
        if is_a_parameter:
            self.a = nn.Parameter(a)
        else:
            self.register_buffer("a", a)

        if is_r_parameter:
            assert not isinstance(r_lim, (float, int)), \
                f"r_lim should be a tuple when is_r_parameter=True, got {r_lim}"
            self.r = nn.Parameter(_uniform(r_lim[0], r_lim[1], (1,)))
        else:
            r = torch.tensor([float(r_lim)]) if isinstance(r_lim, (float, int)) else _uniform(r_lim[0], r_lim[1], (1,))
            self.register_buffer("r", r)

    def _wave_speed(self) -> torch.Tensor:
        """Return the (1,)-shaped speed: the learnable ``c`` if present, else the ``c0`` buffer."""
        speed = getattr(self, "c", None)
        return speed if speed is not None else self.c0

    def reset_parameters(self):
        """Reset all learnable parameters to random values within their limits."""
        c = getattr(self, "c", None)  # absent when the speed is the fixed buffer c0
        if isinstance(c, nn.Parameter):
            nn.init.uniform_(c, self.c_lim[0], self.c_lim[1])
        if isinstance(self.a, nn.Parameter):
            nn.init.uniform_(self.a, -1.0, 1.0)
        if isinstance(self.r, nn.Parameter):
            nn.init.uniform_(self.r, self.r_lim[0], self.r_lim[1])

    def initial_condition(self, points: torch.Tensor,
                          K: Tuple[int, int] = (2, 6),
                          r_lim: Tuple[float, float] = (0.1, 0.9)) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a random multi-mode wave initial condition.

        Draws a mode count K_val ~ randint(K[0], K[1]) (upper bound exclusive),
        coefficients a ~ U(-1, 1)^{K_val×K_val}, and r ~ U(r_lim). It then
        evaluates
            u0 = π/K_val² Σ_{i,j} a_ij (i²+j²)^{-r} sin(πix) sin(πjy).

        Parameters:
        -----------
            points: torch.Tensor (n_points, 2)
                Spatial coordinates, on the same device as ``self.a``
            K: Tuple[int, int]
                Half-open range for the number of Fourier modes
            r_lim: Tuple[float, float]
                Range for the decay exponent

        Returns:
        --------
            u0: torch.Tensor (n_points,)
                Initial displacement (device/dtype of ``self.a``)
            v0: torch.Tensor (n_points,)
                Initial velocity, all ones (device/dtype of ``self.a``)
        """
        device = self.a.device
        K_val = np.random.randint(K[0], K[1])
        a = torch.from_numpy(np.random.uniform(low=-1.0, high=1.0, size=[K_val, K_val]))
        r = _uniform(r_lim[0], r_lim[1], (1,))
        j, i = torch.meshgrid(torch.arange(1, K_val + 1), torch.arange(1, K_val + 1), indexing='xy')

        # Broadcast to (n_points, K_val, K_val), then reduce over both mode axes
        a = a[None, :, :].to(device)
        i, j = i.to(device)[None, :, :], j.to(device)[None, :, :]
        x, y = points[:, 0][:, None, None], points[:, 1][:, None, None]
        r = r.to(device)

        u0 = (torch.pi / K_val / K_val * (a * (i * i + j * j)**(-r) *
              torch.sin(torch.pi * i * x) * torch.sin(torch.pi * j * y))).sum((-2, -1))
        v0 = torch.ones(u0.shape)

        u0 = u0.to(device).type(self.a.dtype)
        v0 = v0.to(device).type(self.a.dtype)

        return u0, v0

    def initial_prop_speed(self, points: torch.Tensor) -> torch.Tensor:
        """
        Broadcast the constant propagation speed to every point.

        Works whether the speed is learnable (``c``) or fixed (``c0``). For a
        learnable speed, gradients flow back to ``c``.

        Parameters:
        -----------
            points: torch.Tensor (n_points, 2)
                Spatial coordinates. Only the number of rows is used.

        Returns:
        --------
            c: torch.Tensor (n_points,)
                Wave speed at each point (an expanded view, not a copy)
        """
        return self._wave_speed().expand(points.shape[0])


class ParametricGaussian(nn.Module):
    """
    Wave-speed field built from a constant background plus four Gaussian bumps.

        c(x) = c0 + Σ_k v_k * exp(-|x - (p_k + δp_k)|² / (2 σ_k²))

    The base anchors p_k are (0.25, 0.25), (0.25, 0.75), (0.75, 0.25) and
    (0.75, 0.75). At construction, the background c0, the anchor offsets δp_k,
    the amplitudes v_k and the widths σ_k are each drawn uniformly within their
    limits. Each is registered either as a learnable parameter or as a fixed
    buffer.
    """
    c0: torch.Tensor
    danchors: torch.Tensor
    v: torch.Tensor
    sigma: torch.Tensor

    def __init__(self,
                 c0_lim: Tuple[float, float] = (1500, 2500),
                 v_lim: Tuple[float, float] = (1000, 2500),
                 sigma_lim: Tuple[float, float] = (1/12, 1/6),
                 danchor_lim: Tuple[float, float] = (-0.3125, 0.3125),
                 is_c0_parameter: bool = True,
                 is_anchor_parameter: bool = True,
                 is_v_parameter: bool = True,
                 is_sigma_parameter: bool = True):
        """
        Build the Gaussian wave-speed model with randomly initialised fields.

        Parameters:
        -----------
            c0_lim: Range for the background speed c0
            v_lim: Range for the bump amplitudes v
            sigma_lim: Range for the bump widths sigma
            danchor_lim: Range for each coordinate of the anchor offsets
            is_*_parameter: When True the field is an nn.Parameter, otherwise a buffer
        """
        super().__init__()

        self.c0_lim, self.v_lim = c0_lim, v_lim
        self.sigma_lim, self.danchor_lim = sigma_lim, danchor_lim

        def store(name: str, value: torch.Tensor, learnable: bool) -> None:
            # Learnable tensors become nn.Parameters; the others are plain buffers.
            if learnable:
                setattr(self, name, nn.Parameter(value))
            else:
                self.register_buffer(name, value)

        # Draw order (c0, danchors, v, sigma) is fixed so seeded runs reproduce.
        store("c0", _uniform(c0_lim[0], c0_lim[1], (1,)), is_c0_parameter)

        # `danchors` is not registered yet, so `anchors` is the bare (4, 2) base grid.
        offsets = _uniform(danchor_lim[0], danchor_lim[1], self.anchors.shape)
        store("danchors", offsets, is_anchor_parameter)

        n_bumps = offsets.shape[0]
        store("v", _uniform(v_lim[0], v_lim[1], (n_bumps,)), is_v_parameter)
        store("sigma", _uniform(sigma_lim[0], sigma_lim[1], (n_bumps,)), is_sigma_parameter)

    @property
    def anchors(self) -> torch.Tensor:
        """Bump centres: the fixed base grid plus ``danchors`` once that field exists."""
        base = torch.tensor([
            [0.25, 0.25],
            [0.25, 0.75],
            [0.75, 0.25],
            [0.75, 0.75],
        ])
        if not hasattr(self, "danchors"):
            return base
        offsets = self.danchors
        return base.to(offsets.device).type(offsets.dtype) + offsets

    def reset_parameters(self):
        """Redraw every learnable field uniformly within its limits (buffers are untouched)."""
        limits = (
            ("c0", self.c0_lim),
            ("danchors", self.danchor_lim),
            ("v", self.v_lim),
            ("sigma", self.sigma_lim),
        )
        for name, lim in limits:
            field = getattr(self, name)
            if isinstance(field, nn.Parameter):
                nn.init.uniform_(field, lim[0], lim[1])

    def initial_condition_single_gaussian(self, points: torch.Tensor,
                                          sig: float = 0.2) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Single bump exp(-|x - center| / sig) centred in the points' bounding box.

        The bump decays with the distance itself, not the squared distance.

        Parameters:
        -----------
            points: torch.Tensor (n_points, 2)
            sig: Length scale of the bump

        Returns:
        --------
            u0, v0: Initial displacement and zero initial velocity
        """
        lo = points.min(0).values
        hi = points.max(0).values
        centre = (hi + lo) / 2
        dist = torch.linalg.norm(points - centre, dim=1)
        u = torch.exp(-dist / sig)
        return u, torch.zeros_like(u)

    def initial_condition(self, points: torch.Tensor,
                          num_centers: Tuple[int, int] = (2, 6),
                          xc_lim: Tuple[float, float] = (1/6, 5/6),
                          yc_lim: Tuple[float, float] = (1/6, 5/6),
                          s_lim: Tuple[float, float] = (0.039, 0.156)) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sum of randomly placed Gaussian bumps with zero initial velocity.

        The number of bumps is drawn from randint(num_centers[0], num_centers[1]).
        Candidates are resampled until each lies at least 2·s_i away from every
        previously accepted centre i. NumPy's global RNG is used, so the result
        does not depend on the module's parameters.

        Parameters:
        -----------
            points: torch.Tensor (n_points, 2)
            num_centers: Half-open range for the number of bumps
            xc_lim, yc_lim: Ranges for the bump centre coordinates
            s_lim: Range for the bump standard deviations

        Returns:
        --------
            u0, v0: Initial displacement and velocity (don't require grad)
        """
        n_centers = np.random.randint(num_centers[0], num_centers[1])
        accepted = []
        while len(accepted) < n_centers:
            cx = np.random.uniform(xc_lim[0], xc_lim[1])
            cy = np.random.uniform(yc_lim[0], yc_lim[1])
            width = np.random.uniform(s_lim[0], s_lim[1])
            too_close = any(
                np.sqrt((cx - px)**2 + (cy - py)**2) < 2 * ps
                for px, py, ps in accepted
            )
            if not too_close:
                accepted.append((cx, cy, width))

        u0 = torch.zeros_like(points[:, 0])
        v0 = torch.zeros_like(u0)
        xs, ys = points[:, 0], points[:, 1]
        for cx, cy, width in accepted:
            u0 += torch.exp(-((xs - cx)**2 + (ys - cy)**2) / (2 * width**2))

        return u0, v0

    def initial_prop_speed(self, points: torch.Tensor) -> torch.Tensor:
        """
        Evaluate c(x) at each point (differentiable w.r.t. the learnable fields).

        Parameters:
        -----------
            points: torch.Tensor (n_points, 2)

        Returns:
        --------
            c: torch.Tensor (n_points,)
                Wave speed at each point
        """
        diff = points[:, None, :] - self.anchors[None, :, :]  # (n_points, 4, 2)
        sq_dist = (diff**2).sum(-1)                            # (n_points, 4)
        bumps = self.v[None, :] * torch.exp(-sq_dist / (2 * self.sigma[None, :]**2))
        return bumps.sum(1) + self.c0


class ParametricLinearLayer(nn.Module):
    """
    Parametric piecewise-constant wave speed model (layered medium).

    The y-range ``y_lim`` is split into ``n_layers`` horizontal strips of equal
    height, each with its own learnable wave speed.

    Attributes:
        c: Wave speed in each layer, shape (n_layers,)
        bucket: Layer boundaries along y, shape (n_layers + 1,) (buffer)
    """
    c: torch.Tensor
    bucket: torch.Tensor

    def __init__(self,
                 n_layers: int = 2,
                 c_lim: Tuple[float, float] = (1.0, 2.0),
                 x_lim: Tuple[float, float] = (0.0, 1.0),
                 y_lim: Tuple[float, float] = (0.0, 1.0)):
        """
        Initialize the layered wave speed model.

        Parameters:
        -----------
            n_layers: Number of horizontal layers
            c_lim: Range for the initial wave speed of each layer
            x_lim: Domain x-coordinate limits
            y_lim: Domain y-coordinate limits
        """
        super().__init__()
        self.c = nn.Parameter(_uniform(c_lim[0], c_lim[1], (n_layers,)))
        self.n_layers = n_layers
        self.c_lim = c_lim
        self.x_lim = x_lim
        self.y_lim = y_lim

        # Pad the top boundary slightly so y == y_lim[1] falls inside the last layer
        delta = 1e-7
        self.register_buffer("bucket", torch.linspace(y_lim[0], y_lim[1] + delta, n_layers + 1))

    def reset_parameters(self):
        """Reset wave speeds to random values within limits."""
        nn.init.uniform_(self.c, self.c_lim[0], self.c_lim[1])

    def _check_in_domain(self, points: torch.Tensor) -> None:
        """Assert that every point lies in the configured x_lim × y_lim rectangle."""
        x, y = points[:, 0], points[:, 1]
        (x_lo, x_hi), (y_lo, y_hi) = self.x_lim, self.y_lim
        inside = (x >= x_lo) & (x <= x_hi) & (y >= y_lo) & (y <= y_hi)
        assert bool(inside.all()), f"Points should be in [{x_lo}, {x_hi}] x [{y_lo}, {y_hi}]"

    def initial_condition(self, points: torch.Tensor,
                          num_centers: Tuple[int, int] = (2, 6),
                          xc_lim: Tuple[float, float] = (1/6, 5/6),
                          yc_lim: Tuple[float, float] = (1/6, 5/6),
                          s_lim: Tuple[float, float] = (0.039, 0.156)) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a multi-center Gaussian initial condition.

        The number of bumps is drawn from randint(num_centers[0], num_centers[1]).
        Candidates are resampled until each lies at least 2·s_i away from every
        previously accepted centre i.

        Parameters:
        -----------
            points: torch.Tensor (n_points, 2)
                Must lie inside the configured x_lim × y_lim domain
            num_centers, xc_lim, yc_lim, s_lim: Gaussian parameters

        Returns:
        --------
            u0, v0: Initial displacement and zero initial velocity
        """
        self._check_in_domain(points)

        u0 = torch.zeros_like(points[:, 0])
        v0 = torch.zeros_like(u0)
        n = np.random.randint(num_centers[0], num_centers[1])
        centers = []

        for _ in range(n):
            while True:
                xc = np.random.uniform(xc_lim[0], xc_lim[1])
                yc = np.random.uniform(yc_lim[0], yc_lim[1])
                s = np.random.uniform(s_lim[0], s_lim[1])
                valid = True
                for (xc_i, yc_i, s_i) in centers:
                    if np.sqrt((xc - xc_i)**2 + (yc - yc_i)**2) < 2 * s_i:
                        valid = False
                        break
                if valid:
                    centers.append((xc, yc, s))
                    break

        for xc, yc, s in centers:
            u0 += torch.exp(-((points[:, 0] - xc)**2 + (points[:, 1] - yc)**2) / (2 * s**2))

        return u0, v0

    def initial_prop_speed(self, points: torch.Tensor) -> torch.Tensor:
        """
        Compute the piecewise-constant wave speed from each point's y-coordinate.

        Parameters:
        -----------
            points: torch.Tensor (n_points, 2)
                Must lie inside the configured x_lim × y_lim domain

        Returns:
        --------
            c: torch.Tensor (n_points,)
                Wave speed at each point (differentiable w.r.t. ``c``)
        """
        self._check_in_domain(points)

        # right=True gives bucket[k-1] <= y < bucket[k], i.e. layer k-1. The clamp
        # keeps y == y_lim[1] in the top layer even when float32 rounding erases
        # the +delta padding of the last boundary (bucketize then returns n_layers+1).
        layer_idx = torch.bucketize(points[:, 1], self.bucket, right=True)
        layer_idx = layer_idx.clamp(1, self.c.shape[0]) - 1
        return self.c[layer_idx]

