import numpy as np
from typing import Callable, Optional, Tuple, Union
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import spsolve
from skfem import Basis, LinearForm, BilinearForm, DiscreteField, asm


class WeakForm:
    """
    Weak (variational) formulation of a PDE.

    The stored callables are the integrands handed to scikit-fem's
    ``BilinearForm`` / ``LinearForm`` wrappers at assembly time.

    Attributes
    ----------
    bilinear_form : Callable
        Integrand of the bilinear form a(u, v).
    linear_form : Optional[Callable]
        Integrand of the linear form L(v) (right-hand side); ``None`` stands
        for a zero right-hand side.
    """

    def __init__(self, bilinear_form: Callable, linear_form: Optional[Callable] = None):
        """
        Store the integrands of the weak formulation.

        Parameters
        ----------
        bilinear_form : Callable
            Integrand defining a(u, v).
        linear_form : Optional[Callable], default=None
            Integrand defining L(v); leave as ``None`` for a homogeneous
            right-hand side.
        """
        self.bilinear_form = bilinear_form
        self.linear_form = linear_form

    def assemble_matrix(self, basis: Basis) -> csr_matrix:
        """
        Assemble the system matrix of the bilinear form on ``basis``.

        Parameters
        ----------
        basis : Basis
            Finite element basis used for assembly.

        Returns
        -------
        csr_matrix
            The assembled system matrix.
        """
        form = BilinearForm(self.bilinear_form)
        return asm(form, basis)

    def assemble_vector(self, basis: Basis) -> np.ndarray:
        """
        Assemble the right-hand side vector of the linear form on ``basis``.

        Parameters
        ----------
        basis : Basis
            Finite element basis used for assembly.

        Returns
        -------
        np.ndarray
            The assembled load vector, or a zero vector of length ``basis.N``
            when no linear form was supplied.
        """
        if self.linear_form is not None:
            form = LinearForm(self.linear_form)
            return asm(form, basis)
        return np.zeros(basis.N)


class DirichletBC:
    """
    Class for Dirichlet boundary conditions.

    The boundary is detected geometrically. A mesh node counts as a boundary
    node when its x or y coordinate is (numerically) 0 or 1, i.e. the domain
    is assumed to be the unit square [0, 1] x [0, 1]. Node indices are used
    directly as degree-of-freedom indices. This presumes one DOF per mesh
    node with matching numbering (e.g. linear Lagrange elements) — TODO
    confirm before using with other element types.

    Attributes
    ----------
    value_func : Callable
        Function that receives the node coordinate array and returns the
        Dirichlet value(s): either one value per node or a single scalar.
    """

    def __init__(self, value_func: Callable):
        """
        Initialize Dirichlet boundary condition.

        Parameters
        ----------
        value_func : Callable
            Function that takes the coordinate array ``basis.mesh.p`` and
            returns the boundary value(s), one per node or a single scalar.
        """
        self.value_func = value_func

    def apply(
        self,
        basis: Basis,
        A: csr_matrix,
        b: np.ndarray
    ) -> Tuple[csr_matrix, np.ndarray]:
        """
        Apply the Dirichlet boundary conditions to the linear system.

        Each boundary row of ``A`` is replaced by the matching row of the
        identity matrix, and the matching entry of ``b`` is replaced by the
        prescribed value. The solution therefore takes those values at the
        boundary DOFs. Columns are not modified, so a symmetric ``A``
        generally becomes non-symmetric. ``A`` and ``b`` themselves are not
        mutated.

        Parameters
        ----------
        basis : Basis
            The finite element basis (only ``basis.mesh.p`` is read).
        A : csr_matrix
            The system matrix.
        b : np.ndarray
            The right-hand side vector.

        Returns
        -------
        Tuple[csr_matrix, np.ndarray]
            The modified system matrix and right-hand side vector.
        """
        # Boundary nodes of the unit square [0, 1] x [0, 1].
        coords = basis.mesh.p
        on_boundary = (
            np.isclose(coords[0, :], 0.0)
            | np.isclose(coords[0, :], 1.0)
            | np.isclose(coords[1, :], 0.0)
            | np.isclose(coords[1, :], 1.0)
        )
        boundary_dofs = np.flatnonzero(on_boundary)

        # asarray makes scalars, 0-d arrays and per-node arrays uniform.
        # Calling len() on a 0-d array would raise TypeError.
        u_vals = np.asarray(self.value_func(coords))

        # Row elimination with sparse products instead of per-entry writes.
        # (Per-entry CSR writes can change the sparsity structure on every
        # call.) row_mask is diagonal with 0 at boundary DOFs and 1 elsewhere,
        # so row_mask @ A zeroes the boundary rows. unit_diag then puts 1 on
        # their diagonal.
        A = csr_matrix(A)
        n = A.shape[0]
        keep = np.ones(n, dtype=A.dtype)
        keep[boundary_dofs] = 0
        all_idx = np.arange(n)
        row_mask = csr_matrix((keep, (all_idx, all_idx)), shape=(n, n))
        unit_diag = csr_matrix(
            (
                np.ones(boundary_dofs.size, dtype=A.dtype),
                (boundary_dofs, boundary_dofs),
            ),
            shape=(n, n),
        )
        A = csr_matrix(row_mask @ A + unit_diag)

        b = b.copy()
        if u_vals.size == 1:
            # A single value applies to every boundary DOF.
            b[boundary_dofs] = u_vals.item()
        else:
            b[boundary_dofs] = u_vals[boundary_dofs]

        return A, b


