'''
Author: 
Email:
Date: 2020-10-12 22:14:36
LastEditTime: 2021-05-30 16:30:51
Description: 
    - grid data loader
    - tree data loader
    - grammar data loader
'''

import numpy as np

import torch.utils.data as torch_data
from utils import COLOR
from structure import Tree


class ObjectGridDataset(torch_data.Dataset):
    """Dataset of fixed-size object-pose arrays loaded from .npy files.

    The array at ``path`` (optionally followed by the array at ``oracle_path``,
    concatenated along axis 0) is normalized once at construction time:
    columns 0 and 1 of the last axis are divided by ``position_scale`` and
    column 2 by ``2*pi`` (presumably an angle in radians — TODO confirm).

    Args:
        path: .npy file holding a numeric array indexable as ``[:, :, k]``
            (the original author notes ``[B, 10, 6]`` for the oracle file).
        oracle_path: .npy file with matching trailing dims; only loaded when
            ``train_w_gt`` is truthy.
        position_scale: divisor applied to columns 0 and 1.
        train_w_gt: when truthy, the oracle samples are appended.
    """

    def __init__(self, path, oracle_path, position_scale, train_w_gt):
        self.path = path
        # a numeric ndarray (np.load without allow_pickle), not a dict
        self.data = np.load(self.path)

        if train_w_gt:
            oracle_data = np.load(oracle_path)  # [B, 10, 6]
            self.data = np.concatenate([self.data, oracle_data], axis=0)
        else:
            oracle_data = []
        self.num_examples = len(self.data)

        self.data = self._normalize(self.data, position_scale)

        print(COLOR.GREEN+'Object Pose Grid Dataset:')
        print('\tFilename:', self.path)
        print('\tNumber of original data:', self.num_examples-len(oracle_data))
        print('\tNumber of repeated oracle data:', len(oracle_data))
        print(COLOR.WHITE+'')

    @staticmethod
    def _normalize(data, position_scale):
        """Scale columns 0/1 by ``position_scale`` and column 2 by ``2*pi``.

        Float (or complex) input is modified in place and returned; integer or
        bool input is first promoted to float32, because in-place true division
        on an integer array raises a numpy casting error.
        """
        if not np.issubdtype(data.dtype, np.inexact):
            data = data.astype(np.float32)
        data[:, :, 0] /= position_scale
        data[:, :, 1] /= position_scale
        data[:, :, 2] /= (2 * np.pi)
        return data

    def __len__(self):
        return self.num_examples

    def __getitem__(self, index):
        return self.data[index]


class ObjectTreeDataset(torch_data.Dataset):
    """Dataset of pre-built ``Tree`` objects, one per loaded record.

    Records come from ``np.load(..., allow_pickle=True).tolist()`` on ``path``
    (and, when ``train_w_gt`` is truthy, on ``oracle_path``, appended after
    them). Every record is wrapped as ``Tree(record, position_scale)`` during
    construction, so item access simply returns the stored tree.
    """

    def __init__(self, path, oracle_path, position_scale, train_w_gt):
        self.path = path
        oracle_data = []
        # Presumably a list of dicts (TODO confirm schema against the writer).
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load trusted dataset files.
        self.data = np.load(self.path, allow_pickle=True).tolist()
        if train_w_gt:
            oracle_data = np.load(oracle_path, allow_pickle=True).tolist()
            # list concatenation: oracle records follow the original ones
            self.data = self.data + oracle_data

        self.num_examples = len(self.data)
        self.trees = [
            Tree(self.data[idx], position_scale)
            for idx in range(self.num_examples)
        ]

        n_oracle = len(oracle_data)
        print(COLOR.GREEN+'Object Pose Tree Dataset:')
        print('\tFilename:', self.path)
        print('\tNumber of original data:', self.num_examples - n_oracle)
        print('\tNumber of repeated oracle data:', n_oracle)
        print(COLOR.WHITE+'')

    def __len__(self):
        return self.num_examples

    def __getitem__(self, index):
        return self.trees[index]

    def collate_fn(self, batch):
        """Return the batch list unchanged.

        The default DataLoader collate cannot stack custom ``Tree`` objects,
        so pass this as ``collate_fn`` to receive a plain list of trees.
        """
        return batch


class ObjectGrammarDataset(torch_data.Dataset):
    """Dataset of object-pose grammar arrays loaded from .npy files.

    The array at ``path`` (optionally followed by the array at ``oracle_path``,
    concatenated along axis 0) is transposed from ``[N, A, B]`` to
    ``[N, B, A]``. The author notes shape ``[N, 25, 15]``, giving
    ``[N, 15, 25]`` after the transpose. Rows 9 and 10 of axis 1 are then
    divided by ``position_scale`` and row 11 by ``2*pi`` (presumably an angle
    in radians — TODO confirm).

    Args:
        path: .npy file holding a numeric 3-D array.
        oracle_path: .npy file with matching trailing dims; only loaded when
            ``train_w_gt`` is truthy.
        position_scale: divisor applied to rows 9 and 10 (after transpose).
        train_w_gt: when truthy, the oracle samples are appended.
    """

    def __init__(self, path, oracle_path, position_scale, train_w_gt):
        self.path = path
        # a numeric ndarray (np.load without allow_pickle), not a dict
        self.data = np.load(self.path)  # [N1, 25, 15]

        if train_w_gt:
            oracle_data = np.load(oracle_path)  # [N2, 25, 15]
            self.data = np.concatenate([self.data, oracle_data], axis=0)
        else:
            oracle_data = []

        self.data = self.data.transpose(0, 2, 1)  # [N1+N2, 15, 25]
        self.num_examples = len(self.data)

        self.data = self._normalize(self.data, position_scale)

        print(COLOR.GREEN+'Object Pose Grammar Dataset:')
        print('\tFilename:', self.path)
        print('\tNumber of original data:', self.num_examples-len(oracle_data))
        print('\tNumber of repeated oracle data:', len(oracle_data))
        print(COLOR.WHITE+'')

    @staticmethod
    def _normalize(data, position_scale):
        """Scale rows 9/10 of axis 1 by ``position_scale`` and row 11 by ``2*pi``.

        Float (or complex) input is modified in place and returned; integer or
        bool input is first promoted to float32, because in-place true division
        on an integer array raises a numpy casting error.
        """
        if not np.issubdtype(data.dtype, np.inexact):
            data = data.astype(np.float32)
        data[:, 9, :] /= position_scale
        data[:, 10, :] /= position_scale
        data[:, 11, :] /= (2 * np.pi)
        return data

    def __len__(self):
        return self.num_examples

    def __getitem__(self, index):
        return self.data[index]
