import matplotlib.pyplot as plt
import numpy as np
# %matplotlib auto
import nibabel as nib
from nibabel import nifti1
from nibabel.viewers import OrthoSlicer3D
import os
import cv2
from tqdm import tqdm
import time
import open3d as o3d
import pandas as pd
import glob
import shutil
import random
import copy
from sklearn.decomposition import PCA
import gc

from copy import deepcopy
print(time.asctime())


def split_dataset(uniquebinganhao, ratio):
    """Shuffle the ids reproducibly and cut them into a train and a test part.

    The global ``random`` generator is seeded with 42 and ``uniquebinganhao``
    is shuffled *in place*, so the caller's list is reordered as a side
    effect. The shuffled list is printed before it is split.

    Args:
        uniquebinganhao: Mutable list of ids to split.
        ratio: Fraction of the list that goes to the train part; the cut
            index is ``int(len(uniquebinganhao) * ratio)``.

    Returns:
        ``(train_set, test_set)``: the leading and trailing slices of the
        shuffled list.
    """
    random.seed(42)
    random.shuffle(uniquebinganhao)
    print(uniquebinganhao)

    cut = int(len(uniquebinganhao) * ratio)
    return uniquebinganhao[:cut], uniquebinganhao[cut:]

def read_nii_gz(example_filename, scale=0.8):
    """Load a NIfTI volume and resample it by nearest-neighbour lookup.

    The per-axis step is ``abs(affine[a, a] * scale)``, i.e. only the
    diagonal of the image affine is used.
    NOTE(review): this treats the affine diagonal as the voxel size, which
    presumably assumes an axis-aligned affine -- confirm for rotated scans.

    Args:
        example_filename: Path handed to ``nib.load``.
        scale: Factor applied to the affine before the steps are read.

    Returns:
        Array of shape ``(int(nx * sx), int(ny * sy), int(nz * sz))`` where
        output voxel ``(i, j, k)`` holds input voxel
        ``(int(i / sx), int(j / sy), int(k / sz))``.
    """
    img = nib.load(example_filename)
    img_data = np.array(img.get_fdata())

    affinemat = img.affine * scale
    step_x = abs(affinemat[0, 0])
    step_y = abs(affinemat[1, 1])
    step_z = abs(affinemat[2, 2])

    volume_x = int(img_data.shape[0] * step_x)
    volume_y = int(img_data.shape[1] * step_y)
    volume_z = int(img_data.shape[2] * step_z)

    # Source index for every output index along each axis. This is the same
    # truncating division the original per-voxel loop performed, done once
    # per axis instead of once per voxel.
    src_x = (np.arange(volume_x) / step_x).astype(int)
    src_y = (np.arange(volume_y) / step_y).astype(int)
    src_z = (np.arange(volume_z) / step_z).astype(int)

    # np.ix_ builds the open mesh so the three index vectors are combined as
    # an outer product, yielding the full resampled volume in one gather.
    return img_data[np.ix_(src_x, src_y, src_z)]

def generate_normal(img_data, xyz, color):
    """Compute one label-based normal per point via ``get_point_normal``.

    The component id of a point is decoded from its colour: pure red, green
    and blue map to 11, 12 and 13; any other colour maps to
    ``round(red / 0.1) + 1``.

    Args:
        img_data: Label volume, indexed as ``img_data[x, y, z]``.
        xyz: Sequence of points; each coordinate is truncated with ``int``
            before it is used as a voxel index.
        color: Per-point colours, indexed in step with ``xyz``; channels 0-2
            are read.

    Returns:
        List with one normal (as returned by ``get_point_normal``) per point.
    """
    nxnynz = []
    for index, point in enumerate(xyz):
        rgb = color[index]

        # Round instead of truncating: 0.3 / 0.1 evaluates to
        # 2.9999999999999996, so int(c / 0.1 + 1) returned 3 instead of 4.
        color_id = int(round(float(rgb[0]) / 0.1)) + 1

        if rgb[0] == 1.0 and rgb[1] == 0 and rgb[2] == 0:
            color_id = 11
        if rgb[0] == 0 and rgb[1] == 1.0 and rgb[2] == 0:
            color_id = 12
        if rgb[0] == 0 and rgb[1] == 0 and rgb[2] == 1.0:
            color_id = 13

        nxnynz.append(
            get_point_normal(color_id, img_data, int(point[0]), int(point[1]), int(point[2])))
    return nxnynz

