import os
import numpy as np

# import vtk
import gc
import glob
import copy
from sklearn.decomposition import PCA
# import matplotlib.pyplot as plt
import cv2
# import vtkmodules.all as vtk
# from vtk.util.numpy_support import vtk_to_numpy
# from vtk.util.numpy_support import numpy_to_vtk
# from vtk import (vtkPolyDataReader,vtkUnstructuredGridReader, vtkDataSetMapper, vtkActor,
#                  vtkRenderer, vtkRenderWindow, vtkRenderWindowInteractor,vtkJPEGWriter)
import matplotlib
# matplotlib.use('TkAgg')
from matplotlib import pylab as plt
# %matplotlib inline
import nibabel as nib
from nibabel import nifti1
from nibabel.viewers import OrthoSlicer3D
from copy import deepcopy
import time
from tqdm import tqdm
#guild filter

def read_nii_gz(example_filename, scale=1.5):
    """Load a NIfTI volume and resample it by nearest-neighbour index lookup.

    For each axis ``a`` the stretch factor is ``s_a = |affine[a, a] * scale|``.
    The output has ``int(n_a * s_a)`` voxels along that axis. Output voxel
    ``(x, y, z)`` copies input voxel ``(int(x / s_0), int(y / s_1), int(z / s_2))``.

    Only the diagonal of the affine is used; rotation/shear terms are ignored.
    NOTE(review): volumes with a non-axis-aligned affine will be resampled
    incorrectly.

    Args:
        example_filename: path handed to ``nibabel.load``.
        scale: extra factor applied to the affine diagonal before resampling.

    Returns:
        float64 ndarray with the resampled volume. This assumes a 3-D image;
        TODO confirm the behaviour wanted for 4-D inputs.
    """
    img = nib.load(example_filename)
    img_data = np.asarray(img.get_fdata())

    # Per-axis stretch factor; the same arithmetic as abs((affine * scale)[a, a]).
    steps = np.abs(np.diag(img.affine)[:3] * scale)
    out_shape = [int(img_data.shape[axis] * steps[axis]) for axis in range(3)]

    # Vectorized replacement for the former triple Python loop. The source index
    # for every output position is computed once per axis. Truncating toward
    # zero matches int() for these non-negative values. The open mesh from
    # np.ix_ then gathers the whole resampled volume in one step.
    src_idx = [
        (np.arange(out_shape[axis]) / steps[axis]).astype(np.intp)
        for axis in range(3)
    ]
    return img_data[np.ix_(*src_idx)].astype(np.float64, copy=False)


def is_con_xyz(img_data, i, j, k, conponent_id):
    """Return 1 if voxel (i, j, k) has label ``conponent_id`` and at least one
    of its 6 face neighbours has a different label, else 0.

    Voxel values are truncated with ``int()`` before comparison.

    The caller must keep (i, j, k) at least one voxel away from every edge.
    ``i + 1`` raises IndexError at the last slice, and ``i - 1`` at index 0
    would silently wrap around for a numpy array.
    """
    # Explicit equality: the former `int(v) ^ id < 0.5` trick only worked for
    # non-negative values. Any negative voxel XORs to a negative number and was
    # therefore treated as "equal" to every id.
    if int(img_data[i, j, k]) != conponent_id:
        return 0
    # 6-connected face neighbours, checked in the original order.
    for di, dj, dk in ((1, 0, 0), (-1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, 1), (0, 0, -1)):
        if int(img_data[i + di, j + dj, k + dk]) != conponent_id:
            return 1
    return 0


def is_erjianban(img_data, i_ejb, j_ejb, k_ejb, conponent_id):
    """Return 1 if voxel (i_ejb, j_ejb, k_ejb) has label 1 and any of its 6 face
    neighbours has label 4, else 0.

    ``conponent_id`` is accepted only for signature compatibility and is
    ignored; the labels 1 and 4 are fixed.

    The name is pinyin for "mitral valve" (erjianban). Presumably labels 1 and 4
    are the structures the valve separates; TODO confirm against the labelling
    scheme.

    Values are truncated with ``int()``. The indices must be at least one voxel
    away from every edge.
    """
    inner_label, neighbour_label = 1, 4
    # Explicit equality. The former `int(v) ^ id < 0.5` test treated any
    # negative value as a match.
    if int(img_data[i_ejb, j_ejb, k_ejb]) != inner_label:
        return 0
    for di, dj, dk in ((1, 0, 0), (-1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, 1), (0, 0, -1)):
        if int(img_data[i_ejb + di, j_ejb + dj, k_ejb + dk]) == neighbour_label:
            return 1
    return 0


def is_sanjianban(img_data, i, j, k, conponent_id):
    """Return 1 if voxel (i, j, k) has label 3 and any of its 6 face neighbours
    has label 5, else 0.

    ``conponent_id`` is accepted only for signature compatibility and is
    ignored; the labels 3 and 5 are fixed.

    The name is pinyin for "tricuspid valve" (sanjianban). Presumably labels 3
    and 5 are the structures the valve separates; TODO confirm against the
    labelling scheme.

    Values are truncated with ``int()``. The indices must be at least one voxel
    away from every edge.
    """
    inner_label, neighbour_label = 3, 5
    # Explicit equality. The former `int(v) ^ id < 0.5` test treated any
    # negative value as a match.
    if int(img_data[i, j, k]) != inner_label:
        return 0
    for di, dj, dk in ((1, 0, 0), (-1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, 1), (0, 0, -1)):
        if int(img_data[i + di, j + dj, k + dk]) == neighbour_label:
            return 1
    return 0


