import os
import tempfile

import gmsh
import meshio
from dolfin import *
import numpy as np


def compute_dof_map_ns2d(W):
    """
    Compute dof maps for the components u0, u1 and p of a mixed space.

    For each component, every dof of its collapsed sub-space is tagged with a
    unique positive value (1..n). The tags are assigned into a
    zero-initialised function on ``W``, and the positions that received a
    positive tag identify that component's dofs in the mixed vector.

    Parameters
    ----------
    W : dolfin.FunctionSpace
        Function space for the mixed space
        [Vector[u0, u1], p]

    Returns
    -------
    select_p_map, reorder_p_map,
        select_u0_map, reorder_u0_map,
        select_u1_map, reorder_u1_map
    select_*_map : numpy.ndarray
        Integer indices (ascending) of the entries of ``w.vector()[:]``
        that belong to component *.
    reorder_*_map : numpy.ndarray
        Integer permutation such that
        ``w.vector()[:][select_*_map] == f.vector()[:][reorder_*_map]``,
        where ``f`` is the same component on its collapsed sub-space.
    """

    def _locate(collapsed_space, *path):
        # Tag the collapsed dofs with 1..n so that each one is positive
        # and uniquely identifiable after assignment.
        tagged = Function(collapsed_space)
        tagged.vector()[:] = np.arange(1, len(tagged.vector()[:]) + 1)

        # Copy the tags into a fresh (zero) function on the mixed space,
        # following the sub-function path, e.g. (0, 0) for u0.
        w = Function(W)
        target = w
        for index in path:
            target = target.sub(index)
        assign(target, tagged)

        # Only this component's dofs received a positive tag.
        w_vec = w.vector()[:]
        select = np.where(w_vec > 0)[0]
        # The rank of each tag equals its 0-based position in the
        # collapsed space.
        reorder = np.argsort(np.argsort(w_vec[select]))
        return select, reorder

    W_u = W.sub(0).collapse()
    select_u0_map, reorder_u0_map = _locate(W_u.sub(0).collapse(), 0, 0)
    select_u1_map, reorder_u1_map = _locate(W_u.sub(1).collapse(), 0, 1)
    select_p_map, reorder_p_map = _locate(W.sub(1).collapse(), 1)

    return (
        select_p_map, reorder_p_map,
        select_u0_map, reorder_u0_map,
        select_u1_map, reorder_u1_map,
    )


def compute_dof_map_burgers2d(W):
    """
    Compute dof maps for the components u0, u1 of a vector space.

    For each component, every dof of its collapsed sub-space is tagged with a
    unique positive value (1..n). The tags are assigned into a
    zero-initialised function on ``W``, and the positions that received a
    positive tag identify that component's dofs in the full vector.

    Parameters
    ----------
    W : dolfin.FunctionSpace
        Function space for the mixed space
        Vector[u0, u1]

    Returns
    -------
    select_u0_map, reorder_u0_map,
        select_u1_map, reorder_u1_map
    select_*_map : numpy.ndarray
        Integer indices (ascending) of the entries of ``w.vector()[:]``
        that belong to component *.
    reorder_*_map : numpy.ndarray
        Integer permutation such that
        ``w.vector()[:][select_*_map] == f.vector()[:][reorder_*_map]``,
        where ``f`` is the same component on its collapsed sub-space.
    """

    def _locate(collapsed_space, index):
        # Tag the collapsed dofs with 1..n so that each one is positive
        # and uniquely identifiable after assignment.
        tagged = Function(collapsed_space)
        tagged.vector()[:] = np.arange(1, len(tagged.vector()[:]) + 1)

        # Copy the tags into a fresh (zero) function on W.
        w = Function(W)
        assign(w.sub(index), tagged)

        # Only this component's dofs received a positive tag.
        w_vec = w.vector()[:]
        select = np.where(w_vec > 0)[0]
        # The rank of each tag equals its 0-based position in the
        # collapsed space.
        reorder = np.argsort(np.argsort(w_vec[select]))
        return select, reorder

    select_u0_map, reorder_u0_map = _locate(W.sub(0).collapse(), 0)
    select_u1_map, reorder_u1_map = _locate(W.sub(1).collapse(), 1)

    return select_u0_map, reorder_u0_map, \
        select_u1_map, reorder_u1_map


