import matplotlib.pyplot as plt
import numpy as np
import matplotlib
#matplotlib.use('TkAgg')
from matplotlib import pylab as plt
# %matplotlib auto
import nibabel as nib
from nibabel import nifti1
from nibabel.viewers import OrthoSlicer3D
import os
import cv2
from tqdm import tqdm
import scipy.io as sio
import time
import open3d as o3d

print(time.asctime())
import glob
from mpl_toolkits.mplot3d import Axes3D
import gc
import glob


def read_nii_gz(example_filename, scale=1.5):
    """Load a NIfTI-style volume and resample it by nearest neighbour.

    Axis ``k`` is stretched by ``s_k = |affine[k, k]| * scale``.
    - The output has ``int(shape[k] * s_k)`` voxels along that axis.
    - Output voxel ``(i, j, k)`` copies input voxel
      ``(int(i / s_x), int(j / s_y), int(k / s_z))``.

    Only the diagonal of the affine is read. This presumably assumes an
    axis-aligned affine (no rotation/shear).
    NOTE(review): confirm for the data in use.

    Args:
        example_filename: path to any file ``nibabel.load`` can open.
        scale: extra resampling factor applied on top of the affine diagonal.

    Returns:
        float64 ndarray holding the resampled volume.

    Raises:
        ValueError: if one of the first three affine diagonal entries is 0.
            That axis would have an empty output (the previous
            implementation returned such an empty volume silently).
    """
    img = nib.load(example_filename)
    img_data = np.asarray(img.get_fdata())

    # Scaling the whole affine only matters for its diagonal, the sole part read here.
    stretch = np.abs(np.diag(img.affine * scale)[:3])
    if np.any(stretch == 0):
        raise ValueError(
            'read_nii_gz: affine has a zero diagonal entry (oblique affine?): %s' % example_filename)

    out_shape = [int(size * s) for size, s in zip(img_data.shape[:3], stretch)]
    # Source index for every output index along each axis. astype truncates toward
    # zero, exactly like the int() the per-voxel loop used.
    src_index = [(np.arange(size) / s).astype(np.intp) for size, s in zip(out_shape, stretch)]
    # np.ix_ builds an open mesh, so one fancy-indexing gather replaces the
    # triple Python loop over every output voxel.
    return img_data[np.ix_(*src_index)].astype(np.float64, copy=False)


def get_slice_voxel(img3d_data, img3d_center, e1, e2, view=None, source_path=None):
    """Sample a planar slice of a 3-D volume, save it as a PNG and return it.

    Output pixel ``(n_shape // 2 + u, n_shape // 2 + v)`` takes the value of the
    voxel containing ``img3d_center + u * e1 + v * e2``.
    - Sampling is nearest neighbour via ``floor``.
    - ``e1`` steps along image rows and ``e2`` along image columns, one pixel
      per unit of each vector in voxel-index space.
    - Pixels whose sample point falls outside the volume stay 0.

    ``n_shape`` is ``int(1.2 * max(shape[0], shape[2]))``. Offsets ``u`` and
    ``v`` run over ``-(n_shape // 2 - 2) .. n_shape // 2 - 2``.

    Args:
        img3d_data: 3-D array indexed ``[x, y, z]``.
        img3d_center: length-3 voxel coordinate mapped to the central pixel.
        e1, e2: length-3 direction vectors spanning the plane. The caller in
            this file passes unit vectors.
        view: output sub-folder under ``../data/CTA/cta_slice/``. Defaults to
            the module-level ``slice_type`` set by the ``__main__`` script
            (the only source the previous implementation used).
        source_path: the PNG is named after the second-to-last '.'-separated
            part of this path's basename (the stem, for names like
            ``0001_x.mat``). Defaults to the module-level ``example_filename``.

    Returns:
        float32 array of shape ``(n_shape, n_shape)``, min-max scaled to [0, 255].

    Raises:
        ValueError: if the centre or a direction vector contains NaN/inf.
        RuntimeError: if OpenCV fails to encode the PNG.
    """
    if view is None:
        view = slice_type  # module-level global assigned by the __main__ script
    if source_path is None:
        source_path = example_filename  # module-level global assigned by the __main__ script

    center = np.asarray(img3d_center, dtype=np.float64)
    e1 = np.asarray(e1, dtype=np.float64)
    e2 = np.asarray(e2, dtype=np.float64)
    if not (np.isfinite(center).all() and np.isfinite(e1).all() and np.isfinite(e2).all()):
        # E.g. a zero-length vector normalised to NaN: the plane is undefined.
        raise ValueError('get_slice_voxel: centre and direction vectors must be finite')

    n_shape = int(max(img3d_data.shape[0], img3d_data.shape[2]) * 1.2)
    print(img3d_data.shape, n_shape)
    half = n_shape // 2
    # Same pixel footprint as the former four-quadrant loop over range(n_shape // 2 - 1).
    reach = half - 1
    offsets = np.arange(-(reach - 1), reach)
    rows, cols = np.meshgrid(offsets, offsets, indexing='ij')

    # floor (not int() truncation) so coordinates in (-1, 0) fall outside the
    # volume instead of being snapped onto index 0.
    voxel_idx = [np.floor(center[a] + e1[a] * rows + e2[a] * cols).astype(np.intp)
                 for a in range(3)]
    inside = np.ones(rows.shape, dtype=bool)
    for idx, size in zip(voxel_idx, img3d_data.shape[:3]):
        # Index 0 is a valid voxel; the former `> 0` test wrongly dropped it.
        inside &= (idx >= 0) & (idx < size)

    img2d = np.zeros((n_shape, n_shape))
    img2d[half + rows[inside], half + cols[inside]] = img3d_data[
        voxel_idx[0][inside], voxel_idx[1][inside], voxel_idx[2][inside]]

    out_file = '../data/CTA/cta_slice/' + view
    png_path = os.path.join(out_file, os.path.basename(source_path).split('.')[-2] + '.png')
    print(png_path)
    # Min-max stretch straight from float. The former uint8 cast before this
    # step overflowed every intensity outside 0..255 (e.g. raw CT values).
    img2d = cv2.normalize(img2d, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX,
                          dtype=cv2.CV_32F)
    os.makedirs(out_file, exist_ok=True)
    # PNG needs an integer depth: round explicitly instead of relying on
    # imencode's implicit float -> 8-bit conversion.
    ok, png = cv2.imencode('.png', np.clip(np.rint(img2d), 0, 255).astype(np.uint8))
    if not ok:
        raise RuntimeError('get_slice_voxel: PNG encoding failed for ' + png_path)
    # ndarray.tofile instead of cv2.imwrite; presumably to support non-ASCII
    # paths on Windows — TODO confirm.
    png.tofile(png_path)
    return img2d

