import torch
from pathlib import Path
from torch.utils import data
from torchvision.datasets import Omniglot
from torchvision import transforms


def add_dataset_args(parser):
    """Register the dataset-related command-line options on ``parser``.

    Args:
        parser: Any object with an argparse-style ``add_argument`` method
            (e.g. ``argparse.ArgumentParser``).

    Note:
        ``--data_dir`` has no default, so parsing without it yields ``None``;
        ``Dataset`` in this module needs it set (it calls ``Path(args.data_dir)``).
    """
    parser.add_argument(
        '--data_dir',
        type=str,
        help='Root directory for the Omniglot dataset '
             '(downloaded there if not already present).',
    )


class Dataset(data.Dataset):
    """Omniglot images resized to 28x28, binarized and inverted.

    Wraps ``torchvision.datasets.Omniglot`` and returns only the processed
    image tensor for each index; the class label is discarded.

    Args:
        args: Parsed arguments; ``args.data_dir`` is the root directory
            handed to ``Omniglot`` (with ``download=True``, so the data is
            fetched there if missing).
        is_train: Forwarded as Omniglot's ``background`` flag
            (``True`` -> "background" set, ``False`` -> "evaluation" set).
    """

    def __init__(self, args, is_train=True) -> None:
        super().__init__()

        self.args = args
        self.is_train = is_train
        # Resize to 28x28, then convert the PIL image to a float tensor
        # in [0, 1] (C x H x W).
        self.transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
        ])
        self.data_dir = Path(args.data_dir)

        self.data = Omniglot(
            str(self.data_dir),
            background=self.is_train,
            download=True,
        )
        # Always None in this class; retained as a public attribute in
        # case callers read it.
        self.features = None

    def __len__(self):
        """Return the number of samples in the wrapped Omniglot split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the binarized, inverted image at ``idx``.

        The raw image goes through ``self.transform`` and is then
        thresholded at 0.5 and inverted: pixels ``< 0.5`` become 1 and
        pixels ``>= 0.5`` become 0. The dtype of the transformed tensor is
        kept (float32 for ``ToTensor``); shape is presumably 1 x 28 x 28
        for grayscale Omniglot -- TODO confirm.
        """
        image, _ = self.data[idx]
        x = self.transform(image)
        # One comparison does threshold-then-invert in a single pass and,
        # unlike in-place masked assignment, never mutates the transform's
        # output (which may alias cached data for custom transforms).
        return (x < 0.5).to(x.dtype)

