# Uses ModelNet40 as the dataset, together with the preprocessed feature matrices.
import json
import random
from collections import deque
from pathlib import Path
import numpy as np
import torch.utils.data as data
import trimesh
from scipy.spatial.transform import Rotation
from scipy.stats import stats
from sklearn.preprocessing import RobustScaler
# from model import match
from model import match_obj as match
from pygem import FFD
import copy
import torch


def check(path, scalars_data, mesh, cache_dir='/root/autodl-tmp/dataset/cached/'):
    """Make sure every expected per-vertex scalar field is available for ``mesh``.

    A field is recomputed with ``match.find_scalars(path, mesh, name, True)``
    when it is missing from ``scalars_data`` or its size differs from the
    vertex count of ``mesh``. If at least one field was recomputed, the
    available fields are written together with the mesh geometry to
    ``<cache_dir>/<mesh file stem>.vtk`` via ``match.write_vtk_file_v3``.
    Failures of either helper are printed and otherwise ignored (best effort).

    Args:
        path: path of the source mesh file (``Path`` or str).
        scalars_data: dict mapping scalar name -> per-vertex values. It is
            updated in place with any recomputed fields.
        mesh: object exposing ``vertices`` and ``faces`` arrays.
        cache_dir: directory where the VTK cache file is written.

    Returns:
        The same ``scalars_data`` dict.
    """
    # Cache file name: mesh file name with its extension replaced by '.vtk'.
    short_path = str(Path(cache_dir) / (Path(path).stem + '.vtk'))
    # Fields required for every mesh; the dict key equals the scalar name.
    scalar_names = ('corrThickness_MSMAll', 'MyelinMap_MSMAll', 'sulc_MSMAll',
                    'thickness_MSMAll', 'aparc')
    length = mesh.vertices.shape[0]
    recomputed = False

    for name in scalar_names:
        # Keep fields that exist and carry exactly one value per vertex.
        if name in scalars_data and np.array(scalars_data[name]).size == length:
            continue
        try:
            scalars_data[name] = match.find_scalars(path, mesh, name, True)
            recomputed = True
        except Exception as e:
            print(f"Error finding scalar '{name}': {e}")

    if recomputed:
        # Skip fields whose recomputation failed instead of raising KeyError.
        new_scalars_data = {name: np.array(scalars_data[name])
                            for name in scalar_names if name in scalars_data}
        try:
            match.write_vtk_file_v3(mesh.vertices, mesh.faces, new_scalars_data, short_path)
        except Exception as e:
            print(f"Error writing VTK file '{short_path}': {e}")

    return scalars_data


def randomize_mesh_orientation(mesh: trimesh.Trimesh):
    """Return a deep copy of ``mesh`` rotated by a random multiple-of-90° rotation.

    The Euler axis sequence is a random permutation of 'x', 'y', 'z', and each
    of the three angles is drawn independently from {0, 90, 180, 270} degrees
    (Python's ``random`` module). The input mesh is not modified.
    """
    rotated = copy.deepcopy(mesh)
    euler_seq = ''.join(random.sample('xyz', 3))
    quarter_turns = [random.choice([0, 90, 180, 270]) for _ in range(3)]
    rotation = Rotation.from_euler(euler_seq, quarter_turns, degrees=True)
    rotated.vertices = rotation.apply(rotated.vertices)
    return rotated


def random_scale(mesh: trimesh.Trimesh):
    """Scale ``mesh`` in place by one random factor per axis and return it.

    The three factors (x, y, z) come from ``np.random.normal(1, 0.1)``. The
    passed-in object itself is modified; no copy is made.
    """
    per_axis_factor = np.random.normal(1, 0.1, size=(1, 3))
    scaled_vertices = np.multiply(mesh.vertices, per_axis_factor)
    mesh.vertices = scaled_vertices
    return mesh


def mesh_normalize(mesh: trimesh.Trimesh):
    """Return a copy of ``mesh`` translated and uniformly scaled into [0, 1]^3.

    The per-axis minimum is moved to the origin. All coordinates are then
    divided by the single largest shifted coordinate, so the longest side of the
    bounding box becomes 1 and the aspect ratio is preserved. A degenerate mesh
    whose vertices all coincide is only translated to the origin, which avoids a
    0/0 division that would fill the vertices with NaN. The input is not modified.
    """
    normalized = copy.deepcopy(mesh)
    shifted = normalized.vertices - normalized.vertices.min(axis=0)
    extent = shifted.max()
    if extent > 0:
        shifted = shifted / extent
    normalized.vertices = shifted
    return normalized


