#! -*- coding: utf-8
import os.path as path
import typing

import numpy as np
import torch

# Public API of this module: only the Digit5 dataset is exported.
__all__ = ["Digit5"]


class Digit5(torch.utils.data.Dataset):
    """Map-style dataset backed by a single ``digit-5.npz`` archive.

    All arrays are read eagerly into memory at construction time, and the
    archive's file handle is released before ``__init__`` returns.

    Attributes:
        datafile: Full path of the ``.npz`` archive that was loaded.
        targets: Tensor built from the archive's ``targets`` array; one
            entry per sample (its length defines ``len(self)``).
        sources: NumPy array built from the archive's ``sources`` array.
            It is kept on the instance but not returned by ``__getitem__``.
        datas: Tensor built from the archive's ``datas`` array, indexed
            in parallel with ``targets``.
    """

    def __init__(self, datadir: str):
        """Load ``digit-5.npz`` from *datadir*.

        Args:
            datadir: Directory that contains ``digit-5.npz``.

        Raises:
            FileNotFoundError: If the archive does not exist.
            KeyError: If the archive lacks a ``targets``, ``sources`` or
                ``datas`` entry.
        """
        self.datafile = path.join(datadir, "digit-5.npz")
        # np.load on an .npz returns a lazily-read NpzFile that holds an open
        # file handle. The context manager closes it once every array has
        # been copied out (torch.tensor / np.array both copy).
        # NOTE(review): allow_pickle=True is needed for object-dtype arrays
        # (presumably ``sources``), but unpickling can execute arbitrary
        # code. Only load archives from a trusted origin.
        with np.load(self.datafile, allow_pickle=True) as datas:
            self.targets = torch.tensor(datas["targets"])
            self.sources = np.array(datas["sources"])
            self.datas = torch.tensor(datas["datas"])

    def __len__(self):
        """Return the number of samples (the length of ``targets``)."""
        return len(self.targets)

    def __getitem__(self, idx: int) -> typing.Tuple[torch.Tensor,
                                                    torch.Tensor]:
        """Return the ``(input, target)`` pair stored at position *idx*."""
        y = self.targets[idx]
        x = self.datas[idx]
        return x, y
