import os
import os.path as osp

import numpy as np
import torch
from torch.utils.data import Dataset
from plyfile import PlyData
from pandas import DataFrame


# RGB colour (floats in 0-255) for each raw ScanNet label id; looked up by
# ``ScannetDatasetBase.get_colormaps``. Raw ids 13 and 31 have no entry.
SCANNET_COLOR_MAP = {
    0: (0., 0., 0.),
    1: (174., 199., 232.),
    2: (152., 223., 138.),
    3: (31., 119., 180.),
    4: (255., 187., 120.),
    5: (188., 189., 34.),
    6: (140., 86., 75.),
    7: (255., 152., 150.),
    8: (214., 39., 40.),
    9: (197., 176., 213.),
    10: (148., 103., 189.),
    11: (196., 156., 148.),
    12: (23., 190., 207.), # No 13
    14: (247., 182., 210.),
    15: (66., 188., 102.),
    16: (219., 219., 141.),
    17: (140., 57., 197.),
    18: (202., 185., 52.),
    19: (51., 176., 203.),
    20: (200., 54., 131.),
    21: (92., 193., 61.),
    22: (78., 71., 183.),
    23: (172., 114., 82.),
    24: (255., 127., 14.),
    25: (91., 163., 138.),
    26: (153., 98., 156.),
    27: (140., 153., 101.),
    28: (158., 218., 229.),
    29: (100., 125., 154.),
    30: (178., 127., 135.), # No 31
    32: (146., 111., 194.),
    33: (44., 160., 44.),
    34: (112., 128., 144.),
    35: (96., 207., 209.),
    36: (227., 119., 194.),
    37: (213., 92., 176.),
    38: (94., 106., 211.),
    39: (82., 84., 163.),
    40: (100., 85., 144.),
}
# Raw label ids of the 20 classes that are kept. When used as CLASS_LABELS,
# the raw id at position i becomes training label i (see the label_map built in
# ScannetDatasetBase.__init__).
VALID_CLASS_LABELS = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39)
# Human-readable class names, aligned index-for-index with VALID_CLASS_LABELS.
VALID_CLASS_NAMES = (
    'wall',
    'floor',
    'cabinet',
    'bed',
    'chair',
    'sofa',
    'table',
    'door',
    'window',
    'bookshelf',
    'picture',
    'counter',
    'desk',
    'curtain',
    'refrigerator',
    'shower curtain',
    'toilet',
    'sink',
    'bathtub',
    'otherfurniture'
)
# Precomputed raw id (0..40) -> training label map: each id in
# VALID_CLASS_LABELS maps to its index there, every other id maps to 255.
# Not referenced elsewhere in this file; presumably used by external code
# that assumes ignore_label == 255 -- verify before changing.
SCANNET_SEM_SEG_LABEL_MAP = {
    0: 255, 1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5,
    7: 6, 8: 7, 9: 8, 10: 9, 11: 10, 12: 11, 13: 255,
    14: 12, 15: 255, 16: 13, 17: 255, 18: 255, 19: 255,
    20: 255, 21: 255, 22: 255, 23: 255, 24: 14, 25: 255,
    26: 255, 27: 255, 28: 15, 29: 255, 30: 255, 31: 255,
    32: 255, 33: 16, 34: 17, 35: 255, 36: 18, 37: 255,
    38: 255, 39: 19, 40: 255
}


def read_ply(filename):
    """Read the first element of a PLY file into a 2-D numpy array.

    Args:
        filename: path to the ``.ply`` file.

    Returns:
        ``numpy.ndarray`` with one row per record of the file's first PLY
        element and one column per property. ``DataFrame.values`` brings mixed
        property dtypes to a single common dtype.

    Raises:
        ValueError: if the file declares no elements.
    """
    with open(filename, 'rb') as f:
        plydata = PlyData.read(f)
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if not plydata.elements:
        raise ValueError(f"PLY file has no elements: {filename}")
    return DataFrame(plydata.elements[0].data).values


