from pathlib import Path

import torch
from torchvision.transforms.v2 import Compose, Normalize, RandomHorizontalFlip, Resize

from anomalib.data import MVTec
from anomalib.deploy import ExportType, TorchInferencer
from anomalib.engine import Engine
from anomalib.models import Patchcore

# Example: train a model with custom transforms, export it, and run inference
# with the exported model.

# Settings shared by the train and eval pipelines; defining them once keeps the
# two pipelines from drifting apart.
IMAGE_SIZE = (256, 256)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Single source of truth for where the model is exported to and loaded from.
EXPORT_ROOT = Path("export_folder")

engine = Engine()
model = Patchcore()

# the train transform adds a random flip on top of the resize + normalize steps
train_transform = Compose(
    [
        RandomHorizontalFlip(p=0.5),
        Resize(IMAGE_SIZE),
        Normalize(mean=NORM_MEAN, std=NORM_STD),
    ],
)

# the eval transform only resizes and normalizes
eval_transform = Compose(
    [
        Resize(IMAGE_SIZE),
        Normalize(mean=NORM_MEAN, std=NORM_STD),
    ],
)

datamodule = MVTec(train_transform=train_transform, eval_transform=eval_transform)

engine.fit(model, datamodule=datamodule)

# after running fit, the used eval_transform will be stored in the model
model.transform
# Compose(
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )

# when exporting the trained model, the transforms are added in the export
engine.export(model, export_type=ExportType.TORCH, export_root=EXPORT_ROOT)

# load the exported model from export_folder/weights/torch/model.pt
inferencer = TorchInferencer(EXPORT_ROOT / "weights" / "torch" / "model.pt")
inferencer.model.transform
# Compose(
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )

# since the transforms are included in the forward pass, they are applied automatically,
# so an image that is not yet 256x256 can be passed in directly
image = torch.rand((3, 900, 900))
prediction = inferencer.predict(image)
prediction.pred_label
# <LabelName.ABNORMAL: 1>
