import random
import numpy as np
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
import MinkowskiEngine as ME
import open3d as o3d
import math
import glob
import h5py
from prepare_data.utils import pc_utils
import time, pickle

def standarlize(coords):
    """Center a point cloud and scale it so its longest axis extent is 1.

    Args:
        coords: ``numpy.ndarray`` of shape ``(N, 3)`` holding point
            coordinates (the 3x3 scale matrix below fixes the column count).

    Returns:
        ``torch.FloatTensor`` of shape ``(N, 3)``: the points translated so
        their mean is the origin and divided by the largest per-axis extent.
        A degenerate cloud (zero extent on every axis) is returned as all
        zeros instead of NaN.
    """
    mean_posi, min_posi, max_posi = coords.mean(0), coords.min(0), coords.max(0)

    centered = torch.from_numpy(coords - mean_posi)

    # Largest side of the axis-aligned bounding box.
    scale_factor = float(max(max_posi - min_posi))
    if scale_factor == 0:
        # Every point coincides: dividing by zero would fill the result with
        # inf/NaN, so leave the (already all-zero) centered cloud unscaled.
        scale_factor = 1.0
    scale = torch.eye(3) / scale_factor

    # Rotation in the x-y plane; theta is fixed at 0, so this is currently
    # the identity and only kept as a hook for a canonical orientation.
    theta = 0 * math.pi
    rotation = torch.tensor([[math.cos(theta), math.sin(theta), 0],
                             [-math.sin(theta), math.cos(theta), 0],
                             [0, 0, 1]]).float()

    transform = torch.matmul(scale, rotation)
    return torch.matmul(centered.float(), transform)

def scannet_Merge(batch):
    """Collate per-scene samples into one sparse-tensor style batch.

    Each element of ``batch`` is the dict produced by a dataset's
    ``__getitem__``: ``{'scannet': (coords, feats, labels, ...), 'batchSize': n}``.
    If fewer than ``batchSize`` samples were delivered, the first sample is
    repeated until the batch reaches that size.

    Returns:
        dict with ``'coords'`` (output of ``ME.utils.batched_coordinates``),
        ``'feats'`` and ``'labels'`` (per-sample tensors concatenated along
        dim 0).
    """
    shortfall = batch[0]['batchSize'] - len(batch)
    if shortfall > 0:
        # Pad an incomplete batch by repeating its first sample.
        batch = batch + [batch[0]] * shortfall

    point_sets, feature_sets, label_sets = [], [], []
    for sample in batch:
        scene = sample["scannet"]
        point_sets.append(scene[0])
        feature_sets.append(scene[1])
        label_sets.append(scene[2])

    return {
        'coords': ME.utils.batched_coordinates(point_sets, dtype=torch.float32),
        'feats': torch.cat(feature_sets, dim=0),
        'labels': torch.cat(label_sets, dim=0),
    }


class scannet_Dataset(Dataset):
    """ScanNet scenes stored as ``<root>/<scene>/<scene>_new_semantic.npy``.

    Args:
        config: object providing ``dataRoot_scannet``, ``train_file``,
            ``val_file``, ``voxel_size`` and (for training)
            ``batch_size_scannet``.
        phase: ``'train'`` selects ``config.train_file`` and enables random
            rotation; any other value selects ``config.val_file``.
    """

    def __init__(self, config, phase):
        self.scannet_root_dir = config.dataRoot_scannet
        if phase == 'train':
            self.scannet_file_list = self.read_files(config.train_file)
        else:
            self.scannet_file_list = self.read_files(config.val_file)

        self.voxel_size = config.voxel_size
        self.phase = phase
        self.config = config

    def read_files(self, file):
        """Return the scene names listed in ``file``, one per line.

        Everything from the first ``'.'`` on is dropped, so both
        ``scene0000_00`` and ``scene0000_00.ply`` yield ``scene0000_00``.
        Surrounding whitespace (including the line terminator) is stripped
        and blank lines are skipped, so entries without an extension no
        longer keep a trailing newline.
        """
        with open(file) as f:
            stems = [line.strip().split('.')[0] for line in f]
        return [stem for stem in stems if stem]

    def __len__(self):
        return len(self.scannet_file_list)

    def __getitem__(self, idx):
        """Load scene ``idx`` and return it in the collate-ready format.

        Returns:
            dict with ``'scannet'``: ``(coords, feats, labels, scene_name)``
            and ``'batchSize'``: the batch size the collate function should
            pad to (``config.batch_size_scannet`` when training, else 1).
        """
        scene_name = self.scannet_file_list[idx]
        # _new_semantic.npy: 0~19, .npy: 1~20
        path = os.path.join(self.scannet_root_dir, scene_name, scene_name + "_new_semantic.npy")
        data = torch.from_numpy(np.load(path))
        # Columns 0-2 are xyz; everything from column 9 on is used as labels.
        coords, labels = data[:, :3], data[:, 9:]

        # Geometry-only input: every point gets a constant feature of 1.
        feats = torch.ones(len(coords), 1)
        # Center the scene, then express coordinates in voxel units.
        coords = (coords - coords.mean(0)) / self.voxel_size

        if self.phase == 'train':
            coords = pc_utils.random_rotation(coords)
            batchSize = self.config.batch_size_scannet
        else:
            batchSize = 1

        return {'scannet': (coords, feats, labels, scene_name), 'batchSize': batchSize}

