#import matplotlib.pyplot as plt
import numpy as np
# %matplotlib auto
import nibabel as nib
#from nibabel import nifti1
#from nibabel.viewers import OrthoSlicer3D
import os
import cv2
from tqdm import tqdm
import time
import open3d as o3d
import pandas as pd
import glob
import shutil
import random
import copy
#from sklearn.decomposition import PCA
import gc
import scipy.io as sio
from copy import deepcopy
import multiprocessing as mp
# Log the wall-clock time at which the script (or a spawned worker) imports this module.
print(time.asctime(time.localtime()))


def split_dataset(uniquebinganhao, ratio, seed=42):
    """Shuffle case ids reproducibly and split them into a train and a test part.

    The input is NOT modified, and the module-level ``random`` state is left
    untouched. The permutation is the same as the one produced by
    ``random.seed(seed); random.shuffle(ids)``.

    Args:
        uniquebinganhao: Mutable sequence of case identifiers (e.g. a list or a
            1-D numpy array); a shallow copy is shuffled.
        ratio: Fraction of ids put into the training part; the split index is
            ``int(len(ids) * ratio)``.
        seed: Seed for the shuffle; the default reproduces the original split.

    Returns:
        ``(train_set, test_set)``: slices of the shuffled copy, of the same type
        as ``uniquebinganhao``.
    """
    # copy.copy keeps the container type (list -> list, ndarray -> ndarray)
    # while ensuring the caller's object is not shuffled in place.
    shuffled = copy.copy(uniquebinganhao)
    # A private Random instance gives the same permutation as reseeding the
    # global generator, without resetting random state other code relies on.
    random.Random(seed).shuffle(shuffled)
    print(shuffled)
    split_index = int(len(shuffled) * ratio)
    return shuffled[:split_index], shuffled[split_index:]

def read_nii_gz(example_filename, scale=1.5, out_dir="../data/CTA/cta_mat/"):
    """Nearest-neighbour resample a NIfTI volume and cache it as a ``.mat`` file.

    The diagonal of ``img.affine * scale`` is used as a per-axis magnification
    factor ``f``. Output axis ``a`` gets ``int(shape[a] * |f[a]|)`` voxels, and
    output index ``o`` copies input index ``int(o / |f[a]|)``. Off-diagonal
    (rotation/shear) affine terms are ignored.

    Args:
        example_filename: Path to a ``.nii.gz`` volume.
        scale: Extra factor applied to the affine before resampling.
        out_dir: Directory for the ``.mat`` cache; created if missing.

    Returns:
        The resampled volume (ndarray). It is also saved under key
        ``'img_data_volume'`` as ``<out_dir>/<basename without .nii.gz>.mat``.
    """
    img = nib.load(example_filename)
    img_data = np.asarray(img.get_fdata())

    affinemat = img.affine * scale
    factors = [abs(affinemat[ax, ax]) for ax in range(3)]
    out_shape = [int(img_data.shape[ax] * factors[ax]) for ax in range(3)]

    # Per-axis nearest-neighbour source indices, combined with np.ix_. This is
    # the same mapping as looping over every output voxel, but runs in C.
    # astype truncates toward zero exactly like int() for these non-negative
    # values. The clip guards against float rounding pushing the last index to
    # shape[ax].
    src_index = [
        np.minimum((np.arange(out_shape[ax]) / factors[ax]).astype(np.intp),
                   img_data.shape[ax] - 1)
        for ax in range(3)
    ]
    img_data_volume = img_data[np.ix_(*src_index)]

    # Drop the full-resolution data before writing to keep peak memory down.
    del img
    del img_data

    os.makedirs(out_dir, exist_ok=True)
    # basename() also handles the backslash-joined paths glob returns on Windows.
    mat_name = os.path.basename(example_filename).replace('.nii.gz', '.mat')
    sio.savemat(os.path.join(out_dir, mat_name), {'img_data_volume': img_data_volume})
    return img_data_volume