class ScannetDatasetBase(Dataset):
    """Base ``torch`` ``Dataset`` over pre-processed ScanNet ``.ply`` scenes.

    Subclasses must define ``CLASS_LABELS``: the raw label ids (within 0..40)
    that are kept. Kept id ``CLASS_LABELS[i]`` becomes training label ``i``;
    every other raw id in 0..40 becomes ``config.ignore_label``.

    Scene names are read from ``<scannet_path>/meta_data/<split file>`` and
    resolved to ``<scannet_path>/scannet_processed/{train,test}/<name>.ply``.
    This base class defines no ``__getitem__``; subclasses provide it.
    """
    USE_RGB = False
    IN_CHANNELS = 1
    CLASS_LABELS = None
    SPLIT_FILES = {
        'train': 'scannetv2_train.txt',
        'val': 'scannetv2_val.txt',
        'trainval': 'scannetv2_trainval.txt',
        'test': 'scannetv2_test.txt',
        'overfit': 'scannetv2_overfit.txt'
    }

    def __init__(self, split, transform=None, config=None):
        """
        Args:
            split: one of 'train', 'val', 'trainval', 'test', 'overfit'.
            transform: optional callable stored as ``self.transform`` for use
                by subclasses.
            config: object providing ``voxel_size``, ``scannet_path``,
                ``limit_numpoints``, ``cache_data`` and ``ignore_label``.

        Raises:
            ValueError: if ``config`` is None.
        """
        assert self.CLASS_LABELS is not None
        assert split in ['train', 'val', 'trainval', 'test', 'overfit']
        if config is None:
            # The None default exists only for signature compatibility; every
            # setting below is read from ``config``.
            raise ValueError("ScannetDatasetBase requires a config object")
        super(ScannetDatasetBase, self).__init__()

        self.split = split
        self.voxel_size = config.voxel_size
        self.data_path = config.scannet_path
        self.limit_numpoints = config.limit_numpoints
        self.transform = transform
        # filename -> loaded scene data; populated by subclasses when cache_data is set.
        self.cache = {}
        self.cache_data = config.cache_data
        self.split_file = self.SPLIT_FILES[split]
        self.ignore_class_labels = tuple(set(range(41)) - set(self.CLASS_LABELS))
        self.ignore_label = config.ignore_label
        # Raw id -> training label: its position in CLASS_LABELS, else ignore_label.
        kept = set(self.CLASS_LABELS)
        self.label_map = {
            k: self.CLASS_LABELS.index(k) if k in kept else self.ignore_label
            for k in range(41)
        }
        with open(osp.join(self.data_path, 'meta_data', self.split_file), 'r') as f:
            filenames = f.read().splitlines()
        # Only the test split has its own directory; all other splits
        # (including 'val') are stored under 'train'.
        stem = 'test' if split == 'test' else 'train'
        self.filenames = [
            osp.join(self.data_path, 'scannet_processed', stem, f'{filename}.ply')
            for filename in filenames
        ]

    def __len__(self):
        """Number of scenes listed in the split file."""
        return len(self.filenames)

    def get_classnames(self):
        """Return ``{training label: class name}`` for every kept class.

        Assumes every id in ``CLASS_LABELS`` also appears in ``VALID_CLASS_LABELS``.
        """
        classnames = {}
        for class_id in self.CLASS_LABELS:
            classnames[self.label_map[class_id]] = VALID_CLASS_NAMES[VALID_CLASS_LABELS.index(class_id)]
        return classnames

    def get_colormaps(self):
        """Return ``{training label: (r, g, b)}`` for every kept class."""
        colormaps = {}
        for class_id in self.CLASS_LABELS:
            colormaps[self.label_map[class_id]] = SCANNET_COLOR_MAP[class_id]
        return colormaps

    def get_label_unmap(self):
        """Return the inverse of ``label_map`` (training label -> raw label id).

        ``label_map`` maps many raw ids to ``ignore_label``. The inverse keeps the
        last of these in insertion order, which is the highest ignored raw id.
        """
        # Iterate ``.items()``: iterating the dict directly yields bare int keys,
        # which cannot be unpacked into (k, v).
        return {v: k for k, v in self.label_map.items()}


