import os
import os.path as osp
from typing import Sequence

import numpy as np
import torch
from torch.utils.data import Dataset


# Raw label id -> human-readable class name.
# Ids 252-259 carry a 'moving-' prefix on the name of a static class defined
# earlier in this table.
# NOTE: insertion order matters. SemanticKITTIDataset.__init__ walks this dict
# in order and hands out consecutive class indices to kept classes as it meets
# them, and it resolves each 'moving-*' entry through the index already given
# to its static counterpart, so static entries must come first.
label_name_mapping = {
    0: 'unlabeled',
    1: 'outlier',
    10: 'car',
    11: 'bicycle',
    13: 'bus',
    15: 'motorcycle',
    16: 'on-rails',
    18: 'truck',
    20: 'other-vehicle',
    30: 'person',
    31: 'bicyclist',
    32: 'motorcyclist',
    40: 'road',
    44: 'parking',
    48: 'sidewalk',
    49: 'other-ground',
    50: 'building',
    51: 'fence',
    52: 'other-structure',
    60: 'lane-marking',
    70: 'vegetation',
    71: 'trunk',
    72: 'terrain',
    80: 'pole',
    81: 'traffic-sign',
    99: 'other-object',
    252: 'moving-car',
    253: 'moving-bicyclist',
    254: 'moving-person',
    255: 'moving-motorcyclist',
    256: 'moving-on-rails',
    257: 'moving-bus',
    258: 'moving-truck',
    259: 'moving-other-vehicle'
}

# Class names that receive a class index in SemanticKITTIDataset; every other
# name in label_name_mapping is mapped to 255 there.
# The order of this list is only used for membership tests; class indices
# follow the order of label_name_mapping, not the order here.
kept_labels = [
    'road', 'sidewalk', 'parking', 'other-ground', 'building', 'car', 'truck',
    'bicycle', 'motorcycle', 'other-vehicle', 'vegetation', 'trunk', 'terrain',
    'person', 'bicyclist', 'motorcyclist', 'fence', 'pole', 'traffic-sign'
]


