import sys

sys.path.append('.')

import os
import random

import torch
import torch.utils.data as data
import numpy as np
import open3d as o3d


class ShapeNet_Heart_Slice(data.Dataset):
    """
    Heart point-cloud dataset pairing a partial slice with complete targets.

    Samples are identified by model ids listed in ``<dataroot>/<split>.list``
    (one id per line; blank lines are ignored, surrounding whitespace is
    stripped). For every id these ``.ply`` files are read with open3d and
    resampled to a fixed number of points::

        <dataroot>/2d_slicepcd/<id>_<view>.ply   partial slice   -> 512 points
        <dataroot>/partial/<id>_<view>.ply       complete slice  -> 2048 points
        <dataroot>/all/<id>.ply                  complete shape  -> 16384 points

    With ``category='all'``, ``<view>`` is drawn at random from the values of
    ``slice_dict`` ('a4c', 'a2c', 'a5c', 'lax') on every ``__getitem__`` call.
    Otherwise, ``category`` itself is used as the fixed view name.

    Args:
        dataroot: Root directory holding the list files and point clouds.
        split: One of 'train', 'valid', 'test', 'test_novel'.
        category: A view name, or 'all' to pick a random view per sample.
    """

    def __init__(self, dataroot, split, category):
        assert split in ['train', 'valid', 'test', 'test_novel'], "split error value!"

        self.dataroot = dataroot
        self.split = split
        self.category = category
        # Index -> view name; used to fill the ``{}`` path placeholder in 'all' mode.
        self.slice_dict = {0: 'a4c', 1: 'a2c', 2: 'a5c', 3: 'lax'}

        self.partial_slice_paths, self.complete_slice_paths, self.complete_shape_paths = self._load_data()

    def __getitem__(self, index):
        """Return ``(partial_slice, complete_slice, complete_shape, partial_slice_path)``.

        The three clouds are tensors with 512, 2048 and 16384 rows
        respectively. The last element is the resolved partial-slice path.

        Raises:
            ValueError: if any of the three point clouds is empty (e.g. an
                empty or unreadable file).
        """
        # NOTE: every split (not only 'train') draws a random view, so
        # evaluation results vary between runs unless ``random`` is seeded.
        # The draw also happens in single-category mode, where it is unused.
        view = self.slice_dict[random.randint(0, len(self.slice_dict) - 1)]
        partial_slice_path = self._fill_view(self.partial_slice_paths[index], view)
        complete_slice_path = self._fill_view(self.complete_slice_paths[index], view)
        complete_shape_path = self.complete_shape_paths[index]

        partial_slice_pc = self._load_sampled(partial_slice_path, 512)
        complete_slice_pc = self._load_sampled(complete_slice_path, 2048)
        complete_shape_pc = self._load_sampled(complete_shape_path, 16384)

        return (torch.from_numpy(partial_slice_pc),
                torch.from_numpy(complete_slice_pc),
                torch.from_numpy(complete_shape_pc),
                partial_slice_path)

    def __len__(self):
        return len(self.complete_shape_paths)

    def _load_data(self):
        """Build the per-sample path lists from ``<dataroot>/<split>.list``.

        Returns:
            Three parallel lists ``(partial_slice_paths, complete_slice_paths,
            complete_shape_paths)``. In 'all' mode the two slice paths end in
            a ``{}`` placeholder for the view name, which ``__getitem__`` fills.
        """
        # Format only the file name so braces inside ``dataroot`` are untouched.
        list_path = os.path.join(self.dataroot, '{}.list'.format(self.split))
        with open(list_path, 'r') as f:
            # Drop blank lines (e.g. an extra empty line at the end); they would
            # otherwise become bogus samples such as '<dataroot>/all/.ply'.
            model_ids = [line.strip() for line in f.read().splitlines() if line.strip()]

        if self.category != 'all':
            view_suffix = '_' + self.category + '.ply'
        else:
            view_suffix = '_{}.ply'

        partial_slice_paths = [os.path.join(self.dataroot, '2d_slicepcd', m + view_suffix) for m in model_ids]
        complete_slice_paths = [os.path.join(self.dataroot, 'partial', m + view_suffix) for m in model_ids]
        complete_shape_paths = [os.path.join(self.dataroot, 'all', m + '.ply') for m in model_ids]
        return partial_slice_paths, complete_slice_paths, complete_shape_paths

    @staticmethod
    def _fill_view(template, view):
        """Replace the last ``{}`` in ``template`` with ``view``.

        Only the final placeholder (the one in the file name) is touched, so
        braces elsewhere in the path are safe. A template without a
        placeholder is returned unchanged.
        """
        head, sep, tail = template.rpartition('{}')
        return head + view + tail if sep else template

    def _load_sampled(self, path, n):
        """Read the cloud at ``path`` and resample it to exactly ``n`` points.

        Raises:
            ValueError: if the cloud read from ``path`` has no points.
        """
        pc = self.read_point_cloud(path)
        if pc.shape[0] == 0:
            raise ValueError('empty or unreadable point cloud: {}'.format(path))
        return self.random_sample(pc, n)

    def read_point_cloud(self, path):
        """Read ``path`` with open3d and return its points as a float32 array."""
        pc = o3d.io.read_point_cloud(path)
        return np.array(pc.points, np.float32)

    def read_point_cloud_type(self, path):
        """Read points plus a per-point integer label derived from the colors.

        The label is ``(red + 0.001) // 0.1``. Red is presumably a type id
        encoded in steps of 0.1; TODO confirm against the data export.

        Returns:
            ``(points, type_ids)``, where ``type_ids`` has shape ``(N, 1)``.
        """
        pc = o3d.io.read_point_cloud(path)
        point_array = np.array(pc.points, np.float32)
        type_array = np.array(pc.colors, np.float32)[:, 0]
        type_id_array = np.array([int((type_float + 0.001) // 0.1) for type_float in type_array],
                                 np.float32).reshape((point_array.shape[0], 1))
        return point_array, type_id_array

    def random_sample(self, pc, n):
        """Return exactly ``n`` rows of ``pc`` in random order.

        The rows come from a random permutation truncated to ``n``. When
        ``pc`` has fewer than ``n`` rows, it is padded with rows drawn with
        replacement. ``pc`` must be non-empty when ``n > 0``.
        """
        idx = np.random.permutation(pc.shape[0])
        if idx.shape[0] < n:
            idx = np.concatenate([idx, np.random.randint(pc.shape[0], size=n - pc.shape[0])])
        return pc[idx[:n]]


class ShapeNet_Heart_Slice_components(data.Dataset):
    """
    Heart point-cloud dataset with the complete shape and per-component shapes.

    Samples are identified by model ids listed in ``<dataroot>/<split>.list``
    (one id per line; blank lines are ignored, surrounding whitespace is
    stripped). For every id these ``.ply`` files are read with open3d and
    resampled to a fixed number of points::

        <dataroot>/partial_slice/<id>_<view>_02.ply      partial slice, 'all' mode    -> 512
        <dataroot>/partial_slice/<id>_<category>_03.ply  partial slice, single view   -> 512
        <dataroot>/partial/<id>_<view>.ply               complete slice               -> 2048
        <dataroot>/all/<id>.ply                          complete shape               -> 16384
        <dataroot>/component/<id>_<part>.ply             part in lv, rv, aro, la, ra, myo -> 16384 each

    With ``category='all'``, ``<view>`` is drawn at random from the values of
    ``slice_dict`` on every ``__getitem__`` call. Otherwise, ``category`` is
    used as the fixed view name.

    Args:
        dataroot: Root directory holding the list files and point clouds.
        split: One of 'train', 'valid', 'test', 'test_novel'.
        category: A view name, or 'all' to pick a random view per sample.
    """

    # Component file suffixes, in the order of the ``complete_*_paths``
    # attributes and of the tensors returned by ``__getitem__``.
    _COMPONENTS = ('lv', 'rv', 'aro', 'la', 'ra', 'myo')

    def __init__(self, dataroot, split, category):
        assert split in ['train', 'valid', 'test', 'test_novel'], "split error value!"

        self.dataroot = dataroot
        self.split = split
        self.category = category
        # Index -> view name; used to fill the ``{}`` path placeholder in 'all' mode.
        self.slice_dict = {0: 'a4c', 1: 'a3c', 2: 'lax', 3: 'a2c', 4: 'a5c'}

        (self.partial_slice_paths,
         self.complete_slice_paths,
         self.complete_shape_paths,
         self.complete_lv_paths,
         self.complete_rv_paths,
         self.complete_aro_paths,
         self.complete_la_paths,
         self.complete_ra_paths,
         self.complete_myo_paths) = self._load_data()

    def __getitem__(self, index):
        """Return a 10-tuple of tensors plus the resolved partial-slice path.

        The order is: partial slice (512 rows), complete slice (2048 rows),
        complete shape, lv, rv, aro, la, ra, myo (16384 rows each), and
        finally the partial-slice path string.

        Raises:
            ValueError: if any point cloud is empty (e.g. an empty or
                unreadable file).
        """
        # Fill the ``{}`` view placeholder of the slice path templates.
        # NOTE: every split (not only 'train') draws a random view, so
        # evaluation results vary between runs unless ``random`` is seeded.
        view = self.slice_dict[random.randint(0, len(self.slice_dict) - 1)]
        partial_slice_path = self._fill_view(self.partial_slice_paths[index], view)
        complete_slice_path = self._fill_view(self.complete_slice_paths[index], view)

        # Full-resolution targets, in return order: whole shape, then components.
        full_res_paths = (self.complete_shape_paths[index],
                          self.complete_lv_paths[index],
                          self.complete_rv_paths[index],
                          self.complete_aro_paths[index],
                          self.complete_la_paths[index],
                          self.complete_ra_paths[index],
                          self.complete_myo_paths[index])

        partial_slice_pc = self._load_sampled(partial_slice_path, 512)
        complete_slice_pc = self._load_sampled(complete_slice_path, 2048)
        full_res_pcs = [torch.from_numpy(self._load_sampled(p, 16384)) for p in full_res_paths]

        return (torch.from_numpy(partial_slice_pc),
                torch.from_numpy(complete_slice_pc),
                *full_res_pcs,
                partial_slice_path)

    def __len__(self):
        return len(self.complete_shape_paths)

    def _load_data(self):
        """Build the per-sample path lists from ``<dataroot>/<split>.list``.

        Returns:
            Nine parallel lists: partial slice, complete slice, complete shape,
            then one list per entry of ``_COMPONENTS``. In 'all' mode the two
            slice paths contain a ``{}`` placeholder for the view name.
        """
        # Format only the file name so braces inside ``dataroot`` are untouched.
        list_path = os.path.join(self.dataroot, '{}.list'.format(self.split))
        with open(list_path, 'r') as f:
            # Drop blank lines; they would otherwise become bogus samples.
            model_ids = [line.strip() for line in f.read().splitlines() if line.strip()]

        # NOTE(review): single-view mode reads the '_03' partial-slice variant
        # while 'all' mode reads '_02'. This is preserved as-is; confirm it
        # is intentional.
        if self.category != 'all':
            partial_suffix = '_' + self.category + '_03.ply'
            complete_suffix = '_' + self.category + '.ply'
        else:
            partial_suffix = '_{}_02.ply'
            complete_suffix = '_{}.ply'

        partial_slice_paths = [os.path.join(self.dataroot, 'partial_slice', m + partial_suffix) for m in model_ids]
        complete_slice_paths = [os.path.join(self.dataroot, 'partial', m + complete_suffix) for m in model_ids]
        complete_shape_paths = [os.path.join(self.dataroot, 'all', m + '.ply') for m in model_ids]
        component_paths = [
            [os.path.join(self.dataroot, 'component', m + '_' + part + '.ply') for m in model_ids]
            for part in self._COMPONENTS
        ]

        return (partial_slice_paths,
                complete_slice_paths,
                complete_shape_paths,
                *component_paths)

    @staticmethod
    def _fill_view(template, view):
        """Replace the last ``{}`` in ``template`` with ``view``.

        Only the final placeholder (the one in the file name) is touched, so
        braces elsewhere in the path are safe. A template without a
        placeholder is returned unchanged.
        """
        head, sep, tail = template.rpartition('{}')
        return head + view + tail if sep else template

    def _load_sampled(self, path, n):
        """Read the cloud at ``path`` and resample it to exactly ``n`` points.

        Raises:
            ValueError: if the cloud read from ``path`` has no points.
        """
        pc = self.read_point_cloud(path)
        if pc.shape[0] == 0:
            raise ValueError('empty or unreadable point cloud: {}'.format(path))
        return self.random_sample(pc, n)

    def read_point_cloud(self, path):
        """Read ``path`` with open3d and return its points as a float32 array."""
        pc = o3d.io.read_point_cloud(path)
        return np.array(pc.points, np.float32)

    def read_point_cloud_type(self, path):
        """Read points plus a per-point integer label derived from the colors.

        The label is ``(red + 0.001) // 0.1``. Red is presumably a type id
        encoded in steps of 0.1; TODO confirm against the data export.

        Returns:
            ``(points, type_ids)``, where ``type_ids`` has shape ``(N, 1)``.
        """
        pc = o3d.io.read_point_cloud(path)
        point_array = np.array(pc.points, np.float32)
        type_array = np.array(pc.colors, np.float32)[:, 0]
        type_id_array = np.array([int((type_float + 0.001) // 0.1) for type_float in type_array], np.float32).reshape(
            (point_array.shape[0], 1))
        return point_array, type_id_array

    def random_sample(self, pc, n):
        """Return exactly ``n`` rows of ``pc`` in random order.

        The rows come from a random permutation truncated to ``n``. When
        ``pc`` has fewer than ``n`` rows, it is padded with rows drawn with
        replacement. ``pc`` must be non-empty when ``n > 0``.
        """
        idx = np.random.permutation(pc.shape[0])
        if idx.shape[0] < n:
            idx = np.concatenate([idx, np.random.randint(pc.shape[0], size=n - pc.shape[0])])
        return pc[idx[:n]]


