import pytest
import numpy as np
import os
import tempfile
from skfem import MeshTri

from mesh_opt.mesh_utils import load_mesh, create_uniform_mesh, refine_mesh, mesh_to_coords, coords_to_mesh


class TestMeshUtils:
    """Tests for the mesh helpers in ``mesh_opt.mesh_utils``."""

    def test_create_uniform_mesh(self):
        """Test creating a uniform triangular mesh."""
        nx, ny = 5, 5
        mesh = create_uniform_mesh(nx=nx, ny=ny)
        assert isinstance(mesh, MeshTri)
        assert mesh.p.shape[0] == 2  # 2D mesh: one row per coordinate axis
        # Expectations are derived from nx/ny so they stay in sync if the
        # grid size is changed. Assumes a structured grid split into two
        # triangles per cell.
        assert mesh.p.shape[1] == (nx + 1) * (ny + 1)  # nodes
        assert mesh.t.shape[1] == 2 * nx * ny  # triangles

    def test_mesh_to_coords(self):
        """Test converting a mesh to coordinate array."""
        mesh = create_uniform_mesh(nx=3, ny=3)
        coords = mesh_to_coords(mesh)

        # Check shape and content
        assert isinstance(coords, np.ndarray)
        assert coords.shape == (mesh.p.shape[1], 2)  # Nx2 array of (x,y) coordinates
        assert np.allclose(coords[:, 0], mesh.p[0, :])  # x-coordinates match
        assert np.allclose(coords[:, 1], mesh.p[1, :])  # y-coordinates match

    def test_coords_to_mesh(self):
        """Test converting coordinate array back to mesh."""
        original_mesh = create_uniform_mesh(nx=3, ny=3)
        coords = mesh_to_coords(original_mesh)

        node = 5  # arbitrary node to perturb
        shift = 0.1  # x-displacement applied to that node

        # Modify a copy so ``coords`` keeps the original positions.
        modified_coords = coords.copy()
        modified_coords[node, 0] += shift

        # Convert back to mesh
        new_mesh = coords_to_mesh(modified_coords, original_mesh)

        # Verify the mesh has the same topology
        assert np.array_equal(original_mesh.t, new_mesh.t)

        # Verify the coordinates were updated correctly
        assert np.isclose(new_mesh.p[0, node], original_mesh.p[0, node] + shift)
        assert np.isclose(new_mesh.p[1, node], original_mesh.p[1, node])

        # All other nodes should be unchanged
        mask = np.ones(coords.shape[0], dtype=bool)
        mask[node] = False
        assert np.allclose(new_mesh.p[0, mask], original_mesh.p[0, mask])
        assert np.allclose(new_mesh.p[1, mask], original_mesh.p[1, mask])

    def test_refine_mesh(self):
        """Test uniform mesh refinement."""
        mesh = create_uniform_mesh(nx=3, ny=3)
        refined_mesh = refine_mesh(mesh)

        # Refined mesh should have more elements and nodes
        assert refined_mesh.p.shape[1] > mesh.p.shape[1]
        assert refined_mesh.t.shape[1] > mesh.t.shape[1]

    def test_load_mesh(self):
        """Test loading a mesh from file."""
        # Reserve a unique temporary path. The handle is closed immediately
        # so np.savez can reopen it, which matters on platforms that lock
        # open files.
        with tempfile.NamedTemporaryFile(suffix='.npz', delete=False) as tmp:
            tmp_filename = tmp.name

        try:
            mesh = create_uniform_mesh(nx=3, ny=3)
            # NOTE(review): assumes load_mesh reads .npz files with
            # 'points' and 'triangles' keys — confirm against load_mesh.
            np.savez(tmp_filename,
                     points=mesh.p,
                     triangles=mesh.t)

            # Load the mesh from file
            loaded_mesh = load_mesh(tmp_filename)

            # Verify loaded mesh matches the original
            assert isinstance(loaded_mesh, MeshTri)
            assert np.array_equal(loaded_mesh.p, mesh.p)
            assert np.array_equal(loaded_mesh.t, mesh.t)
        finally:
            # Always clean up, even when loading or an assertion fails;
            # previously the file leaked on any failure.
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)