def get_point_normal(conponent_id, img_data_, i, j, k):
    qidian_list = [[i, j, k]]
    zhongdian_list = [[i, j, k]]
    for i_change in [-1, 0, 1]:
        for j_change in [-1, 0, 1]:
            for k_change in [-1, 0, 1]:
                if abs(img_data_[i + i_change, j + j_change, k + k_change] - conponent_id) < 0.5:
                    qidian_list.append([i + i_change, j + j_change, k + k_change])
                if abs(img_data_[i + i_change, j + j_change, k + k_change] - conponent_id) > 0.5:
                    zhongdian_list.append([i + i_change, j + j_change, k + k_change])
    if len(qidian_list) == 0 or len(zhongdian_list) == 0:
        ccc = 1
        # print(conponent_id,len(qidian_list),len(zhongdian_list))

    normal = np.array(zhongdian_list[0]) - np.array(qidian_list[0])
    normal_out = normal
    distance = normal[0] ** 2 + normal[1] ** 2 + normal[2] ** 2
    for zhongdian in zhongdian_list:
        for qidian in qidian_list:
            normal = np.array(zhongdian) - np.array(qidian)
            if normal[0] ** 2 + normal[1] ** 2 + normal[2] ** 2 > distance:
                distance = normal[0] ** 2 + normal[1] ** 2 + normal[2] ** 2
                normal_out = normal
    return normal_out

def from_xyz_load_resample_recon(img_data, xyz_load, color_load, number_of_points):
    """Reconstruct a surface from labelled points and resample it.

    Steps, in order: label-based normals from ``generate_normal`` are
    assigned to the cloud, ``estimate_normals()`` is run, a depth-9 Poisson
    mesh is built, ``number_of_points`` points are Poisson-disk sampled from
    that mesh, and a second depth-9 Poisson mesh is built from the samples.

    NOTE(review): ``estimate_normals()`` runs after normals were assigned;
    presumably the assigned normals only serve as an orientation hint --
    confirm against the Open3D documentation.

    Args:
        img_data: Label volume forwarded to ``generate_normal``.
        xyz_load: Point coordinates.
        color_load: Per-point colours.
        number_of_points: Number of points to sample from the first mesh.

    Returns:
        ``(pcd2, mesh)``: the resampled point cloud and the mesh rebuilt
        from it.
    """
    as_vectors = o3d.utility.Vector3dVector
    poisson = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson

    label_normals = generate_normal(img_data, np.array(xyz_load), color_load)

    cloud = o3d.geometry.PointCloud()
    cloud.points = as_vectors(xyz_load)
    cloud.colors = as_vectors(color_load)
    cloud.normals = as_vectors(label_normals)
    cloud.estimate_normals()

    # First pass: mesh the raw points, then draw an evenly spread sample.
    first_mesh, _ = poisson(cloud, depth=9)
    pcd2 = first_mesh.sample_points_poisson_disk(number_of_points=number_of_points)

    # Second pass: mesh the resampled cloud.
    mesh, _ = poisson(pcd2, depth=9)
    return pcd2, mesh