def convert_mesh_file(file_path, dim=2):
    """
    Convert mesh file from `.msh` to `.xdmf` format and load it with dolfin.

    Only the cells of the requested dimension are kept (triangles for
    ``dim=2``, tetrahedra for ``dim=3``), together with their
    ``gmsh:physical`` tags, stored under the cell-data name
    ``"name_to_read"``. The ``.xdmf`` file (and the ``.h5`` file meshio
    writes alongside it) is created next to ``file_path`` with the same
    base name.

    Parameters
    ----------
    file_path : str
        Path to the mesh file; must end with ``.msh``.
    dim : int
        Dimension of the mesh, 2 or 3. For ``dim=2`` the z coordinate of
        the points is dropped.

    Returns
    -------
    mesh : dolfin.Mesh
        The mesh read back from the written ``.xdmf`` file.

    Raises
    ------
    ValueError
        If ``file_path`` does not end with ``.msh``, if ``dim`` is not 2
        or 3, or if the file has no cells / physical tags of the required
        type.
    """
    if not file_path.endswith(".msh"):
        raise ValueError("Only .msh file is supported")

    cell_types = {2: "triangle", 3: "tetra"}
    if dim not in cell_types:
        raise ValueError(f"dim must be 2 or 3, got {dim!r}")
    cell_type = cell_types[dim]

    msh = meshio.read(file_path)

    # meshio may split one cell type over several blocks. Concatenate them
    # in block order, the same order cell_data_dict uses for the tags, so
    # connectivity and tags stay aligned.
    blocks = [cell.data for cell in msh.cells if cell.type == cell_type]
    if not blocks:
        raise ValueError(f"No {cell_type} cells found in {file_path}")
    cells = np.concatenate(blocks)

    physical = msh.cell_data_dict.get("gmsh:physical", {})
    if cell_type not in physical:
        raise ValueError(
            f"No gmsh:physical data for {cell_type} cells in {file_path}")
    cell_tags = physical[cell_type]

    # Remove the z coordinate for 2D meshes.
    points = msh.points[:, :2] if dim == 2 else msh.points

    out_mesh = meshio.Mesh(
        points=points,
        cells=[(cell_type, cells)],
        cell_data={"name_to_read": [cell_tags]},
    )

    # Swap only the suffix; str.replace would also rewrite ".msh"
    # occurring elsewhere in the path.
    xdmf_path = os.path.splitext(file_path)[0] + ".xdmf"
    meshio.write(xdmf_path, out_mesh)

    mesh = Mesh()
    with XDMFFile(MPI.comm_world, xdmf_path) as infile:
        infile.read(mesh)
    return mesh