def generate_normal(img_data, xyz, color):
    """Compute one normal-direction hint per point of a colour-coded component.

    Each point's colour encodes a label id:
    - pure red (1, 0, 0) -> 11;
    - pure green (0, 1, 0) -> 12;
    - pure blue (0, 0, 1) -> 13;
    - otherwise the red channel ``g`` is a grey level encoding ``round(g / 0.1) + 1``.

    The id and the truncated integer coordinates are passed to
    ``get_point_normal``.

    Args:
        img_data: 3-D label volume indexed by the truncated point coordinates
            (presumably the cached output of ``read_nii_gz`` — confirm).
        xyz: Array of shape (N, 3) with point coordinates.
        color: Length-N sequence of RGB triples in [0, 1].

    Returns:
        List of N integer 3-vectors, one per point, from ``get_point_normal``.
    """
    print(xyz.shape, np.array(color).shape)
    nxnynz = []
    for idx, point in enumerate(xyz):
        rgb = color[idx]
        if rgb[0] == 1.0 and rgb[1] == 0 and rgb[2] == 0:
            color_id = 11
        elif rgb[0] == 0 and rgb[1] == 1.0 and rgb[2] == 0:
            color_id = 12
        elif rgb[0] == 0 and rgb[1] == 0 and rgb[2] == 1.0:
            color_id = 13
        else:
            # Round rather than truncate. Even an exact 0.3 / 0.1 evaluates to
            # 2.9999999999999996, and colours read back from a file are likely
            # 8-bit quantised (e.g. 25/255 = 0.098 for 0.1). int() would map
            # both to the wrong (lower) label.
            color_id = int(round(float(rgb[0]) / 0.1)) + 1
        nxnynz.append(get_point_normal(color_id, img_data,
                                       int(point[0]), int(point[1]), int(point[2])))
    return nxnynz

def get_point_normal(conponent_id, img_data_, i, j, k):
    """Estimate a direction at voxel (i, j, k) from its 3x3x3 label neighbourhood.

    Neighbour voxels are classified by their label:
    - within 0.5 of ``conponent_id`` -> "start" (qidian) points;
    - differing by more than 0.5 -> "end" (zhongdian) points.

    The centre voxel is always seeded into both lists. The result is the
    end-minus-start difference with the largest squared length. On ties the
    first pair found wins, with ends in the outer loop, starts in the inner
    loop, and offsets visited -1, 0, 1 per axis.

    Neighbours outside the volume are skipped. Border voxels therefore neither
    wrap around to the opposite face (negative index) nor raise IndexError.

    Args:
        conponent_id: Label id of the component the point belongs to.
        img_data_: 3-D label volume.
        i, j, k: Integer voxel coordinates of the point.

    Returns:
        Integer ndarray of shape (3,).
    """
    dims = img_data_.shape
    qidian_list = [[i, j, k]]      # start points (same label as the component)
    zhongdian_list = [[i, j, k]]   # end points (different label)
    for i_change in (-1, 0, 1):
        for j_change in (-1, 0, 1):
            for k_change in (-1, 0, 1):
                ni, nj, nk = i + i_change, j + j_change, k + k_change
                if not (0 <= ni < dims[0] and 0 <= nj < dims[1] and 0 <= nk < dims[2]):
                    continue
                diff = abs(img_data_[ni, nj, nk] - conponent_id)
                if diff < 0.5:
                    qidian_list.append([ni, nj, nk])
                elif diff > 0.5:
                    zhongdian_list.append([ni, nj, nk])

    ends = np.array(zhongdian_list)
    starts = np.array(qidian_list)
    normal_out = ends[0] - starts[0]
    distance = int(np.dot(normal_out, normal_out))
    for end in ends:
        for start in starts:
            normal = end - start
            length_sq = int(np.dot(normal, normal))
            # Strict '>' keeps the first maximal pair, as before.
            if length_sq > distance:
                distance = length_sq
                normal_out = normal
    return normal_out

def from_xyz_load_resample_recon(img_data, xyz_load, color_load, number_of_points):
    """Reconstruct one component's surface, resample it, and rebuild the mesh.

    Pipeline:
    1. Build a point cloud from ``xyz_load``/``color_load`` with hint normals
       from ``generate_normal``.
    2. Run ``estimate_normals()``. Per the Open3D docs, existing normals are
       used to orient the re-estimated ones; verify for the installed version.
    3. Poisson-reconstruct (depth 9).
    4. Poisson-disk sample ``number_of_points`` points from that mesh.
    5. Poisson-reconstruct again from the samples.

    Returns:
        ``(resampled_point_cloud, mesh_from_resampled_points)``.
    """
    hint_normals = generate_normal(img_data, np.array(xyz_load), color_load)

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz_load)
    cloud.colors = o3d.utility.Vector3dVector(color_load)
    cloud.normals = o3d.utility.Vector3dVector(hint_normals)
    cloud.estimate_normals()

    poisson = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson
    first_mesh, _ = poisson(cloud, depth=9)
    resampled = first_mesh.sample_points_poisson_disk(number_of_points=number_of_points)
    final_mesh, _ = poisson(resampled, depth=9)
    return resampled, final_mesh