from mpl_toolkits.mplot3d import Axes3D
import gc
import glob
def get_slice_e1_e2_center(example_filename, slice_type='a4c'):
    """Load the stored slice-plane data and slice image for one case.

    Reads ``<stem>.npz`` and ``<stem>.png`` from
    ``../data/CTA/cta_slice/<slice_type>/``, where ``<stem>`` is the
    second-to-last dot-separated piece of the file's base name. The PNG is
    also copied to ``../data/CTA/cta_normal/pointcloud/slice/<slice_type>/``
    and its path is printed.

    Args:
        example_filename: Path whose base name selects the case.
        slice_type: Sub-directory name of the slice view.

    Returns:
        ``(e1, e2, img3d_center, img)``: the three arrays stored under those
        keys in the ``.npz`` file and the decoded image. A single-channel
        image is converted to three channels.
    """
    out_file = '../data/CTA/cta_slice/' + slice_type + '/'
    stem = os.path.basename(example_filename).split('.')[-2]

    # NpzFile keeps the archive open until closed; the context manager
    # releases the handle as soon as the three arrays are read into memory.
    with np.load(out_file + stem + '.npz') as e1e2:
        e1 = e1e2['e1']
        e2 = e1e2['e2']
        img3d_center = e1e2['img3d_center']

    img_path = out_file + stem + '.png'
    shutil.copy(img_path,
                '../data/CTA/cta_normal/pointcloud/slice/' + slice_type + '/' + os.path.basename(
                    img_path))
    print(img_path)

    # Decode from a byte buffer rather than passing the path to OpenCV;
    # flag -1 keeps the file's own channel layout.
    img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1)
    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    return e1, e2, img3d_center, img

def is_con_xy(img, i, j):
    """Tell whether pixel ``(i, j)`` differs from any 4-neighbour in channel 0.

    Only the pixels that are compared are converted to float; the previous
    version copied the whole image to float on every call. The conversion is
    still required so that unsigned integer pixels do not wrap on
    subtraction.

    Args:
        img: Image indexed as ``img[row, col][channel]``.
        i: Row index of the pixel.
        j: Column index of the pixel.

    Returns:
        1 if an existing neighbour (down, up, left, right) differs from the
        pixel by more than 0.05 in channel 0, otherwise 0.
    """
    pixels = np.asarray(img)
    rows, cols = pixels.shape[0], pixels.shape[1]
    centre = float(pixels[i, j][0])

    # (row, col, neighbour exists) in the order down, up, left, right.
    neighbours = (
        (i + 1, j, i + 1 < rows),
        (i - 1, j, i - 1 >= 0),
        (i, j - 1, j - 1 >= 0),
        (i, j + 1, j + 1 < cols),
    )
    for row, col, exists in neighbours:
        if exists and abs(float(pixels[row, col][0]) - centre) > 0.05:
            return 1
    return 0
