import numpy as np
from skfem import MeshTri
from copy import deepcopy


def create_uniform_mesh(nx: int, ny: int) -> MeshTri:
    """
    Create a uniform triangular mesh on the unit square [0,1]×[0,1].

    The square is divided into an ``nx`` by ``ny`` grid of rectangular cells
    and every cell is split into two triangles along the diagonal joining
    its lower-right and upper-left corners.

    Parameters
    ----------
    nx : int
        Number of grid cells in x-direction (must be at least 1)
    ny : int
        Number of grid cells in y-direction (must be at least 1)

    Returns
    -------
    MeshTri
        A triangular mesh built from (nx + 1) * (ny + 1) grid points and
        2 * nx * ny triangles

    Raises
    ------
    ValueError
        If ``nx`` or ``ny`` is smaller than 1
    """
    # Without at least one cell per direction there are no triangles, and the
    # connectivity array would be empty and one-dimensional.
    if nx < 1 or ny < 1:
        raise ValueError(
            f"nx and ny must both be at least 1, got nx={nx}, ny={ny}"
        )

    # Regular grid of nodes. With meshgrid's default 'xy' indexing,
    # flattening numbers the node in column i, row j as j * (nx + 1) + i.
    x = np.linspace(0, 1, nx + 1)
    y = np.linspace(0, 1, ny + 1)
    X, Y = np.meshgrid(x, y)
    points = np.vstack((X.ravel(), Y.ravel()))

    # Lower-left node index of every cell, ordered row by row
    # (j is the slow index, i the fast one).
    corner = (np.arange(ny)[:, None] * (nx + 1) + np.arange(nx)).ravel()

    # Two triangles per cell, stored in adjacent columns: the even columns
    # hold the lower-left triangle, the odd columns the upper-right one.
    triangles = np.empty((3, 2 * corner.size), dtype=int)
    triangles[:, 0::2] = (corner, corner + 1, corner + nx + 1)
    triangles[:, 1::2] = (corner + 1, corner + nx + 2, corner + nx + 1)

    # Create skfem triangular mesh
    return MeshTri(points, triangles)


def refine_mesh(mesh: MeshTri) -> MeshTri:
    """
    Refine a triangular mesh once, uniformly.

    This is a thin wrapper that delegates to ``mesh.refined()`` called
    without arguments and hands back whatever that call returns.

    Parameters
    ----------
    mesh : MeshTri
        Mesh to be refined

    Returns
    -------
    MeshTri
        The mesh returned by ``mesh.refined()``
    """
    refined_mesh = mesh.refined()
    return refined_mesh


def mesh_to_coords(mesh: MeshTri) -> np.ndarray:
    """
    Convert a mesh to a coordinate array.

    Parameters
    ----------
    mesh : MeshTri
        Input mesh; its ``p`` attribute is read as a (2, n_nodes) array
        whose rows hold the x and y coordinates

    Returns
    -------
    np.ndarray
        Newly allocated array of shape (n_nodes, 2) containing the (x,y)
        coordinates of each node. It does not share memory with ``mesh.p``,
        so modifying it leaves the mesh untouched.
    """
    # column_stack places the two coordinate rows side by side as columns and
    # always allocates a new array, so no explicit copies are needed.
    return np.column_stack((mesh.p[0, :], mesh.p[1, :]))


def coords_to_mesh(coords: np.ndarray, original_mesh: MeshTri) -> MeshTri:
    """
    Convert coordinate array back to a mesh, preserving the topology of the original mesh.

    Only the connectivity array ``original_mesh.t`` is carried over to the
    new mesh; no other attribute of ``original_mesh`` is passed on.

    Parameters
    ----------
    coords : np.ndarray
        Array (or array-like) of shape (n_nodes, 2) containing (x,y)
        coordinates of each node. Only the first two columns are used.
    original_mesh : MeshTri
        Original mesh with the desired topology

    Returns
    -------
    MeshTri
        New mesh with updated node coordinates but same topology

    Raises
    ------
    ValueError
        If the number of rows in ``coords`` differs from the number of nodes
        of ``original_mesh``
    """
    coords = np.asarray(coords)

    # Rearrange (n_nodes, 2) into the (2, n_nodes) layout used by mesh.p
    new_points = np.vstack((coords[:, 0], coords[:, 1]))

    # The connectivity refers to nodes by index, so the node count must match
    # the original mesh for the triangles to remain meaningful.
    n_nodes = original_mesh.p.shape[1]
    if new_points.shape[1] != n_nodes:
        raise ValueError(
            f"coords has {new_points.shape[1]} rows but the original mesh "
            f"has {n_nodes} nodes"
        )

    return MeshTri(new_points, original_mesh.t)


def load_mesh(filename: str) -> MeshTri:
    """
    Load a mesh from a file.

    The file is expected to be a NumPy ``.npz`` archive holding the node
    coordinates under the key ``'points'`` and the element connectivity
    under the key ``'triangles'``. Both arrays are passed to ``MeshTri``
    unchanged.
    NOTE(review): presumably the arrays have shapes (2, n_nodes) and
    (3, n_triangles) -- confirm against the code that writes these files.

    Parameters
    ----------
    filename : str
        Path to the mesh file (npz format)

    Returns
    -------
    MeshTri
        Loaded mesh

    Raises
    ------
    KeyError
        If the archive lacks the ``'points'`` or ``'triangles'`` entry
    """
    # The archive object keeps the underlying file open; the context manager
    # closes it once both arrays have been read into memory.
    with np.load(filename) as data:
        points = data['points']
        triangles = data['triangles']

    return MeshTri(points, triangles)