import pytest
import numpy as np
from skfem import MeshTri, ElementTriP1, Basis
from skfem.helpers import dot, grad
from scipy.sparse import csr_matrix

from mesh_opt.mesh_utils import create_uniform_mesh
from mesh_opt.pde_solver import WeakForm, PDESolver, DirichletBC


class TestPDESolver:
    """Tests for weak-form assembly, Dirichlet BCs and the Poisson solve.

    The boundary detection in ``test_dirichlet_bc`` and the analytic
    solution in ``test_solve_poisson`` both assume the mesh produced by
    ``create_uniform_mesh`` spans the unit square [0, 1] x [0, 1].
    NOTE(review): confirm against ``mesh_opt.mesh_utils``.
    """

    @pytest.fixture
    def mesh(self):
        """Create a simple uniform mesh (nx=10, ny=10) for testing."""
        return create_uniform_mesh(nx=10, ny=10)

    @pytest.fixture
    def basis(self, mesh):
        """Create a piecewise-linear (P1) triangular finite element basis."""
        return Basis(mesh, ElementTriP1())

    def test_weak_form_poisson(self, basis):
        """Test weak form assembly for the Poisson equation -Δu = f."""
        # Bilinear form a(u, v) = ∫ ∇u · ∇v dx
        def poisson_form(u, v, w):
            return dot(grad(u), grad(v))

        weak_form = WeakForm(bilinear_form=poisson_form)

        # Assemble the (boundary-condition-free) stiffness matrix
        A = weak_form.assemble_matrix(basis)

        # Structural properties
        assert isinstance(A, csr_matrix)
        assert A.shape[0] == A.shape[1]  # Square matrix
        assert A.shape[0] == basis.N  # Size matches number of DOFs

        dense = A.toarray()

        # Symmetry
        assert np.allclose(dense, dense.T, atol=1e-10)

        # Without Dirichlet conditions the stiffness matrix is only positive
        # *semi*-definite: constant functions have zero gradient, so the
        # constant vector lies in the kernel and the smallest eigenvalue is
        # 0 up to round-off (which may have either sign). Requiring strictly
        # positive eigenvalues would therefore fail or be flaky.
        # A dense symmetric eigensolver is used instead of ARPACK with
        # which='SM', which converges poorly for the smallest eigenvalues
        # without shift-invert; the system here is small.
        eigenvalues = np.linalg.eigvalsh(dense)  # ascending order
        tol = 1e-10 * np.max(np.abs(eigenvalues))

        # Constants are in the kernel (row sums vanish)
        assert np.allclose(A @ np.ones(A.shape[0]), 0.0, atol=1e-10)
        # Positive semi-definite
        assert eigenvalues[0] > -tol
        # On a connected mesh the kernel is exactly the constants, so every
        # other eigenvalue is strictly positive.
        assert eigenvalues[1] > tol

    def test_dirichlet_bc(self, basis):
        """Test that Dirichlet values are written into the right-hand side."""
        def poisson_form(u, v, w):
            return dot(grad(u), grad(v))

        weak_form = WeakForm(bilinear_form=poisson_form)

        # Simple Dirichlet BC: u = 1.0 on the boundary
        def boundary_condition(x):
            return 1.0

        bc = DirichletBC(boundary_condition)

        solver = PDESolver(weak_form)
        A, b = solver.assemble_system(basis, bc)

        # Find boundary nodes using the same approach as in DirichletBC.apply
        # (nodes lying on an edge of the unit square). Like the original
        # check, this assumes node index i corresponds to DOF i.
        x = basis.mesh.p
        boundary_mask = (
            np.isclose(x[0, :], 0.0)
            | np.isclose(x[0, :], 1.0)
            | np.isclose(x[1, :], 0.0)
            | np.isclose(x[1, :], 1.0)
        )
        boundary_dofs = np.flatnonzero(boundary_mask)

        # Guard against a vacuous pass if no boundary nodes were detected
        # (e.g. the mesh does not span the unit square).
        assert boundary_dofs.size > 0

        # Every boundary entry of the right-hand side carries the BC value
        np.testing.assert_allclose(np.asarray(b)[boundary_dofs], 1.0)

    def test_solve_poisson(self, basis):
        """Test a Poisson solve against a known analytic solution.

        Problem: -Δu = 2π² sin(πx) sin(πy) with u = 0 on the boundary of the
        unit square; the exact solution is u(x, y) = sin(πx) sin(πy).
        """
        def source_term(x):
            return 2 * np.pi**2 * np.sin(np.pi * x[0]) * np.sin(np.pi * x[1])

        def analytic_solution(x):
            return np.sin(np.pi * x[0]) * np.sin(np.pi * x[1])

        def poisson_form(u, v, w):
            return dot(grad(u), grad(v))

        def load_form(v, w):
            x = w.x
            return source_term(x) * v

        weak_form = WeakForm(bilinear_form=poisson_form, linear_form=load_form)

        # Homogeneous Dirichlet BC
        def zero_bc(x):
            return 0.0

        bc = DirichletBC(zero_bc)

        solver = PDESolver(weak_form)
        solution = solver.solve(basis, bc)
        u_h = solution.value

        # Exact solution evaluated at the mesh nodes
        # (rows of mesh.p are the x and y coordinates).
        u_exact = analytic_solution(basis.mesh.p)

        assert isinstance(u_h, np.ndarray)
        # Shapes must match exactly: e.g. an (N, 1) solution would silently
        # broadcast against (N,) into an (N, N) difference.
        assert u_h.shape == u_exact.shape

        # Loose tolerance: the mesh is coarse, so high accuracy is not expected.
        error = np.max(np.abs(u_h - u_exact))
        assert error < 0.1