import torch
from torchvision.datasets import LFWPairs
from torchvision.transforms import v2


def lfw_pairs_dataset(subset='train', *, root='./data/', download=True,
                      image_size=(224, 224), image_set='funneled'):
    """Build the torchvision ``LFWPairs`` dataset with a resize/normalize pipeline.

    Every image is resized to ``image_size``, converted from PIL to a
    ``float32`` tensor scaled to ``[0, 1]``, and then normalized with the
    per-channel mean/std values torchvision uses for its ImageNet-pretrained
    models.

    Args:
        subset: Passed to ``LFWPairs`` as ``split`` (torchvision accepts
            ``'train'``, ``'test'`` or ``'10fold'``).
        root: Directory where the dataset is stored / downloaded to.
        download: Forwarded to ``LFWPairs``; when true, torchvision fetches
            the data if it is not already present under ``root``.
        image_size: Target size passed to ``v2.Resize`` (a ``(height, width)``
            tuple by default).
        image_set: LFW image variant (``'original'``, ``'funneled'`` or
            ``'deepfunneled'``); ``'funneled'`` matches torchvision's default.

    Returns:
        The constructed ``LFWPairs`` dataset.
    """
    transform = v2.Compose([
        # Resize while still a PIL image, then convert to a uint8 tensor.
        v2.Resize(image_size),
        v2.PILToTensor(),
        # uint8 -> float32 in [0, 1]; replaces the former
        # Normalize(mean=0, std=255) step, which did the same rescaling.
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return LFWPairs(
        root=root,
        split=subset,
        image_set=image_set,
        transform=transform,
        download=download,
    )