class ScannetRGBDataset(ScannetDatasetBase):
    """ScanNet scenes with RGB point features over the 20 kept classes.

    Items are ``(coords, feats, labels)`` tensors. ``coords`` are float32
    coordinates divided by ``voxel_size``, ``feats`` are float32 RGB values, and
    ``labels`` are int64 training labels (``ignore_label`` for dropped classes).
    """
    USE_RGB = True
    IN_CHANNELS = 3
    CLASS_LABELS = VALID_CLASS_LABELS
    NUM_CLASSES = 20

    def __getitem__(self, idx):
        """Load scene ``idx`` and return ``(coords, feats, labels)`` tensors.

        On training splits ('train' and 'trainval') scenes with more than
        ``limit_numpoints`` points are randomly subsampled to that many points,
        when ``limit_numpoints > 0``.
        """
        data = self._load_data(idx)
        coords, feats, labels = self.get_cfl_from_data(data)
        num_points = len(coords)

        # Substring test: matches both 'train' and 'trainval'.
        is_train = 'train' in self.split
        if is_train and self.limit_numpoints > 0 and num_points > self.limit_numpoints:
            inds = np.random.choice(num_points, self.limit_numpoints, replace=False)
            coords, feats, labels = coords[inds], feats[inds], labels[inds]

        return (
            torch.from_numpy(coords),
            torch.from_numpy(feats),
            torch.from_numpy(labels)
        )

    def _load_data(self, idx):
        """Return the raw array for scene ``idx``, memoized when ``cache_data`` is set."""
        filename = self.filenames[idx]
        if not self.cache_data:
            return read_ply(filename)
        if filename not in self.cache:
            self.cache[filename] = read_ply(filename)
        return self.cache[filename]

    def get_cfl_from_data(self, data):
        """Split a raw scene array into float32 coords, float32 RGB and int64 labels.

        Assumes columns 0-2 are xyz, 3-5 are RGB, and the second-to-last column
        holds the raw label id. TODO: confirm against the preprocessing script.
        ``data`` itself is never modified, so cached arrays stay pristine.
        """
        # Out-of-place division. ``data[:, :3]`` is a view, so the original
        # in-place ``/=`` rescaled the cached array again on every access.
        xyz = data[:, :3] / self.voxel_size
        # Copies: the transform is not guaranteed to work out-of-place.
        rgb = data[:, 3:6].copy()
        labels = data[:, -2].copy()
        if self.transform is not None:
            xyz, rgb, labels = self.transform(xyz, rgb, labels)
        labels = np.array([self.label_map[x] for x in labels])
        return xyz.astype(np.float32), rgb.astype(np.float32), labels.astype(np.int64)

class TinyScannetRGBDataset(ScannetRGBDataset):
    """``ScannetRGBDataset`` whose 'train' split reads the reduced "tiny" scene list.

    Every other split keeps the standard ScanNet v2 split files.
    """
    SPLIT_FILES = dict(ScannetRGBDataset.SPLIT_FILES, train='tiny_scannetv2_train.txt')
    
class TopkScannetRGBDataset(ScannetRGBDataset):
    """``ScannetRGBDataset`` whose 'train' split reads the "topk" scene list.

    Every other split keeps the standard ScanNet v2 split files.
    """
    SPLIT_FILES = dict(ScannetRGBDataset.SPLIT_FILES, train='topk_scannetv2_train.txt')

