{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/myCenterCode\n",
      "noisy_modelnet_file =  /myCenterCode/my_dataset/ModelNet/my_ModelNet10/noisy_modelnet200_5_1_0.1_1234.pkl\n",
      "数据已保存至 /myCenterCode/my_dataset/ModelNet/my_ModelNet10/noisy_modelnet200_5_1_0.1_1234.pkl\n",
      "数据已加载\n",
      "4899 4899 4899\n"
     ]
    }
   ],
   "source": [
    "import os,sys\n",
    "import numpy as np\n",
    "from scipy.stats import multivariate_normal\n",
    "import pickle\n",
    "import trimesh\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
    "sys.path.append(parent_dir)\n",
    "print(parent_dir)\n",
    "from __tool.tool import save_data,load_data\n",
    "\n",
    "def read_off_file(file_path):\n",
    "    \"\"\"\n",
    "    读取 .off 文件并提取顶点和面信息\n",
    "    \"\"\"\n",
    "    with open(file_path, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "        assert lines[0].strip() == \"OFF\", \"Not a valid OFF file.\"\n",
    "        \n",
    "        # 提取顶点和面的数量\n",
    "        n_verts, n_faces, _ = map(int, lines[1].strip().split())\n",
    "        \n",
    "        # 读取顶点坐标\n",
    "        vertices = np.array([list(map(float, lines[i + 2].strip().split())) for i in range(n_verts)])\n",
    "        \n",
    "        # 读取面（如果需要）\n",
    "        faces = [list(map(int, lines[i + 2 + n_verts].strip().split()))[1:] for i in range(n_faces)]\n",
    "        \n",
    "        return vertices, faces\n",
    "\n",
    "def sample_points_from_mesh(vertices, faces, num_points=2048):\n",
    "    \"\"\"\n",
    "    将3D网格转换为点集\n",
    "    \"\"\"\n",
    "    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)\n",
    "    points, _ = trimesh.sample.sample_surface(mesh, num_points)\n",
    "    points = points / (np.max(points) - np.min(points)) * 100\n",
    "    points = points - np.mean(points,axis=0)\n",
    "    return points\n",
    "\n",
    "def load_modelnet_dataset(base_dir, num_points=900, is_modelnet10=False):\n",
    "    \"\"\"\n",
    "    加载 ModelNet 数据集中的所有模型并转化为点集\n",
    "    :param base_dir: 数据集的根目录路径\n",
    "    :param num_points: 每个模型点云的数量\n",
    "    :param is_modelnet10: 是否是 ModelNet10 数据集\n",
    "    \"\"\"\n",
    "    points_list = []\n",
    "    label_list = []\n",
    "    \n",
    "    # ModelNet40 目录结构: 每个类别为一个文件夹，包含 train 和 test 子文件夹\n",
    "    if not is_modelnet10:\n",
    "        # 读取 ModelNet40 类别目录\n",
    "        for label_dir in sorted(os.listdir(base_dir)):\n",
    "            class_dir = os.path.join(base_dir, label_dir)\n",
    "            if not os.path.isdir(class_dir):\n",
    "                continue\n",
    "\n",
    "            # 处理 train 子目录\n",
    "            train_dir = os.path.join(class_dir, 'train')\n",
    "            if os.path.isdir(train_dir):\n",
    "                for file_name in os.listdir(train_dir):\n",
    "                    if file_name.endswith(\".off\"):\n",
    "                        file_path = os.path.join(train_dir, file_name)\n",
    "                        try:\n",
    "                            # 读取 .off 文件并转化为点云\n",
    "                            vertices, faces = read_off_file(file_path)\n",
    "                            points = sample_points_from_mesh(vertices, faces, num_points)\n",
    "                            # 保存训练集点云和标签\n",
    "                            points_list.append(points)\n",
    "                            # print(np.min(points),np.max(points))\n",
    "                            label_list.append(f\"{label_dir}_train\")\n",
    "                        except Exception as e:\n",
    "                            print(f\"Error reading {file_path}: {e}\")\n",
    "                            continue\n",
    "\n",
    "            # 处理 test 子目录\n",
    "            test_dir = os.path.join(class_dir, 'test')\n",
    "            if os.path.isdir(test_dir):\n",
    "                for file_name in os.listdir(test_dir):\n",
    "                    if file_name.endswith(\".off\"):\n",
    "                        file_path = os.path.join(test_dir, file_name)\n",
    "                        try:\n",
    "                            # 读取 .off 文件并转化为点云\n",
    "                            vertices, faces = read_off_file(file_path)\n",
    "                            points = sample_points_from_mesh(vertices, faces, num_points)\n",
    "\n",
    "                            # 保存测试集点云和标签\n",
    "                            points_list.append(points)\n",
    "                            # print(np.min(points),np.max(points))\n",
    "                            label_list.append(f\"{label_dir}_test\")\n",
    "                        except Exception as e:\n",
    "                            print(f\"Error reading {file_path}: {e}\")\n",
    "                            continue\n",
    "\n",
    "    else:\n",
    "        # ModelNet10 目录结构: 每个类别为数字文件夹\n",
    "        for label in sorted(os.listdir(base_dir)):\n",
    "            class_dir = os.path.join(base_dir, label)\n",
    "            if not os.path.isdir(class_dir):\n",
    "                continue\n",
    "            \n",
    "            # 读取 train 和 test 子目录\n",
    "            for subset in ['train', 'test']:\n",
    "                subset_dir = os.path.join(class_dir, subset)\n",
    "                if os.path.isdir(subset_dir):\n",
    "                    for file_name in os.listdir(subset_dir):\n",
    "                        if file_name.endswith(\".off\"):\n",
    "                            file_path = os.path.join(subset_dir, file_name)\n",
    "                            try:\n",
    "                                # 读取 .off 文件并转化为点云\n",
    "                                vertices, faces = read_off_file(file_path)\n",
    "                                points = sample_points_from_mesh(vertices, faces, num_points)\n",
    "\n",
    "                                # 保存点云和标签\n",
    "                                label_list.append(f\"{label}_{subset}\")\n",
    "                                points_list.append(points)\n",
    "                            except Exception as e:\n",
    "                                print(f\"Error reading {file_path}: {e}\")\n",
    "                                continue\n",
    "\n",
    "    return np.array(points_list), np.array(label_list)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def add_noise_to_points(points_list,label_list, num_clusters=5, cluster_radius=1, zeta=0.1,ipynb=False):\n",
    "    \"\"\"\n",
    "    向点云数据中添加噪声点\n",
    "    :param points_list: 原始点云数据\n",
    "    :param num_clusters: 噪声团的个数\n",
    "    :param cluster_radius: 噪声团的半径\n",
    "    :param zeta: 噪声总质量\n",
    "    :return: 添加噪声后的点云数据及权重\n",
    "    \"\"\"\n",
    "    noisy_points_list = []\n",
    "    noisy_weights_list = []\n",
    "    for points,label in zip(points_list,label_list):\n",
    "        # 计算当前点云的质量（点的个数）\n",
    "        num_points = len(points)\n",
    "        # print(np.max(points))\n",
    "        # 生成噪声点\n",
    "        noise_points = []\n",
    "        for _ in range(num_clusters):\n",
    "            # 噪声团的中心\n",
    "            center = np.random.uniform(low=-100, high=100, size=(3,))\n",
    "            cov = np.eye(3) * cluster_radius**2  # 协方差矩阵\n",
    "            gaussian = multivariate_normal(mean=center, cov=cov)\n",
    "\n",
    "            # 计算每个噪声团的噪声点数量\n",
    "            num_noise_points = int(zeta / num_clusters * num_points / (1-zeta))\n",
    "            noise = gaussian.rvs(size=num_noise_points)\n",
    "            noise_points.append(noise)\n",
    "\n",
    "        # 合并噪声点和原始点\n",
    "        noise_points = np.vstack(noise_points)\n",
    "        all_points = np.vstack([points, noise_points])\n",
    "        noisy_points_list.append(all_points)\n",
    "        \n",
    "        if ipynb == True and len(noisy_weights_list) % 1000 == 0:\n",
    "            import matplotlib.pyplot as plt\n",
    "            from mpl_toolkits.mplot3d import Axes3D\n",
    "            # 创建 3D 图形\n",
    "            fig = plt.figure()\n",
    "            ax = fig.add_subplot(111, projection='3d')\n",
    "            # 绘制三维点集\n",
    "            ax.scatter(np.vstack([points, noise_points])[:,0], np.vstack([points, noise_points])[:,1], np.vstack([points, noise_points])[:,2], c='r', marker='o')  # 红色点\n",
    "            # 添加标签\n",
    "            ax.set_xlabel('X Label')\n",
    "            ax.set_ylabel('Y Label')\n",
    "            ax.set_zlabel('Z Label')\n",
    "            plt.title(f\"label{label}\")\n",
    "            # 显示图形\n",
    "            plt.show()\n",
    "    \n",
    "    weights_list = [np.ones(len(list(ll)))/len(list(ll)) for ll in points_list]\n",
    "    noisy_weights_list = [np.ones(len(list(ll)))/len(list(ll)) for ll in noisy_points_list]\n",
    "    noisy_label_list = label_list\n",
    "    return points_list,weights_list,label_list, noisy_points_list, noisy_weights_list, noisy_label_list\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# 示例用法\n",
    "modelnet_dir = '/myCenterCode/my_dataset/ModelNet/ModelNet10'  # 设置 ModelNet 数据集路径\n",
    "num_clusters = 5  # 噪声团的个数\n",
    "cluster_radius = 1  # 噪声团的半径\n",
    "zeta = 0.1  # 噪声总质量\n",
    "clean_supportSize = 200\n",
    "seed = 1234; np.random.seed(seed)\n",
    "# 1. 加载 ModelNet 数据\n",
    "\n",
    "\n",
    "noisy_modelnet_path = \"/myCenterCode/my_dataset/ModelNet/my_ModelNet10/\"\n",
    "noisy_modelnet_file = noisy_modelnet_path + \"noisy_modelnet\" + str(clean_supportSize) + \"_\" + str(num_clusters) + \"_\"+ str(cluster_radius)+ \"_\"+ str(zeta)+ \"_\"+ str(seed) + \".pkl\"\n",
    "print(\"noisy_modelnet_file = \",noisy_modelnet_file)\n",
    "if not(os.path.exists(noisy_modelnet_file)):\n",
    "    points_list, label_list = load_modelnet_dataset(modelnet_dir,clean_supportSize)\n",
    "    \n",
    "    # points_list, label_list = points_list[:100], label_list[:100]\n",
    "    ## delete\n",
    "    # 2. 为每个点云添加噪声\n",
    "    points_list,weights_list,label_list, noisy_points_list, noisy_weights_list, noisy_label_list = add_noise_to_points(points_list,label_list, num_clusters, cluster_radius, zeta, ipynb=False)\n",
    "    # 3. 保存数据\n",
    "    my_data_list = [points_list,weights_list,label_list, noisy_points_list, noisy_weights_list, noisy_label_list]\n",
    "    save_data(my_data_list,noisy_modelnet_file)\n",
    "\n",
    "# 4. 加载数据\n",
    "my_data_list = load_data(noisy_modelnet_file)\n",
    "points_list,weights_list,label_list, noisy_points_list, noisy_weights_list, noisy_label_list = my_data_list\n",
    "print(len(noisy_points_list), len(noisy_weights_list), len(noisy_label_list))  # 输出加载后的数据长度\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
