'''
Author: 
Email:
Date: 2020-10-12 22:14:36
LastEditTime: 2021-05-30 18:36:22
Description: 
    Dataset classes for vehicle poses. VehicleSceneTreeDataset builds a tree
    structure for each data sample; VehiclePoseGridDataset serves fixed-size
    grid samples.
'''

import numpy as np

import torch.utils.data as torch_data

from utils import COLOR
from structure import Tree


class VehicleSceneTreeDataset(torch_data.Dataset):
    """Dataset that loads dict-style scene samples and converts each one into a ``Tree``.

    Every sample in the file at ``path`` is wrapped in a ``Tree`` built with
    ``position_scale`` when the dataset is constructed, so ``__getitem__`` only
    performs a list lookup.

    Note: per the original author, tree samples consume a lot of stack, so the
    DataLoader's number of workers can only be set to 0.
    """

    def __init__(self, path, position_scale=1):
        self.path = path
        # NOTE(review): presumably a pickled object array of per-sample dicts;
        # it is indexed positionally below -- confirm against the data writer.
        self.data = np.load(self.path, allow_pickle=True)
        self.num_examples = len(self.data)

        # Build every tree up front, in file order.
        self.trees = [
            Tree(self.data[sample_idx], position_scale)
            for sample_idx in range(self.num_examples)
        ]

        print(COLOR.GREEN + 'Vehicle Pose Dataset:')
        print('\tFilename:', self.path)
        print('\tNumber of data:', len(self.trees))
        print(COLOR.WHITE + '')

    def __len__(self):
        """Return the number of samples read from the file."""
        return self.num_examples

    def __getitem__(self, index):
        """Return the prebuilt ``Tree`` stored at position ``index``."""
        return self.trees[index]

    # torch.utils.data.DataLoader cannot stack this custom data type, so this
    # identity collate function hands the batch back as a plain list.
    def collate_fn(self, batch):
        """Return ``batch`` unchanged."""
        return batch


class VehiclePoseGridDataset(torch_data.Dataset):
    """Dataset of vehicle poses laid out on a fixed grid.

    This dataset is simple because all samples have the same data dimensions.
    Per the original author, the plane is first divided into 64 grid cells,
    and each cell holds at most 2 vehicles.

    The file at ``path`` must contain a pickled dict with the keys ``'data'``
    (an array whose first axis indexes samples) and ``'grid_center'``.
    """
    def __init__(self, path='./data/vehicle_pose_grid.npy'):
        self.path = path
        # The file stores a pickled dict; .item() unwraps it from the object
        # array returned by np.load.
        # NOTE(review): allow_pickle=True can execute arbitrary code while
        # loading -- only use this with trusted files.
        data = np.load(self.path, allow_pickle=True).item()
        self.data = data['data']
        self.grid_center = data['grid_center']
        self.num_examples = self.data.shape[0]

        print(COLOR.GREEN+'Grid Vehicle Pose Dataset:')
        print('\tFilename:', self.path)
        print('\tNumber of points:', self.num_examples)
        print(COLOR.WHITE+'')

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.num_examples

    def __getitem__(self, index):
        """Return ``(vehicle_type, position)`` for sample ``index``.

        Each sample is indexed with three axes. Along the last axis, channels
        0-2 are returned as the position and channel 3 as the vehicle type.
        """
        one_data = self.data[index]
        position = one_data[:, :, 0:3]
        # Named vehicle_type to avoid shadowing the builtin ``type``.
        vehicle_type = one_data[:, :, 3]
        return vehicle_type, position