def _centroid(points):
    """Return the mean x, y and z of an (N, 3) point array as a length-3 array."""
    return np.array([points[:, 0].mean(), points[:, 1].mean(), points[:, 2].mean()])


def _unit(vec):
    """Return the length-3 vector ``vec`` divided by its Euclidean norm."""
    return vec / ((vec[0] ** 2 + vec[1] ** 2 + vec[2] ** 2) ** 0.5)


def _split_by_color(xyz, color):
    """Group point coordinates by the label colour stored in the point cloud.

    Returns ``{group_id: (K, 3) array}``. Points keep their original order
    within each group.
    - 1..5: points whose R, G and B are all within 0.01 of 0.0, 0.1, 0.2, 0.3
      and 0.4 respectively.
    - 7: R exceeds G by more than 0.001.
    - 8: otherwise, G exceeds R by more than 0.001.
    - 9: otherwise, B exceeds R by more than 0.001.
    """
    red, green, blue = color[:, 0], color[:, 1], color[:, 2]
    groups = {}
    for group_id, level in enumerate((0.0, 0.1, 0.2, 0.3, 0.4), start=1):
        # The grey levels are 0.1 apart with a 0.01 tolerance, so these masks never overlap.
        grey = ((np.abs(red - level) < 0.01) & (np.abs(green - level) < 0.01)
                & (np.abs(blue - level) < 0.01))
        groups[group_id] = xyz[grey]
    in7 = red - green > 0.001
    in8 = ~in7 & (green - red > 0.001)
    in9 = ~in7 & ~in8 & (blue - red > 0.001)
    groups[7], groups[8], groups[9] = xyz[in7], xyz[in8], xyz[in9]
    return groups


