import torch

# Names exported by a star import of this module: only the batched entry point.
__all__ = [
    "marching_tetrahedra",
]

# Marching-tetrahedra lookup tables.
#
# A tetrahedron's case index is the sum of 2**i over its local vertices i that
# have sdf > 0, giving 16 cases. Row c of ``triangle_table`` lists the triangles
# emitted for case c as triples of local edge indices (0-5, in the order defined
# by ``base_tet_edges``); -1 pads unused slots.
triangle_table = torch.tensor(
    [
        [-1, -1, -1, -1, -1, -1],
        [1, 0, 2, -1, -1, -1],
        [4, 0, 3, -1, -1, -1],
        [1, 4, 2, 1, 3, 4],
        [3, 1, 5, -1, -1, -1],
        [2, 3, 0, 2, 5, 3],
        [1, 4, 0, 1, 5, 4],
        [4, 2, 5, -1, -1, -1],
        [4, 5, 2, -1, -1, -1],
        [4, 1, 0, 4, 5, 1],
        [3, 2, 0, 3, 5, 2],
        [1, 3, 5, -1, -1, -1],
        [4, 1, 2, 4, 3, 1],
        [3, 0, 4, -1, -1, -1],
        [2, 0, 1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1],
    ],
    dtype=torch.long,
)

# Triangles emitted per case: a triangle slot (columns 0-2 or 3-5) is in use
# exactly when its first entry is not the -1 padding.
num_triangles_table = (triangle_table[:, ::3] >= 0).sum(dim=1)

# The six tetrahedron edges as local vertex pairs, flattened to
# [a0, b0, a1, b1, ...]; local edge k joins local vertices (a_k, b_k).
base_tet_edges = torch.tensor(
    [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]],
    dtype=torch.long,
).reshape(-1)

# Bit weight (2**i) of local vertex i when forming a case index.
v_id = torch.tensor([1, 2, 4, 8], dtype=torch.long)


def _unbatched_marching_tetrahedra(vertices, tets, sdf, scales, chunk_size=32 * 1024 * 1024):
    """Unbatched marching tetrahedra.

    Refer to :func:`marching_tetrahedra`. Finds every tetrahedron edge whose two
    endpoints disagree on ``sdf > 0`` together with the triangles built on those
    edges, then gathers the endpoint positions, SDF values and scales of each
    such edge. Surface vertex positions are not interpolated here.

    Args:
        vertices (torch.Tensor): vertex positions, of shape ``(num_vertices, 3)``.
        tets (torch.LongTensor): tetrahedra as vertex ids, of shape ``(num_tets, 4)``.
        sdf (torch.Tensor): per-vertex SDF values, of shape ``(num_vertices,)``.
        scales (torch.Tensor): per-vertex scalars, of shape ``(num_vertices,)``.
        chunk_size (int): tetrahedra are processed in slices of at most this many
            to bound peak memory. Crossing edges shared between slices are
            deduplicated, so only the order of the faces depends on it.

    Returns:
        tuple:
            - ``(edge_verts, edge_sdf)``: endpoint positions of shape
              ``(num_edges, 2, 3)`` and endpoint SDF values of shape
              ``(num_edges, 2, 1)``.
            - ``edge_scales``: endpoint scales of shape ``(num_edges, 2, 1)``.
            - ``faces``: LongTensor of shape ``(num_faces, 3)`` whose entries index
              the crossing edges.
            - ``edge_ids``: LongTensor of shape ``(num_edges, 2)`` with the endpoint
              vertex ids of each crossing edge (smaller id first, rows sorted).
    """
    device = vertices.device
    with torch.no_grad():
        occ_n = sdf > 0
        if tets.shape[0] > chunk_size:
            edge_ids, faces = _merge_tet_chunks(tets, occ_n, device, chunk_size)
        else:
            edge_ids, faces = _crossing_edges_and_faces(tets, occ_n, device)

    # Gather outside ``no_grad`` so the outputs stay differentiable with respect
    # to ``vertices``, ``sdf`` and ``scales``. Gathering straight from the inputs
    # also keeps their dtype regardless of which path produced ``edge_ids``.
    flat_ids = edge_ids.reshape(-1)
    edge_verts = vertices[flat_ids].reshape(-1, 2, 3)
    edge_sdf = sdf[flat_ids].reshape(-1, 2, 1)
    edge_scales = scales[flat_ids].reshape(-1, 2, 1)
    return (edge_verts, edge_sdf), edge_scales, faces, edge_ids


def _crossing_edges_and_faces(tets, occ_n, device):
    """Return ``(edge_ids, faces)`` for one pass over ``tets``.

    ``edge_ids`` is a ``(num_edges, 2)`` long tensor of the unique edges whose
    endpoints differ in ``occ_n`` (smaller id first, rows sorted), and ``faces``
    is a ``(num_faces, 3)`` long tensor indexing rows of ``edge_ids``.
    """
    occ_fx4 = occ_n[tets.reshape(-1)].reshape(-1, 4)
    occ_sum = torch.sum(occ_fx4, -1)
    # Only tetrahedra with mixed occupancy intersect the surface.
    valid_tets = (occ_sum > 0) & (occ_sum < 4)

    # Six edges per valid tet, stored as (min_id, max_id) so shared edges dedupe.
    all_edges = tets[valid_tets][:, base_tet_edges.to(device)].reshape(-1, 2)
    order = all_edges[:, 0] > all_edges[:, 1]
    all_edges[order] = all_edges[order][:, [1, 0]]
    unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
    unique_edges = unique_edges.long()

    # Keep only edges whose endpoints straddle the surface, renumber them
    # densely, and mark every other edge with -1.
    mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
    mapping = torch.full((unique_edges.shape[0],), -1, dtype=torch.long, device=device)
    mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=device)
    idx_map = mapping[idx_map].reshape(-1, 6)

    # Case index in [0, 16): bit i is set when local vertex i is occupied.
    tetindex = (occ_fx4[valid_tets] * v_id.to(device).unsqueeze(0)).sum(-1)
    num_triangles = num_triangles_table.to(device)[tetindex]
    table = triangle_table.to(device)
    one_tri = num_triangles == 1
    two_tri = num_triangles == 2

    # Map each case's local edge indices to the dense crossing-edge numbering.
    faces = torch.cat((
        torch.gather(input=idx_map[one_tri], dim=1,
                     index=table[tetindex[one_tri]][:, :3]).reshape(-1, 3),
        torch.gather(input=idx_map[two_tri], dim=1,
                     index=table[tetindex[two_tri]][:, :6]).reshape(-1, 3),
    ), dim=0)
    return unique_edges[mask_edges], faces