def mesh_deformation(mesh: trimesh.Trimesh):
    """Apply a small random free-form deformation to ``mesh`` in place.

    A pygem ``FFD`` with [2, 2, 2] control points is created. The
    ``array_mu_x``/``array_mu_y``/``array_mu_z`` entries of the control points
    at lattice index (1, 1, 1) and (0, 0, 0) are set to values drawn uniformly
    from [0, 0.1). The vertices are replaced by the deformed ones and the same
    (mutated) mesh object is returned.
    """
    lattice = FFD([2, 2, 2])
    # One draw of six values: (1,1,1) gets offsets[0:3], (0,0,0) gets offsets[3:6].
    offsets = np.random.rand(6) * 0.1
    for corner_no, corner in enumerate(((1, 1, 1), (0, 0, 0))):
        dx, dy, dz = offsets[3 * corner_no:3 * corner_no + 3]
        lattice.array_mu_x[corner] = dx
        lattice.array_mu_y[corner] = dy
        lattice.array_mu_z[corner] = dz
    mesh.vertices = lattice(mesh.vertices)
    return mesh


def load_mesh(path, augments=(), request=(), seed=None, max_patches=256, patch_size=64):
    """Load a classification mesh and return zero-padded per-patch face features.

    The label is derived from the file path: 0 when it contains 'L_wm', else 1.

    Args:
        path: mesh file path accepted by ``trimesh.load_mesh``.
        augments: augmentation names applied in order; 'orient', 'scale' and
            'deformation' are recognised, anything else is ignored.
        request: feature names among 'area' (1 channel), 'normal' (3),
            'center' (3), 'face_angles' (3) and 'curvs' (3). Channels are stacked
            in that fixed order, independent of the order inside ``request``.
        seed: unused; kept for signature compatibility.
        max_patches: number of patches after zero-padding.
        patch_size: number of faces per patch.

    Returns:
        ``(feats, centers, coordinates, faces, Fs, label)`` with ``feats`` of
        shape (C, max_patches, patch_size), ``centers`` (max_patches,
        patch_size, 3), ``coordinates`` (max_patches, patch_size, 9),
        ``faces`` (max_patches, patch_size, 3), ``Fs`` a 0-d array holding the
        unpadded face count, and ``label`` an int.

    Raises:
        ValueError: if ``request`` selects no feature, the face count is not a
            positive multiple of ``patch_size``, or more than ``max_patches``
            patches would be needed.
    """
    label = 0 if 'L_wm' in str(path) else 1

    mesh = trimesh.load_mesh(path, process=False)
    for method in augments:
        if method == 'orient':
            mesh = randomize_mesh_orientation(mesh)
        elif method == 'scale':
            mesh = random_scale(mesh)
        elif method == 'deformation':
            mesh = mesh_deformation(mesh)

    F = mesh.faces
    V = mesh.vertices
    Fs = F.shape[0]

    # (Fs, 3, 3): positions of the three corners of every face.
    face_vertices = V[F.flatten()].reshape(-1, 3, 3)
    face_coordinate = face_vertices.reshape(-1, 9)
    face_center = face_vertices.mean(axis=1)

    vertex_normals = mesh.vertex_normals
    face_normals = mesh.face_normals
    # (3, Fs): dot product of each corner's vertex normal with the face normal.
    face_curvs = np.vstack([
        (vertex_normals[F[:, corner]] * face_normals).sum(axis=1) for corner in range(3)
    ])

    feats = []
    if 'area' in request:
        feats.append(mesh.area_faces)
    if 'normal' in request:
        feats.append(face_normals.T)
    if 'center' in request:
        feats.append(face_center.T)
    if 'face_angles' in request:
        feats.append(np.sort(mesh.face_angles, axis=1).T)
    if 'curvs' in request:
        feats.append(np.sort(face_curvs, axis=0))
    if not feats:
        raise ValueError(f'no known feature requested: {request!r}')
    feats = np.vstack(feats)

    patch_num = Fs // patch_size
    if patch_num == 0 or Fs % patch_size:
        raise ValueError(f'{path}: face count {Fs} is not a positive multiple of {patch_size}')
    if patch_num > max_patches:
        raise ValueError(f'{path}: {patch_num} patches exceed max_patches={max_patches}')

    # Patch p holds faces p, p + patch_num, p + 2 * patch_num, ... (strided grouping).
    # NOTE(review): presumably relies on the face order of the upstream remeshing — confirm.
    indices = np.arange(Fs).reshape(patch_size, patch_num).T
    pad = max_patches - patch_num

    feats_patcha = np.concatenate(
        (feats[:, indices], np.zeros((feats.shape[0], pad, patch_size), dtype=np.float32)), axis=1)
    center_patcha = np.concatenate(
        (face_center[indices], np.zeros((pad, patch_size, 3), dtype=np.float32)), axis=0)
    cordinates_patcha = np.concatenate(
        (face_coordinate[indices], np.zeros((pad, patch_size, 9), dtype=np.float32)), axis=0)
    faces_patcha = np.concatenate(
        (F[indices], np.zeros((pad, patch_size, 3), dtype=int)), axis=0)

    Fs_patcha = np.array(Fs)

    return feats_patcha, center_patcha, cordinates_patcha, faces_patcha, Fs_patcha, label