class SemanticKITTIDataset(Dataset):
    """LiDAR scans laid out as ``<root>/<seq>/velodyne/*.bin`` as a Dataset.

    Each item is a ``(pc, feat, labels)`` triple of tensors:

    * ``pc``: the first three columns of the scan after rotation (and, for
      training splits, random scaling), divided by ``config.voxel_size``.
    * ``feat``: all four columns of the scan; the first three are the
      rotated/scaled coordinates (not divided by the voxel size), the
      fourth is copied unchanged from the file (presumably intensity --
      TODO confirm).
    * ``labels``: per-point class indices looked up in ``label_map``;
      points that belong to no kept class get ``IGNORE_LABEL``.

    If a ``transform`` is supplied it is called as
    ``transform(pc, feat, labels)`` on the NumPy arrays and its return
    values are what get converted to tensors.
    """

    USE_RGB = False
    IN_CHANNELS = 4
    NUM_CLASSES = 19
    # Value stored in ``label_map`` for raw ids that are not a kept class.
    IGNORE_LABEL = 255
    SEQUENCES = {
        'train': ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10'],
        'val': ['08'],
        'trainval': ['00', '01', '02', '03', '04', '05', '06', '07', '08',
                     '09', '10'],
        'test': ['11', '12', '13', '14', '15', '16', '17', '18', '19', '20',
                 '21']
    }

    def __init__(self, split, transform=None, config=None):
        """Index the scan files of ``split`` and build the label lookup.

        Args:
            split: One of the keys of ``SEQUENCES``.
            transform: Optional callable ``(pc, feat, labels) -> (pc, feat,
                labels)`` applied to the NumPy arrays of every item.
            config: Object exposing ``voxel_size``, ``kitti_path`` (the
                directory that contains the sequence folders) and
                ``limit_numpoints`` (values <= 0 disable subsampling).

        Raises:
            ValueError: If ``split`` is unknown or ``config`` is missing.
        """
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if split not in self.SEQUENCES:
            raise ValueError(
                f'unknown split {split!r}; '
                f'expected one of {sorted(self.SEQUENCES)}')
        if config is None:
            raise ValueError(
                'config is required and must provide voxel_size, kitti_path '
                'and limit_numpoints')
        super().__init__()

        self.split = split
        self.voxel_size = config.voxel_size
        self.data_path = config.kitti_path
        self.transform = transform
        self.limit_numpoints = config.limit_numpoints
        self.files = []
        self.seqs = self.SEQUENCES[split]
        for seq in self.seqs:
            scan_dir = osp.join(self.data_path, seq, 'velodyne')
            # Only ``.bin`` files are scans; stray files in the directory
            # would otherwise be parsed as point clouds.
            self.files.extend(
                osp.join(scan_dir, name)
                for name in sorted(os.listdir(scan_dir))
                if name.endswith('.bin'))

        reverse_label_name_mapping = {}
        # Lookup table indexed by raw label id (the ids in
        # label_name_mapping go up to 259). Integer dtype so the gather in
        # __getitem__ needs no float round-trip.
        # NOTE: ids that do not appear in label_name_mapping keep the
        # initial value 0, i.e. they resolve to class index 0.
        self.label_map = np.zeros(260, dtype=np.int64)
        cnt = 0
        for label_id, name in label_name_mapping.items():
            if label_id > 250:
                # 'moving-*' entry: reuse the index already assigned to the
                # static class of the same name. This relies on the static
                # entry appearing earlier in label_name_mapping.
                static_name = name.replace('moving-', '')
                if static_name in kept_labels:
                    self.label_map[label_id] = \
                        reverse_label_name_mapping[static_name]
                else:
                    self.label_map[label_id] = self.IGNORE_LABEL
            elif label_id == 0:
                self.label_map[label_id] = self.IGNORE_LABEL
            elif name in kept_labels:
                self.label_map[label_id] = cnt
                reverse_label_name_mapping[name] = cnt
                cnt += 1
            else:
                self.label_map[label_id] = self.IGNORE_LABEL

        self.reverse_label_name_mapping = reverse_label_name_mapping
        self.num_classes = cnt
        # Rotation (radians, about the third axis) applied to non-training
        # splits; callers may change it after construction.
        self.angle = 0.0

    def get_classnames(self):
        """Return a ``{class_index: class_name}`` dict for the kept classes."""
        return {
            class_id: name
            for name, class_id in self.reverse_label_name_mapping.items()
        }

    def __len__(self):
        """Return the number of indexed scan files."""
        return len(self.files)

    def __getitem__(self, index):
        """Load, augment and label the scan at ``index``.

        Returns:
            ``(pc, feat, labels)`` as tensors; see the class docstring.

        Raises:
            ValueError: If a label file exists but its number of entries
                differs from the number of points in the scan.
        """
        scan_path = self.files[index]
        with open(scan_path, 'rb') as scan_file:
            raw = np.fromfile(scan_file, dtype=np.float32).reshape(-1, 4)

        # 'train' and 'trainval' both count as training splits.
        is_train = 'train' in self.split
        if is_train:
            # Random rotation about the third axis plus a small random scale.
            theta = np.random.uniform(0, 2 * np.pi)
            scale_factor = np.random.uniform(0.95, 1.05)
        else:
            theta = self.angle
            scale_factor = 1.0
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rot_mat = np.array([[cos_t, sin_t, 0],
                            [-sin_t, cos_t, 0],
                            [0, 0, 1]])

        block = np.empty_like(raw)
        block[:, :3] = np.dot(raw[:, :3], rot_mat) * scale_factor
        block[:, 3] = raw[:, 3]
        pc = block[:, :3].copy()
        feat = block

        # Derive the label path from the directory layout rather than with
        # str.replace on the whole path: a dataset root that itself contains
        # 'velodyne' or '.bin' would otherwise yield a wrong path and the
        # scan would silently be treated as unlabeled.
        scan_stem = osp.splitext(osp.basename(scan_path))[0]
        label_file = osp.join(
            osp.dirname(osp.dirname(scan_path)), 'labels',
            scan_stem + '.label')
        if osp.exists(label_file):
            with open(label_file, 'rb') as label_stream:
                all_labels = np.fromfile(
                    label_stream, dtype=np.int32).reshape(-1)
            if all_labels.shape[0] != pc.shape[0]:
                raise ValueError(
                    f'{label_file} holds {all_labels.shape[0]} labels but '
                    f'{scan_path} holds {pc.shape[0]} points')
        else:
            # No label file: raw id 0 for every point.
            all_labels = np.zeros(pc.shape[0], dtype=np.int32)

        # Only the low 16 bits of each entry are used as the raw label id.
        labels = self.label_map[all_labels & 0xFFFF].astype(
            np.int64, copy=False)

        num_points = len(pc)
        if (is_train and self.limit_numpoints > 0
                and num_points > self.limit_numpoints):
            # Random subsample without replacement, training splits only.
            inds = np.random.choice(
                num_points, self.limit_numpoints, replace=False)
            pc = pc[inds]
            feat = feat[inds]
            labels = labels[inds]
        pc /= self.voxel_size
        if self.transform is not None:
            pc, feat, labels = self.transform(pc, feat, labels)

        return (torch.from_numpy(pc), torch.from_numpy(feat),
                torch.from_numpy(labels))


if __name__ == '__main__':
    # Manual smoke test: walk the training split and print, for every scan,
    # the per-axis extent of the coordinates returned by the dataset.
    from tqdm import tqdm

    class FakeConfig:
        """Bare-bones stand-in for the config object the dataset reads."""

        def __init__(self, path):
            self.limit_numpoints = -1  # not > 0, so no subsampling happens
            self.voxel_size = 0.05
            self.kitti_path = path

    dataset_root = '/root/data/SemanticKITTI/sequences'
    dset = SemanticKITTIDataset('train', config=FakeConfig(dataset_root))
    print(f'# classes: {dset.num_classes}')

    for sample in tqdm(dset):
        coords = sample[0]
        upper, _ = coords.max(0)
        lower, _ = coords.min(0)
        print(upper - lower)