class PDESolver:
    """
    Class for solving PDEs using the finite element method.

    Attributes
    ----------
    weak_form : WeakForm
        The weak formulation of the PDE
    """

    class Solution:
        """
        Result object returned by :meth:`PDESolver.solve`.

        Attributes
        ----------
        field : DiscreteField
            The solution as produced by ``basis.interpolate``.
        value : np.ndarray
            The solution vector returned by ``spsolve``.
        mesh
            The mesh of the basis the problem was solved on.

        Any attribute not found on the wrapper is looked up on ``field``.
        """

        def __init__(self, field, array, mesh):
            self.field = field
            self.value = array
            self.mesh = mesh

        def __getattr__(self, name):
            # Only called for attributes missing from the instance. Read
            # ``field`` straight from __dict__ so that a partially constructed
            # instance (e.g. during copy.copy or unpickling) raises
            # AttributeError instead of recursing forever.
            try:
                field = self.__dict__["field"]
            except KeyError:
                raise AttributeError(name) from None
            return getattr(field, name)

    def __init__(self, weak_form: WeakForm):
        """
        Initialize the PDE solver.

        Parameters
        ----------
        weak_form : WeakForm
            The weak formulation of the PDE
        """
        self.weak_form = weak_form

    def assemble_system(
        self,
        basis: Basis,
        bc: Optional[DirichletBC] = None
    ) -> Tuple[csr_matrix, np.ndarray]:
        """
        Assemble the linear system for the PDE.

        Parameters
        ----------
        basis : Basis
            The finite element basis
        bc : Optional[DirichletBC], default=None
            Dirichlet boundary conditions, applied after assembly when given

        Returns
        -------
        Tuple[csr_matrix, np.ndarray]
            The assembled system matrix and right-hand side vector
        """
        A = self.weak_form.assemble_matrix(basis)
        b = self.weak_form.assemble_vector(basis)

        if bc is not None:
            A, b = bc.apply(basis, A, b)

        return A, b

    def solve(
        self,
        basis: Basis,
        bc: Optional[DirichletBC] = None
    ) -> "PDESolver.Solution":
        """
        Solve the PDE.

        Parameters
        ----------
        basis : Basis
            The finite element basis
        bc : Optional[DirichletBC], default=None
            Dirichlet boundary conditions

        Returns
        -------
        PDESolver.Solution
            Wrapper exposing ``field`` (the interpolated solution), ``value``
            (the solution vector) and ``mesh``. Other attributes are
            forwarded to ``field``.
        """
        A, b = self.assemble_system(basis, bc)
        solution_array = spsolve(A, b)

        # DiscreteField's constructor takes (value, grad, ...) at quadrature
        # points, not (basis, dofs). The basis builds it from the DOF vector.
        solution_field = basis.interpolate(solution_array)

        return self.Solution(solution_field, solution_array, basis.mesh)