def load_mesh_seg(path, normalize=True, augments=(), request=(), seed=None,
                  max_patches=256, patch_size=64):
    """Load a segmentation mesh plus its face labels as zero-padded patches.

    The labels are read from the ``sub_labels`` entry of the ``.json`` file that
    has the same path as the mesh with its extension replaced.

    Args:
        path: mesh file path accepted by ``trimesh.load_mesh``.
        normalize: if True, apply ``mesh_normalize`` after the augmentations.
        augments: augmentation names applied in order; 'orient', 'scale' and
            'deformation' are recognised, anything else is ignored.
        request: feature names among 'area' (1 channel), 'normal' (3),
            'center' (3), 'face_angles' (3) and 'curvs' (3), stacked in that
            fixed order.
        seed: unused; kept for signature compatibility.
        max_patches: number of patches after zero-padding.
        patch_size: number of faces per patch.

    Returns:
        ``(faces, feats, Fs, centers, coordinates, labels)``:
          faces: (max_patches, patch_size, 3) vertex indices.
          feats: (max_patches, patch_size, C) — channel-last.
          Fs: (max_patches, patch_size, 1) filled with the unpadded face count.
          centers: (max_patches, patch_size, 3) face centroids.
          coordinates: (max_patches, patch_size, 9) corner coordinates per face.
          labels: (max_patches, patch_size, 1) ``sub_labels - 1``; padding is 0.

    Raises:
        ValueError: if ``request`` selects no feature, the face count is not a
            positive multiple of ``patch_size``, or more than ``max_patches``
            patches would be needed.
    """
    mesh = trimesh.load_mesh(path, process=False)
    # with_suffix replaces only the final extension. Splitting the path string
    # on the first '.' broke paths like '../data/x.obj' or dotted directory names.
    label_path = Path(path).with_suffix('.json')
    with open(label_path, encoding='utf-8') as f:
        segment = json.load(f)

    # Shift to 0-based labels (the file presumably stores 1-based ones).
    sub_labels = np.array(segment['sub_labels']) - 1

    for method in augments:
        if method == 'orient':
            mesh = randomize_mesh_orientation(mesh)
        elif method == 'scale':
            mesh = random_scale(mesh)
        elif method == 'deformation':
            mesh = mesh_deformation(mesh)
    if normalize:
        mesh = mesh_normalize(mesh)

    F = mesh.faces
    V = mesh.vertices
    Fs = F.shape[0]

    # (Fs, 3, 3): positions of the three corners of every face.
    face_vertices = V[F.flatten()].reshape(-1, 3, 3)
    face_coordinate = face_vertices.reshape(-1, 9)
    face_center = face_vertices.mean(axis=1)

    vertex_normals = mesh.vertex_normals
    face_normals = mesh.face_normals
    # (3, Fs): dot product of each corner's vertex normal with the face normal.
    face_curvs = np.vstack([
        (vertex_normals[F[:, corner]] * face_normals).sum(axis=1) for corner in range(3)
    ])

    feats = []
    if 'area' in request:
        feats.append(mesh.area_faces)
    if 'normal' in request:
        feats.append(face_normals.T)
    if 'center' in request:
        feats.append(face_center.T)
    if 'face_angles' in request:
        feats.append(np.sort(mesh.face_angles, axis=1).T)
    if 'curvs' in request:
        feats.append(np.sort(face_curvs, axis=0))
    if not feats:
        raise ValueError(f'no known feature requested: {request!r}')
    feats = np.vstack(feats)

    patch_num = Fs // patch_size
    if patch_num == 0 or Fs % patch_size:
        raise ValueError(f'{path}: face count {Fs} is not a positive multiple of {patch_size}')
    if patch_num > max_patches:
        raise ValueError(f'{path}: {patch_num} patches exceed max_patches={max_patches}')
    if patch_num != max_patches:
        # Flag meshes that need padding.
        print(path)

    # Patch p holds faces p, p + patch_num, p + 2 * patch_num, ... (strided grouping).
    # NOTE(review): presumably relies on the face order of the upstream remeshing — confirm.
    indices = np.arange(Fs).reshape(patch_size, patch_num).T
    pad = max_patches - patch_num

    # NOTE(review): padded faces get label 0, the same as the first real class;
    # confirm the loss masks them (e.g. using Fs).
    label_patcha = np.concatenate(
        (sub_labels[indices], np.zeros((pad, patch_size), dtype=np.float32)), axis=0)
    label_patcha = np.expand_dims(label_patcha, axis=2)

    feats_patcha = np.concatenate(
        (feats[:, indices], np.zeros((feats.shape[0], pad, patch_size), dtype=np.float32)), axis=1)
    center_patcha = np.concatenate(
        (face_center[indices], np.zeros((pad, patch_size, 3), dtype=np.float32)), axis=0)
    cordinates_patcha = np.concatenate(
        (face_coordinate[indices], np.zeros((pad, patch_size, 9), dtype=np.float32)), axis=0)
    faces_patcha = np.concatenate(
        (F[indices], np.zeros((pad, patch_size, 3), dtype=int)), axis=0)
    # (C, patches, faces) -> (patches, faces, C)
    feats_patcha = feats_patcha.transpose(1, 2, 0)
    Fs_patcha = np.full((max_patches, patch_size, 1), Fs)

    return faces_patcha, feats_patcha, Fs_patcha, center_patcha, cordinates_patcha, label_patcha


