import pytest
import numpy as np
from skfem import ElementTriP1, Basis

from mesh_opt.mesh_utils import create_uniform_mesh
from mesh_opt.pde_solver import WeakForm, PDESolver, DirichletBC
from mesh_opt.error_metrics import interpolate_solution, compute_mse
from skfem.helpers import dot, grad


class TestErrorMetrics:
    """Tests for ``interpolate_solution`` and ``compute_mse``.

    Every test works on the same model problem, -Δu = 1 with u = 0 on the
    boundary, discretised with P1 triangles (``ElementTriP1``) on the mesh
    returned by ``create_uniform_mesh(nx=5, ny=5)``.
    """

    @staticmethod
    def _solve_poisson(basis):
        """Solve -Δu = 1 with a zero Dirichlet condition on *basis*.

        Shared by the ``solution`` fixture and the refinement check in
        ``test_compute_mse`` so both always solve exactly the same problem.

        Args:
            basis: scikit-fem ``Basis`` to assemble and solve on.

        Returns:
            Whatever ``PDESolver.solve`` returns for this problem.
        """

        def poisson_form(u, v, w):
            # Bilinear form of the Laplacian: ∫ ∇u · ∇v.
            return dot(grad(u), grad(v))

        def load_form(v, w):
            # Constant unit source term f = 1.
            return 1.0 * v

        def zero_bc(x):
            # Homogeneous Dirichlet data.
            return 0.0

        weak_form = WeakForm(bilinear_form=poisson_form, linear_form=load_form)
        bc = DirichletBC(zero_bc)
        solver = PDESolver(weak_form)
        return solver.solve(basis, bc)

    @pytest.fixture
    def mesh(self):
        """Create a 5x5 uniform mesh for testing."""
        return create_uniform_mesh(nx=5, ny=5)

    @pytest.fixture
    def basis(self, mesh):
        """Create a piecewise-linear (P1) basis on the test mesh."""
        return Basis(mesh, ElementTriP1())

    @pytest.fixture
    def solution(self, basis):
        """Solve the model Poisson problem on the test basis."""
        return self._solve_poisson(basis)

    def test_solution_structure(self, solution):
        """The solution exposes the attributes interpolation relies on."""
        assert hasattr(solution, 'field')
        assert hasattr(solution, 'value')
        assert hasattr(solution, 'mesh')

    def test_interpolate_solution(self, solution):
        """The solution can be interpolated at arbitrary interior points."""
        # Interior points; assumes the mesh spans [0, 1]^2 — TODO confirm
        # against create_uniform_mesh.
        test_points = np.array([
            [0.3, 0.4],
            [0.7, 0.2],
            [0.5, 0.5],
        ])

        interpolated_values = interpolate_solution(solution, test_points)

        # One scalar value per query point.
        assert isinstance(interpolated_values, np.ndarray)
        assert interpolated_values.shape == (len(test_points),)

        # A positive source with zero boundary data gives a strictly positive
        # solution in the interior (maximum principle); NaNs also fail here.
        assert np.all(interpolated_values > 0)

    def test_compute_mse(self, solution, basis):
        """compute_mse is a non-negative float and does not worsen under refinement."""

        # Reference field: a bubble with the same zero boundary values and
        # central peak as the true solution, but NOT the exact solution of
        # -Δu = 1 (its negative Laplacian is 8x(1-x) + 8y(1-y) and it peaks at
        # 0.25). The MSE therefore contains a model error that mesh refinement
        # cannot remove, so only "not noticeably worse" is asserted below.
        def ground_truth_func(points):
            x, y = points[:, 0], points[:, 1]
            return 4 * x * (1 - x) * y * (1 - y)

        # A private RandomState seeded with 42 yields exactly the same samples
        # as the legacy ``np.random.seed(42); np.random.rand(...)`` pair, but
        # without mutating NumPy's global RNG (which would leak into other
        # tests and make them order-dependent).
        rng = np.random.RandomState(42)
        n_points = 20
        # Map [0, 1) to [0.1, 0.9) to stay away from the boundary (assumes a
        # unit-square domain — TODO confirm).
        sample_points = rng.rand(n_points, 2) * 0.8 + 0.1
        ground_truth = ground_truth_func(sample_points)

        mse = compute_mse(solution, sample_points, ground_truth)

        assert isinstance(mse, float)
        assert mse >= 0

        # Re-solve the identical problem on a uniformly refined mesh.
        refined_basis = Basis(basis.mesh.refined(), ElementTriP1())
        refined_solution = self._solve_poisson(refined_basis)
        refined_mse = compute_mse(refined_solution, sample_points, ground_truth)

        # Refinement should not make the fit noticeably worse; the 10% slack
        # absorbs the irreducible model error in the reference field.
        assert refined_mse <= mse * 1.1