def is_zhudongmaiban(img_data, i, j, k, conponent_id):
    """Return 1 if voxel (i, j, k) has label 1 and any of its 6 face neighbours
    has label 6, else 0.

    ``conponent_id`` is accepted only for signature compatibility and is
    ignored; the labels 1 and 6 are fixed.

    The name is pinyin for "aortic valve" (zhudongmaiban). Presumably labels 1
    and 6 are the structures the valve separates; TODO confirm against the
    labelling scheme.

    Values are truncated with ``int()``. The indices must be at least one voxel
    away from every edge.
    """
    inner_label, neighbour_label = 1, 6
    # Explicit equality. The former `int(v) ^ id < 0.5` test treated any
    # negative value as a match.
    if int(img_data[i, j, k]) != inner_label:
        return 0
    for di, dj, dk in ((1, 0, 0), (-1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, 1), (0, 0, -1)):
        if int(img_data[i + di, j + dj, k + dk]) == neighbour_label:
            return 1
    return 0


def generate_conponent(img_data):
    """Scan the interior voxels of ``img_data`` and collect boundary points.

    The outermost layer along each of the first three axes is skipped, so the
    neighbour lookups in the helper predicates stay in bounds. Points are
    appended in (i, j, k) scan order.

    Returns a 13-tuple of lists:
      * items 0-9: ``[i, j, k]`` for voxels where ``is_con_xyz(..., label)`` is
        true, for label 1..10 respectively. Only the first matching label is
        recorded per voxel.
      * items 10-12: ``[i + 0.2, j, k]`` for voxels flagged by ``is_erjianban``,
        ``is_sanjianban`` and ``is_zhudongmaiban`` respectively.
    """
    label_buckets = [[] for _ in range(10)]
    erjianban_pts = []
    sanjianban_pts = []
    zhudongmaiban_pts = []
    # Interface predicates paired with the list each one feeds. The trailing 1
    # mirrors the id argument the predicates receive.
    interface_checks = (
        (is_erjianban, erjianban_pts),
        (is_sanjianban, sanjianban_pts),
        (is_zhudongmaiban, zhudongmaiban_pts),
    )
    dims = img_data.shape

    for i in tqdm(range(1, dims[0] - 1)):
        for j in range(1, dims[1] - 1):
            for k in range(1, dims[2] - 1):
                for predicate, bucket in interface_checks:
                    if predicate(img_data, i, j, k, 1):
                        bucket.append([i + 0.2, j, k])
                # Stop at the first label that reports a boundary, like the
                # former if/continue chain did.
                for label, bucket in enumerate(label_buckets, start=1):
                    if is_con_xyz(img_data, i, j, k, label):
                        bucket.append([i, j, k])
                        break

    print('load conponents end!')

    return (*label_buckets, erjianban_pts, sanjianban_pts, zhudongmaiban_pts)

# Label volumes to convert. The list is sorted because glob order is arbitrary,
# and the processing loop resumes from index `n` into this list, which needs a
# stable order.
example_filename_list = sorted(glob.glob("../data/CTA/cta_labels/labels*/*.nii.gz"))
print(len(example_filename_list))
# Point clouds already written by a previous run.
already_ply = sorted(glob.glob("../data/CTA/cta_ply/0*_label_*.ply"))
print(already_ply)
# Stems (file names without the .ply suffix) of finished outputs.
# os.path.basename is used rather than split('/'), because glob joins path
# parts with os.sep, which is not '/' on Windows.
already_list = [os.path.basename(path).split('.ply')[0] for path in already_ply]
print(len(already_list), already_list)

import open3d as o3d

# Convert every label volume that has no .ply yet into a coloured point cloud.
count = 0  # files handled so far (converted or skipped); used for progress output
n = 0      # start offset into example_filename_list, for resuming a run
# Set for O(1) membership; entries are output stems without extension.
already_done = set(already_list)
out_file = "../data/CTA/cta_ply/"
# Ensure the output directory exists before anything is written into it.
os.makedirs(out_file, exist_ok=True)

# Gray level for each of the ten label-boundary clouds, and RGB for the three
# interface clouds returned by generate_conponent (in that order).
label_gray_levels = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
interface_rgb = ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))

for example_filename in example_filename_list[n:]:
    # basename is separator-agnostic; split('/') fails on Windows glob results.
    base_name = os.path.basename(example_filename)
    # Use the same stem for the "already done" check and the output name, so a
    # file written here is recognised as finished on the next run.
    stem = base_name.split('.nii')[0]
    if stem in already_done:
        count += 1
        continue

    print(base_name, '[', time.asctime(), ']', count + n, 'of', len(example_filename_list))
    img_data = read_nii_gz(example_filename, scale=1.5)
    components = generate_conponent(img_data)
    del img_data  # release the resampled volume before building the cloud
    # Count every attempted file, including the ones skipped below, so the
    # printed progress index does not drift.
    count += 1

    # Files missing any component are skipped (nothing is written for them).
    if any(len(points) == 0 for points in components):
        gc.collect()
        continue

    xyz = np.concatenate(components, axis=0)
    color = np.concatenate(
        [np.full((len(points), 3), level)
         for points, level in zip(components[:10], label_gray_levels)]
        + [np.tile(rgb, (len(points), 1))
           for points, rgb in zip(components[10:], interface_rgb)],
        axis=0,
    )

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    pcd.colors = o3d.utility.Vector3dVector(color)

    ply_path = os.path.join(out_file, stem + '.ply')
    # write_point_cloud reports failure through its boolean return value.
    if not o3d.io.write_point_cloud(ply_path, pcd):
        print('failed to write', ply_path)
    del pcd
    gc.collect()