import os
import sys
import numpy as np
import pickle
import collections
from torch.utils.data import Dataset
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)

import ipdb
st = ipdb.set_trace


class ScannetDetectionDataset(Dataset):
    """Walk ScanNet detection boxes and accumulate per-class box statistics.

    Although this subclasses ``torch.utils.data.Dataset``, ``__getitem__``
    returns ``None``. Indexing a scan loads its ``<scan_name>_bbox.npy`` file
    and folds the boxes into running statistics stored on the instance:
    ``max_obj``, ``uniq_classes``, ``all_sizes`` and ``sizes``.
    """

    # Default directory of the per-scan ``*.npy`` files, resolved relative to
    # the current working directory (not to this file).
    DEFAULT_DATA_PATH = os.path.join('./dataset/language_grounding/scans',
                                     'scannet_train_detection_data')

    def __init__(self, split_set='train', data_path=None):
        """Collect the scan names belonging to ``split_set``.

        Args:
            split_set: ``'all'`` uses every scan found in ``data_path``.
                ``'train'``, ``'val'`` or ``'test'`` use the scans listed in
                ``<ROOT_DIR>/scannet/meta_data/scannetv2_<split_set>.txt``
                that are also present in ``data_path``.
            data_path: directory holding the per-scan ``*.npy`` files.
                Defaults to ``DEFAULT_DATA_PATH``.

        Raises:
            ValueError: if ``split_set`` is not one of the names above.
        """
        self.data_path = self.DEFAULT_DATA_PATH if data_path is None else data_path

        # A scan id is the first 12 characters of a file name (presumably
        # ScanNet ids such as 'scene0000_00'). Several files share one id, so
        # de-duplicate with a set, which also gives O(1) membership tests.
        available = {os.path.basename(fname)[0:12]
                     for fname in os.listdir(self.data_path)
                     if fname.startswith('scene')}

        if split_set == 'all':
            # Sorted so the index -> scan mapping is deterministic across runs.
            self.scan_names = sorted(available)
        elif split_set in ('train', 'val', 'test'):
            split_filename = os.path.join(ROOT_DIR, 'scannet/meta_data',
                                          'scannetv2_{}.txt'.format(split_set))
            with open(split_filename, 'r') as f:
                listed = f.read().splitlines()
            # Drop scans listed in the split file but missing from data_path.
            self.scan_names = [sname for sname in listed if sname in available]
            print('kept {} scans out of {}'.format(len(self.scan_names), len(listed)))
        else:
            # Fail fast: returning here would leave an instance without
            # scan_names/statistics that only breaks later, far from the cause.
            raise ValueError('illegal split name: {!r}'.format(split_set))

        # Statistics filled in by __getitem__.
        self.sizes = {}            # class label -> (k, 3) array of box[3:6] rows
        self.all_sizes = []        # box[3:6] of every box visited, in visit order
        self.uniq_classes = set()  # every integer class label seen
        self.max_obj = 0           # largest number of boxes seen in a single scan

    def __len__(self):
        """Return the number of scans in the selected split."""
        return len(self.scan_names)

    def __getitem__(self, idx):
        """Load the boxes of scan ``idx`` and fold them into the statistics.

        Reads ``<data_path>/<scan_name>_bbox.npy``. For every row, the last
        column is taken as the integer class label and columns 3:6 as the
        box extent (presumably dx, dy, dz; confirm against the preprocessing
        script). Returns ``None``; results accumulate on the instance.
        """
        scan_name = self.scan_names[idx]
        instance_bboxes = np.load(os.path.join(self.data_path, scan_name) + '_bbox.npy')
        self.max_obj = max(self.max_obj, instance_bboxes.shape[0])
        for box in instance_bboxes:
            box_label = int(box[-1])
            box_size = box[3:6]
            self.uniq_classes.add(box_label)
            self.all_sizes.append(box_size)
            if box_label in self.sizes:
                self.sizes[box_label] = np.vstack((self.sizes[box_label], box_size))
            else:
                self.sizes[box_label] = np.expand_dims(box_size, axis=0)

if __name__ == "__main__":
    # Indexing each scan is what populates the statistics (max_obj, sizes,
    # all_sizes, uniq_classes); the value returned by indexing is unused.
    train_set = ScannetDetectionDataset(split_set='train')
    num_scans = len(train_set)
    for scan_idx in range(num_scans):
        train_set[scan_idx]

    print(train_set.max_obj)

    # Deliberate ipdb breakpoint for inspecting the collected statistics.
    st()

    # Disabled post-processing recipe, kept for reference: writes per-class
    # mean box sizes (as a dict, then as a label-ordered array) under
    # meta_data/. NOTE(review): the second file is written with pickle even
    # though its name ends in .npz, so np.load cannot read it as an archive.
    #
    # mean_sizes = {key: np.mean(train_set.sizes[key], axis=0)
    #               for key in train_set.sizes}
    # with open("meta_data/mean_sizes_485_classes.pkl", "wb") as a_file:
    #     pickle.dump(mean_sizes, a_file)
    #
    # od = collections.OrderedDict(sorted(mean_sizes.items()))
    # mean_sizes_arr = np.zeros((len(mean_sizes), 3))
    # for idx, key in enumerate(od.keys()):
    #     mean_sizes_arr[idx] = od[key]
    # with open("meta_data/mean_sizes_485_classes_ordered.npz", "wb") as a_file:
    #     pickle.dump(mean_sizes_arr, a_file)
    #
    # all_sizes = np.array(train_set.all_sizes)
    # print(np.mean(all_sizes, axis=0))
    # st()