def get_face_k_ring(mesh, start_face_idx, k=1):
    """Return the faces within ``k`` adjacency steps of ``start_face_idx``.

    Breadth-first search over the graph built from ``mesh.face_adjacency``
    (pairs of faces that share an edge). The start face itself is excluded and
    faces are returned in BFS order, nearest rings first.

    The adjacency lists are built on the first call and stored on the mesh as
    ``mesh._face_graph``. Later calls reuse them without rebuilding, even if the
    mesh's faces have changed since.
    """
    if not hasattr(mesh, '_face_graph'):
        adjacency_lists = [[] for _ in range(len(mesh.faces))]
        for face_a, face_b in mesh.face_adjacency:
            adjacency_lists[face_a].append(face_b)
            adjacency_lists[face_b].append(face_a)
        mesh._face_graph = adjacency_lists
    neighbours_of = mesh._face_graph

    seen = {start_face_idx}
    ring = []
    frontier = [start_face_idx]
    level = 0
    # Expand one ring per iteration; stop after k rings or when nothing is left.
    while frontier and level < k:
        next_frontier = []
        for face in frontier:
            for neighbour in neighbours_of[face]:
                if neighbour in seen:
                    continue
                seen.add(neighbour)
                ring.append(neighbour)
                next_frontier.append(neighbour)
        frontier = next_frontier
        level += 1
    return ring

