#! -*- coding: utf-8
import os.path as path
import typing

import numpy as np
import torch

# Label vocabulary: the lowercase hexadecimal ASCII code of every character
# class, ordered digits 0-9, then uppercase A-Z, then lowercase a-z
# (10 + 26 + 26 = 62 entries).
# NOTE(review): these hex strings presumably mirror the NIST "by_class"
# directory names used during preprocessing — confirm against that script.
CLASSES = [
    format(code, "x")
    for code in (
        *range(ord("0"), ord("9") + 1),
        *range(ord("A"), ord("Z") + 1),
        *range(ord("a"), ord("z") + 1),
    )
]

# Public API for `from ... import *`: only the dataset class is exported;
# CLASSES remains importable by explicit name.
__all__ = ["FEMNIST"]


class FEMNIST(torch.utils.data.Dataset):
    """FEMNIST split backed by a single preprocessed ``.npz`` archive.

    The archive is read eagerly at construction time. Its ``"datas"`` and
    ``"targets"`` arrays are copied into tensors, and the file is closed, so
    indexing afterwards performs no file I/O.

    Args:
        datadir: Directory containing ``by_class.train.npz`` and/or
            ``by_class.test.npz``.
        train: Load the training archive when ``True`` (default), otherwise
            the test archive.

    Attributes:
        datafile: Full path of the archive that was loaded.
        datas: Tensor built from the archive's ``"datas"`` array.
        targets: Tensor built from the archive's ``"targets"`` array.
    """

    def __init__(self, datadir: str, train: bool = True):
        self.datafile = path.join(datadir, "by_class.train.npz" if train else "by_class.test.npz")
        # For an .npz file, np.load returns a lazily-reading NpzFile that holds
        # the underlying file handle open; the context manager closes it once
        # both arrays have been copied into tensors (torch.tensor always copies).
        # NOTE(review): allow_pickle=True means a crafted archive can execute
        # arbitrary code when loaded — only point this at trusted data.
        with np.load(self.datafile, allow_pickle=True) as datas:
            self.targets = torch.tensor(datas["targets"])
            self.datas = torch.tensor(datas["datas"])

    def __len__(self) -> int:
        """Return the number of samples, i.e. the length of ``targets``."""
        return len(self.targets)

    def __getitem__(self, idx: int) -> typing.Tuple[torch.Tensor,
                                                    torch.Tensor]:
        """Return the ``(x, y)`` pair at position ``idx``.

        ``x`` is ``self.datas[idx]`` and ``y`` is ``self.targets[idx]``.
        """
        return self.datas[idx], self.targets[idx]

