import pytest
import numpy as np
from skfem import ElementTriP1, Basis
from skfem.helpers import dot, grad

from mesh_opt.mesh_utils import create_uniform_mesh, mesh_to_coords, coords_to_mesh
from mesh_opt.pde_solver import WeakForm, PDESolver, DirichletBC
from mesh_opt.error_metrics import compute_mse
from mesh_opt.trainer import TrainerConfig, Trainer


class TestTrainer:
    """Tests for ``TrainerConfig`` and ``Trainer`` on a small Poisson problem."""

    @staticmethod
    def _make_trainer(config, problem):
        """Return a Trainer wired to the solver, BC and ground truth in *problem*.

        Args:
            config: The ``TrainerConfig`` to use.
            problem: The dict produced by the ``poisson_problem`` fixture.
        """
        return Trainer(
            config,
            problem['solver'],
            problem['bc'],
            problem['ground_truth'],
        )

    @staticmethod
    def _central_difference(trainer, params, index, step):
        """Return the central-difference derivative of the loss w.r.t. ``params[index]``.

        Evaluates ``trainer.compute_loss`` at ``params`` shifted by ``+step``
        and then ``-step`` along one component; ``params`` is not mutated.
        """
        params_plus = params.copy()
        params_plus[index] += step
        loss_plus = trainer.compute_loss(params_plus)

        params_minus = params.copy()
        params_minus[index] -= step
        loss_minus = trainer.compute_loss(params_minus)

        return (loss_plus - loss_minus) / (2 * step)

    @pytest.fixture
    def mesh(self):
        """Create a small uniform mesh for testing."""
        return create_uniform_mesh(nx=3, ny=3)

    @pytest.fixture
    def poisson_problem(self, mesh):
        """Create a simple Poisson problem for testing.

        Returns a dict with keys ``'weak_form'``, ``'bc'``, ``'solver'`` and
        ``'ground_truth'``.
        """
        # Weak form for the Poisson equation: (grad u, grad v) = (1, v).
        def poisson_form(u, v, w):
            return dot(grad(u), grad(v))

        def load_form(v, w):
            return 1.0 * v

        weak_form = WeakForm(bilinear_form=poisson_form, linear_form=load_form)

        # Homogeneous Dirichlet boundary condition.
        def zero_bc(x):
            return 0.0

        bc = DirichletBC(zero_bc)

        solver = PDESolver(weak_form)

        # Reference solution the trainer's loss is measured against.
        def ground_truth(points):
            x, y = points[:, 0], points[:, 1]
            return 4 * x * (1 - x) * y * (1 - y)

        return {
            'weak_form': weak_form,
            'bc': bc,
            'solver': solver,
            'ground_truth': ground_truth
        }

    def test_trainer_config(self):
        """TrainerConfig stores the values it was constructed with."""
        config = TrainerConfig(
            step_size=0.01,
            fd_radius=0.1,
            n_samples=10,
            regularization_weight=0.001
        )

        assert config.step_size == 0.01
        assert config.fd_radius == 0.1
        assert config.n_samples == 10
        assert config.regularization_weight == 0.001

    def test_trainer_initialization(self, mesh, poisson_problem):
        """initialize() returns a flat vector: all x coordinates, then all y coordinates."""
        trainer = self._make_trainer(TrainerConfig(), poisson_problem)

        params = trainer.initialize(mesh)

        n_nodes = mesh.p.shape[1]
        assert params.shape == (n_nodes * 2,)  # x and y coordinates flattened

        # Verify params match mesh coordinates, x block first then y block.
        mesh_coords = mesh_to_coords(mesh)
        assert np.allclose(params[:n_nodes], mesh_coords[:, 0])  # x coordinates
        assert np.allclose(params[n_nodes:], mesh_coords[:, 1])  # y coordinates

    def test_trainer_vectorization(self, mesh):
        """_mesh_to_params followed by _params_to_mesh round-trips the coordinates."""
        # Solver / BC / ground truth are not used by the conversion helpers.
        trainer = Trainer(TrainerConfig(), None, None, None)

        original_coords = mesh_to_coords(mesh)
        params = trainer._mesh_to_params(mesh)

        new_mesh = trainer._params_to_mesh(params, mesh)
        new_coords = mesh_to_coords(new_mesh)

        assert np.allclose(original_coords, new_coords)

    def test_gradient_estimation(self, mesh, poisson_problem):
        """Estimated gradient is well-shaped, non-trivial and roughly agrees with central differences."""
        config = TrainerConfig(n_samples=5, fd_radius=0.01)
        trainer = self._make_trainer(config, poisson_problem)

        params = trainer.initialize(mesh)

        # Loss at the initial point must be a scalar float.
        loss = trainer.compute_loss(params)
        assert isinstance(loss, float)

        # Named ``grad_est`` so it does not shadow the module-level ``grad``
        # imported from ``skfem.helpers``.
        grad_est = trainer.estimate_gradient(params)

        assert grad_est.shape == params.shape
        # At least some components must be non-zero.
        assert np.any(grad_est != 0)

        # Compare signs against central differences on the first few
        # parameters (a deterministic prefix, not a random selection).
        step = 1e-5
        n_check = min(5, len(params))
        fd_grad_array = np.array([
            self._central_difference(trainer, params, i, step)
            for i in range(n_check)
        ])
        grad_sample = grad_est[:n_check]

        # A component counts as matching if the signs agree or the reference
        # derivative is numerically zero (sign is meaningless there).
        sign_match = (
            (np.sign(fd_grad_array) == np.sign(grad_sample))
            | (np.abs(fd_grad_array) < 1e-10)
        )

        # Deliberately lenient (any, not all) to allow for randomness in the
        # estimator, as in the original test.
        assert np.any(sign_match)

    def test_optimization_step(self, mesh, poisson_problem):
        """A single optimization step does not increase the loss (beyond 0.1%)."""
        config = TrainerConfig(step_size=0.01, n_samples=10)
        trainer = self._make_trainer(config, poisson_problem)

        params = trainer.initialize(mesh)
        loss_before = trainer.compute_loss(params)

        new_params = trainer.step(params)
        loss_after = trainer.compute_loss(new_params)

        # 0.1% tolerance for numerical noise.
        assert loss_after <= loss_before * 1.001

    def test_run_optimization(self, mesh, poisson_problem):
        """run() returns params and a loss history that is (nearly) non-increasing."""
        config = TrainerConfig(step_size=0.01, n_samples=5, n_iterations=3)
        trainer = self._make_trainer(config, poisson_problem)

        result = trainer.run(mesh)

        assert 'params' in result
        assert 'loss_history' in result
        # One entry for the initial loss plus one per iteration.
        assert len(result['loss_history']) == config.n_iterations + 1

        # Each loss may exceed its predecessor by at most 5% (noise tolerance).
        history = result['loss_history']
        for i in range(1, len(history)):
            assert history[i] <= history[i - 1] * 1.05