def load_mesh_shape(path, augments=(), request=(), seed=None, max_patches=1024, patch_size=64):
    """Load a mesh and return its per-face features grouped into zero-padded patches.

    Steps:
      1. Load ``path`` with trimesh (``process=False``, so the geometry is not
         post-processed, e.g. vertices are not merged).
      2. Apply the requested augmentations in order.
      3. Compute the per-face features selected by ``request``.
      4. Group the faces into ``Fs // patch_size`` patches of ``patch_size``
         faces, then zero-pad the patch axis up to ``max_patches``.

    Args:
        path: mesh file path accepted by ``trimesh.load_mesh``.
        augments: augmentation names; 'orient', 'scale' and 'deformation' are
            recognised, anything else is ignored.
        request: feature names among 'area' (1 channel), 'normal' (3),
            'center' (3), 'face_angles' (3) and 'curvs' (3). Channels are stacked
            in that fixed order, independent of the order inside ``request``.
        seed: unused; kept for signature compatibility.
        max_patches: number of patches after zero-padding.
        patch_size: number of faces per patch.

    Returns:
        ``(feats, centers, coordinates, faces, Fs)``:
          feats: (C, max_patches, patch_size).
          centers: (max_patches, patch_size, 3) face centroids.
          coordinates: (max_patches, patch_size, 9) the three corner positions.
          faces: (max_patches, patch_size, 3) vertex indices.
          Fs: 0-d array holding the unpadded face count.

    Raises:
        ValueError: if ``request`` selects no feature, the face count is not a
            positive multiple of ``patch_size``, or more than ``max_patches``
            patches would be needed.
    """
    mesh = trimesh.load_mesh(path, process=False)

    for method in augments:
        if method == 'orient':
            mesh = randomize_mesh_orientation(mesh)
        elif method == 'scale':
            mesh = random_scale(mesh)
        elif method == 'deformation':
            mesh = mesh_deformation(mesh)

    F = mesh.faces
    V = mesh.vertices
    Fs = F.shape[0]

    # (Fs, 3, 3): positions of the three corners of every face.
    face_vertices = V[F.flatten()].reshape(-1, 3, 3)
    # (Fs, 9): the 9 corner coordinates (3 x 3) of every face.
    face_coordinate = face_vertices.reshape(-1, 9)
    face_center = face_vertices.mean(axis=1)

    vertex_normals = mesh.vertex_normals
    face_normals = mesh.face_normals
    # (3, Fs): dot product of each corner's vertex normal with the face normal.
    face_curvs = np.vstack([
        (vertex_normals[F[:, corner]] * face_normals).sum(axis=1) for corner in range(3)
    ])

    # TODO: per-vertex scalar fields (e.g. 'thickness_MSMAll', 'MyelinMap_MSMAll',
    # 'sulc_MSMAll'; see `check`) and 1-ring neighbourhood statistics (see
    # `get_face_k_ring`) were prototyped here as additional feature channels.
    feats = []
    if 'area' in request:
        feats.append(mesh.area_faces)
    if 'normal' in request:
        feats.append(face_normals.T)
    if 'center' in request:
        feats.append(face_center.T)
    if 'face_angles' in request:
        feats.append(np.sort(mesh.face_angles, axis=1).T)
    if 'curvs' in request:
        feats.append(np.sort(face_curvs, axis=0))
    if not feats:
        raise ValueError(f'no known feature requested: {request!r}')
    feats = np.vstack(feats)  # (C, Fs)

    patch_num = Fs // patch_size
    if patch_num == 0 or Fs % patch_size:
        raise ValueError(f'{path}: face count {Fs} is not a positive multiple of {patch_size}')
    if patch_num > max_patches:
        raise ValueError(f'{path}: {patch_num} patches exceed max_patches={max_patches}')

    # indices has shape (patch_num, patch_size); patch p holds faces
    # p, p + patch_num, p + 2 * patch_num, ... (strided grouping).
    # NOTE(review): presumably relies on the face order of the upstream remeshing — confirm.
    indices = np.arange(Fs).reshape(patch_size, patch_num).T
    pad = max_patches - patch_num

    # Gather per patch and zero-pad the patch axis up to max_patches.
    feats_patcha = np.concatenate(
        (feats[:, indices], np.zeros((feats.shape[0], pad, patch_size), dtype=np.float32)), axis=1)
    center_patcha = np.concatenate(
        (face_center[indices], np.zeros((pad, patch_size, 3), dtype=np.float32)), axis=0)
    cordinates_patcha = np.concatenate(
        (face_coordinate[indices], np.zeros((pad, patch_size, 9), dtype=np.float32)), axis=0)
    faces_patcha = np.concatenate(
        (F[indices], np.zeros((pad, patch_size, 3), dtype=int)), axis=0)

    Fs_patcha = np.array(Fs)
    return feats_patcha, center_patcha, cordinates_patcha, faces_patcha, Fs_patcha

