"""
Simple mesh generation utilities without Gmsh dependency.

This module provides mesh generation using NumPy/tensor operations for 
structured meshes and special mesh loading from files.
"""

import meshio
import os
import numpy as np
import scipy.io as sio
from typing import Tuple, Optional, Any


class Tmsh:
    """
    Simple tensor-based mesh generator.
    
    Provides structured mesh generation without requiring Gmsh. All
    generators are static methods and return a ``meshio.Mesh`` whose
    ``point_data`` carries a boolean ``boundary_mask`` entry.
    """
    
    @staticmethod
    def gen_quad(xlims: Tuple[float, float] = (0., 1.), 
                 ylims: Tuple[float, float] = (0., 1.), 
                 nx: int = 10, 
                 ny: int = 10) -> meshio.Mesh:
        """
        Generate a structured quadrilateral mesh on a rectangular domain.
        
        Parameters:
        -----------
            xlims: tuple(float, float)
                The x-axis boundaries (default: (0, 1))
            ylims: tuple(float, float)
                The y-axis boundaries (default: (0, 1))
            nx: int
                Number of nodes in x direction (default: 10)
            ny: int
                Number of nodes in y direction (default: 10)
                
        Returns:
        --------
            mesh: meshio.Mesh
                Structured quadrilateral mesh with boundary_mask.
                Points are stored x-major: node ``i * ny + j`` sits at
                ``(x_i, y_j)``. Each quad lists its corners as
                ``(i, j), (i+1, j), (i+1, j+1), (i, j+1)``, which is
                counter-clockwise when both limits are increasing.
        """
        # Each axis is sampled from its own limits with its own node count.
        xs = np.linspace(xlims[0], xlims[1], nx)
        ys = np.linspace(ylims[0], ylims[1], ny)
        # "ij" indexing yields (nx, ny) arrays: x varies along axis 0 and
        # y along axis 1, matching the node-id grid below.
        x, y = np.meshgrid(xs, ys, indexing="ij")

        nids = np.arange(nx * ny).reshape(nx, ny)
        # Nodes on the outer ring of the index grid lie on the domain boundary.
        boundary_mask = np.zeros_like(nids, dtype=bool)
        boundary_mask[0, :] = True
        boundary_mask[-1, :] = True
        boundary_mask[:, 0] = True
        boundary_mask[:, -1] = True
        
        mesh = meshio.Mesh(
            points=np.stack([x, y], axis=-1).reshape(-1, 2),
            cells={'quad': np.stack([
                nids[:-1, :-1],
                nids[1:, :-1],
                nids[1:, 1:],
                nids[:-1, 1:]
            ], axis=-1).reshape(-1, 4)},
            point_data={
                'boundary_mask': boundary_mask.ravel(),
            }
        )
        return mesh
    
    @staticmethod
    def gen_batman() -> meshio.Mesh:
        """
        Load the Batman mesh from a .mat file.

        The file is looked up in a fixed list of candidate locations (two
        relative to the current working directory, one relative to this
        module); the first existing file is used.
        
        Returns:
        --------
            mesh: meshio.Mesh
                The Batman-shaped triangular mesh, translated so that the
                minimum coordinate along every axis is zero.

        Raises:
        -------
            FileNotFoundError
                If batman.mat is not present at any candidate location.
        """
        # Try different possible locations for the mesh file
        possible_paths = [
            "./mesh_file/batman.mat",
            "./Dataset/mesh_file/batman.mat",
            os.path.join(os.path.dirname(__file__), "../../../Dataset/mesh_file/batman.mat"),
        ]
        
        mat = None
        for path in possible_paths:
            if os.path.isfile(path):
                mat = sio.loadmat(path)
                break
        
        if mat is None:
            raise FileNotFoundError(
                f"Could not find batman.mat. Searched paths: {possible_paths}"
            )
        
        # Transposed so that rows are points
        # (assumes 'xcg' is stored as (dim, n_points) -- TODO confirm).
        points = mat['xcg'].T
        # Center the mesh
        points = points - (points.max(axis=0) + points.min(axis=0)) / 2
        # Shift to the first quadrant
        points = points - points.min(axis=0)

        # Connectivity is shifted by one, presumably from MATLAB's 1-based
        # node numbering to 0-based indices.
        cells = {'triangle': mat['e2vcg'].T.astype(np.int64) - 1}
        # Marks every node whose y is not below the first node's y
        # (within 1e-7).
        # NOTE(review): this test is one-sided (no abs); it only selects a
        # single horizontal edge if node 0 lies on the top edge -- confirm.
        boundary_mask = (points[0, 1] - points[:, 1]) < 1e-7
        mesh = meshio.Mesh(
            points=points, 
            cells=cells, 
            point_data={'boundary_mask': boundary_mask}
        )

        return mesh


