import os
import os.path as osp

import numpy as np
import torch
from torch.utils.data import Dataset

from .scannet import read_ply


# Names of the 13 S3DIS semantic classes.
# NOTE(review): the list position presumably matches the integer label id
# stored in the processed files -- confirm against the preprocessing script.
CLASSES = [
    'ceiling',
    'floor',
    'wall',
    'beam',
    'column',
    'window',
    'door',
    'chair',
    'table',
    'bookcase',
    'sofa',
    'board',
    'clutter',
]


def read_txt(path):
    """Return the lines of a text file with surrounding whitespace removed.

    Args:
        path: Path of the text file to read.

    Returns:
        A list with one string per line of the file, each passed through
        ``str.strip()`` (so blank lines come back as empty strings).
    """
    with open(path) as handle:
        return [raw_line.strip() for raw_line in handle]


class S3DISArea5DatasetBase(Dataset):
    """Common loading logic for the S3DIS datasets using the Area 5 split.

    Areas 1-4 and 6 are listed for ``'train'``; area 5 is listed for both
    ``'val'`` and ``'test'`` (see ``SPLIT_FILES``).

    Concrete subclasses must set ``IN_CHANNELS`` and implement
    ``get_cfl_from_data``.
    """

    USE_RGB = True
    IN_CHANNELS = None  # must be overridden by subclasses (asserted in __init__)
    NUM_CLASSES = 13
    # Meta-data files (under <s3dis_path>/meta_data) listing the samples of each split.
    SPLIT_FILES = {
        'train': ['area1.txt', 'area2.txt', 'area3.txt', 'area4.txt', 'area6.txt'],
        'val': ['area5.txt'],
        'test': ['area5.txt']
    }

    def __init__(self, split, transform=None, config=None):
        """Collect the sample file paths for ``split``.

        Args:
            split: One of ``'train'``, ``'val'`` or ``'test'``.
            transform: Optional callable invoked by subclasses as
                ``transform(xyz, rgb, label)``; stored as-is.
            config: Object providing ``voxel_size``, ``s3dis_path``,
                ``limit_numpoints`` and ``cache_data`` attributes. Although
                it defaults to ``None``, a real config object is required.
        """
        assert self.IN_CHANNELS is not None
        assert split in ('train', 'val', 'test')
        super().__init__()

        self.split = split
        self.voxel_size = config.voxel_size
        self.data_path = config.s3dis_path
        self.limit_numpoints = config.limit_numpoints
        self.transform = transform
        self.cache = {}
        self.cache_data = config.cache_data
        self.split_files = self.SPLIT_FILES[split]

        # Every line of every split file names one sample inside s3dis_processed.
        meta_dir = osp.join(self.data_path, 'meta_data')
        processed_dir = osp.join(self.data_path, 's3dis_processed')
        self.filenames = [
            osp.join(processed_dir, sample_name)
            for split_file in self.split_files
            for sample_name in read_txt(osp.join(meta_dir, split_file))
        ]

    def __len__(self):
        """Return the number of sample files in this split."""
        return len(self.filenames)

    def get_classnames(self):
        """Return the module-level ``CLASSES`` list."""
        return CLASSES

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(coords, feats, labels)`` tensors.

        For a training split with a positive ``limit_numpoints``, samples
        holding more points than the limit are randomly subsampled (without
        replacement) down to exactly ``limit_numpoints`` points.
        """
        coords, feats, labels = self.get_cfl_from_data(self._load_data(idx))
        total = len(coords)

        max_points = self.limit_numpoints
        if 'train' in self.split and max_points > 0 and total > max_points:
            # The same index set is applied to all three arrays so rows stay aligned.
            keep = np.random.choice(total, max_points, replace=False)
            coords, feats, labels = coords[keep], feats[keep], labels[keep]

        return tuple(torch.from_numpy(array) for array in (coords, feats, labels))

    def get_cfl_from_data(self, data):
        """Convert loaded data into ``(coords, feats, labels)`` numpy arrays.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def _load_data(self, idx):
        """Read the file of sample ``idx``, going through the cache if enabled."""
        path = self.filenames[idx]
        if not self.cache_data:
            return read_ply(path)
        if path not in self.cache:
            self.cache[path] = read_ply(path)
        return self.cache[path]