class ClassificationDataset(data.Dataset):
    """Classification dataset over the regular files directly under ``dataroot``.

    Items are produced by ``load_mesh``, which also derives the class label
    from the file path. Augmentations are applied only when ``train`` is True.
    """

    def __init__(self, dataroot, train=True, augment=None):
        super().__init__()

        self.dataroot = Path(dataroot)
        self.augments = []
        self.mode = 'train' if train else 'test'
        self.feats = ['area', 'face_angles', 'curvs', 'normal']

        self.mesh_paths = []
        self.labels = []  # not populated; labels come from load_mesh
        self.browse_dataroot()
        if train and augment:
            self.augments = augment

    def browse_dataroot(self):
        """Collect the regular files directly inside ``dataroot`` (no recursion).

        The paths are sorted because ``Path.iterdir`` yields entries in
        arbitrary order, which would make the index -> file mapping depend on
        the file system.
        """
        self.shape_classes = ['HCA', 'Numeric']
        self.mesh_paths = np.array(sorted(p for p in self.dataroot.iterdir() if p.is_file()))

    def __getitem__(self, idx):
        # Train and test differ only in self.augments (empty unless training).
        mesh_path = self.mesh_paths[idx]
        feats, center, cordinates, faces, Fs, label = load_mesh(
            mesh_path, augments=self.augments, request=self.feats)
        return feats, center, cordinates, faces, Fs, label, str(mesh_path)

    def __len__(self):
        return len(self.mesh_paths)


class SegmentationDataset(data.Dataset):
    """Segmentation dataset over ``<dataroot>/<train|test>/<dataset>/*.obj``.

    Items are produced by ``load_mesh_seg``, which also reads the matching
    ``.json`` label file of each mesh. Augmentations are applied only in
    training mode.
    """

    def __init__(self, dataroot, train=True, augments=None):
        super().__init__()

        self.dataroot = dataroot
        # load_mesh_seg iterates over the augmentations, so None must become an
        # empty list (otherwise every training item raised TypeError).
        self.augments = augments if augments is not None else []
        self.mode = 'train' if train else 'test'
        self.feats = ['area', 'face_angles', 'curvs', 'normal']

        self.mesh_paths = []
        self.raw_paths = []  # not populated
        self.seg_paths = []  # not populated
        self.browse_dataroot()

    def browse_dataroot(self):
        """Collect every ``.obj`` file one directory level below ``<dataroot>/<mode>``.

        Entries are sorted because ``Path.iterdir`` yields them in arbitrary
        order; the paths are stored as strings.
        """
        split_dir = Path(self.dataroot) / self.mode
        mesh_paths = []
        for dataset in sorted(split_dir.iterdir()):
            if not dataset.is_dir():
                continue
            mesh_paths.extend(str(p) for p in sorted(dataset.iterdir()) if p.suffix == '.obj')
        self.mesh_paths = np.array(mesh_paths)

    def __getitem__(self, idx):
        augments = self.augments if self.mode == 'train' else []
        return load_mesh_seg(
            self.mesh_paths[idx],
            normalize=True,
            augments=augments,
            request=self.feats)

    def __len__(self):
        return len(self.mesh_paths)


class ShapeNetDataset(data.Dataset):
    """Dataset over the regular files directly under ``dataroot``.

    Items are produced by ``load_mesh_shape``. This class does not derive a
    class label, so every item carries the placeholder label 0. Augmentations
    are applied only when ``train`` is True.
    """

    def __init__(self, dataroot, train=True, augment=None):
        super().__init__()

        self.dataroot = Path(dataroot)
        self.augments = []
        self.feats = ['area', 'face_angles', 'curvs', 'normal']
        self.mesh_paths = []
        self.browse_dataroot()
        if train and augment:
            self.augments = augment

    def browse_dataroot(self):
        """Collect the regular files directly inside ``dataroot`` (no recursion).

        The paths are sorted because ``Path.iterdir`` yields entries in
        arbitrary order, which would make the index -> file mapping depend on
        the file system.
        """
        self.shape_classes = ['HCA', 'Numeric']
        self.mesh_paths = np.array(sorted(p for p in self.dataroot.iterdir() if p.is_file()))

    def __getitem__(self, idx):
        mesh_path = self.mesh_paths[idx]
        feats, center, cordinates, faces, Fs = load_mesh_shape(
            mesh_path, augments=self.augments, request=self.feats)
        label = 0  # placeholder label
        return feats, center, cordinates, faces, Fs, label, str(mesh_path)

    def __len__(self):
        return len(self.mesh_paths)
