from torchvision.transforms.v2 import Resize

# Build a single Resize transform and hand it to the datamodule through the
# `transform` argument.
# NOTE(review): `MVTec` is not imported in the visible part of this snippet;
# presumably it is brought into scope elsewhere. Confirm before running.
target_size = (256, 256)
transform = Resize(size=target_size)
datamodule = MVTec(transform=transform)

# Run the datamodule's preparation and setup steps before inspecting it.
datamodule.prepare_data()
datamodule.setup()

# Inspect the transforms exposed by the datamodule. The expected output below
# shows the same Resize reported for both the train and the eval transform.
datamodule.train_transform
# Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=warn)
datamodule.eval_transform
# Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=warn)

# Fetch the first item of the train and test data and look at the shape of
# its "image" entry; the expected output shows 256x256 images in both cases.
train_sample = next(iter(datamodule.train_data))
train_sample["image"].shape
# torch.Size([3, 256, 256])
test_sample = next(iter(datamodule.test_data))
test_sample["image"].shape
# torch.Size([3, 256, 256])