def get_normal_and_save(img_data, plypath, xyz_load, color_load, number_of_points, i):
    """Reconstruct and resample one component, writing both results to disk.

    Same pipeline as ``from_xyz_load_resample_recon``: hint normals, then
    ``estimate_normals``, then Poisson mesh, then Poisson-disk sampling, then a
    second Poisson mesh. It additionally writes:

    - ``../data/CTA/cta_chongcaiyang/pointcloud/partial/<stem>_hole_<i>.ply``
    - ``../data/CTA/cta_chongcaiyang/mesh/partial/<stem>_mesh_<i>.ply``

    ``<stem>`` is the file name of ``plypath`` up to ``.ply``. Missing
    directories are created; a failed write is reported, not raised.

    Returns:
        ``(resampled_point_cloud, mesh_from_resampled_points)``.
    """
    nxnynz = generate_normal(img_data, np.array(xyz_load), color_load)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz_load)
    pcd.colors = o3d.utility.Vector3dVector(color_load)
    pcd.normals = o3d.utility.Vector3dVector(nxnynz)
    # Re-estimate normals; the hint normals set above presumably only steer
    # their orientation (Open3D behaviour — confirm for the installed version).
    pcd.estimate_normals()

    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9)
    pcd2 = mesh.sample_points_poisson_disk(number_of_points=number_of_points)

    # basename() also handles the backslash-joined paths glob returns on Windows.
    stem = os.path.basename(plypath).split('.ply')[0]

    pcd_dir = "../data/CTA/cta_chongcaiyang/pointcloud/partial/"
    os.makedirs(pcd_dir, exist_ok=True)
    pcd_path = os.path.join(pcd_dir, stem + "_hole_%s.ply" % (i))
    if not o3d.io.write_point_cloud(pcd_path, pcd2):
        print('Failed to write point cloud:', pcd_path)

    mesh2, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd2, depth=9)
    mesh_dir = "../data/CTA/cta_chongcaiyang/mesh/partial/"
    os.makedirs(mesh_dir, exist_ok=True)
    mesh_path = os.path.join(mesh_dir, stem + "_mesh_%s.ply" % (i))
    if not o3d.io.write_triangle_mesh(mesh_path, mesh2):
        print('Failed to write mesh:', mesh_path)

    return pcd2, mesh2

def get_normal_chongcaiyang(plypath):
    """Resample every grey-coded component of one case and save the results.

    Steps:
      1. Load ``plypath`` and voxel-downsample it (voxel size 1.5).
      2. Load the case's cached label volume from ``../data/CTA/cta_mat/``
         (the ``.mat`` written by ``read_nii_gz``).
      3. Split the points by grey colour level 0.0, 0.1, ..., 0.5 (tolerance
         0.01 per channel) into components 1-6.
      4. Reconstruct and resample each non-empty component with
         ``get_normal_and_save``, which also writes per-component files.
      5. Write the merged point cloud and mesh under
         ``../data/CTA/cta_chongcaiyang/{pointcloud,mesh}/all/``.

    Points whose colour matches none of the six grey levels are ignored.

    Args:
        plypath: Path to the case's coloured ``.ply``. Its file stem (up to the
            first '.') is the case id used to find the label volume.

    Raises:
        FileNotFoundError: No ``.nii.gz`` label file exists for the case.
        ValueError: The point cloud does not carry one colour per point.
    """
    pcd_input = o3d.io.read_point_cloud(plypath)
    pcd_input = pcd_input.voxel_down_sample(voxel_size=1.5)
    xyz = np.asarray(pcd_input.points).reshape(-1, 3)
    color = np.asarray(pcd_input.colors).reshape(-1, 3)
    if color.shape != xyz.shape:
        raise ValueError('%s: expected one RGB colour per point, got colours %s for points %s'
                         % (plypath, color.shape, xyz.shape))

    # basename() also handles the backslash-joined paths glob returns on Windows.
    binganhao = os.path.basename(plypath).split('.')[0]
    label_files = glob.glob('../data/CTA/cta_labels/*/' + binganhao + '.nii.gz')
    if not label_files:
        raise FileNotFoundError('No label volume found for case %s' % binganhao)
    mat_name = os.path.basename(label_files[0]).split('.')[0] + '.mat'
    img_data = sio.loadmat(os.path.join('../data/CTA/cta_mat/', mat_name))['img_data_volume']

    # (grey level, target number of resampled points, component index).
    # NOTE(review): the point budgets are hard-coded; their origin is not visible here.
    caiyanglist = [(0.0, 4038, 1), (0.1, 4660, 2), (0.2, 3814, 3),
                   (0.3, 1299, 4), (0.4, 1480, 5), (0.5, 1139, 6)]

    pcd_chongcaiyang = o3d.geometry.PointCloud()
    mesh_chongcaiyang = o3d.geometry.TriangleMesh()
    for grey, number_of_points, i in caiyanglist:
        # The levels are 0.1 apart with a 0.01 tolerance, so these masks are disjoint.
        in_component = np.all(np.abs(color - grey) < 0.01, axis=1)
        if not in_component.any():
            # Poisson reconstruction/sampling cannot work from an empty cloud.
            print('%s: no points for component %d (grey %.1f), skipping' % (plypath, i, grey))
            continue
        pcd_component, mesh_component = get_normal_and_save(
            img_data, plypath, xyz[in_component], color[in_component], number_of_points, i)
        pcd_chongcaiyang = pcd_chongcaiyang + pcd_component
        mesh_chongcaiyang = mesh_chongcaiyang + mesh_component
    del img_data

    out_name = os.path.basename(plypath)
    pcd_all_dir = "../data/CTA/cta_chongcaiyang/pointcloud/all/"
    mesh_all_dir = "../data/CTA/cta_chongcaiyang/mesh/all/"
    os.makedirs(pcd_all_dir, exist_ok=True)
    os.makedirs(mesh_all_dir, exist_ok=True)
    o3d.io.write_point_cloud(os.path.join(pcd_all_dir, out_name), pcd_chongcaiyang)
    o3d.io.write_triangle_mesh(os.path.join(mesh_all_dir, out_name), mesh_chongcaiyang)


