"""
Helmholtz equation data generators for TensorGalerkin

Solves the Helmholtz equation:
    -Δu - k²u = f  in Ω
    u = g          on ∂Ω (Dirichlet boundary; g is given per node, e.g. g = 0)

where k is the wavenumber, which can be constant or spatially varying.
"""

import numpy as np
import scipy.sparse
import scipy.sparse.linalg
from typing import Union, Callable, List, Optional

from ...discretization import tri3, quad4, tri_gauss_points, quad_gauss_points


class HelmholtzGen:
    """Data generators for Helmholtz equation problems.

    Groups two namespaces of static helpers:
        WavenumberGenerator -- builds nodal wavenumber fields k(x)
        Random              -- builds source terms f and FEM reference solutions u
    """

    class WavenumberGenerator:
        """Generate nodal wavenumber fields k(x) for the Helmholtz equation."""

        @staticmethod
        def constant(points: np.ndarray, k: float = 1.0) -> np.ndarray:
            """
            Generate constant wavenumber field.

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Mesh node coordinates (only the row count is used)
                k: float
                    Constant wavenumber value (default: 1.0)

            Returns:
            --------
                k_field: np.ndarray (n_points,)
                    Wavenumber at each node. Always floating point (complex if
                    k is complex), even when k is passed as an int, so the
                    result matches the dtype of the other generators.
            """
            return np.full(points.shape[0], k, dtype=np.result_type(float, k))

        @staticmethod
        def gaussian(points: np.ndarray,
                     k0: float,
                     cx: float = 0.5,
                     cy: float = 0.5,
                     sigma: float = 0.2,
                     k_base: float = 0.0) -> np.ndarray:
            """
            Generate Gaussian-shaped wavenumber field.

            k(x) = k_base + k0 * exp(-|x - c|² / (2σ²))

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Mesh node coordinates (columns 0 and 1 are x and y)
                k0: float
                    Peak wavenumber amplitude
                cx, cy: float
                    Center of Gaussian (default: 0.5, 0.5)
                sigma: float
                    Standard deviation of Gaussian (default: 0.2); must be non-zero
                k_base: float
                    Base wavenumber (added to Gaussian, default: 0.0)

            Returns:
            --------
                k_field: np.ndarray (n_points,)
                    Wavenumber at each node
            """
            x, y = points[:, 0], points[:, 1]
            r2 = (x - cx) ** 2 + (y - cy) ** 2
            return k_base + k0 * np.exp(-r2 / (2 * sigma ** 2))

        @staticmethod
        def sinusoidal(points: np.ndarray,
                       k0: float = 1.0,
                       kx: int = 1,
                       ky: int = 1,
                       k_base: float = 1.0) -> np.ndarray:
            """
            Generate sinusoidal wavenumber field.

            k(x) = k_base + k0 * sin(πkx*x) * sin(πky*y)

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Mesh node coordinates (columns 0 and 1 are x and y)
                k0: float
                    Amplitude of sinusoidal variation (default: 1.0)
                kx, ky: int
                    Number of half-periods over a unit interval in x and y
                    (default: 1, 1)
                k_base: float
                    Base wavenumber (default: 1.0)

            Returns:
            --------
                k_field: np.ndarray (n_points,)
                    Wavenumber at each node
            """
            x, y = points[:, 0], points[:, 1]
            return k_base + k0 * np.sin(np.pi * kx * x) * np.sin(np.pi * ky * y)

        @staticmethod
        def random_field(points: np.ndarray,
                         k_min: float = 0.5,
                         k_max: float = 2.0,
                         seed: Optional[int] = None) -> np.ndarray:
            """
            Generate random wavenumber field (independent uniform samples per node).

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Mesh node coordinates (only the row count is used)
                k_min: float
                    Minimum wavenumber (default: 0.5)
                k_max: float
                    Maximum wavenumber (default: 2.0)
                seed: int, optional
                    Random seed for reproducibility. NOTE: when given, this
                    reseeds NumPy's *global* RNG (np.random.seed), so it also
                    affects later np.random calls, e.g. Random.source.

            Returns:
            --------
                k_field: np.ndarray (n_points,)
                    Wavenumber at each node, drawn from [k_min, k_max)
            """
            if seed is not None:
                np.random.seed(seed)
            return np.random.uniform(k_min, k_max, points.shape[0])

        @staticmethod
        def piecewise_constant(points: np.ndarray,
                               regions: List[Callable],
                               k_values: List[float],
                               k_default: float = 1.0) -> np.ndarray:
            """
            Generate piecewise constant wavenumber field based on regions.

            Regions are applied in order, so where regions overlap the LAST
            matching region's value wins.

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Mesh node coordinates (columns 0 and 1 are x and y)
                regions: List[Callable]
                    Each callable takes (x, y) arrays and returns a boolean mask.
                    Example: lambda x, y: (x < 0.5) & (y < 0.5)
                k_values: List[float]
                    Wavenumber value for each region (same length as regions)
                k_default: float
                    Default wavenumber for points not in any region (default: 1.0)

            Returns:
            --------
                k_field: np.ndarray (n_points,)
                    Wavenumber at each node. Floating point (complex if any
                    value is complex) even if k_default is an int, so region
                    values are never truncated.

            Raises:
            -------
                ValueError
                    If regions and k_values have different lengths.
            """
            regions = list(regions)
            k_values = list(k_values)
            if len(regions) != len(k_values):
                raise ValueError(
                    "regions and k_values must have the same length, "
                    f"got {len(regions)} and {len(k_values)}"
                )

            x, y = points[:, 0], points[:, 1]
            # Promote to at least float so e.g. k_default=1 does not turn a
            # region value of 2.5 into 2. Using one array avoids the argument
            # count limit of np.result_type when there are many regions.
            dtype = np.result_type(float, np.asarray([k_default, *k_values]))
            k_field = np.full(points.shape[0], k_default, dtype=dtype)

            for region_fn, k_val in zip(regions, k_values):
                k_field[region_fn(x, y)] = k_val

            return k_field

    class Random:
        """Random data generators for the Helmholtz equation."""

        @staticmethod
        def source(points: np.ndarray,
                   low: float = -1.0,
                   high: float = 1.0) -> np.ndarray:
            """
            Generate random source function values at mesh nodes.

            Uses NumPy's global RNG (np.random.uniform).

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Mesh node coordinates (only the row count is used)
                low: float
                    Lower bound for random values (default: -1.0)
                high: float
                    Upper bound for random values (default: 1.0)

            Returns:
            --------
                f: np.ndarray (n_points,)
                    Source function values drawn from [low, high)
            """
            return np.random.uniform(low, high, points.shape[0])

        @staticmethod
        def source_sinusoidal(points: np.ndarray,
                              a: Optional[np.ndarray] = None,
                              K: int = 4,
                              amplitude: float = 1.0) -> np.ndarray:
            """
            Generate sinusoidal source function.

            f(x,y) = Σ_{i,j=1..K} a_ij * sin(πi*x) * sin(πj*y)

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Mesh node coordinates (columns 0 and 1 are x and y)
                a: np.ndarray (K, K), optional
                    Coefficients; a[i-1, j-1] multiplies mode (i, j). It is
                    broadcast against the (K, K) mode grid, so its shape must
                    be compatible with (K, K). If None, coefficients are drawn
                    uniformly from [-amplitude, amplitude) with the global RNG.
                K: int
                    Number of Fourier modes in each direction (default: 4)
                amplitude: float
                    Scale factor for random coefficients (default: 1.0)

            Returns:
            --------
                f: np.ndarray (n_points,)
                    Source function values
            """
            if a is None:
                a = np.random.uniform(-amplitude, amplitude, (K, K))

            # Shape (n_points, 1, 1) so the coordinates broadcast over the mode grid.
            x, y = points[:, 0][:, None, None], points[:, 1][:, None, None]
            # With default 'xy' indexing: i[r, c] = r + 1 (row mode), j[r, c] = c + 1.
            j, i = np.meshgrid(np.arange(1, K + 1), np.arange(1, K + 1))

            f = (a * np.sin(np.pi * i * x) * np.sin(np.pi * j * y)).sum(axis=(-2, -1))
            return f

        @staticmethod
        def solution(mesh,
                     f: np.ndarray,
                     k: Union[float, np.ndarray],
                     damping: float = 0.0) -> np.ndarray:
            """
            Solve the Helmholtz equation with the finite element method.

            Solves: (A - M_k + damping * M) u = b, with Dirichlet values imposed
            by static condensation, where:
                A_IJ   = ∫ ∇N^I · ∇N^J dx        (stiffness matrix)
                M_IJ   = ∫ N^I · N^J dx          (mass matrix)
                M_k_IJ = ∫ k²(x) N^I · N^J dx    (k² * M for constant k)
                b_I    = ∫ f · N^I dx            (load vector)
            f and (for spatially varying k) k² are interpolated from their
            nodal values with the element shape functions.

            Parameters:
            -----------
                mesh: meshio.Mesh
                    The finite element mesh with boundary conditions (any object
                    exposing `points`, `cells_dict` and `point_data` works).
                    Must have 'boundary_mask' and 'boundary_value' in point_data;
                    nodes where boundary_mask is true take their value from
                    boundary_value. Uses 'triangle' cells if present, otherwise
                    'quad' cells. A mesh mixing both is assembled from its
                    triangles only.
                f: array-like (n_points,)
                    Source function values at mesh nodes
                k: float or array-like (n_points,)
                    Wavenumber, either constant or one value per node. A size-1
                    array is treated as a constant.
                damping: float
                    Optional spectral shift (default: 0.0). When positive, adds
                    damping * M to the system matrix, which is equivalent to
                    replacing k² by k² - damping. It is a real shift, not
                    absorption. It can move the system away from a resonance
                    of A - k²M. Non-positive values are ignored.

            Returns:
            --------
                u: np.ndarray (n_points,)
                    Solution values at mesh nodes (complex if the system is,
                    e.g. for complex k)

            Raises:
            -------
                AssertionError
                    If the mesh lacks 'boundary_mask' or 'boundary_value'.
                NotImplementedError
                    If the mesh has neither triangle nor quad cells.
                ValueError
                    If f, or a spatially varying k, does not have shape (n_points,).
            """
            # Validate mesh data
            assert "boundary_mask" in mesh.point_data, "Mesh must have boundary_mask in point_data"
            assert "boundary_value" in mesh.point_data, "Mesh must have boundary_value in point_data"

            points = mesh.points

            # Determine element type and evaluate shape functions at quadrature points
            if "triangle" in mesh.cells_dict:
                elements = mesh.cells_dict['triangle']
                qpoints = tri_gauss_points(ngp=4)
                shape_fn = tri3
            elif "quad" in mesh.cells_dict:
                elements = mesh.cells_dict['quad']
                qpoints = quad_gauss_points(ngp=4)
                shape_fn = quad4
            else:
                raise NotImplementedError("Mesh should have triangle or quad cells")
            quadrature_weight, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
            mesh_points_2d = points[elements][:, :, :2]
            # shape_val: [n_quad, n_basis]; shape_grad: [n_elem, n_quad, n_basis, 2];
            # jac_det: [n_elem, n_quad] (ranks as used by the einsums below).
            shape_val, shape_grad, jac_det = shape_fn(xi, eta, mesh_points_2d)

            boundary_mask = np.asarray(mesh.point_data['boundary_mask']).astype(bool)
            boundary_value = np.asarray(mesh.point_data['boundary_value'])
            num_points = points.shape[0]
            num_elements, num_basis = elements.shape

            # Validate nodal inputs: a too-long array would otherwise be indexed
            # silently with the wrong values.
            f = np.asarray(f)
            if f.shape != (num_points,):
                raise ValueError(f"f must have shape ({num_points},), got {f.shape}")

            k = np.asarray(k)
            is_k_spatial = k.size > 1
            if is_k_spatial and k.shape != (num_points,):
                raise ValueError(
                    f"spatially varying k must have shape ({num_points},), got {k.shape}"
                )

            # abs() so elements with either node orientation contribute positive measure
            JxW = np.abs(jac_det) * quadrature_weight  # [n_elements, n_quadrature]

            # Global (row, col) node pair for every entry of every element matrix,
            # ordered like X_elem.ravel(): element, then local row i, then local col j.
            elem_u = np.repeat(elements, num_basis, axis=1).ravel()
            elem_v = np.tile(elements, (1, num_basis)).ravel()

            # Unique global sparsity pattern ("edges"); tocsr() merges duplicate pairs
            tmp = scipy.sparse.coo_matrix(
                (np.ones_like(elem_u), (elem_u, elem_v)),
                shape=(num_points, num_points)
            ).tocsr().tocoo()
            edge_u, edge_v = tmp.row, tmp.col
            num_edges = len(edge_u)

            # Lookup table: (row, col) -> edge id
            eids_csr = scipy.sparse.coo_matrix(
                (np.arange(num_edges), (edge_u, edge_v)),
                shape=(num_points, num_points)
            ).tocsr()

            # Scatter matrix summing element-matrix entries into their global edge
            elem_eids = np.array(eids_csr[elem_u, elem_v]).ravel()
            ele2msh_edge = scipy.sparse.coo_matrix(
                (np.ones_like(elem_eids), (elem_eids, np.arange(num_elements * num_basis * num_basis))),
                shape=(num_edges, num_elements * num_basis * num_basis)
            ).astype(np.float64)

            # Scatter matrix summing element load entries into their global node
            ele2msh_node = scipy.sparse.coo_matrix(
                (np.ones([num_elements * num_basis]),
                 (elements.ravel(), np.arange(num_elements * num_basis))),
                shape=(num_points, num_elements * num_basis)
            ).tocsr()

            # Stiffness: K_ij = ∫ ∇N^i · ∇N^j dx
            K_elem = np.einsum("eqib,eqjb,eq->eqij", shape_grad, shape_grad, JxW)
            K_elem = K_elem.sum(1)  # [n_elements, n_basis, n_basis]

            # Mass: M_ij = ∫ N^i · N^j dx
            M_elem = np.einsum("qi,qj,eq->eij", shape_val, shape_val, JxW)

            # Load: F_i = ∫ f · N^i dx
            f_elem = f[elements]  # [n_elements, n_basis]
            f_quad = np.einsum("qi,ei->eq", shape_val, f_elem)  # f at quadrature points
            F_elem = np.einsum("eq,qi,eq->ei", f_quad, shape_val, JxW)

            # Assemble global quantities
            K_global = ele2msh_edge @ K_elem.ravel()  # [n_edges]
            M_global = ele2msh_edge @ M_elem.ravel()  # [n_edges]
            F_global = ele2msh_node @ F_elem.ravel()  # [n_nodes]

            # k² term
            if is_k_spatial:
                k2_elem = k[elements] ** 2  # [n_elements, n_basis]
                k2_quad = np.einsum("qi,ei->eq", shape_val, k2_elem)  # k² at quadrature points
                # Weighted mass: M_k_ij = ∫ k²(x) N^i · N^j dx
                M_k_elem = np.einsum("eq,qi,qj,eq->eij", k2_quad, shape_val, shape_val, JxW)
                M_k_global = ele2msh_edge @ M_k_elem.ravel()
            else:
                # Constant k; reshape(()) turns any size-1 array into a scalar so
                # the result keeps shape [n_edges].
                M_k_global = k.reshape(()) ** 2 * M_global

            # System matrix: A - M_k (+ damping * M if requested)
            system_global = K_global - M_k_global
            if damping > 0:
                system_global = system_global + damping * M_global

            # Static condensation of Dirichlet nodes
            is_inner_node = ~boundary_mask
            is_outer_node = boundary_mask
            is_inner_u = is_inner_node[edge_u]
            is_inner_v = is_inner_node[edge_v]
            is_inner_edge = is_inner_u & is_inner_v
            is_ou2in_edge = is_inner_u & is_outer_node[edge_v]

            n_inner_nodes = is_inner_node.sum()
            n_outer_nodes = is_outer_node.sum()

            # Separate 0-based numbering for interior and boundary nodes
            local_nids = np.full((num_points,), -1, dtype=np.int64)
            local_nids[is_inner_node] = np.arange(n_inner_nodes)
            local_nids[is_outer_node] = np.arange(n_outer_nodes)

            # Interior-interior block
            K_inner = scipy.sparse.coo_matrix(
                (system_global[is_inner_edge],
                 (local_nids[edge_u[is_inner_edge]], local_nids[edge_v[is_inner_edge]])),
                shape=(n_inner_nodes, n_inner_nodes)
            ).tocsr()

            # Interior-row / boundary-column block, moved to the right-hand side
            K_ou2in = scipy.sparse.coo_matrix(
                (system_global[is_ou2in_edge],
                 (local_nids[edge_u[is_ou2in_edge]], local_nids[edge_v[is_ou2in_edge]])),
                shape=(n_inner_nodes, n_outer_nodes)
            ).tocsr()

            F_condensed = F_global[is_inner_node] - K_ou2in @ boundary_value[is_outer_node]

            # Solve for interior values (nothing to solve if every node is constrained)
            if n_inner_nodes > 0:
                u_inner = scipy.sparse.linalg.spsolve(K_inner, F_condensed)
            else:
                u_inner = np.zeros(0)

            # Full nodal solution; dtype follows the solve so complex results are kept
            U = np.zeros((num_points,), dtype=np.result_type(float, u_inner, boundary_value))
            U[is_outer_node] = boundary_value[is_outer_node]
            U[is_inner_node] = u_inner

            return U