def modelnet_Merge(batch):
    """Collate ModelNet samples into one batched coordinate tensor per category.

    Each element of ``batch`` is a dict holding one point set per name in
    ``config.model_list`` plus ``'nega_data'``, and the shared ``'config'``.
    An incomplete batch is padded up to ``config.batch_size_modelnet`` by
    repeating its first sample.

    Returns:
        dict mapping every category name (and ``'nega_data'``) to the result
        of ``ME.utils.batched_coordinates`` over that category's point sets.
    """
    cfg = batch[0]['config']

    shortfall = cfg.batch_size_modelnet - len(batch)
    if shortfall > 0:
        batch = batch + [batch[0]] * shortfall

    return {
        key: ME.utils.batched_coordinates([sample[key] for sample in batch], dtype=torch.float32)
        for key in cfg.model_list + ['nega_data']
    }

class modelnet_Dataset(Dataset):
    """ModelNet40 point clouds grouped into target categories and negatives.

    Args:
        config: object providing ``dataRoot_modelnet``, ``voxel_size``,
            ``model_list`` (target categories), ``valida_model_list`` and
            ``model_list_validation`` (the latter is subtracted from the
            former to obtain the negative categories).
    """

    def __init__(self, config):
        self.config = config
        self.root_dir = config.dataRoot_modelnet
        self.voxel_size = config.voxel_size
        self.num_points = 8192
        self.shape_names = self.load_shape_names(config.dataRoot_modelnet)
        self.modelnet_data = self.load_modelnet(config.dataRoot_modelnet)

    def load_shape_names(self, file):
        """Read ``<file>/shape_names.txt`` and return one name per line.

        The position of a name in the returned list is used as its class
        index, so blank lines are kept (as empty strings) rather than
        dropped.
        """
        with open(os.path.join(file, "shape_names.txt")) as f:
            return [line.rstrip("\n") for line in f]

    def load_modelnet_b(self, data_root):
        """Load the ModelNet40 HDF5 training shards and split them by class.

        Args:
            data_root: directory containing ``ply_data_train*.h5`` files,
                each with ``data`` and ``label`` datasets.

        Returns:
            dict of float tensors keyed by ``chair``, ``bed``, ``sofa``,
            ``table``, ``bookshelf``, ``desk`` and ``nega_data``; coordinates
            are divided by ``self.voxel_size``.
        """
        assert os.path.exists(data_root), f"{data_root} does not exist"
        files = glob.glob(os.path.join(data_root, "ply_data_%s*.h5" % "train"))
        assert len(files) > 0, "No files found"

        data, labels = [], []
        for h5_name in files:
            # Open read-only explicitly; the shards are never modified here.
            with h5py.File(h5_name, "r") as f:
                data.extend(f["data"][:].astype("float32"))
                labels.extend(f["label"][:].astype("int64"))
        data = np.stack(data, axis=0) / self.voxel_size
        # Flatten to one label per shape so the masks stay 1-D even when
        # only a single shape was loaded.
        labels = np.stack(labels, axis=0).reshape(-1)

        masks = {
            name: labels == self.shape_names.index(name)
            for name in ('chair', 'bed', 'sofa', 'table', 'bookshelf', 'desk')
        }
        # NOTE(review): 'bed' and 'bookshelf' are not excluded here, so they
        # also appear among the negatives — confirm this is intended.
        index_nega = ~(masks['chair'] | masks['sofa'] | masks['table'] | masks['desk'])

        result = {name: torch.from_numpy(data[mask, :, :]).float()
                  for name, mask in masks.items()}
        result["nega_data"] = torch.from_numpy(data[index_nega, :, :]).float()
        return result

    def load_modelnet(self, data_root):
        """Load the pre-sampled per-category ``.npy`` clouds.

        Files are read from ``$HOME/dataset/modelnet40_8192/<name>_8192.npy``.
        NOTE(review): ``data_root`` is accepted but not used — confirm
        whether the hard-coded location is intentional.

        Returns:
            dict mapping each name in ``config.model_list`` to its tensor,
            plus ``'nega_data'``: the stacked clouds of every category in
            ``config.valida_model_list`` that is not listed in
            ``config.model_list_validation``. All coordinates are divided by
            ``self.voxel_size``.
        """
        # Build the negative list without touching config.valida_model_list:
        # removing items in place would corrupt the shared config and make a
        # second dataset construction fail.
        excluded = set(self.config.model_list_validation)
        negative_names = [name for name in self.config.valida_model_list
                          if name not in excluded]

        nega_data = [self._load_category(name) for name in negative_names]
        nega_data = np.vstack(nega_data) / self.voxel_size

        models = {"nega_data": torch.from_numpy(nega_data)}
        for item in self.config.model_list:
            models[item] = torch.from_numpy(self._load_category(item) / self.voxel_size)

        return models

    def _load_category(self, name):
        """Return the raw array stored for category ``name``."""
        return np.load(os.path.join(os.environ['HOME'], "dataset/modelnet40_8192", '%s_8192.npy' % name))

    def __len__(self):
        # One epoch walks the negatives once; smaller categories wrap around.
        return len(self.modelnet_data['nega_data'])

    def __getitem__(self, idx):
        """Return one transformed model per category plus the config.

        Categories with fewer than ``len(self)`` models are indexed modulo
        their own length. Each model is passed through
        ``pc_utils.resize_rotation(model, category_name)``.
        """
        data = {}
        for item in (self.config.model_list + ['nega_data']):
            models = self.modelnet_data[item]
            model = models[idx % len(models)]
            data[item] = pc_utils.resize_rotation(model, item)

        # The collate function reads batch size and category list from here.
        data['config'] = self.config
        return data

