import pytest
import numpy as np
import os
from skfem import ElementTriP1, Basis
from skfem.helpers import dot, grad

from mesh_opt.mesh_utils import create_uniform_mesh
from mesh_opt.pde_solver import WeakForm, PDESolver, DirichletBC
from mesh_opt.trainer import TrainerConfig, Trainer
from mesh_opt.reports import plot_mse_history, export_optimized_mesh, visualize_mesh_solution


class TestReports:
    """Smoke tests for the reporting helpers in ``mesh_opt.reports``.

    Each test writes one artifact into a pytest-provided temporary directory
    and checks that a non-empty file was produced. The mesh export test
    additionally inspects the saved arrays.
    """

    @staticmethod
    def _make_solver_and_bc():
        """Build the PDE solver and boundary condition shared by the tests.

        The weak form pairs the bilinear form ``dot(grad(u), grad(v))`` with
        the linear form ``1.0 * v``; the Dirichlet condition returns ``0.0``
        everywhere.

        Returns:
            A ``(solver, bc)`` tuple of ``PDESolver`` and ``DirichletBC``.
        """
        def poisson_form(u, v, w):
            return dot(grad(u), grad(v))

        def load_form(v, w):
            return 1.0 * v

        def zero_bc(x):
            return 0.0

        weak_form = WeakForm(bilinear_form=poisson_form, linear_form=load_form)
        return PDESolver(weak_form), DirichletBC(zero_bc)

    @staticmethod
    def _assert_nonempty_file(path):
        """Assert that *path* exists on disk and has a size greater than zero."""
        assert os.path.exists(path)
        assert os.path.getsize(path) > 0

    @pytest.fixture
    def training_result(self):
        """Create a sample training result for testing.

        Runs a deliberately tiny training job (2 iterations, 3 samples) on a
        3x3 uniform mesh. The tests below read the ``'loss_history'`` and
        ``'final_mesh'`` entries of the returned result.
        """
        mesh = create_uniform_mesh(nx=3, ny=3)
        solver, bc = self._make_solver_and_bc()

        def ground_truth(points):
            # Reference function 4 * x * (1 - x) * y * (1 - y), evaluated on
            # the first two columns of ``points``.
            x, y = points[:, 0], points[:, 1]
            return 4 * x * (1 - x) * y * (1 - y)

        # Keep the run short: this fixture only needs *a* result, not a
        # converged one.
        config = TrainerConfig(n_iterations=2, n_samples=3)
        trainer = Trainer(config, solver, bc, ground_truth)

        return trainer.run(mesh)

    def test_plot_mse_history(self, training_result, tmpdir):
        """Test plotting MSE history."""
        output_file = os.path.join(tmpdir, "mse_history.png")

        plot_mse_history(training_result['loss_history'], output_file)

        self._assert_nonempty_file(output_file)

    def test_export_optimized_mesh(self, training_result, tmpdir):
        """Test exporting an optimized mesh."""
        output_file = os.path.join(tmpdir, "optimized_mesh.npz")

        export_optimized_mesh(training_result['final_mesh'], output_file)

        self._assert_nonempty_file(output_file)

        # np.load on an .npz archive returns a lazily-read NpzFile that keeps
        # the underlying file open; the context manager closes it so the
        # descriptor is not leaked and the temp directory can be removed.
        with np.load(output_file) as data:
            assert 'points' in data
            assert 'triangles' in data
            assert data['points'].shape[0] == 2  # 2D mesh
            assert data['triangles'].shape[0] == 3  # Triangular elements

    def test_visualize_mesh_solution(self, training_result, tmpdir):
        """Test visualization of a solution on a mesh."""
        output_file = os.path.join(tmpdir, "solution_visualization.png")

        # Build a P1 triangular basis on the optimized mesh and solve the
        # same PDE problem that the training fixture uses.
        final_mesh = training_result['final_mesh']
        basis = Basis(final_mesh, ElementTriP1())
        solver, bc = self._make_solver_and_bc()
        solution = solver.solve(basis, bc)

        visualize_mesh_solution(final_mesh, solution, output_file)

        self._assert_nonempty_file(output_file)