import os

import numpy as np
import matplotlib.pyplot as plt
import trimesh
import networkx as nx
import open3d as o3d
import gdist
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components, shortest_path
from scipy.io import loadmat, savemat
from sklearn import neighbors
import hdf5storage

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
#import debugpy; debugpy.listen(('127.0.0.1', 57004)); debugpy.wait_for_client()


def compute_geodesic_distmat(verts, faces):
    """
    Compute the all-pairs geodesic distance matrix of a triangle mesh.

    Geodesic distances are approximated by shortest paths (Dijkstra) on the
    mesh edge graph, where every edge is weighted by its exact Euclidean
    length.

    Args:
        verts (np.ndarray): array of vertex coordinates [n, 3]
        faces (np.ndarray): array of triangular faces [m, 3] holding 0-based
            vertex indices (any integer or integral float dtype)

    Returns:
        np.ndarray: symmetric geodesic distance matrix [n, n] (float64)

    Raises:
        ValueError: if the mesh has no vertices, if a face references a
            vertex index outside [0, n), or if the edge graph is not
            connected (this includes vertices that no face references).
    """
    verts = np.asarray(verts, dtype=np.float64)
    faces = np.asarray(faces, dtype=np.int64).reshape(-1, 3)
    num_verts = verts.shape[0]

    if num_verts == 0:
        raise ValueError('Mesh has no vertices')
    if faces.size and (faces.min() < 0 or faces.max() >= num_verts):
        raise ValueError(
            f'Face indices must lie in [0, {num_verts}); got range '
            f'[{faces.min()}, {faces.max()}]. 1-based (MATLAB) indices must '
            'be shifted by -1 before calling this function.'
        )

    # The three sides of each triangle: (v0, v1), (v1, v2), (v2, v0).
    starts = faces.ravel()
    ends = np.roll(faces, -1, axis=1).ravel()
    # Orient every edge as (low, high) so shared edges of neighbouring faces
    # coincide, drop self-loops from degenerate triangles, and deduplicate
    # via a scalar key (low * n + high) to avoid np.unique(axis=0) overhead.
    low = np.minimum(starts, ends)
    high = np.maximum(starts, ends)
    keep = low != high
    keys = np.unique(low[keep] * num_verts + high[keep])
    rows, cols = np.divmod(keys, num_verts)

    # Exact edge lengths straight from the vertex coordinates; only one
    # direction is stored because the graph is treated as undirected below.
    lengths = np.linalg.norm(verts[rows] - verts[cols], axis=1)
    adjacency = csr_matrix((lengths, (rows, cols)), shape=(num_verts, num_verts))

    # A disconnected graph would yield infinite distances; refuse it early.
    n_components, _ = connected_components(adjacency, directed=False)
    if n_components != 1:
        raise ValueError('Graph not connected')

    geodesic = shortest_path(adjacency, method='D', directed=False)
    # Each Dijkstra source is solved independently, so d(i, j) and d(j, i)
    # can differ by rounding; averaging makes the matrix exactly symmetric.
    return 0.5 * (geodesic + geodesic.T)

# def compute_geodesic_distmat1(verts, triv):
#     """
#     计算给定 mesh 的测地距离矩阵。

#     :param verts: numpy array of shape (N, 3), 顶点坐标。
#     :param triv: numpy array of shape (M, 3), 三角面索引（从0开始）。
#     :return: numpy array of shape (N, N), 测地距离矩阵。
#     """
#     num_verts = verts.shape[0]
#     dist_matrix = np.zeros((num_verts, num_verts), dtype=np.float64)

#     for i in range(num_verts):
#         source_indices = np.array([i], dtype=np.int32)
#         distances = gdist.compute_gdist(
#             np.ascontiguousarray(verts, dtype=np.float64),
#             np.ascontiguousarray(triv, dtype=np.int32),
#             source_indices
#         )
#         dist_matrix[i, :] = distances

#     return dist_matrix



def compute_dist_matrices(shapes_dir):
    """
    Compute and save geodesic distance matrices for every .mat shape file.

    Each ``<name>.mat`` file directly inside ``shapes_dir`` is read with
    ``scipy.io.loadmat`` and must contain a struct ``X`` with fields ``vert``
    and ``triv``. The resulting distance matrix is converted to float32 and
    written under the key ``D`` in MATLAB v7.3 format to
    ``shapes_dir/distance_matrix/<name>.mat``. Shapes whose output file already
    exists are skipped, so an interrupted run can simply be restarted.

    :param shapes_dir: Directory containing .mat files with shape data.
    """
    # Sorted so the processing order is deterministic (os.listdir is not).
    files = sorted(f for f in os.listdir(shapes_dir) if f.endswith('.mat'))
    matrices_dir = os.path.join(shapes_dir, "distance_matrix")
    os.makedirs(matrices_dir, exist_ok=True)

    for i, file_name in enumerate(files):
        print(f"Processing {i + 1} of {len(files)}")

        # Skip shapes that were already processed in an earlier run.
        output_file = os.path.join(matrices_dir, file_name)
        if os.path.exists(output_file):
            continue

        # Load the shape data.
        data = loadmat(os.path.join(shapes_dir, file_name))
        verts = data['X']['vert'][0, 0]
        # NOTE(review): triv is passed through unchanged, i.e. assumed to be
        # 0-based; if these files store MATLAB-style 1-based indices it must
        # be shifted by -1 here — TODO confirm for this dataset.
        triv = data['X']['triv'][0, 0]

        # Compute the distance matrix and store it in single precision.
        D = compute_geodesic_distmat(verts, triv).astype(np.float32)

        # Write to a temporary file, then rename it into place. An interrupted
        # write therefore never leaves a truncated output_file that later runs
        # would skip as "done". The temporary name keeps the '.mat' suffix
        # because hdf5storage appends one otherwise.
        tmp_file = os.path.join(matrices_dir, file_name[:-len('.mat')] + '.partial.mat')
        if os.path.exists(tmp_file):
            # Stale leftover from an interrupted run; start from a clean file.
            os.remove(tmp_file)
        hdf5storage.savemat(tmp_file, {'D': D}, format='7.3', store_python_metadata=False)
        os.replace(tmp_file, output_file)

if __name__ == "__main__":
    import sys

    # The shapes directory may be given as the first command-line argument;
    # without one, the original hard-coded dataset path is used.
    default_shapes_dir = "/workspace/projects/Frosting/neuromorph/data/meshes/mushroom/mat"
    compute_dist_matrices(sys.argv[1] if len(sys.argv) > 1 else default_shapes_dir)
    