def pick_point(img3d_center, e1, e2, img, color_list=None, xyzd=None):
    """Collect 3-D positions and colours of edge pixels in a slice image.

    The image is scanned outward from its centre in four quadrants. A pixel
    is kept when its channel 0 is positive and ``is_con_xy`` reports a
    differing neighbour. Its 3-D position is
    ``img3d_center +/- i * e1 +/- j * e2`` truncated to integers, and its
    colour is the pixel value divided by 255.

    Pixels on the centre row or column (``i == 0`` or ``j == 0``) are visited
    by more than one quadrant and are therefore appended more than once.

    Args:
        img3d_center: 3-D position of the image centre.
        e1: 3-D step per image row (NumPy array).
        e2: 3-D step per image column (NumPy array).
        img: Slice image indexed as ``img[row, col][channel]``.
        color_list: Optional list to append colours to; a new list is
            created when omitted.
        xyzd: Optional list to append positions to; a new list is created
            when omitted.

    Returns:
        ``(color_list, xyzd)``.
    """
    # Fresh lists per call: the former mutable defaults were shared between
    # calls, so results silently accumulated across images.
    color_list = [] if color_list is None else color_list
    xyzd = [] if xyzd is None else xyzd

    centre_row, centre_col = img.shape[0] // 2, img.shape[1] // 2
    origin = np.array(img3d_center)
    # Same visiting order as the original four copy-pasted branches.
    quadrants = ((1, 1), (-1, 1), (-1, -1), (1, -1))

    for i in tqdm(range(img.shape[0] // 2)):
        for j in range(img.shape[1] // 2):
            for sign_i, sign_j in quadrants:
                row = centre_row + sign_i * i
                col = centre_col + sign_j * j
                if img[row, col][0] > 0 and is_con_xy(img, row, col):
                    position = origin + e1 * [sign_i * i] + e2 * [sign_j * j]
                    xyzd.append(np.array(position, dtype='int'))
                    color_list.append(img[row, col] / 255)

    return color_list, xyzd

def _column_means(points):
    """Return the per-column means of an (n, 3) array as a length-3 vector."""
    return np.array([points[:, 0].mean(), points[:, 1].mean(), points[:, 2].mean()])


def _unit_vector(vector):
    """Return the length-3 ``vector`` divided by its Euclidean norm."""
    return vector / ((vector[0] ** 2 + vector[1] ** 2 + vector[2] ** 2) ** 0.5)


def _normalized_copy(cloud, origin, rotation):
    """Return a deep copy of ``cloud`` shifted by ``-origin``, rotated about (0, 0, 0) and scaled by 1/300."""
    moved = copy.deepcopy(cloud)
    moved = moved.translate(-1 * origin).rotate(rotation, center=(0, 0, 0))
    return moved.scale(1 / 300, (0, 0, 0))


# Batch driver. For every case found in cta_normal/pointcloud/all it
#   1. reads the colour-coded cloud from cta_ply and the resampled cloud from
#      cta_chongcaiyang,
#   2. derives an orthonormal frame from three landmarks of those clouds,
#   3. writes the centred / rotated / scaled full cloud, the cloud of points
#      selected from cta_ply, and one near-plane partial cloud per slice
#      type.
# Cases that already have an '_a4c' partial file are skipped.
if __name__ == '__main__':
    pathlist = glob.glob('../data/CTA/cta_normal/pointcloud/all/0*%*.ply')
    alreadypath_list = glob.glob("../data/CTA/cta_normal/pointcloud/partial/0*_*_a4c.ply")
    already_list = {path.split('/')[-1].split('_a4c')[0] for path in alreadypath_list}

    # File names listed here are excluded from processing.
    donotprocessnamelist = []

    pathlist_end = [path for path in pathlist if path.split('/')[-1] not in donotprocessnamelist]
    print(len(pathlist), len(pathlist_end))
    pathlist = pathlist_end

    slice_type_list = ['a4c', 'a2c', 'a3c', 'a5c', 'lax']

    for i_plypath, source_path in enumerate(pathlist):

        gc.collect()
        plypath = source_path.replace('cta_normal/pointcloud/all', 'cta_ply')
        ply_name = plypath.split('/')[-1]
        if ply_name.split('.')[0] in already_list:
            continue
        print('>>>>>>', i_plypath, '/', len(pathlist), plypath)
        case_name = ply_name.split('.ply')[0]

        # ---- colour-coded cloud from cta_ply ------------------------------
        pcd_input = o3d.io.read_point_cloud(plypath)
        bm_xyz = np.array(pcd_input.points)
        bm_color = np.array(pcd_input.colors)

        # Three mutually exclusive groups, tested in this order:
        #   7: red exceeds green, 8: green exceeds red, 9: blue exceeds red.
        is_group7 = bm_color[:, 0] - bm_color[:, 1] > 0.001
        is_group8 = ~is_group7 & (bm_color[:, 1] - bm_color[:, 0] > 0.001)
        is_group9 = ~is_group7 & ~is_group8 & (bm_color[:, 2] - bm_color[:, 0] > 0.001)
        xyz_load7, color_load7 = bm_xyz[is_group7], bm_color[is_group7]
        xyz_load8, color_load8 = bm_xyz[is_group8], bm_color[is_group8]
        xyz_load9, color_load9 = bm_xyz[is_group9], bm_color[is_group9]
        if len(xyz_load7) == 0 or len(xyz_load8) == 0:
            raise ValueError(
                'colour groups 7 and 8 must both be present to define the x axis: ' + plypath)

        pcd_bm = o3d.geometry.PointCloud()
        pcd_bm.points = o3d.utility.Vector3dVector(
            np.concatenate([xyz_load7, xyz_load8, xyz_load9]))
        pcd_bm.colors = o3d.utility.Vector3dVector(
            np.concatenate([color_load7, color_load8, color_load9]))

        # ---- resampled full cloud -----------------------------------------
        pcd_all = o3d.io.read_point_cloud(
            "../data/CTA/cta_chongcaiyang/pointcloud/all/" + ply_name)
        xyz = np.array(pcd_all.points)
        color = np.array(pcd_all.colors)
        nxnynz = np.array(pcd_all.normals)

        example_filename = os.path.join("../data/CTA/cta_mat", case_name + ".mat")
        print(example_filename)

        # Label 1: colour ~ (0, 0, 0); label 2: colour ~ (0.1, 0.1, 0.1).
        is_label1 = np.all(np.abs(color) < 0.01, axis=1)
        is_label2 = ~is_label1 & np.all(np.abs(color - 0.1) < 0.01, axis=1)
        xyz_load1 = xyz[is_label1]
        xyz_load2 = xyz[is_label2]
        print(xyz_load1.shape, xyz_load8.shape, xyz_load7.shape)
        if len(xyz_load1) == 0 or len(xyz_load2) == 0:
            # Previously an empty label-2 group left xj_point undefined, or
            # silently reused the value computed for the previous case.
            raise ValueError(
                'label groups 1 and 2 must both be present to define the y axis: ' + plypath)

        # ---- case frame (independent of the slice type, so built once) ----
        center_point = _column_means(xyz_load1)
        print(center_point)

        # Label-2 point farthest from the label-1 centroid; the first one
        # wins on ties.
        # NOTE(review): 'xj' presumably stands for the apex -- confirm.
        squared_distance = ((center_point[0] - xyz_load2[:, 0]) ** 2
                            + (center_point[1] - xyz_load2[:, 1]) ** 2
                            + (center_point[2] - xyz_load2[:, 2]) ** 2)
        xj_point = xyz_load2[np.argmax(squared_distance)]
        print(xj_point)

        y_axis = xj_point - center_point

        # x axis runs from the group-8 centroid to the group-7 centroid.
        sjb_point = _column_means(xyz_load8)
        ejb_point = _column_means(xyz_load7)
        x_axis = ejb_point - sjb_point

        # z is perpendicular to x and y; y is then rebuilt so the three axes
        # are mutually orthogonal before all of them are normalised.
        z_axis = np.cross(x_axis, y_axis)
        y_axis = np.cross(z_axis, x_axis)
        x_axis = _unit_vector(x_axis)
        y_axis = _unit_vector(y_axis)
        z_axis = _unit_vector(z_axis)

        # Rows are the frame axes, so R_inv @ p gives p's coordinates in the
        # frame. The original reached the same matrix through inv(inv(...)).
        R_inv = np.array([x_axis, y_axis, z_axis])

        all_center_point = _column_means(xyz)

        o3d.io.write_point_cloud(
            "../data/CTA/cta_normal/pointcloud/all/" + case_name + '.ply',
            _normalized_copy(pcd_all, all_center_point, R_inv))
        o3d.io.write_point_cloud(
            "../data/CTA/cta_normal/pointcloud/all_bm/" + case_name + '.ply',
            _normalized_copy(pcd_bm, all_center_point, R_inv))

        # ---- one partial cloud per slice plane ----------------------------
        for slice_type in slice_type_list:
            e1, e2, img3d_center, img = get_slice_e1_e2_center(example_filename, slice_type=slice_type)
            plane_normal = np.cross(e2, e1)

            # Offset of every point from the slice plane along plane_normal.
            # NOTE(review): plane_normal is not normalised here, so the
            # threshold 2 is a true distance only if |e1 x e2| == 1 -- confirm.
            plane_offset = (plane_normal[0] * (xyz[:, 0] - img3d_center[0]) +
                            plane_normal[1] * (xyz[:, 1] - img3d_center[1]) +
                            plane_normal[2] * (xyz[:, 2] - img3d_center[2]))
            near_plane = np.abs(plane_offset) < 2

            pcd_partial = o3d.geometry.PointCloud()
            pcd_partial.points = o3d.utility.Vector3dVector(xyz[near_plane])
            # Partial clouds are stored with colours halved.
            pcd_partial.colors = o3d.utility.Vector3dVector(color[near_plane] * 0.5)
            pcd_partial.normals = o3d.utility.Vector3dVector(nxnynz[near_plane])

            partial_path = ("../data/CTA/cta_normal/pointcloud/partial/" + case_name
                            + '_' + slice_type + '.ply')
            o3d.io.write_point_cloud(
                partial_path, _normalized_copy(pcd_partial, all_center_point, R_inv))
            print('>>>>>>>>>>>>>>>', partial_path)

