import sys
sys.path.append('.')

import os
import random

import torch
import torch.utils.data as data
import numpy as np
import open3d as o3d


class ShapeNet_Heart(data.Dataset):
    """
    Paired partial/complete cardiac point-cloud dataset.

    Sample ids are read one per line from ``<dataroot>/<split>.list``. Each id
    maps to:
      * a partial input cloud  ``<dataroot>/partial/<id>_<view>.ply``
      * a complete target cloud ``<dataroot>/component/<id>_lv.ply``

    When ``category == 'all'`` the view suffix is drawn at random on every
    ``__getitem__`` call. Otherwise ``category`` itself is used as the fixed
    view suffix.

    Args:
        dataroot: directory holding the ``.list`` files and the ``partial`` /
            ``component`` sub-directories.
        split: one of ``'train'``, ``'valid'``, ``'test'``, ``'test_novel'``.
        category: a view suffix (e.g. ``'a4c'``) or ``'all'``.
    """

    def __init__(self, dataroot, split, category):
        assert split in ['train', 'valid', 'test', 'test_novel'], "split error value!"

        self.dataroot = dataroot
        self.split = split
        self.category = category
        # Index -> view suffix used to fill the partial-path template.
        self.slice_dict = {0: 'a4c', 1: 'a2c', 2: 'a5c', 3: 'lax', 4: 'a3c'}

        self.partial_paths, self.complete_paths = self._load_data()

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            ``(partial, complete, partial_path)``:
              * ``partial``: tensor of 512 points sampled from the partial cloud.
              * ``complete``: tensor of 16384 points sampled from the target cloud.
              * ``partial_path``: the concrete partial file that was read.
        """
        # NOTE(review): slice_dict has five entries but only keys 0-3 are drawn,
        # so 'a3c' is never sampled. The view is also random on non-train
        # splits, so evaluation is not deterministic. Confirm both are intended.
        view = self.slice_dict[random.randint(0, 3)]
        partial_path = self._fill_view(self.partial_paths[index], view)
        complete_path = self.complete_paths[index]

        partial_pc = self.random_sample(self.read_point_cloud(partial_path), 512)
        complete_pc = self.random_sample(self.read_point_cloud(complete_path), 16384)

        return torch.from_numpy(partial_pc), torch.from_numpy(complete_pc), partial_path

    def __len__(self):
        return len(self.complete_paths)

    @staticmethod
    def _fill_view(path_template, view):
        """Substitute ``view`` into the ``{}`` placeholder of the file name.

        Only the last path component is formatted, so braces inside the
        directory part (e.g. in ``dataroot``) are left untouched. Paths without
        a placeholder are returned unchanged.
        """
        tail = os.path.basename(path_template)
        head = path_template[:len(path_template) - len(tail)]
        return head + tail.format(view)

    def _load_data(self):
        """Read the split's id list and build the partial/complete path lists.

        Returns:
            ``(partial_paths, complete_paths)``, two lists of equal length.
            When ``category == 'all'`` each partial file name contains a ``{}``
            placeholder that ``__getitem__`` fills with a view suffix.
        """
        # Format only the file name; formatting the joined path would choke on
        # any '{' or '}' present in dataroot.
        list_path = os.path.join(self.dataroot, '{}.list'.format(self.split))
        with open(list_path, 'r') as f:
            # Blank lines would otherwise yield bogus ids such as '_lv.ply'.
            model_ids = [line for line in f.read().splitlines() if line.strip()]

        if self.category == 'all':
            partial_suffix = '_{}.ply'  # view chosen per __getitem__ call
        else:
            partial_suffix = '_' + self.category + '.ply'

        partial_paths = [
            os.path.join(self.dataroot, 'partial', model_id + partial_suffix)
            for model_id in model_ids
        ]
        # Target is the '<id>_lv.ply' component file. Per a former note, for
        # valve-only training targets were taken from 'all_bm/<id>.ply' instead.
        complete_paths = [
            os.path.join(self.dataroot, 'component', model_id + '_lv.ply')
            for model_id in model_ids
        ]
        return partial_paths, complete_paths

    def read_point_cloud(self, path):
        """Read ``path`` with open3d and return its points as a float32 array."""
        pc = o3d.io.read_point_cloud(path)
        return np.array(pc.points, np.float32)

    def read_point_cloud_component(self, path, target_color=(0.5, 0.5, 0.5), tol=0.01):
        """Return the points whose RGB color matches ``target_color``.

        Args:
            path: point-cloud file to read with open3d.
            target_color: RGB triple to select. The default keeps the
                (0.5, 0.5, 0.5) points, as the original filter did.
            tol: per-channel absolute tolerance (strict ``<``).

        Returns:
            float32 array of the matching points. It has shape ``(0, 3)`` when
            nothing matches or the cloud stores no per-point colors.
        """
        pc = o3d.io.read_point_cloud(path)
        # The previous version read pc.points as the colors and called .shape
        # on the open3d vector, which always raised. Convert both to ndarrays.
        xyz = np.asarray(pc.points)
        color = np.asarray(pc.colors)
        if len(color) == 0 or len(color) != len(xyz):
            # No usable per-point colors, so nothing can match.
            return np.empty((0, 3), np.float32)
        mask = np.all(np.abs(color - np.asarray(target_color)) < tol, axis=1)
        return np.asarray(xyz[mask], np.float32)

    def random_sample(self, pc, n):
        """Return exactly ``n`` rows of ``pc`` in random order.

        If ``pc`` has fewer than ``n`` rows, every row is kept once and the
        remainder is filled with randomly duplicated rows.

        Raises:
            ValueError: if ``pc`` is empty.
        """
        num_points = pc.shape[0]
        if num_points == 0:
            raise ValueError("cannot sample {} points from an empty point cloud".format(n))
        idx = np.random.permutation(num_points)
        if num_points < n:
            idx = np.concatenate([idx, np.random.randint(num_points, size=n - num_points)])
        return pc[idx[:n]]

class ShapeNet_GM(data.Dataset):
    """
    Paired partial/complete point-cloud dataset for the "gm" data.

    Sample ids are read one per line from ``<dataroot>/<split>_gm.list``. Each
    id maps to:
      * an input cloud  ``<dataroot>/x_guanmai/<id>.ply``
      * a target cloud  ``<dataroot>/y_clear_gm/<id>.ply``

    ("guanmai"/"gm" presumably refers to coronary vessels; confirm.)

    Args:
        dataroot: directory holding the ``.list`` files and the two
            point-cloud sub-directories.
        split: one of ``'train'``, ``'valid'``, ``'test'``, ``'test_novel'``.
        category: stored but not used when building paths.
    """

    def __init__(self, dataroot, split, category):
        assert split in ['train', 'valid', 'test', 'test_novel'], "split error value!"

        self.dataroot = dataroot
        self.split = split
        self.category = category
        # Not used by this dataset's path logic; kept for attribute parity
        # with ShapeNet_Heart.
        self.slice_dict = {0: 'a4c', 1: 'a2c', 2: 'a5c', 3: 'lax'}

        self.partial_paths, self.complete_paths = self._load_data()

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            ``(partial, complete, partial_path)``:
              * ``partial``: tensor of 2048 points sampled from the input cloud.
              * ``complete``: tensor of 16384 points sampled from the target cloud.
              * ``partial_path``: the input file that was read.
        """
        partial_path = self.partial_paths[index]
        complete_path = self.complete_paths[index]

        partial_pc = self.random_sample(self.read_point_cloud(partial_path), 2048)
        complete_pc = self.random_sample(self.read_point_cloud(complete_path), 16384)
        return torch.from_numpy(partial_pc), torch.from_numpy(complete_pc), partial_path

    def __len__(self):
        return len(self.complete_paths)

    def _load_data(self):
        """Read ``<split>_gm.list`` and build the input/target path lists.

        Returns:
            ``(partial_paths, complete_paths)``, two lists of equal length.
        """
        # Format only the file name; formatting the joined path would choke on
        # any '{' or '}' present in dataroot.
        list_path = os.path.join(self.dataroot, '{}.list'.format(self.split + '_gm'))
        with open(list_path, 'r') as f:
            # Blank lines would otherwise yield bogus ids such as '.ply'.
            model_ids = [line for line in f.read().splitlines() if line.strip()]

        partial_paths = [
            os.path.join(self.dataroot, 'x_guanmai', model_id + '.ply')
            for model_id in model_ids
        ]
        complete_paths = [
            os.path.join(self.dataroot, 'y_clear_gm', model_id + '.ply')
            for model_id in model_ids
        ]
        return partial_paths, complete_paths

    def read_point_cloud(self, path):
        """Read ``path`` with open3d and return its points as a float32 array."""
        pc = o3d.io.read_point_cloud(path)
        return np.array(pc.points, np.float32)

    def random_sample(self, pc, n):
        """Return exactly ``n`` rows of ``pc`` in random order.

        If ``pc`` has fewer than ``n`` rows, every row is kept once and the
        remainder is filled with randomly duplicated rows.

        Raises:
            ValueError: if ``pc`` is empty.
        """
        num_points = pc.shape[0]
        if num_points == 0:
            raise ValueError("cannot sample {} points from an empty point cloud".format(n))
        idx = np.random.permutation(num_points)
        if num_points < n:
            idx = np.concatenate([idx, np.random.randint(num_points, size=n - num_points)])
        return pc[idx[:n]]