class S3DIS(Dataset):
    """S3DIS rooms loaded from pre-subsampled binary PLY files.

    Room names are discovered from ``$HOME/dataset/S3DIS/original_ply/*.ply``
    and the matching sub-sampled clouds are read from
    ``$HOME/dataset/S3DIS/input_0.040/<room>.ply``. Samples are returned in
    the same ``{'scannet': ..., 'batchSize': ...}`` layout that
    ``scannet_Dataset`` produces.
    """

    def __init__(self, test_area_idx, config, phase):
        """Load every room belonging to the requested split into memory.

        Args:
            test_area_idx: area number held out for validation; rooms whose
                name contains ``'Area_<test_area_idx>'`` form the ``'val'``
                split, all others the ``'train'`` split.
            config: object providing ``voxel_size`` and (for training)
                ``batch_size_scannet``.
            phase: ``'train'`` or ``'val'``. Any other value matches no room
                and yields an empty dataset.
        """
        self.config = config
        self.phase = phase
        self.name = 'S3DIS'
        self.voxel_size = config.voxel_size
        self.path = os.environ['HOME'] + "/dataset/S3DIS"
        self.label_to_names = {0: 'ceiling',
                               1: 'floor',
                               2: 'wall',
                               3: 'beam',
                               4: 'column',
                               5: 'window',
                               6: 'door',
                               7: 'table',
                               8: 'chair',
                               9: 'sofa',
                               10: 'bookcase',
                               11: 'board',
                               12: 'clutter'}
        self.num_classes = len(self.label_to_names)
        self.label_values = np.sort(list(self.label_to_names))
        self.label_to_idx = {l: i for i, l in enumerate(self.label_values)}
        self.ignored_labels = np.array([])

        self.val_split = 'Area_' + str(test_area_idx)
        # glob returns files in arbitrary order; sort so that a given index
        # always refers to the same room.
        self.all_files = sorted(glob.glob(os.path.join(self.path, 'original_ply', '*.ply')))
        self.sub_grid_size = 0.04

        # Containers that are initialised here but not filled by this class.
        self.val_proj = []
        self.val_labels = []
        self.possibility = {}
        self.min_possibility = {}
        self.input_trees = {'training': [], 'validation': []}
        self.input_colors = {'training': [], 'validation': []}
        self.input_labels = {'training': [], 'validation': []}
        self.input_names = {'training': [], 'validation': []}

        # PLY "format" keyword -> numpy byte-order prefix.
        self.valid_formats = {'ascii': '', 'binary_big_endian': '>',
                              'binary_little_endian': '<'}

        # PLY scalar type name -> numpy type code (without byte order).
        self.ply_dtypes = {
            b'int8': 'i1',
            b'char': 'i1',
            b'uint8': 'u1',
            b'uchar': 'u1',
            b'int16': 'i2',
            b'short': 'i2',
            b'uint16': 'u2',
            b'ushort': 'u2',
            b'int32': 'i4',
            b'int': 'i4',
            b'uint32': 'u4',
            b'uint': 'u4',
            b'float32': 'f4',
            b'float': 'f4',
            b'float64': 'f8',
            b'double': 'f8',
        }
        self.load_sub_sampled_clouds(self.sub_grid_size)

    def parse_header(self, plyfile, ext):
        """Parse the remainder of a point-cloud PLY header.

        Reads lines from ``plyfile`` up to and including ``end_header`` (or
        end of file). Lines are classified by their first token, so comment
        lines that merely mention "element" or "property" are ignored.

        Args:
            plyfile: binary file object positioned after the format line.
            ext: numpy byte-order prefix (``'<'`` or ``'>'``).

        Returns:
            ``(num_points, properties)`` where ``num_points`` is the count
            from the last ``element`` line seen (``None`` if there was none)
            and ``properties`` is a list of ``(name, dtype_string)`` pairs
            usable as a numpy structured dtype.

        Raises:
            KeyError: if a property uses a type not in ``self.ply_dtypes``.
        """
        properties = []
        num_points = None

        while True:
            line = plyfile.readline()
            if not line:
                break  # end of file before end_header

            tokens = line.split()
            if not tokens:
                continue

            keyword = tokens[0]
            if keyword == b'end_header':
                break
            if keyword == b'element':
                num_points = int(tokens[2])
            elif keyword == b'property':
                properties.append((tokens[2].decode(), ext + self.ply_dtypes[tokens[1]]))

        return num_points, properties

    def parse_mesh_header(self, plyfile, ext):
        """Parse the remainder of a triangular-mesh PLY header.

        Args:
            plyfile: binary file object positioned after the format line.
            ext: numpy byte-order prefix (``'<'`` or ``'>'``).

        Returns:
            ``(num_points, num_faces, vertex_properties)``; the counts are
            ``None`` when the corresponding element is absent.

        Raises:
            ValueError: if a face property is not a ``uchar``-counted list of
                4-byte integers, which is the only layout ``read_ply`` can
                decode.
            KeyError: if a vertex property uses an unknown type.
        """
        vertex_properties = []
        num_points = None
        num_faces = None
        current_element = None
        # Index types that match the 4-byte face fields read by read_ply.
        face_index_types = (b'int', b'int32', b'uint', b'uint32')

        while True:
            line = plyfile.readline()
            if not line:
                break  # end of file before end_header

            tokens = line.split()
            if not tokens:
                continue

            keyword = tokens[0]
            if keyword == b'end_header':
                break

            if keyword == b'element':
                # Track which element the following property lines belong to.
                current_element = tokens[1].decode()
                if current_element == 'vertex':
                    num_points = int(tokens[2])
                elif current_element == 'face':
                    num_faces = int(tokens[2])

            elif keyword == b'property':
                if current_element == 'vertex':
                    vertex_properties.append((tokens[2].decode(), ext + self.ply_dtypes[tokens[1]]))
                elif current_element == 'face':
                    supported = (len(tokens) >= 4
                                 and tokens[1:3] == [b'list', b'uchar']
                                 and tokens[3] in face_index_types)
                    if not supported:
                        raise ValueError('Unsupported faces property : '
                                         + line.decode(errors='replace').strip())

        return num_points, num_faces, vertex_properties

    def read_ply(self, filename, triangular_mesh=False):
        """
        Read a binary ".ply" file.

        Parameters
        ----------
        filename : string
            the name of the file to read.
        triangular_mesh : bool
            when True the file is read as a mesh (vertices followed by
            triangular faces); otherwise as a plain point cloud.

        Returns
        -------
        result : array or list
            for a point cloud, a numpy structured array with one field per
            PLY property; for a mesh, ``[vertex_data, faces]`` where
            ``faces`` has one row of three vertex indices per face.

        Raises
        ------
        ValueError
            if the file does not start with ``ply``, is ASCII encoded, or
            (mesh mode) declares an unsupported face property.
        KeyError
            if the format keyword or a property type is unknown.

        Examples
        --------
        >>> data = read_ply('example.ply')
        >>> points = np.vstack((data['x'], data['y'], data['z'])).T
        """

        with open(filename, 'rb') as plyfile:

            # Check if the file start with ply
            if b'ply' not in plyfile.readline():
                raise ValueError('The file does not start whith the word ply')

            # get binary_little/big or ascii
            fmt = plyfile.readline().split()[1].decode()
            if fmt == "ascii":
                raise ValueError('The file is not binary')

            # get extension for building the numpy dtypes
            ext = self.valid_formats[fmt]

            if triangular_mesh:
                num_points, num_faces, properties = self.parse_mesh_header(plyfile, ext)

                # Vertex records come first, immediately followed by faces.
                vertex_data = np.fromfile(plyfile, dtype=properties, count=num_points)

                # Each face: a one-byte vertex count, then three 4-byte indices.
                face_properties = [('k', ext + 'u1'),
                                   ('v1', ext + 'i4'),
                                   ('v2', ext + 'i4'),
                                   ('v3', ext + 'i4')]
                faces_data = np.fromfile(plyfile, dtype=face_properties, count=num_faces)

                faces = np.vstack((faces_data['v1'], faces_data['v2'], faces_data['v3'])).T
                data = [vertex_data, faces]

            else:
                num_points, properties = self.parse_header(plyfile, ext)
                data = np.fromfile(plyfile, dtype=properties, count=num_points)

        return data

    def load_sub_sampled_clouds(self, sub_grid_size):
        """Read the sub-sampled cloud of every room in the current split.

        Fills ``self.coords`` with one ``(N, 3)`` xyz array per room and
        ``self.labels`` with the matching ``class`` arrays.

        Args:
            sub_grid_size: grid size used to name the input directory
                (``input_<size with 3 decimals>``).
        """
        tree_path = os.path.join(self.path, 'input_{:.3f}'.format(sub_grid_size))

        self.coords = []
        self.labels = []

        for file_path in self.all_files:
            cloud_name = os.path.splitext(os.path.basename(file_path))[0]
            cloud_split = 'val' if self.val_split in cloud_name else 'train'

            # Only rooms of the requested split are read from disk.
            if cloud_split != self.phase:
                continue

            sub_ply_file = os.path.join(tree_path, '{:s}.ply'.format(cloud_name))
            data = self.read_ply(sub_ply_file)

            sub_points = np.vstack((data['x'], data['y'], data['z'])).T
            sub_labels = data['class']

            self.coords.append(sub_points)
            self.labels.append(sub_labels)

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        """Return room ``idx`` in the collate-ready format.

        Returns:
            dict with ``'scannet'``: ``(coords, feats, labels, raw_points)``
            and ``'batchSize'``: ``config.batch_size_scannet`` when training,
            else 1. ``raw_points`` is the untransformed numpy array.
        """
        raw_points = self.coords[idx]
        coords = torch.from_numpy(raw_points)
        labels = torch.from_numpy(self.labels[idx])
        # Duplicate the class column into an (N, 2) tensor; presumably this
        # mirrors the multi-column label layout of the ScanNet arrays —
        # TODO confirm against the consumer.
        labels = torch.cat((labels.unsqueeze(-1), labels.unsqueeze(-1)), dim=1)

        # Geometry-only input: every point gets a constant feature of 1.
        feats = torch.ones(len(coords), 1)
        # Center the room, then express coordinates in voxel units.
        coords = (coords - coords.mean(0)) / self.voxel_size

        if self.phase == 'train':
            coords = pc_utils.random_rotation(coords)
            batchSize = self.config.batch_size_scannet
        else:
            batchSize = 1

        return {'scannet': (coords, feats, labels, raw_points), 'batchSize': batchSize}