def _optional_kwargs(args: Any, names: Tuple[str, ...]) -> dict:
    """Return ``{name: args.name}`` for every listed attribute that is not None."""
    kwargs = {}
    for name in names:
        value = getattr(args, name)
        if value is not None:
            kwargs[name] = value
    return kwargs


def _gmsh_kwargs(args: Any, shape: str) -> dict:
    """Build the keyword arguments passed to the Gmsh generator for *shape*."""
    if shape in ("rectangle", "L_shape"):
        kwargs = _optional_kwargs(args, ("xlims", "ylims"))
    else:
        # circle / ellipse are described by radius and center
        kwargs = _optional_kwargs(args, ("radius", "center"))
        # A circle takes a scalar radius: keep only the first entry when a
        # list/tuple was supplied.
        if shape == "circle" and isinstance(kwargs.get("radius"), (list, tuple)):
            kwargs["radius"] = kwargs["radius"][0]
    kwargs["chara_length"] = args.chara_length
    kwargs["verbose"] = args.verbose
    return kwargs


def init_mesh(args: Any) -> meshio.Mesh:
    """
    Initialize a mesh based on the given arguments.

    This function dispatches to the appropriate mesh generator based on
    the element type and shape specified in the arguments. Quad rectangles
    and the batman mesh are produced by ``Tmsh``; every other supported
    combination is delegated to ``Gmsh``, which is imported only when such
    a combination is requested.

    Args:
        args: Configuration object with mesh parameters:
            - element: str ("tri" or "quad")
            - shape: str ("rectangle", "circle", "ellipse", "L_shape", "batman")
            - xlims: Optional[Tuple[float, float]]
            - ylims: Optional[Tuple[float, float]]
            - radius: Optional[Union[float, Tuple[float, float]]]
            - center: Optional[Tuple[float, float]]
            - chara_length: float
            - grid: int (for structured meshes)
            - verbose: bool
            - use_free_boundary: bool

    Returns:
        mesh: meshio.Mesh
            The initialized mesh. When ``args.use_free_boundary`` is true,
            its ``boundary_mask`` point data is replaced by an all-False
            array of the same shape.

    Raises:
        NotImplementedError: If the shape is not supported for the specified element.
        ValueError: If the element is unknown.
    """
    # (element, shape) -> name of the Gmsh static method that builds it.
    # NOTE(review): "gen_quad_cirlce" is spelled as the existing call site
    # spells it; confirm against gmsh_gen before renaming.
    gmsh_generators = {
        ("tri", "rectangle"): "gen_tri_rectangle",
        ("tri", "circle"): "gen_tri_circle",
        ("tri", "ellipse"): "gen_tri_ellipse",
        ("tri", "L_shape"): "gen_tri_L_shape",
        ("quad", "circle"): "gen_quad_cirlce",
        ("quad", "ellipse"): "gen_quad_ellipse",
        ("quad", "L_shape"): "gen_quad_L_shape",
    }

    element = args.element
    if element not in ("tri", "quad"):
        raise ValueError(f"Unknown element {args.element}")

    shape = args.shape
    if element == "tri" and shape == "batman":
        mesh = Tmsh.gen_batman()
    elif element == "quad" and shape == "rectangle":
        kwargs = _optional_kwargs(args, ("xlims", "ylims"))
        # The structured grid uses the same node count along both axes.
        kwargs["nx"] = args.grid
        kwargs["ny"] = args.grid
        mesh = Tmsh.gen_quad(**kwargs)
    elif (element, shape) in gmsh_generators:
        # Import Gmsh lazily so the Tmsh-only paths above keep working when
        # the Gmsh-backed module cannot be imported.
        from .gmsh_gen import Gmsh

        generator = getattr(Gmsh, gmsh_generators[(element, shape)])
        mesh = generator(**_gmsh_kwargs(args, shape))
    else:
        raise NotImplementedError(f"Shape {args.shape} is not supported for {element} element")

    if args.use_free_boundary:
        # Free boundary: no node is flagged as boundary.
        mesh.point_data['boundary_mask'] = np.zeros_like(
            mesh.point_data['boundary_mask']
        ).astype(bool)

    return mesh

