from typing import Dict, Mapping, Union

import torch

# Keys that may label the tensors stored in a DictDataset.
LabelType = Union[int, str]


class DictDataset(torch.utils.data.Dataset):
    """A dataset of labeled tensors that share a common length.

    Indexing the dataset returns a dict mapping each label to
    ``tensor[index]`` for the corresponding tensor, so every sample carries
    one slice from each labeled tensor.

    Args:
        tensors: Mapping from label to tensor. Must be non-empty, and every
            tensor must have the same ``len()`` (size of its first dimension).

    Raises:
        ValueError: If ``tensors`` is empty or the tensors differ in length.
    """

    def __init__(self, tensors: Mapping[LabelType, torch.Tensor]):
        # Explicit raises instead of ``assert`` so validation still happens
        # when Python runs with optimizations enabled (``-O``).
        if not tensors:
            raise ValueError("DictDataset requires at least one tensor.")
        self.tensors = tensors
        # The first tensor's length is the reference every other tensor must match.
        self.length = len(next(iter(self.tensors.values())))
        for key, tensor in tensors.items():
            if len(tensor) != self.length:
                raise ValueError(
                    f"All tensors must be the same length, but {key} is not "
                    f"(expected {self.length}, got {len(tensor)})."
                )

    def __len__(self) -> int:
        """Return the number of samples, i.e. the shared tensor length."""
        return self.length

    def __getitem__(self, index: int) -> Dict[LabelType, torch.Tensor]:
        """Return ``{label: tensor[index]}`` for every stored tensor."""
        return {key: tensor[index] for key, tensor in self.tensors.items()}