def _merge_tet_chunks(tets, occ_n, device, chunk_size):
    """Run :func:`_crossing_edges_and_faces` on slices of ``tets`` and merge them.

    Edges shared by tetrahedra in different slices are deduplicated once at the
    end, so the returned edges equal those of a single pass; faces keep the
    slice order.
    """
    chunk_edges, chunk_faces = [], []
    offset = 0
    for tet_chunk in torch.chunk(tets, tets.shape[0] // chunk_size + 1):
        torch.cuda.empty_cache()
        edges, faces = _crossing_edges_and_faces(tet_chunk, occ_n, device)
        chunk_edges.append(edges)
        # Shift this slice's face indices into the concatenated edge list.
        chunk_faces.append(faces + offset)
        offset += edges.shape[0]

    all_edges = torch.cat(chunk_edges, dim=0)
    unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
    faces = idx_map[torch.cat(chunk_faces, dim=0)]
    torch.cuda.empty_cache()
    return unique_edges, faces


def marching_tetrahedra(vertices, tets, sdf, scales):
    r"""Extract the surface-crossing edges and triangle faces of tetrahedral SDF grids.

    Implements the topology part of the marching tetrahedra algorithm described in
    `An efficient method of triangulating equi-valued surfaces by using tetrahedral cells`_.
    A vertex counts as occupied when its SDF value is strictly positive. Every
    tetrahedron edge whose two endpoints disagree on occupancy crosses the surface
    and yields one mesh vertex (shared between tetrahedra). Each crossed
    tetrahedron contributes one or two triangles built on those edges.

    Surface vertex positions are *not* interpolated here. For every crossing edge,
    the function returns the positions, SDF values and scales of its two
    endpoints, and leaves placing the vertex along the edge to the caller. These
    values are gathered from the inputs with their autograd history intact, so
    they are differentiable with respect to ``vertices``, ``sdf`` and ``scales``.

    Args:
        vertices (torch.Tensor): batched vertices of tetrahedral meshes, of shape
                                 :math:`(\text{batch_size}, \text{num_vertices}, 3)`.
        tets (torch.LongTensor): unbatched tetrahedral mesh topology shared by all
                                 batch items, of shape :math:`(\text{num_tetrahedrons}, 4)`.
        sdf (torch.Tensor): batched SDFs which specify the SDF value of each vertex, of shape
                            :math:`(\text{batch_size}, \text{num_vertices})`.
        scales (torch.Tensor): batched per-vertex scalars gathered alongside the SDF, of shape
                               :math:`(\text{batch_size}, \text{num_vertices})`.

    Returns:
        (list[tuple]): a list of four tuples, each holding one entry per batch item:

            - endpoint data: a pair ``(edge_verts, edge_sdf)`` per item, of shapes
              :math:`(\text{num_edges}, 2, 3)` and :math:`(\text{num_edges}, 2, 1)`.
            - endpoint scales, of shape :math:`(\text{num_edges}, 2, 1)`.
            - faces (torch.LongTensor), of shape :math:`(\text{num_faces}, 3)`,
              whose entries index the crossing edges.
            - crossing-edge vertex ids (torch.LongTensor), of shape
              :math:`(\text{num_edges}, 2)`, smaller vertex id first.

        For an empty batch, four empty tuples are returned.

    Example:
        >>> vertices = torch.tensor([[[0, 0, 0],
        ...               [1, 0, 0],
        ...               [0, 1, 0],
        ...               [0, 0, 1]]], dtype=torch.float)
        >>> tets = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)
        >>> sdf = torch.tensor([[-1., -1., 0.5, 0.5]], dtype=torch.float)
        >>> scales = torch.ones((1, 4))
        >>> verts_list, scales_list, faces_list, edges_list = marching_tetrahedra(
        ...     vertices, tets, sdf, scales)
        >>> edges_list[0]
        tensor([[0, 2],
                [0, 3],
                [1, 2],
                [1, 3]])
        >>> faces_list[0]
        tensor([[3, 0, 1],
                [3, 2, 0]])
        >>> verts_list[0][0].shape, verts_list[0][1].shape
        (torch.Size([4, 2, 3]), torch.Size([4, 2, 1]))

    .. _An efficient method of triangulating equi-valued surfaces by using tetrahedral cells:
        https://search.ieice.org/bin/summary.php?id=e74-d_1_214
    """
    list_of_outputs = [_unbatched_marching_tetrahedra(vertices[b], tets, sdf[b], scales[b])
                       for b in range(vertices.shape[0])]
    if not list_of_outputs:
        # ``zip(*[])`` yields nothing, which callers could not unpack into the
        # four per-output sequences returned for a non-empty batch.
        return [(), (), (), ()]
    # Transpose the per-item 4-tuples into four per-output tuples.
    return list(zip(*list_of_outputs))