def main():
    """Resample every label volume whose ``.mat`` cache does not exist yet.

    For each ``../data/CTA/cta_ply/0*%*.ply`` file, take the case id (file
    stem up to the first '.'). If no matching ``0*_label_*.mat`` cache is in
    ``../data/CTA/cta_mat/``, the corresponding
    ``../data/CTA/cta_labels/*/<case>.nii.gz`` is resampled by ``read_nii_gz``
    in a fresh single-worker process pool. Cases without a label file are
    reported and skipped.
    """
    gc.collect()
    pathlist = glob.glob('../data/CTA/cta_ply/0*%*.ply')
    print(len(pathlist))

    # ---------------Read input data-------------------
    # Manual resume offset: set to k > 0 to skip the first k ply files.
    n = 0

    # Case ids that already have a cached .mat. The directory is scanned once,
    # not on every iteration, and the set is extended as volumes are written.
    # basename() keeps this working with the backslash-joined paths that
    # glob returns on Windows.
    already_done = {
        os.path.basename(path).split('.mat')[0]
        for path in glob.glob("../data/CTA/cta_mat/0*_label_*.mat")
    }

    for i_n, plypath in enumerate(pathlist[n:], start=n):
        binganhao = os.path.basename(plypath).split('.')[0]
        if binganhao in already_done:
            continue
        print('>>>>>>>>>>>>>>>', i_n, plypath)

        example_filename1 = glob.glob('../data/CTA/cta_labels/*/' + binganhao + '.nii.gz')
        if not example_filename1:
            # Raising here would make the __main__ retry loop fail on this same
            # case forever, so report and move on instead.
            print('No label volume found for', binganhao, '- skipping')
            continue

        # One fresh single-worker pool per volume, presumably so the worker's
        # memory is released after each large volume (assumption — confirm).
        # The with-block terminates the worker even if read_nii_gz raises, so
        # retries from __main__ do not accumulate orphaned processes.
        with mp.Pool(processes=1) as pool:
            pool.map(read_nii_gz, [example_filename1[0]])
            pool.close()
            pool.join()
        already_done.add(binganhao)


if __name__ == '__main__':
    import traceback

    print('Start Retrying in 5 seconds...')
    time.sleep(5)
    # Re-run main() until it completes. main() skips cases whose .mat cache
    # already exists, so each retry resumes after the cases already finished.
    # NOTE(review): a persistent error retries forever; Ctrl-C
    # (KeyboardInterrupt is not an Exception) still stops the loop.
    while True:
        try:
            main()
            break
        except Exception as e:
            print('\033[1;31mAn error occurred: \033[0m', e)
            # The message alone hides the exception type and location.
            traceback.print_exc()
            print('Retrying in 5 seconds...')
            time.sleep(5)

