"""
Poisson equation data generators for TensorGalerkin
"""

import numpy as np
import scipy.sparse
import scipy.sparse.linalg
from typing import Union, Tuple, Optional
from ...discretization import tri3, quad4, tri_gauss_points, quad_gauss_points


def _sine_mode_sum(points: np.ndarray, a: np.ndarray, power: float) -> np.ndarray:
    """
    Evaluate sum_{i,j=1..K} a_ij * (i^2 + j^2)^power * sin(i*pi*x) * sin(j*pi*y)

    Shared kernel of PoissonGen.MultiAnalytical.source and .solution, which
    differ only in the exponent and in the leading scale factor.

    Parameters:
    -----------
        points: np.ndarray (n_points, >=2)
            Spatial coordinates; only the first two columns are read
        a: np.ndarray (K, K) or (N, K, K)
            Mode coefficients; a[..., i-1, j-1] weights sin(i*pi*x) * sin(j*pi*y)
        power: float
            Exponent applied to (i^2 + j^2)

    Returns:
    --------
        s: np.ndarray (n_points,) or (N, n_points)
            Unscaled mode sum
    """
    assert len(a.shape) == 2 or len(a.shape) == 3, f"Shape of a should be (K, K) or (N, K, K), got {a.shape}"
    assert a.shape[-1] == a.shape[-2], f"Shape of a should be (K, K) or (N, K, K), got {a.shape}"

    K = a.shape[-1]
    # Float mode indices: with integer arrays an integer negative exponent
    # (e.g. r=0 in ``solution``) raises in NumPy, and a large integer
    # exponent would overflow int64.
    modes = np.arange(1, K + 1, dtype=np.float64)
    i = modes[:, None]  # (K, 1): x-frequency, indexes axis -2 of a
    j = modes[None, :]  # (1, K): y-frequency, indexes axis -1 of a
    coeff = a * (i * i + j * j) ** power  # (K, K) or (N, K, K)

    if len(a.shape) == 2:
        coeff = coeff[None, ...]  # (1, K, K)
        x, y = points[:, 0][:, None, None], points[:, 1][:, None, None]  # (n_points, 1, 1)
    else:
        coeff = coeff[:, None, ...]  # (N, 1, K, K)
        x, y = points[:, 0][None, :, None, None], points[:, 1][None, :, None, None]  # (1, n_points, 1, 1)

    return (coeff * np.sin(np.pi * i * x) * np.sin(np.pi * j * y)).sum((-2, -1))