def generate_complex_mesh2d(bbox, voids, mesh_length):
    """
    Generate 2D complex mesh with rectangular or circular voids.

    The domain is built with gmsh's OpenCASCADE kernel and meshed with gmsh.
    It is written to a private temporary directory as an MSH 2 file, then
    converted and loaded with ``convert_mesh_file``. The temporary files
    are removed afterwards, and gmsh is finalized even if meshing fails.

    Parameters
    ----------
    bbox : list
        Bounding box of the (outer rectangular) domain,
        ``[x_min, x_max, y_min, y_max]``.
    voids : dict
        Dictionary of voids. The key is the name of the void, and the value is
        a list of voids with the same name. Each void is a list of numbers
        representing the shape of the void.
        Available voids: "circ": circle (x, y, r),
        "rec": rectangle (x1, x2, y1, y2) (i.e., bounding box).
        Other keys are ignored.
    mesh_length : float
        Maximum mesh length (gmsh option ``Mesh.CharacteristicLengthMax``).

    Returns
    -------
    mesh : dolfin.Mesh
        The generated mesh
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        msh_path = os.path.join(tmp_dir, "mesh.msh")

        gmsh.initialize()
        try:
            gmsh.model.add("model")
            occ = gmsh.model.occ

            # Outer rectangle: corner (x_min, y_min), size (dx, dy).
            domain = occ.addRectangle(
                bbox[0], bbox[2], 0, bbox[1] - bbox[0], bbox[3] - bbox[2])

            # Collect all void shapes and subtract them in a single boolean
            # cut. This avoids reusing the object's tag after each cut.
            tools = []
            for name, configs in voids.items():
                if name == "circ":
                    tools += [
                        (2, occ.addDisk(c[0], c[1], 0, c[2], c[2]))
                        for c in configs
                    ]
                elif name == "rec":
                    tools += [
                        (2, occ.addRectangle(
                            c[0], c[2], 0, c[1] - c[0], c[3] - c[2]))
                        for c in configs
                    ]
            if tools:
                occ.cut([(2, domain)], tools)

            # Synchronize the model and generate 2D mesh
            occ.synchronize()
            gmsh.option.setNumber("Mesh.CharacteristicLengthMax", mesh_length)
            gmsh.model.mesh.generate(2)

            gmsh.option.setNumber("Mesh.MshFileVersion", 2)
            gmsh.write(msh_path)
        finally:
            gmsh.finalize()

        # Convert while the temporary directory (msh/xdmf/h5) still exists.
        return convert_mesh_file(msh_path, dim=2)


def generate_complex_mesh3d(bbox, voids, mesh_length):
    """
    Generate 3D complex mesh with box, sphere or cylinder voids.

    The domain is built with gmsh's OpenCASCADE kernel and meshed with gmsh.
    It is written to a private temporary directory as an MSH 2 file, then
    converted and loaded with ``convert_mesh_file``. The temporary files
    are removed afterwards, and gmsh is finalized even if meshing fails.

    Parameters
    ----------
    bbox : list
        Bounding box of the (outer) domain,
        ``[x_min, x_max, y_min, y_max, z_min, z_max]``.
    voids : dict
        Dictionary of voids. The key is the name of the void, and the value is
        a list of voids with the same name. Each void is a list of numbers
        representing the shape of the void.
        Available voids: "sphere": sphere (x, y, z, r),
        "box": box (x1, x2, y1, y2, z1, z2) (i.e., bounding box),
        "cylinder": cylinder along the z axis (x, y, z1, z2, r).
        Other keys are ignored.
    mesh_length : float
        Maximum mesh length (gmsh option ``Mesh.CharacteristicLengthMax``).

    Returns
    -------
    mesh : dolfin.Mesh
        The generated mesh
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        msh_path = os.path.join(tmp_dir, "mesh.msh")

        gmsh.initialize()
        try:
            gmsh.model.add("model")
            occ = gmsh.model.occ

            # Outer box: corner (x_min, y_min, z_min), size (dx, dy, dz).
            domain = occ.addBox(
                bbox[0], bbox[2], bbox[4],
                bbox[1] - bbox[0], bbox[3] - bbox[2], bbox[5] - bbox[4])

            # Collect all void shapes and subtract them in a single boolean
            # cut. This avoids reusing the object's tag after each cut.
            tools = []
            for name, configs in voids.items():
                if name == "sphere":
                    tools += [
                        (3, occ.addSphere(c[0], c[1], c[2], c[3]))
                        for c in configs
                    ]
                elif name == "box":
                    tools += [
                        (3, occ.addBox(
                            c[0], c[2], c[4],
                            c[1] - c[0], c[3] - c[2], c[5] - c[4]))
                        for c in configs
                    ]
                elif name == "cylinder":
                    # Base centre (x, y, z1), axis vector (0, 0, z2 - z1).
                    tools += [
                        (3, occ.addCylinder(
                            c[0], c[1], c[2], 0, 0, c[3] - c[2], c[4]))
                        for c in configs
                    ]
            if tools:
                occ.cut([(3, domain)], tools)

            # Synchronize the model and generate 3D mesh
            occ.synchronize()
            gmsh.option.setNumber("Mesh.CharacteristicLengthMax", mesh_length)
            gmsh.model.mesh.generate(3)

            gmsh.option.setNumber("Mesh.MshFileVersion", 2)
            gmsh.write(msh_path)
        finally:
            gmsh.finalize()

        # Convert while the temporary directory (msh/xdmf/h5) still exists.
        return convert_mesh_file(msh_path, dim=3)
