"""
Mesh generation utilities for TensorGalerkin
"""

import meshio
import numpy as np
import torch
import torch_geometric as pyg
from typing import Optional, Tuple, Union


def mesh_to_pyg_graph(mesh, dtype=torch.float64, selfloop=True, use_dense=True):
    """
    Convert meshio mesh to PyTorch Geometric graph

    The first supported cell block found ("triangle" is checked before "quad")
    defines the connectivity; any other cell blocks are ignored.

    Parameters:
    -----------
        mesh: meshio.Mesh
            Input mesh. Must provide ``point_data['boundary_mask']``.
        dtype: torch.dtype
            Floating-point data type used for the node coordinates ``x`` and
            for ``boundary_value``
        selfloop: bool
            Whether to run ``AddSelfLoops`` followed by ``RemoveDuplicatedEdges``
            on the graph. Duplicate removal only happens in this branch.
        use_dense: bool
            Whether to use dense connectivity (all pairs in element) vs loop connectivity.
            Dense connectivity pairs every element node with every element
            node, including itself.

    Returns:
    --------
        data: torch_geometric.data.Data
            Graph representation of the mesh with attributes ``x`` (all columns
            of ``mesh.points``), ``boundary_mask``, ``boundary_value`` (all
            zeros) and ``edge_index`` of shape (2, num_edges)

    Raises:
    -------
        ValueError
            If the mesh has neither "triangle" nor "quad" cells
    """
    if "triangle" in mesh.cells_dict:
        elements = mesh.cells_dict["triangle"]
    elif "quad" in mesh.cells_dict:
        elements = mesh.cells_dict["quad"]
    else:
        raise ValueError(f"Unknown cell type {mesh.cells_dict.keys()}")

    # Connectivity is normalised to int64 so that edge_index is a LongTensor
    # regardless of the integer type the mesh was stored with.
    elements = torch.from_numpy(np.ascontiguousarray(elements, dtype=np.int64))
    # Honour the requested dtype for coordinates as well as boundary values.
    node_coords = torch.from_numpy(mesh.points.astype(np.float64)).to(dtype)
    boundary_mask = torch.from_numpy(mesh.point_data['boundary_mask']).bool()
    # NOTE(review): any 'boundary_value' stored in mesh.point_data is not read
    # here; the graph always carries zeros. Confirm this is intended.
    boundary_value = torch.zeros_like(boundary_mask, dtype=dtype)

    # Create edge connectivity
    if use_dense:
        # Dense connectivity: connect all pairs of nodes in each element.
        # Broadcasting yields, per element, the pairs (x[a], x[b]) with `a`
        # varying slowest - the same order as an 'ij' meshgrid of the element
        # with itself.
        nodes_per_element = elements.shape[1]
        pair_shape = (-1, nodes_per_element, nodes_per_element)
        sources = elements.unsqueeze(2).expand(*pair_shape)
        targets = elements.unsqueeze(1).expand(*pair_shape)
        pairs = torch.stack([sources, targets], dim=-1)
    else:
        # Loop connectivity: connect consecutive nodes in element (forming a loop).
        # Each node is paired with its predecessor in the element's node list.
        previous = torch.roll(elements, shifts=1, dims=1)
        pairs = torch.stack([previous, elements], dim=-1)
    edges = pairs.reshape(-1, 2).T.contiguous()

    # Create PyG graph
    graph = pyg.data.Data(
        x=node_coords,
        boundary_mask=boundary_mask,
        boundary_value=boundary_value,
        edge_index=edges
    )

    if selfloop:
        graph = pyg.transforms.AddSelfLoops()(graph)
        graph = pyg.transforms.RemoveDuplicatedEdges()(graph)

    return graph


