import os

import torch
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2


def pico_mnist_dataset(image_dir='./data/'):
    """Build the train and test ``ImageFolder`` datasets for the pico-MNIST images.

    Both splits share a single preprocessing pipeline:

    1. Resize to 224x224.
    2. Convert the PIL image to a tensor.
    3. Rescale pixel values to ``float32`` in ``[0, 1]``.
    4. Normalize the 3 channels with the mean/std values commonly used for
       ImageNet-pretrained models.

    Args:
        image_dir: Root directory containing ``train`` and ``test``
            subdirectories, each laid out as ``ImageFolder`` expects (one
            subdirectory per class). A trailing path separator is optional.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``ImageFolder`` instances.
    """
    transform = v2.Compose([
        v2.Resize((224, 224)),
        v2.PILToTensor(),
        # scale=True maps uint8 [0, 255] to float [0, 1]. The Normalize
        # mean/std below are defined for that range; without scaling,
        # the normalized values would be ~255x too large.
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # os.path.join works whether or not image_dir ends with a separator
    # (plain concatenation turned './data' into './datatrain/').
    train_dataset = ImageFolder(root=os.path.join(image_dir, 'train'), transform=transform)
    test_dataset = ImageFolder(root=os.path.join(image_dir, 'test'), transform=transform)

    return train_dataset, test_dataset