class S3DISArea5RGBDataset(S3DISArea5DatasetBase):
    """S3DIS dataset whose per-point features are the three colour columns."""

    USE_RGB = True
    IN_CHANNELS = 3

    def __init__(self, split, transform=None, config=None):
        """Forward all arguments unchanged to the base-class constructor."""
        super(S3DISArea5RGBDataset, self).__init__(split, transform, config)

    def get_cfl_from_data(self, data):
        """Split a loaded sample into coordinates, colour features and labels.

        Args:
            data: 2-D array indexed as ``data[:, :3]`` (xyz), ``data[:, 3:6]``
                (rgb) and ``data[:, 6]`` (label).

        Returns:
            Tuple ``(xyz, rgb, label)`` cast to float32, float32 and int64.
            ``xyz`` is divided by ``self.voxel_size``; if a transform is set
            it is applied to all three arrays after that scaling.
        """
        rgb, label = data[:, 3:6], data[:, 6]
        # Divide out-of-place: ``data[:, :3]`` is a view, so ``/=`` would
        # rescale the caller's array. With ``cache_data`` enabled that array
        # lives in the cache and would shrink again on every access.
        xyz = data[:, :3] / self.voxel_size
        if self.transform is not None:
            xyz, rgb, label = self.transform(xyz, rgb, label)
        return xyz.astype(np.float32), rgb.astype(np.float32), label.astype(np.int64)
    
    
class S3DISArea5RGB1cmDataset(S3DISArea5RGBDataset):
    """RGB-feature S3DIS dataset reading samples from ``s3dis_processed_1cm``.

    Feature extraction is inherited from ``S3DISArea5RGBDataset``; only the
    directory holding the sample files differs.
    NOTE(review): the directory name suggests data pre-processed at 1 cm
    resolution -- confirm against the preprocessing script.
    """

    def __init__(self, split, transform=None, config=None):
        """Collect the sample file paths for ``split``.

        Args:
            split: One of ``'train'``, ``'val'`` or ``'test'``.
            transform: Optional callable applied as ``transform(xyz, rgb, label)``.
            config: Object providing ``voxel_size``, ``s3dis_path``,
                ``limit_numpoints`` and ``cache_data`` attributes.
        """
        assert self.IN_CHANNELS is not None
        assert split in ('train', 'val', 'test')
        # The super() lookup starts *after* S3DISArea5DatasetBase, so the
        # base-class constructor is skipped and every attribute it would
        # normally set is assigned below instead.
        super(S3DISArea5DatasetBase, self).__init__()

        self.split = split
        self.voxel_size = config.voxel_size
        self.data_path = config.s3dis_path
        self.limit_numpoints = config.limit_numpoints
        self.transform = transform
        self.cache = {}
        self.cache_data = config.cache_data
        self.split_files = self.SPLIT_FILES[split]

        # Every line of every split file names one sample inside s3dis_processed_1cm.
        meta_dir = osp.join(self.data_path, 'meta_data')
        sample_dir = osp.join(self.data_path, 's3dis_processed_1cm')
        self.filenames = [
            osp.join(sample_dir, sample_name)
            for split_file in self.split_files
            for sample_name in read_txt(osp.join(meta_dir, split_file))
        ]