class ScannetRGBXYZDataset(ScannetRGBDataset):
    """RGB dataset whose features also carry per-scene normalized xyz (6 channels)."""
    IN_CHANNELS = 6

    def get_cfl_from_data(self, data):
        """Return float32 coords, float32 ``[rgb, normalized xyz]`` features and int64 labels.

        Normalized xyz is the voxel-scaled coordinate, centered on the scene mean
        and divided by twice each axis' max absolute offset, so it lies in
        [-0.5, 0.5]. ``data`` itself is never modified.
        """
        # Out-of-place: an in-place ``/=`` on the ``data[:, :3]`` view would
        # rescale a cached array again on every access.
        xyz = data[:, :3] / self.voxel_size
        # Copies: the transform is not guaranteed to work out-of-place.
        rgb = data[:, 3:6].copy()
        labels = data[:, -2].copy()
        if self.transform is not None:
            xyz, rgb, labels = self.transform(xyz, rgb, labels)
        norm_xyz = xyz - xyz.mean(0, keepdims=True)  # centering
        denom = np.abs(norm_xyz).max(0, keepdims=True)
        # An axis with zero extent is all zeros after centering; divide by 1
        # there instead of producing 0/0 = NaN.
        denom = np.where(denom > 0, denom, 1.0)
        # Parenthesized: the original ``/ 2 * denom`` multiplied by the extent.
        norm_xyz = norm_xyz / (2 * denom)
        feats = np.concatenate([rgb, norm_xyz], axis=1)
        labels = np.array([self.label_map[x] for x in labels])
        return xyz.astype(np.float32), feats.astype(np.float32), labels.astype(np.int64)


# -------------------------
#          Misc
# -------------------------

def find_sparse_classes():
    """Print per-class point totals over the ScanNet 'trainval' split, rarest first.

    One-off analysis helper. The output recorded at the end of this function
    was used to choose the sparse classes.
    """
    from tqdm import tqdm
    from eval import SimpleConfig

    dataset = ScannetRGBDataset("trainval", config=SimpleConfig())
    class_names = dataset.get_classnames()

    per_class = torch.zeros(dataset.NUM_CLASSES, dtype=torch.int64)
    for _coords, _feats, labels in tqdm(dataset):
        present, counts = torch.unique(labels, return_counts=True)
        keep = present != dataset.ignore_label
        per_class.index_add_(0, present[keep], counts[keep])

    counts_sorted, labels_sorted = torch.sort(per_class)
    for label, count in zip(labels_sorted, counts_sorted):
        label_id = label.item()
        print(f"{label_id}. {class_names[label_id]}: {count}")

    # 17. sink: 461593              ----
    # 16. toilet: 549237                |
    # 18. bathtub: 571389               |
    # 15. shower curtain: 595718        | 7 sparse classes
    # 10. picture: 885819               |
    # 11. counter: 889212               |
    # 14. refrigerator: 945132      ----
    # 12. desk: 3330654
    # 13. curtain: 3669519
    # 5. sofa: 4190918
    # 9. bookshelf: 4418153
    # 3. bed: 4677775
    # 6. table: 6399954
    # 19. otherfurniture: 6507146
    # 8. window: 6811178
    # 2. cabinet: 7046686
    # 7. door: 7895117
    # 4. chair: 13105750
    # 1. floor: 40913240
    # 0. wall: 52758588
    
def assign_sparse_categories():
    """Assign each 'trainval' scene to the rarest sparse class it contains.

    The sparse classes are the seven rarest classes reported by
    ``find_sparse_classes``. Each scene name is filed under the sparse class
    with the fewest points in that scene. Ties go to the class listed first in
    ``sclasses``. Scenes containing no sparse class are filed under ``-1``.
    The mapping is written to ``./sparse_assign_results.json``; JSON stores the
    int keys as strings.
    """
    import json
    from tqdm import tqdm
    from eval import SimpleConfig

    cfg = SimpleConfig()
    dset = ScannetRGBDataset("trainval", config=cfg)

    sclasses = [17, 16, 18, 15, 10, 11, 14]
    results = {sclass: [] for sclass in sclasses}
    results[-1] = []
    for i, (coords, feats, labels) in enumerate(tqdm(dset)):
        fname = osp.splitext(osp.basename(dset.filenames[i]))[0]
        unique_labels, counts = torch.unique(labels, return_counts=True)
        present = dict(zip(unique_labels.tolist(), counts.tolist()))
        # Compare only sparse classes that actually occur in the scene. The
        # original padded absent classes with the scene's max count, so an
        # absent class could win a tie, and scenes with no valid labels crashed
        # on ``counts.max()``.
        candidates = [sclass for sclass in sclasses if sclass in present]
        if candidates:
            results[min(candidates, key=present.__getitem__)].append(fname)
        else:
            results[-1].append(fname)

    with open("./sparse_assign_results.json", "w") as f:
        json.dump(results, f)
        