class PoissonGen:
    """Data generators for Poisson equation problems"""
    
    class MultiAnalytical:
        """Multi-mode analytical solutions for Poisson equation.

        The pair satisfies -Laplacian(u) = f on [0, 1]^2 with u = 0 on the
        boundary:
            u = 1/(pi K^2) * sum_ij a_ij (i^2+j^2)^(r-1) sin(i pi x) sin(j pi y)
            f =   pi/K^2   * sum_ij a_ij (i^2+j^2)^r     sin(i pi x) sin(j pi y)
        """
        
        @staticmethod
        def source(points: np.ndarray, a: np.ndarray, r: float = 2.0) -> np.ndarray:
            """
            Generate the Poisson source function at each point in the domain
            
            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                a: np.ndarray (K, K) or (N, K, K)
                    Coefficients for the analytical solution
                r: float
                    Power parameter for the solution (any real value, int or float)
                    
            Returns:
            --------
                f: np.ndarray (n_points) or (N, n_points)
                    Source function values
            """
            series = _sine_mode_sum(points, a, r)
            K = a.shape[-1]
            return np.pi / K / K * series
        
        @staticmethod
        def solution(points: np.ndarray, a: np.ndarray, r: float = 2.0) -> np.ndarray:
            """
            Generate the analytical solution for Poisson equation
            
            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                a: np.ndarray (K, K) or (N, K, K)
                    Coefficients for the analytical solution
                r: float
                    Power parameter for the solution (any real value, int or float)
                    
            Returns:
            --------
                u: np.ndarray (n_points) or (N, n_points)
                    Analytical solution values
            """
            series = _sine_mode_sum(points, a, r - 1)
            K = a.shape[-1]
            return 1 / np.pi / K / K * series
    
    class Discontinuous:
        """Discontinuous source functions.
        
        These functions have low regularity (non-smooth or discontinuous),
        which makes them challenging for PINN but should be handled well by
        Galerkin methods that use weak formulations.
        """
        
        @staticmethod
        def abs_sin(points: np.ndarray, k: int = 2) -> np.ndarray:
            """
            f = |sin(k*pi*x)| * |sin(k*pi*y)|
            
            Continuous but derivative discontinuous at x = n/k and y = m/k.
            
            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                k: int
                    Frequency parameter (number of half-periods in [0,1])
                    
            Returns:
            --------
                f: np.ndarray (n_points,)
                    Source function values
            """
            x, y = points[:, 0], points[:, 1]
            return np.abs(np.sin(k * np.pi * x)) * np.abs(np.sin(k * np.pi * y))
        
        @staticmethod
        def step(points: np.ndarray, x0: float = 0.5) -> np.ndarray:
            """
            Step function: f = 1 if x > x0 else -1
            
            Discontinuous along the line x = x0.
            
            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                x0: float
                    Location of the step discontinuity
                    
            Returns:
            --------
                f: np.ndarray (n_points,)
                    Source function values
            """
            x = points[:, 0]
            return np.where(x > x0, 1.0, -1.0)
        
        @staticmethod
        def indicator_circle(points: np.ndarray, cx: float = 0.5, cy: float = 0.5, r: float = 0.25) -> np.ndarray:
            """
            Indicator function of a circle: f = 1 inside, 0 outside
            
            Discontinuous along the circle boundary (points exactly on the
            circle get 0).
            
            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                cx, cy: float
                    Center of the circle
                r: float
                    Radius of the circle
                    
            Returns:
            --------
                f: np.ndarray (n_points,)
                    Source function values
            """
            x, y = points[:, 0], points[:, 1]
            dist_sq = (x - cx)**2 + (y - cy)**2
            return np.where(dist_sq < r**2, 1.0, 0.0)
        
        @staticmethod
        def checkerboard(points: np.ndarray, n: int = 4) -> np.ndarray:
            """
            Checkerboard pattern: alternating +1/-1
            
            Discontinuous along grid lines, creating n x n cells. The cell
            containing the origin has value +1.
            
            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                n: int
                    Number of divisions in each direction
                    
            Returns:
            --------
                f: np.ndarray (n_points,)
                    Source function values
            """
            x, y = points[:, 0], points[:, 1]
            ix = np.floor(n * x).astype(int)
            iy = np.floor(n * y).astype(int)
            return np.where((ix + iy) % 2 == 0, 1.0, -1.0)
        
        @staticmethod
        def constant(points: np.ndarray, value: float = 1.0) -> np.ndarray:
            """
            Constant source: f = value everywhere.
            
            Useful for testing domain singularities (e.g., L-shape corner)
            without source function complexity.
            
            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                value: float
                    Constant value
                    
            Returns:
            --------
                f: np.ndarray (n_points,)
                    Source function values
            """
            return np.full(len(points), value)
    
    class Random:
        """Random data generators for Poisson equation.

        All generators draw from NumPy's global random state (np.random), so
        they are reproducible via np.random.seed.
        """
        
        @staticmethod
        def initial_condition(points: np.ndarray, 
                            amplitude: float = 1.0,
                            num_modes: int = 5) -> np.ndarray:
            """
            Generate random initial condition for Poisson problem

            u0 = sum_{i,j=1..num_modes} c_ij sin(i*pi*x + px_ij) sin(j*pi*y + py_ij)
            with c_ij ~ amplitude * U(-1, 1) / (i^2 + j^2) and phases ~ U(0, 2*pi).
            
            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                amplitude: float
                    Scale of the random mode coefficients (the summed field
                    can exceed this value)
                num_modes: int
                    Number of Fourier modes to include in each direction
                    
            Returns:
            --------
                u0: np.ndarray (n_points,)
                    Random initial condition
            """
            x, y = points[:, 0], points[:, 1]
            u0 = np.zeros(len(points))
            
            # Draw order (amp, phase_x, phase_y per mode) is kept fixed so
            # seeded runs stay reproducible.
            for i in range(1, num_modes + 1):
                for j in range(1, num_modes + 1):
                    amp = amplitude * np.random.uniform(-1, 1) / (i*i + j*j)
                    phase_x = np.random.uniform(0, 2*np.pi)
                    phase_y = np.random.uniform(0, 2*np.pi)
                    u0 += amp * np.sin(i * np.pi * x + phase_x) * np.sin(j * np.pi * y + phase_y)
            
            return u0
        
        @staticmethod
        def source_function(points: np.ndarray,
                          amplitude: float = 1.0,
                          num_modes: int = 3) -> np.ndarray:
            """
            Generate random source function for Poisson equation

            f = sum_{i,j=1..num_modes} c_ij sin(i*pi*x) sin(j*pi*y)
            with c_ij ~ amplitude * U(-1, 1).
            
            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                amplitude: float
                    Scale of each random mode coefficient (the summed source
                    can reach num_modes**2 * amplitude)
                num_modes: int
                    Number of modes in each direction
                    
            Returns:
            --------
                f: np.ndarray (n_points,)
                    Random source function
            """
            x, y = points[:, 0], points[:, 1]
            f = np.zeros(len(points))
            
            for i in range(1, num_modes + 1):
                for j in range(1, num_modes + 1):
                    amp = amplitude * np.random.uniform(-1, 1)
                    f += amp * np.sin(i * np.pi * x) * np.sin(j * np.pi * y)
            
            return f
        
        @staticmethod
        def gaussian_source(points: np.ndarray,
                          centers: Union[np.ndarray, int] = None,
                          amplitudes: Union[np.ndarray, float] = 1.0,
                          widths: Union[np.ndarray, float] = 0.1) -> np.ndarray:
            """
            Generate Gaussian source functions

            f = sum_k amplitudes[k] * exp(-|p - centers[k]|^2 / (2 * widths[k]^2))
            
            Parameters:
            -----------
                points: np.ndarray (n_points, >=2)
                    Spatial coordinates; only the first two columns are read
                centers: np.ndarray (n_centers, 2) or (2,), int, or None
                    Centers of Gaussian sources, or the number of random centers
                    drawn uniformly from [0.1, 0.9]^2 (Python or NumPy int).
                    None means a single center at (0.5, 0.5).
                amplitudes: np.ndarray (n_centers,) or float
                    Amplitudes of Gaussian sources (scalars are broadcast)
                widths: np.ndarray (n_centers,) or float
                    Widths (standard deviations) of Gaussian sources
                    (scalars are broadcast)
                    
            Returns:
            --------
                f: np.ndarray (n_points,)
                    Gaussian source function

            Raises:
            -------
                ValueError
                    If amplitudes or widths cannot be broadcast to (n_centers,)
            """
            if isinstance(centers, (int, np.integer)):
                n_centers = int(centers)
                centers = np.random.uniform(0.1, 0.9, (n_centers, 2))
            elif centers is None:
                centers = np.array([[0.5, 0.5]])
            else:
                # A single center given as [x, y] becomes one row, not two
                # scalar "centers".
                centers = np.atleast_2d(np.asarray(centers, dtype=np.float64))
            
            n_centers = len(centers)
            # broadcast_to rejects length mismatches instead of letting zip()
            # silently drop Gaussians.
            amplitudes = np.broadcast_to(np.asarray(amplitudes, dtype=np.float64), (n_centers,))
            widths = np.broadcast_to(np.asarray(widths, dtype=np.float64), (n_centers,))
            
            xy = points[:, :2]  # tolerate (x, y, z) points like the other generators
            f = np.zeros(len(points))
            for center, amp, width in zip(centers, amplitudes, widths):
                dist_sq = np.sum((xy - center)**2, axis=1)
                f += amp * np.exp(-dist_sq / (2 * width**2))
            
            return f
        
        @staticmethod
        def source(points: np.ndarray, low: float = -1.0, high: float = 1.0) -> np.ndarray:
            """
            Generate uniform random source function (i.i.d. per point)
            
            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates
                low: float
                    Lower bound of the source function
                high: float
                    Upper bound of the source function
                    
            Returns:
            --------
                f: np.ndarray (n_points,)
                    Random source function values
            """
            f = np.random.uniform(low, high, (points.shape[0],))
            return f
        
        @staticmethod
        def solution(mesh, f: np.ndarray, sigma: float = 1.0, *,
                     tri_ngp: int = 3, quad_ngp: int = 1) -> np.ndarray:
            """
            Solve Poisson equation using finite element method
            
            Solves: -σ∇²u = f with Dirichlet values on the nodes flagged by
            point_data['boundary_mask'] (taken from point_data['boundary_value']).
            If the mesh holds both triangles and quads, only the triangles are used.
            
            Parameters:
            -----------
                mesh: meshio.Mesh
                    The computational mesh with boundary conditions
                f: np.ndarray (n_points,)
                    Source function values at mesh nodes
                sigma: float
                    Diffusion coefficient (default: 1.0)
                tri_ngp: int (keyword-only)
                    Quadrature argument passed to tri_gauss_points (default 3)
                quad_ngp: int (keyword-only)
                    Quadrature argument passed to quad_gauss_points (default 1,
                    a single-point rule, which under-integrates the bilinear
                    stiffness matrix; use a larger value if supported)
                    
            Returns:
            --------
                u: np.ndarray (n_points,)
                    Finite element solution
            """
            assert "boundary_mask" in mesh.point_data, "Mesh must have boundary_mask in point_data"
            assert "boundary_value" in mesh.point_data, "Mesh must have boundary_value in point_data"
            
            points = mesh.points
            num_points = points.shape[0]
            # The mask may arrive as an int/float array (possibly (n, 1)); `~`
            # on integers is bitwise NOT, which would turn the masks below into
            # integer index arrays. Cast explicitly and validate lengths.
            boundary_mask = np.asarray(mesh.point_data['boundary_mask']).reshape(num_points).astype(bool)
            boundary_value = np.asarray(mesh.point_data['boundary_value'], dtype=np.float64).reshape(num_points)
            f = np.asarray(f, dtype=np.float64).reshape(num_points)
            
            # Element type, quadrature and shape functions
            cells = mesh.cells_dict
            if "triangle" in cells:
                elements = cells['triangle']
                qpoints = tri_gauss_points(ngp=tri_ngp)
                shape_fn = tri3
            elif "quad" in cells:
                elements = cells['quad']
                qpoints = quad_gauss_points(ngp=quad_ngp)
                shape_fn = quad4
            else:
                raise NotImplementedError("Mesh should have triangle or quadrilateral cells")
            quadrature_weight, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
            mesh_points_2d = points[elements][:, :, :2]  # drop z if present
            shape_val, shape_grad, jac_det = shape_fn(xi, eta, mesh_points_2d)
            
            num_elements, num_basis = elements.shape
            n_local = num_elements * num_basis
            
            # Scatter matrix: sums per-element nodal entries into global nodes
            ele2msh_node = scipy.sparse.coo_matrix((
                np.ones([n_local]),  # data
                (elements.ravel(), np.arange(n_local)),  # (row, col)
            ), shape=(num_points, n_local)).tocsr()
            
            # Global (row, col) node pair of every element-matrix entry, laid
            # out to match K_elem.ravel(): flat index e*b*b + i*b + j holds
            # (elements[e, i], elements[e, j]).
            elem_u = np.repeat(elements, num_basis, axis=1).ravel()
            elem_v = np.tile(elements, (1, num_basis)).ravel()
            
            # Unique global node pairs ("edges") = sparsity pattern of K
            pattern = scipy.sparse.coo_matrix((
                np.ones_like(elem_u),  # data
                (elem_u, elem_v),  # (row, col)
            ), shape=(num_points, num_points)).tocsr().tocoo()
            edge_u, edge_v = pattern.row, pattern.col
            num_edges = len(edge_u)
            
            # Look up the edge id of every element-matrix entry
            eids_csr = scipy.sparse.coo_matrix((
                np.arange(num_edges), (edge_u, edge_v)
            ), shape=(num_points, num_points)).tocsr()
            elem_eids = np.array(eids_csr[elem_u, elem_v]).ravel()
            ele2msh_edge = scipy.sparse.coo_matrix((
                np.ones_like(elem_eids),
                (elem_eids, np.arange(n_local * num_basis))
            ), shape=(num_edges, n_local * num_basis)).astype(np.float64)
            
            # Jacobian-weighted quadrature weights: [num_elements, num_quadrature_points]
            JxW = jac_det * quadrature_weight
            
            # Element load vectors. NOTE(review): this computes
            # F_b = f_b * sum_q N_b(q) JxW(q) = f_b * ∫ N_b (a row-sum-lumped
            # load), not the consistent Galerkin load ∫ N_b * sum_c N_c f_c.
            # Behavior kept as-is; confirm it is intended.
            f_elem = f[elements]  # [num_elements, num_basis]
            f_quad = shape_val[None, :, :] * f_elem[:, None, :]  # [num_elements, num_quadrature_points, num_basis]
            f_quad = f_quad * JxW[:, :, None]
            F_elem = f_quad.sum(1)  # [num_elements, num_basis]
            
            # Element stiffness matrices: sigma * ∫ grad N_i . grad N_j
            K_elem = sigma * np.einsum("eqib,eqjb,eq->eij", shape_grad, shape_grad, JxW)  # [num_elements, num_basis, num_basis]
            
            # Global assembly
            F = ele2msh_node @ F_elem.ravel()  # [num_nodes]
            K = ele2msh_edge @ K_elem.ravel()  # [num_edges]
            
            # Static condensation of the Dirichlet (outer) nodes
            is_outer_node = boundary_mask
            is_inner_node = ~boundary_mask
            is_inner_u = is_inner_node[edge_u]
            is_inner_edge = is_inner_u & is_inner_node[edge_v]
            is_ou2in_edge = is_inner_u & is_outer_node[edge_v]
            
            n_inner_nodes = int(is_inner_node.sum())
            n_outer_nodes = int(is_outer_node.sum())
            
            U = np.zeros((num_points,))
            U[is_outer_node] = boundary_value[is_outer_node]
            if n_inner_nodes == 0:
                # Every node is prescribed: nothing to solve (spsolve would
                # otherwise receive a 0x0 system).
                return U
            
            # Local ids within the inner and outer node sets
            local_nids = np.full((num_points,), -1, dtype=np.int64)
            local_nids[is_inner_node] = np.arange(n_inner_nodes)
            local_nids[is_outer_node] = np.arange(n_outer_nodes)
            
            K_inner = scipy.sparse.coo_matrix((
                K[is_inner_edge].ravel(),
                (local_nids[edge_u[is_inner_edge]], local_nids[edge_v[is_inner_edge]])
            ), shape=(n_inner_nodes, n_inner_nodes)).tocsr()
            
            K_ou2in = scipy.sparse.coo_matrix((
                K[is_ou2in_edge].ravel(),
                (local_nids[edge_u[is_ou2in_edge]], local_nids[edge_v[is_ou2in_edge]])
            ), shape=(n_inner_nodes, n_outer_nodes)).tocsr()
            
            # Move the known boundary contribution to the right-hand side.
            # NOTE(review): with no Dirichlet nodes K_inner is singular.
            F_inner = F[is_inner_node] - K_ou2in @ boundary_value[is_outer_node]
            
            U[is_inner_node] = scipy.sparse.linalg.spsolve(K_inner, F_inner)
            return U