class S3DISArea5RGBXYZDataset(S3DISArea5DatasetBase):
    """S3DIS dataset whose features are colour plus normalised coordinates."""

    USE_RGB = True
    IN_CHANNELS = 6

    def __init__(self, split, transform=None, config=None):
        """Forward all arguments unchanged to the base-class constructor."""
        super(S3DISArea5RGBXYZDataset, self).__init__(split, transform, config)

    def get_cfl_from_data(self, data):
        """Split a loaded sample into coordinates, 6-channel features and labels.

        Args:
            data: 2-D array indexed as ``data[:, :3]`` (xyz), ``data[:, 3:6]``
                (rgb) and ``data[:, 6]`` (label).

        Returns:
            Tuple ``(xyz, feats, label)`` cast to float32, float32 and int64.
            ``xyz`` is divided by ``self.voxel_size`` (and transformed, if a
            transform is set). ``feats`` concatenates ``rgb`` with the
            coordinates centred on their mean and divided by twice the
            largest absolute offset per axis, so each coordinate feature
            lies in ``[-0.5, 0.5]``.
        """
        rgb, label = data[:, 3:6], data[:, 6]
        # Divide out-of-place: ``data[:, :3]`` is a view, so ``/=`` would
        # rescale the caller's array. With ``cache_data`` enabled that array
        # lives in the cache and would shrink again on every access.
        xyz = data[:, :3] / self.voxel_size
        if self.transform is not None:
            xyz, rgb, label = self.transform(xyz, rgb, label)

        centered = xyz - xyz.mean(0, keepdims=True)
        denom = np.abs(centered).max(0, keepdims=True)
        # An axis with zero extent (e.g. a single point) has only zero
        # offsets; dividing by 1 keeps them at 0 instead of producing NaN.
        denom = np.where(denom > 0, denom, 1.0)
        # Parenthesised on purpose: ``centered / 2 * denom`` would multiply
        # by ``denom`` instead of dividing by it.
        norm_xyz = centered / (2 * denom)

        feats = np.concatenate([rgb, norm_xyz], axis=1)
        return xyz.astype(np.float32), feats.astype(np.float32), label.astype(np.int64)


class S3DISArea5RGBXYZ1cmDataset(S3DISArea5RGBXYZDataset):
    """RGB+XYZ-feature S3DIS dataset reading samples from ``s3dis_processed_1cm``.

    Feature extraction is inherited from ``S3DISArea5RGBXYZDataset``; only
    the directory holding the sample files differs.
    NOTE(review): the directory name suggests data pre-processed at 1 cm
    resolution -- confirm against the preprocessing script.
    """

    def __init__(self, split, transform=None, config=None):
        """Collect the sample file paths for ``split``.

        Args:
            split: One of ``'train'``, ``'val'`` or ``'test'``.
            transform: Optional callable applied as ``transform(xyz, rgb, label)``.
            config: Object providing ``voxel_size``, ``s3dis_path``,
                ``limit_numpoints`` and ``cache_data`` attributes.
        """
        assert self.IN_CHANNELS is not None
        assert split in ('train', 'val', 'test')
        # The super() lookup starts *after* S3DISArea5DatasetBase, so the
        # base-class constructor is skipped and every attribute it would
        # normally set is assigned below instead.
        super(S3DISArea5DatasetBase, self).__init__()

        self.split = split
        self.voxel_size = config.voxel_size
        self.data_path = config.s3dis_path
        self.limit_numpoints = config.limit_numpoints
        self.transform = transform
        self.cache = {}
        self.cache_data = config.cache_data
        self.split_files = self.SPLIT_FILES[split]

        # Every line of every split file names one sample inside s3dis_processed_1cm.
        meta_dir = osp.join(self.data_path, 'meta_data')
        sample_dir = osp.join(self.data_path, 's3dis_processed_1cm')
        self.filenames = [
            osp.join(sample_dir, sample_name)
            for split_file in self.split_files
            for sample_name in read_txt(osp.join(meta_dir, split_file))
        ]


# Script entry point: iterate over the training split once and save the
# number of points of every sample to ``record_num_points_s3dis.npy``.
if __name__ == '__main__':
    from tqdm import tqdm

    class FakeConfig:
        """Minimal stand-in for the project config consumed by the datasets."""

        def __init__(self, path):
            self.s3dis_path = path
            self.cache_data = False
            self.ignore_label = 255
            self.limit_numpoints = -1  # non-positive: never subsample
            # The dataset constructor reads ``config.voxel_size``; without it
            # this script failed with AttributeError before loading anything.
            # The value only rescales coordinates, so any positive number
            # works for counting points.
            self.voxel_size = 0.05

    cfg = FakeConfig('/root/data/S3DIS')
    dset = S3DISArea5RGBXYZDataset('train', config=cfg)

    record_num_points = []
    for coords, feats, labels in tqdm(dset):
        record_num_points.append(len(coords))

    record_num_points = np.array(record_num_points)
    np.save('record_num_points_s3dis.npy', record_num_points)