class MeshGen:
    """Mesh generation utilities

    Builds structured quadrilateral or triangular meshes on an axis-aligned
    rectangle. Every generated mesh carries two point-data arrays:
    ``boundary_mask`` (True on the rectangle's outer nodes) and
    ``boundary_value`` (all zeros, i.e. homogeneous Dirichlet data).
    """

    @staticmethod
    def init_mesh(config):
        """
        Initialize a mesh based on configuration

        Parameters:
        -----------
            config: object
                Configuration object with mesh parameters, read via attributes:
                ``grid`` (cells per direction; takes precedence over ``nx`` /
                ``ny``), ``nx`` / ``ny`` (default 32), ``xlims`` / ``ylims``
                (default [0.0, 1.0]) and ``element`` ('quad' or 'tri',
                default 'quad')

        Returns:
        --------
            mesh: meshio.Mesh
                Generated mesh

        Raises:
        -------
            ValueError
                If ``element`` is neither 'quad' nor 'tri'
        """
        # For backward compatibility and simplicity, we'll create a simple rectangular mesh
        # In a full implementation, this would support multiple geometries and mesh generators

        # Default parameters
        nx = getattr(config, 'grid', getattr(config, 'nx', 32))
        ny = getattr(config, 'grid', getattr(config, 'ny', 32))
        xlims = getattr(config, 'xlims', [0.0, 1.0])
        ylims = getattr(config, 'ylims', [0.0, 1.0])
        element_type = getattr(config, 'element', 'quad')

        if element_type == 'quad':
            return MeshGen._create_quad_mesh(nx, ny, xlims, ylims)
        elif element_type == 'tri':
            return MeshGen._create_tri_mesh(nx, ny, xlims, ylims)
        else:
            raise ValueError(f"Unsupported element type: {element_type}")

    @staticmethod
    def _structured_grid(nx: int, ny: int,
                         xlims: Tuple[float, float],
                         ylims: Tuple[float, float]):
        """Build the points, cell corner indices and point data of an nx-by-ny grid.

        Returns:
        --------
            points: np.ndarray of shape ((nx + 1) * (ny + 1), 3)
                Node coordinates with z = 0; node ``j * (nx + 1) + i`` sits at
                column ``i`` (x direction) and row ``j`` (y direction)
            corners: tuple of four 1-D integer arrays (n0, n1, n2, n3)
                Counter-clockwise corner node indices of every cell, ordered
                row by row (j outer, i inner)
            point_data: dict
                ``boundary_mask`` and ``boundary_value`` arrays, one entry per node
        """
        # Create grid points
        x = np.linspace(xlims[0], xlims[1], nx + 1)
        y = np.linspace(ylims[0], ylims[1], ny + 1)
        X, Y = np.meshgrid(x, y)
        points = np.column_stack([X.ravel(), Y.ravel(), np.zeros(X.size)])

        # Lower-left node of every cell, then the remaining corners relative to it
        cell_rows = np.arange(ny)[:, None]
        cell_cols = np.arange(nx)[None, :]
        n0 = (cell_rows * (nx + 1) + cell_cols).ravel()
        n1 = n0 + 1
        n2 = n1 + (nx + 1)
        n3 = n0 + (nx + 1)

        # Boundary nodes are the first/last row and column of the node grid.
        # Classifying by grid index is exact; comparing coordinates against
        # the limits with a floating-point tolerance can also capture interior
        # nodes when the grid spacing is small relative to the coordinate
        # magnitude.
        on_boundary = np.zeros((ny + 1, nx + 1), dtype=bool)
        on_boundary[[0, -1], :] = True
        on_boundary[:, [0, -1]] = True

        point_data = {
            "boundary_mask": on_boundary.ravel(),
            "boundary_value": np.zeros(len(points))  # Homogeneous Dirichlet BC
        }
        return points, (n0, n1, n2, n3), point_data

    @staticmethod
    def _create_quad_mesh(nx: int, ny: int,
                         xlims: Tuple[float, float] = (0.0, 1.0),
                         ylims: Tuple[float, float] = (0.0, 1.0)) -> meshio.Mesh:
        """Create a structured quadrilateral mesh on rectangle

        Parameters:
        -----------
            nx, ny: int
                Number of cells in the x and y directions
            xlims, ylims: Tuple[float, float]
                (min, max) extent of the rectangle in each direction

        Returns:
        --------
            mesh: meshio.Mesh
                Mesh with one "quad" cell block of nx * ny cells
        """
        points, (n0, n1, n2, n3), point_data = MeshGen._structured_grid(
            nx, ny, xlims, ylims)

        # One quad per cell, corners listed counter-clockwise
        quads = np.column_stack([n0, n1, n2, n3])

        # Create mesh
        cells = [("quad", quads)]
        return meshio.Mesh(points, cells, point_data=point_data)

    @staticmethod
    def _create_tri_mesh(nx: int, ny: int,
                        xlims: Tuple[float, float] = (0.0, 1.0),
                        ylims: Tuple[float, float] = (0.0, 1.0)) -> meshio.Mesh:
        """Create a structured triangular mesh on rectangle

        Each grid cell is split along the n1-n3 diagonal into a lower and an
        upper triangle, stored consecutively.

        Parameters:
        -----------
            nx, ny: int
                Number of grid cells in the x and y directions
            xlims, ylims: Tuple[float, float]
                (min, max) extent of the rectangle in each direction

        Returns:
        --------
            mesh: meshio.Mesh
                Mesh with one "triangle" cell block of 2 * nx * ny cells
        """
        points, (n0, n1, n2, n3), point_data = MeshGen._structured_grid(
            nx, ny, xlims, ylims)

        # Split every cell into two triangles
        lower = np.column_stack([n0, n1, n3])  # Lower triangle
        upper = np.column_stack([n1, n2, n3])  # Upper triangle
        # Interleave so the two triangles of a cell are adjacent in the list
        tris = np.stack([lower, upper], axis=1).reshape(-1, 3)

        # Create mesh
        cells = [("triangle", tris)]
        return meshio.Mesh(points, cells, point_data=point_data)

    @staticmethod
    def mesh_to_pyg_graph(mesh, **kwargs):
        """Convenience method - wrapper around module function"""
        return mesh_to_pyg_graph(mesh, **kwargs)