from spaghettini import quick_register

import dgl

from torch.utils.data.dataset import Dataset


@quick_register
class TensorDictDataset(Dataset):
    """Map-style dataset over a dict of tensors sharing their first dimension.

    Item ``i`` is a dict with the same keys as ``tensors_dict`` whose values
    are ``tensor[i]`` for each stored tensor.
    """

    def __init__(self, tensors_dict):
        """
        Args:
            tensors_dict: Mapping from name to a tensor (anything supporting
                ``.size(0)`` and indexing). All values must have the same size
                along dimension 0; that shared size is the dataset length.

        Raises:
            ValueError: If ``tensors_dict`` is empty, or if its tensors differ
                in size along dimension 0.
        """
        if not tensors_dict:
            # With no tensors there is no size to take the dataset length from.
            raise ValueError("tensors_dict must contain at least one tensor.")
        # Keyed by name so a mismatch can be reported precisely.
        tensor_sizes = {k: tensor.size(0) for k, tensor in tensors_dict.items()}
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if len(set(tensor_sizes.values())) != 1:
            raise ValueError(f"Tensor sizes don't match: {tensor_sizes}")
        self.length = next(iter(tensor_sizes.values()))
        self.tensors_dict = tensors_dict

    def __getitem__(self, index):
        """Return ``{key: tensor[index]}`` for every stored tensor."""
        return {k: v[index] for k, v in self.tensors_dict.items()}

    def __len__(self):
        """Return the shared size of the tensors along dimension 0."""
        return self.length


@quick_register
class GraphTensorDictDataset(Dataset):
    """Dataset of individual DGL graphs built from one batched graph.

    The constructor writes every tensor in ``tensors_dict`` into the batched
    graph's ``ndata`` under its key. It then splits the batch with
    ``dgl.unbatch``. Each dataset item is one of the resulting graphs.
    """

    def __init__(self, graphs, tensors_dict):
        """
        Args:
            graphs: A batched DGL graph. NOTE: it is modified in place. Every
                entry of ``tensors_dict`` is stored in ``graphs.ndata``.
            tensors_dict: Mapping from feature name to a tensor. All tensors
                must share the same size along dimension 0. They are assumed
                to be per-node features of ``graphs``, with one row per node
                of the whole batch. May be empty, in which case the graphs
                are split without adding any features.

        Raises:
            ValueError: If the tensors differ in size along dimension 0.
        """
        tensor_sizes = {k: tensor.size(0) for k, tensor in tensors_dict.items()}
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if len(set(tensor_sizes.values())) > 1:
            raise ValueError(f"Tensor sizes don't match: {tensor_sizes}")
        # Attach the features to the batched graph before unbatching, so each
        # unbatched graph carries its share of them.
        for k, v in tensors_dict.items():
            graphs.ndata[k] = v

        self.unbatched_graphs = dgl.unbatch(graphs)
        # Number of items, consistent with __len__. The tensors' first
        # dimension counts nodes of the whole batch, not graphs, so it must
        # not be used as the dataset length.
        self.length = len(self.unbatched_graphs)

    def __getitem__(self, index):
        """Return the ``index``-th unbatched graph."""
        return self.unbatched_graphs[index]

    def __len__(self):
        """Return the number of unbatched graphs."""
        return len(self.unbatched_graphs)