def make_tiny_scannet_splits_random(ratio=0.125, seed=None, output_path="tiny_scannetv2_train.txt"):
    """Write a random subset of the ScanNet 'train' scenes to a split file.

    Args:
        ratio: fraction of training scenes to keep; ``int(ratio * n)`` scenes
            are sampled without replacement.
        seed: optional seed for a reproducible subset. ``None`` keeps the
            original behaviour of drawing from NumPy's global RNG.
        output_path: file that receives the scene names, one per line.

    Raises:
        ValueError: if ``ratio`` is outside [0, 1].
    """
    from eval import SimpleConfig

    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be in [0, 1], got {ratio}")

    cfg = SimpleConfig()
    train_dset = ScannetRGBDataset("train", config=cfg)
    train_fnames = [osp.splitext(osp.basename(path))[0] for path in train_dset.filenames]
    num_train_tiny = int(ratio * len(train_fnames))
    rng = np.random if seed is None else np.random.RandomState(seed)
    tiny_train_fnames = rng.choice(train_fnames, num_train_tiny, replace=False)

    with open(output_path, "w") as f:
        f.writelines(f"{fname}\n" for fname in tiny_train_fnames)
            
def find_topk_scannet(k=16):
    """Print the ``k`` ScanNet 'train' scenes with the most points, largest first.

    Points are counted on the raw scene arrays from ``_load_data``. Going through
    ``__getitem__`` would randomly subsample training scenes to
    ``limit_numpoints``, which caps (and ties) the counts of exactly the large
    scenes this helper is meant to find.
    """
    from tqdm import tqdm
    from eval import SimpleConfig

    cfg = SimpleConfig()
    dset = ScannetRGBDataset("train", config=cfg)
    num_points = np.array([len(dset._load_data(idx)) for idx in tqdm(range(len(dset)))])
    topk_indices = num_points.argsort()[::-1][:k]
    for idx in topk_indices:
        print(f"{dset.filenames[idx]}: {num_points[idx]}")

def test_tiny_scannet():
    """Print how many scenes the tiny ScanNet 'train' split contains (300 when last noted)."""
    from eval import SimpleConfig

    print(len(TinyScannetRGBDataset("train", config=SimpleConfig())))
    
def find_sparse_classes_tiny():
    """Print per-class point totals over the tiny ScanNet 'train' split, rarest first.

    Same analysis as ``find_sparse_classes``, run on ``TinyScannetRGBDataset``.
    """
    from tqdm import tqdm
    from eval import SimpleConfig

    dataset = TinyScannetRGBDataset("train", config=SimpleConfig())
    class_names = dataset.get_classnames()

    per_class = torch.zeros(dataset.NUM_CLASSES, dtype=torch.int64)
    for _coords, _feats, labels in tqdm(dataset):
        present, counts = torch.unique(labels, return_counts=True)
        keep = present != dataset.ignore_label
        per_class.index_add_(0, present[keep], counts[keep])

    counts_sorted, labels_sorted = torch.sort(per_class)
    for label, count in zip(labels_sorted, counts_sorted):
        label_id = label.item()
        print(f"{label_id}. {class_names[label_id]}: {count}")

if __name__ == '__main__':
    # Only the top-k scan runs by default. Other one-off utilities in this
    # module (make_tiny_scannet_splits_random, test_tiny_scannet,
    # find_sparse_classes_tiny, ...) can be invoked here by hand when needed.
    find_topk_scannet()