if __name__ == '__main__':
    ply_path_list = glob.glob("../data/CTA/cta_normal/pointcloud/all/0*_*.ply")
    alreadypath_list = glob.glob("../data/CTA/cta_normal/pointcloud/partial/0*_*_a4c.ply")
    # basename also handles the '\\' separator glob produces on Windows; split('/') did not.
    already_list = [os.path.basename(path).split('_a4c')[0] for path in alreadypath_list]
    already_done = set(already_list)  # O(1) membership test per file
    print(len(already_list))
    print(len(ply_path_list))
    n = 0  # index of the first file to process; raise it to resume a partial run
    for i_index in range(n, len(ply_path_list)):
        plyP = ply_path_list[i_index].replace('cta_normal/pointcloud/all', 'cta_ply')
        if os.path.basename(plyP).split('.')[0] in already_done:
            continue

        # Kept at module level on purpose: get_slice_voxel reads example_filename
        # (and slice_type, assigned below) as module globals.
        example_filename = "../data/CTA/cta_mat/" + os.path.basename(plyP).split('.ply')[0] + ".mat"
        print('>>>>>>>>>>', i_index, len(ply_path_list), plyP)
        pcd_all = o3d.io.read_point_cloud(plyP)
        xyz = np.array(pcd_all.points)
        color = np.array(pcd_all.colors)
        print(color.shape[0], np.unique(color))
        del pcd_all
        if color.ndim != 2 or color.shape[0] != xyz.shape[0]:
            print('skip', plyP, ': point cloud has no per-point colours')
            continue

        groups = _split_by_color(xyz, color)
        print(groups[1].shape, groups[2].shape, groups[3].shape, groups[8].shape, groups[7].shape)
        empty = [group_id for group_id, pts in groups.items() if len(pts) == 0]
        if empty:
            # Previously an empty group raised IndexError and aborted the whole batch.
            print('skip', plyP, ': no points for colour group(s)', empty)
            continue

        # Landmark names look like pinyin abbreviations. Presumably (verify
        # against the labelling pipeline):
        #   ejb = mitral valve, sjb = tricuspid valve, zdmb = aortic valve,
        #   zf / yf = left / right atrium, xj = apex.
        center_point = _centroid(groups[1])
        rvcenter_point = _centroid(groups[3])
        zf_point = _centroid(groups[4])
        yf_point = _centroid(groups[5])
        ejb_point = _centroid(groups[7])
        sjb_point = _centroid(groups[8])
        zdmb_point = _centroid(groups[9])
        esjb_center = (ejb_point + sjb_point) * 0.5

        # xj_point is the group-2 point farthest from center_point (the first one wins on ties).
        # It is recomputed for every file, so a degenerate file can no longer
        # silently reuse the previous file's point.
        cand = groups[2]
        dist2 = ((center_point[0] - cand[:, 0]) ** 2 + (center_point[1] - cand[:, 1]) ** 2
                 + (center_point[2] - cand[:, 2]) ** 2)
        far = int(np.argmax(dist2))
        if not dist2[far] > 0:
            print('skip', plyP, ': every group-2 point coincides with the group-1 centroid')
            continue
        xj_point = cand[far]

        # Presumably a cached read_nii_gz(example_filename, scale=1.5) result
        # (same 'img_data_volume' name) — TODO confirm.
        img3d_data = sio.loadmat(example_filename)['img_data_volume']

        # Each view: (slice_type, centre, unit e1, unit e2, saved normal).
        # The e2 axes are cross products taken BEFORE e1/e2 are normalised.
        views = []

        # a4c
        e1 = esjb_center - xj_point
        normal = np.cross(xj_point - ejb_point, sjb_point - ejb_point)
        views.append(('a4c', esjb_center, _unit(e1), _unit(np.cross(normal, e1)), normal))

        # a2c
        e1 = ejb_point - xj_point
        normal = rvcenter_point - center_point
        views.append(('a2c', ejb_point, _unit(e1), _unit(np.cross(e1, normal)), normal))

        # a3c
        e1 = ejb_point - xj_point
        normal = np.cross(center_point - ejb_point, zdmb_point - ejb_point)
        views.append(('a3c', (center_point + ejb_point) / 2, _unit(e1),
                      _unit(np.cross(e1, normal)), normal))

        # lax: e2 is defined first; the saved normal is rebuilt from the unit axes.
        e2 = zdmb_point - xj_point
        normal = np.cross(zdmb_point - ejb_point, center_point - ejb_point)
        e1_unit, e2_unit = _unit(np.cross(e2, normal)), _unit(e2)
        views.append(('lax', (center_point + ejb_point) / 2, e1_unit, e2_unit,
                      np.cross(e1_unit, e2_unit)))

        # a5c: the saved normal is rebuilt from the unit axes.
        e1 = esjb_center - xj_point
        normal = np.cross(yf_point - zf_point, zdmb_point - zf_point)
        e1_unit, e2_unit = _unit(e1), _unit(np.cross(e1, normal))
        views.append(('a5c', zdmb_point, e1_unit, e2_unit, np.cross(e1_unit, e2_unit)))

        # Module-level loop, so slice_type is the global get_slice_voxel reads.
        for slice_type, img3d_center, e1, e2, normal in views:
            out_file = '../data/CTA/cta_slice/' + slice_type
            os.makedirs(out_file, exist_ok=True)
            npz_path = os.path.join(out_file, os.path.basename(plyP).split('.ply')[0] + '.npz')
            print(npz_path)
            np.savez(npz_path, plypath=plyP, niipath=example_filename, e1=e1, e2=e2,
                     img3d_center=img3d_center, normal=normal)
            get_slice_voxel(img3d_data, img3d_center, e1, e2)

        # Drop this volume before the next file's loadmat to keep peak memory down.
        del img3d_data
