"""
Gmsh-based mesh generation for finite element methods.

This module provides mesh generation using the Gmsh library for various 
domain geometries including rectangles, circles, ellipses, and L-shaped domains.
"""

import os
import tempfile
from functools import partial
from typing import Tuple

import gmsh
import meshio
import numpy as np

# Number of decimal places each float parameter is rounded to (via np.round)
# when it is embedded in a cache file name.
DIGIT: int = 3


class Gmsh:
    """
    Mesh generator using the Gmsh library.

    Supports generation of triangular ("tri") and quadrilateral ("quad")
    meshes for rectangles, circles (disks), ellipses and L-shaped domains.
    Every generator returns a ``meshio.Mesh`` with 2D points and a boolean
    ``point_data['boundary_mask']`` marking the nodes on the domain boundary.

    Before meshing, each generator looks in ``cache_path`` for a VTK file
    whose name encodes the geometry parameters and the element type, and
    returns that mesh instead if the file exists.
    NOTE(review): no method of this class writes those VTK files, so the
    cache only hits when other code has populated ``cache_path``; confirm
    this is intended.
    """

    # Directory searched for cached VTK meshes.
    cache_path = "./.gmsh_cache"

    # ------------------------------------------------------------------
    # Private helpers shared by the generators
    # ------------------------------------------------------------------

    @staticmethod
    def _cache_file(prefix: str, values: Tuple[float, ...], element: str) -> str:
        """
        Return the cache VTK path for one geometry.

        The file name joins ``prefix``, every value rounded to ``DIGIT``
        decimals, and ``element`` with underscores. The element type is part
        of the key so that a cached triangular mesh is never returned for a
        quadrilateral request (or vice versa).
        """
        parts = [prefix, *(str(np.round(v, DIGIT)) for v in values), element]
        return os.path.join(Gmsh.cache_path, "_".join(parts) + ".vtk")

    @staticmethod
    def _load_cached(vtk_path: str):
        """
        Return the cached mesh stored at ``vtk_path``, or None if absent.

        The returned mesh has its points truncated to 2D and its
        ``boundary_mask`` point data converted to bool.
        """
        if not os.path.exists(vtk_path):
            return None
        mesh = meshio.read(vtk_path, file_format="vtk")
        mesh.points = mesh.points[:, :2]
        mesh.point_data['boundary_mask'] = mesh.point_data['boundary_mask'].astype(bool)
        return mesh

    @staticmethod
    def _run_gmsh(build_geometry, element: str, chara_length: float,
                  visualize: bool, verbose: bool) -> meshio.Mesh:
        """
        Mesh a 2D OCC geometry with Gmsh and return it as a ``meshio.Mesh``.

        Parameters:
        -----------
            build_geometry: callable
                Called with no arguments after ``gmsh.initialize()``; it must
                add the OCC entities and return the tags of the surfaces to
                be meshed (used for quad recombination).
            element, chara_length, visualize, verbose:
                Same meaning as in the public generators.

        The intermediate ``.msh`` file lives in a private temporary directory,
        so concurrent calls cannot clobber each other's files and nothing is
        left behind on failure. ``gmsh.finalize()`` runs even if meshing fails.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            msh_path = os.path.join(tmp_dir, "mesh.msh")
            gmsh.initialize()
            try:
                if not verbose:
                    gmsh.option.setNumber("General.Terminal", 0)

                surfaces = build_geometry()
                gmsh.model.occ.synchronize()

                if element == "quad":
                    for tag in surfaces:
                        gmsh.model.mesh.setRecombine(2, tag)

                # Prescribe the target element size at every geometric point.
                gmsh.model.mesh.setSize(gmsh.model.getEntities(0), chara_length)
                gmsh.model.mesh.generate(2)

                if visualize:
                    gmsh.fltk.run()

                gmsh.write(msh_path)
            finally:
                gmsh.finalize()
            return meshio.read(msh_path)

    @staticmethod
    def _finalize_mesh(mesh, boundary_mask):
        """Store ``boundary_mask`` as bool point data and drop the z coordinate."""
        mesh.point_data['boundary_mask'] = np.asarray(boundary_mask, dtype=bool)
        mesh.points = mesh.points[:, :2]
        return mesh

    @staticmethod
    def _curve_atol(chara_length: float, scale: float) -> float:
        """
        Absolute tolerance for testing a normalized level set ``F(x, y) == 1``.

        For a curve of size ``scale`` (radius / largest semi-axis), F changes
        by roughly ``2 * d / scale`` for a node at distance ``d`` from the
        curve. Boundary nodes satisfy F == 1 up to rounding error, while the
        nearest interior nodes typically sit a sizeable fraction of
        ``chara_length`` away. A tolerance proportional to
        ``chara_length / scale`` therefore separates the two for any mesh
        resolution. A tolerance of ``sqrt(chara_length) / 10`` would not: it
        misclassifies interior nodes once ``chara_length`` falls below about
        0.004.
        """
        return chara_length / (10.0 * scale)

    # ------------------------------------------------------------------
    # Public generators
    # ------------------------------------------------------------------

    @staticmethod
    def gen_rectangle(xlims: Tuple[float, float] = (0., 1.),
                      ylims: Tuple[float, float] = (0., 1.),
                      chara_length: float = 0.01,
                      element: str = "tri",
                      visualize: bool = False,
                      verbose: bool = True) -> meshio.Mesh:
        """
        Generate mesh for a rectangular domain.

        Parameters:
        -----------
            xlims: tuple(float, float)
                The left and right boundaries of the rectangle (default: (0, 1))
            ylims: tuple(float, float)
                The bottom and top boundaries of the rectangle (default: (0, 1))
            chara_length: float
                The characteristic length of the mesh (default: 0.01)
            element: str
                The element type ("tri" or "quad") (default: "tri")
            visualize: bool
                Whether to open the Gmsh GUI on the mesh (default: False)
            verbose: bool
                Whether to print Gmsh output (default: True)

        Returns:
        --------
            mesh: meshio.Mesh
                The generated mesh with 2D points and a boolean
                ``point_data['boundary_mask']`` (True on the four edges)

        Raises:
        -------
            AssertionError
                If ``element`` is unknown or, when no cached mesh exists,
                if ``xlims``/``ylims`` are not increasing.
        """
        assert element in ["tri", "quad"], f"Unknown element {element}"
        vtk_path = Gmsh._cache_file(
            "rectangle", (xlims[0], xlims[1], ylims[0], ylims[1], chara_length), element
        )
        cached = Gmsh._load_cached(vtk_path)
        if cached is not None:
            return cached

        assert xlims[1] > xlims[0] and ylims[1] > ylims[0], \
            f"xlims and ylims must be increasing, but got xlims={xlims}, ylims={ylims}"
        width = xlims[1] - xlims[0]
        height = ylims[1] - ylims[0]

        def build_geometry():
            return [gmsh.model.occ.addRectangle(xlims[0], ylims[0], 0, width, height)]

        mesh = Gmsh._run_gmsh(build_geometry, element, chara_length, visualize, verbose)

        x, y = mesh.points[:, 0], mesh.points[:, 1]
        boundary_mask = (
            np.isclose(x, xlims[0]) | np.isclose(x, xlims[1]) |
            np.isclose(y, ylims[0]) | np.isclose(y, ylims[1])
        )
        return Gmsh._finalize_mesh(mesh, boundary_mask)

    @staticmethod
    def gen_circle(radius: float = 1.,
                   center: Tuple[float, float] = (0., 0.),
                   chara_length: float = 0.1,
                   element: str = "tri",
                   visualize: bool = False,
                   verbose: bool = True) -> meshio.Mesh:
        """
        Generate mesh for a circular domain (disk).

        Parameters:
        -----------
            radius: float
                The radius of the circle (default: 1.0)
            center: tuple(float, float)
                The center of the circle (default: (0., 0.))
            chara_length: float
                The characteristic length of the mesh (default: 0.1)
            element: str
                The element type ("tri" or "quad") (default: "tri")
            visualize: bool
                Whether to open the Gmsh GUI on the mesh (default: False)
            verbose: bool
                Whether to print Gmsh output (default: True)

        Returns:
        --------
            mesh: meshio.Mesh
                The generated mesh with 2D points and a boolean
                ``point_data['boundary_mask']`` (True for nodes on the circle
                of the given ``radius`` around ``center``)

        Raises:
        -------
            AssertionError
                If ``element`` is not "tri" or "quad".
        """
        assert element in ["tri", "quad"], f"Unknown element {element}"
        vtk_path = Gmsh._cache_file(
            "circle", (radius, center[0], center[1], chara_length), element
        )
        cached = Gmsh._load_cached(vtk_path)
        if cached is not None:
            return cached

        def build_geometry():
            return [gmsh.model.occ.addDisk(center[0], center[1], 0, radius, radius)]

        mesh = Gmsh._run_gmsh(build_geometry, element, chara_length, visualize, verbose)

        # Normalized squared distance from the center: exactly 1 on the circle.
        dx = mesh.points[:, 0] - center[0]
        dy = mesh.points[:, 1] - center[1]
        level = (dx * dx + dy * dy) / radius ** 2
        boundary_mask = np.isclose(
            level, 1.0, rtol=0.0, atol=Gmsh._curve_atol(chara_length, radius)
        )
        return Gmsh._finalize_mesh(mesh, boundary_mask)

    # Alias for backward compatibility (original had typo "cirlce")
    gen_cirlce = gen_circle

    @staticmethod
    def gen_ellipse(radius: Tuple[float, float] = (2., 1.),
                    center: Tuple[float, float] = (0., 0.),
                    chara_length: float = 0.1,
                    element: str = "quad",
                    visualize: bool = False,
                    verbose: bool = True) -> meshio.Mesh:
        """
        Generate mesh for an elliptical domain.

        Parameters:
        -----------
            radius: tuple(float, float)
                The semi-axes of the ellipse along x and y (default: (2., 1.))
            center: tuple(float, float)
                The center of the ellipse (default: (0., 0.))
            chara_length: float
                The characteristic length of the mesh (default: 0.1)
            element: str
                The element type ("tri" or "quad") (default: "quad")
            visualize: bool
                Whether to open the Gmsh GUI on the mesh (default: False)
            verbose: bool
                Whether to print Gmsh output (default: True)

        Returns:
        --------
            mesh: meshio.Mesh
                The generated mesh with 2D points and a boolean
                ``point_data['boundary_mask']`` (True for nodes on the ellipse)

        Raises:
        -------
            AssertionError
                If ``element`` is not "tri" or "quad".
        """
        assert element in ["tri", "quad"], f"Unknown element {element}"
        vtk_path = Gmsh._cache_file(
            "ellipse", (radius[0], radius[1], center[0], center[1], chara_length), element
        )
        cached = Gmsh._load_cached(vtk_path)
        if cached is not None:
            return cached

        def build_geometry():
            return [gmsh.model.occ.addDisk(center[0], center[1], 0, radius[0], radius[1])]

        mesh = Gmsh._run_gmsh(build_geometry, element, chara_length, visualize, verbose)

        # Ellipse level set relative to its center: exactly 1 on the boundary.
        dx = mesh.points[:, 0] - center[0]
        dy = mesh.points[:, 1] - center[1]
        level = (dx / radius[0]) ** 2 + (dy / radius[1]) ** 2
        # The level set varies slowest along the longer semi-axis, so that
        # axis sets the tolerance scale.
        atol = Gmsh._curve_atol(chara_length, max(radius[0], radius[1]))
        boundary_mask = np.isclose(level, 1.0, rtol=0.0, atol=atol)
        return Gmsh._finalize_mesh(mesh, boundary_mask)

    @staticmethod
    def gen_L_shape(xlims: Tuple[float, float] = (0., 1.),
                    ylims: Tuple[float, float] = (0., 1.),
                    chara_length: float = 0.1,
                    element: str = "tri",
                    visualize: bool = False,
                    verbose: bool = True) -> meshio.Mesh:
        """
        Generate mesh for an L-shaped domain.

        The L-shape is created by removing the upper-right quarter from a rectangle.

        Parameters:
        -----------
            xlims: tuple(float, float)
                The left and right boundaries (default: (0, 1))
            ylims: tuple(float, float)
                The bottom and top boundaries (default: (0, 1))
            chara_length: float
                The characteristic length of the mesh (default: 0.1)
            element: str
                The element type ("tri" or "quad") (default: "tri")
            visualize: bool
                Whether to open the Gmsh GUI on the mesh (default: False)
            verbose: bool
                Whether to print Gmsh output (default: True)

        Returns:
        --------
            mesh: meshio.Mesh
                The generated mesh with 2D points and a boolean
                ``point_data['boundary_mask']`` (True on the outer edges and
                on the two re-entrant edges of the removed quarter)

        Raises:
        -------
            AssertionError
                If ``element`` is unknown or, when no cached mesh exists,
                if ``xlims``/``ylims`` are not increasing.
        """
        assert element in ["tri", "quad"], f"Unknown element {element}"
        vtk_path = Gmsh._cache_file(
            "L_shape", (xlims[0], xlims[1], ylims[0], ylims[1], chara_length), element
        )
        cached = Gmsh._load_cached(vtk_path)
        if cached is not None:
            return cached

        assert xlims[1] > xlims[0] and ylims[1] > ylims[0], \
            f"xlims and ylims must be increasing, but got xlims={xlims}, ylims={ylims}"
        width = xlims[1] - xlims[0]
        height = ylims[1] - ylims[0]
        # Corner of the removed upper-right quarter (the re-entrant corner).
        x_mid = xlims[0] + width / 2
        y_mid = ylims[0] + height / 2

        def build_geometry():
            outer = gmsh.model.occ.addRectangle(xlims[0], ylims[0], 0, width, height)
            notch = gmsh.model.occ.addRectangle(x_mid, y_mid, 0, width / 2, height / 2)
            # Recombine the surfaces produced by the boolean cut rather than
            # the pre-cut tag, which the cut may have replaced.
            out_dim_tags, _ = gmsh.model.occ.cut([(2, outer)], [(2, notch)])
            return [tag for dim, tag in out_dim_tags if dim == 2]

        mesh = Gmsh._run_gmsh(build_geometry, element, chara_length, visualize, verbose)

        x, y = mesh.points[:, 0], mesh.points[:, 1]
        outer_edges = (
            np.isclose(x, xlims[0]) | np.isclose(x, xlims[1]) |
            np.isclose(y, ylims[0]) | np.isclose(y, ylims[1])
        )
        # Re-entrant edges: x == x_mid for y >= y_mid, and y == y_mid for
        # x >= x_mid. The isclose terms keep the corner node even if its
        # coordinates carry rounding error.
        upper_half = (y > y_mid) | np.isclose(y, y_mid)
        right_half = (x > x_mid) | np.isclose(x, x_mid)
        reentrant_edges = (
            (np.isclose(x, x_mid) & upper_half) |
            (np.isclose(y, y_mid) & right_half)
        )
        return Gmsh._finalize_mesh(mesh, outer_edges | reentrant_edges)

    # Helper for creating partial functions: inside the class body the
    # generators are still staticmethod wrappers, so unwrap to the plain
    # function before binding arguments with functools.partial.
    def _get_callable_for_partial(method):
        if hasattr(method, "__func__"):
            return method.__func__
        return method

    # Convenience generators with the element type pre-bound. The partial
    # objects are wrapped in staticmethod so they are never bound to an
    # instance (functools.partial becomes a method descriptor in Python 3.14);
    # class-level access such as Gmsh.gen_tri_rectangle(...) is unchanged.
    gen_tri_rectangle = staticmethod(partial(_get_callable_for_partial(gen_rectangle), element="tri"))
    gen_tri_circle = staticmethod(partial(_get_callable_for_partial(gen_circle), element="tri"))
    gen_tri_ellipse = staticmethod(partial(_get_callable_for_partial(gen_ellipse), element="tri"))
    gen_quad_rectangle = staticmethod(partial(_get_callable_for_partial(gen_rectangle), element="quad"))
    gen_quad_circle = staticmethod(partial(_get_callable_for_partial(gen_circle), element="quad"))
    gen_quad_ellipse = staticmethod(partial(_get_callable_for_partial(gen_ellipse), element="quad"))
    gen_tri_L_shape = staticmethod(partial(_get_callable_for_partial(gen_L_shape), element="tri"))
    gen_quad_L_shape = staticmethod(partial(_get_callable_for_partial(gen_L_shape), element="quad"))

    # Backward compatibility aliases (original had typo "cirlce")
    gen_quad_cirlce = gen_quad_circle

