import matplotlib.pyplot as plt
import numpy as np
# %matplotlib auto
import nibabel as nib
from nibabel import nifti1
from nibabel.viewers import OrthoSlicer3D
import os
import cv2
from tqdm import tqdm
import time
import open3d as o3d
import pandas as pd
import glob
import shutil
import random
import copy
from sklearn.decomposition import PCA
import gc

from copy import deepcopy
# Log the local wall-clock time at which the script starts running.
print(time.asctime(time.localtime()))

def split_dataset(uniquebinganhao, ratio, seed=42):
    """Shuffle IDs reproducibly and split them into a train part and a test part.

    Parameters
    ----------
    uniquebinganhao : sequence
        IDs to split (presumably unique case/record numbers -- confirm with
        caller). It must support ``len()``, slicing and item assignment, for
        example a list or a 1-D numpy array. The caller's object is NOT modified.
    ratio : float
        Fraction of IDs assigned to the training part. The split index is
        ``int(len(uniquebinganhao) * ratio)``.
    seed : int, optional
        Seed for the private RNG used to shuffle. The default of 42 reproduces
        the ordering obtained from ``random.seed(42); random.shuffle(...)``.

    Returns
    -------
    tuple
        ``(train_set, test_set)``: slices of the shuffled copy.
    """
    # Shuffle a shallow copy with a private RNG. This keeps the caller's
    # sequence intact and leaves the global `random` state untouched.
    # Random(seed) is seeded exactly like random.seed(seed), so the resulting
    # order is unchanged.
    shuffled = copy.copy(uniquebinganhao)
    random.Random(seed).shuffle(shuffled)
    print(shuffled)
    split_index = int(len(shuffled) * ratio)
    return shuffled[:split_index], shuffled[split_index:]

# Grey level shared by the R, G and B channels of each label colour in the
# input point clouds, paired with the suffix appended to that component's
# output file name. The names presumably denote cardiac structures (LV,
# myocardium, RV, LA, RA, aorta) -- confirm against the labelling scheme.
COMPONENT_LEVELS = (
    (0.0, 'lv'),
    (0.1, 'myo'),
    (0.2, 'rv'),
    (0.3, 'la'),
    (0.4, 'ra'),
    (0.5, 'aro'),
)


def split_by_label_color(xyz, color, levels, tol=0.01):
    """Partition points by grey-level label colour.

    For each value ``v`` in *levels*, selects the points whose R, G and B
    channels all lie strictly within *tol* of ``v``. A point claimed by an
    earlier level is never assigned to a later one (first match wins). Points
    that match no level are dropped. The relative point order is preserved
    within each part.

    Parameters
    ----------
    xyz : array-like, shape (N, 3)
        Point coordinates.
    color : array-like, shape (N, 3)
        Per-point RGB colours.
    levels : iterable of float
        Grey levels to match, in priority order.
    tol : float, optional
        Per-channel absolute tolerance. Defaults to 0.01.

    Returns
    -------
    list of (ndarray, ndarray)
        One ``(xyz_part, color_part)`` pair per level. Each array has shape
        ``(k, 3)``, and ``k`` may be 0.

    Raises
    ------
    ValueError
        If *xyz* and *color* contain a different number of points.
    """
    xyz = np.asarray(xyz, dtype=float).reshape(-1, 3)
    color = np.asarray(color, dtype=float).reshape(-1, 3)
    if xyz.shape[0] != color.shape[0]:
        raise ValueError(
            f'xyz has {xyz.shape[0]} points but color has {color.shape[0]}')
    # Track unassigned points so overlapping levels/tolerances behave like an
    # if/elif chain.
    unclaimed = np.ones(color.shape[0], dtype=bool)
    parts = []
    for level in levels:
        mask = unclaimed & np.all(np.abs(color - level) < tol, axis=1)
        unclaimed &= ~mask
        parts.append((xyz[mask], color[mask]))
    return parts


if __name__ == '__main__':
    gc.collect()
    input_pattern = '../data/CTA/cta_normal/pointcloud/all/0*_label_*.ply'
    output_dir = '../data/CTA/cta_normal/pointcloud/component/'
    pathlist = glob.glob(input_pattern)
    print(len(pathlist), pathlist)

    # --------------- read input data -------------------
    # Index of the first file to process; raise it to resume a partial run.
    n = 0
    # write_point_cloud returns False instead of raising when it cannot write,
    # so make sure the destination directory exists up front.
    os.makedirs(output_dir, exist_ok=True)
    levels = [level for level, _ in COMPONENT_LEVELS]

    for i_n, plypath in enumerate(pathlist[n:], start=n):
        print(i_n, plypath)
        pcd_input = o3d.io.read_point_cloud(plypath)
        xyz = np.array(pcd_input.points)
        color = np.array(pcd_input.colors)
        if color.shape[0] != xyz.shape[0]:
            # For example, the file carries no per-point colours, so the
            # labels cannot be recovered.
            print('skip (points/colors count mismatch):', plypath,
                  xyz.shape[0], color.shape[0])
            continue

        parts = split_by_label_color(xyz, color, levels)
        # Number of points found for each component, in COMPONENT_LEVELS order.
        print(*(part_xyz.shape[0] for part_xyz, _ in parts))

        # basename/splitext handle OS-specific separators and dots inside the
        # file name, unlike splitting on '/' and '.'.
        stem = os.path.splitext(os.path.basename(plypath))[0]
        for (part_xyz, part_color), (_, suffix) in zip(parts, COMPONENT_LEVELS):
            # Build a point cloud holding only this component's points.
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(part_xyz)
            pcd.colors = o3d.utility.Vector3dVector(part_color)
            out_path = os.path.join(output_dir, stem + '_' + suffix + '.ply')
            if not o3d.io.write_point_cloud(out_path, pcd):
                print('failed to write', out_path)