from torchvision.transforms.v2 import (
    Compose,
    Normalize,
    RandomAdjustSharpness,
    RandomHorizontalFlip,
    Resize,
)

# Build separate training and evaluation transform pipelines and hand both to
# the MVTec datamodule.
#
# The resize target and normalisation statistics are defined once so that the
# training and evaluation pipelines cannot silently drift apart.
IMAGE_SIZE = (256, 256)
# These values match the widely used ImageNet per-channel mean/std.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Training pipeline: random augmentations first, then the deterministic
# resize + normalise steps shared with evaluation.
train_transform = Compose(
    [
        RandomAdjustSharpness(sharpness_factor=0.7, p=0.5),
        RandomHorizontalFlip(p=0.5),
        Resize(IMAGE_SIZE, antialias=True),
        Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ],
)

# Evaluation pipeline: only the deterministic steps, no random augmentation.
eval_transform = Compose(
    [
        Resize(IMAGE_SIZE, antialias=True),
        Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ],
)

# NOTE(review): MVTec is not imported in the visible part of this file;
# presumably it comes from `anomalib.data` -- confirm and import it.
datamodule = MVTec(train_transform=train_transform, eval_transform=eval_transform)
datamodule.prepare_data()
datamodule.setup()

# Show the transforms stored on the datamodule. print() is used because a bare
# attribute expression is only echoed by an interactive session and has no
# visible effect when this code runs as a script.
print(datamodule.train_transform)
# Expected output:
# Compose(
#       RandomAdjustSharpness(p=0.5, sharpness_factor=0.7)
#       RandomHorizontalFlip(p=0.5)
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )
print(datamodule.eval_transform)
# Expected output:
# Compose(
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )
