import torch.nn as nn
import kornia.augmentation as K

import random 

class AF(nn.Module):
    """Random affine (rotation + shear) layer built on ``kornia.augmentation``.

    The layer receives an ``(image, cover_image)`` pair, uses only ``image``,
    and returns the transformed image. Around the affine op the values are
    rescaled with ``(x + 1) / 2`` and restored with ``y * 2 - 1``.

    NOTE(review): the rescaling maps [-1, 1] onto [0, 1], so the input is
    presumably normalized to [-1, 1] -- confirm against the caller.
    """

    def __init__(self, s=(0.0, 0.0), r=(0.0, 0.0)):
        """Create the underlying ``K.RandomAffine`` module.

        Args:
            s: Forwarded unchanged as the ``shear`` argument of
                ``K.RandomAffine``.
            r: Forwarded unchanged as the ``degrees`` argument of
                ``K.RandomAffine``.
        """
        super().__init__()
        # p=1.0: always apply.
        # same_on_batch=True: per the kornia docs, one sampled transform is
        # shared by every sample in the batch.
        affine_options = {
            "degrees": r,
            "shear": s,
            "p": 1.0,
            "same_on_batch": True,
        }
        self.affine = K.RandomAffine(**affine_options)

    def forward(self, image_and_cover):
        """Apply the affine transform to the first element of the pair.

        Args:
            image_and_cover: Two-element iterable ``(image, cover_image)``.
                The second element is unpacked but never used.

        Returns:
            ``self.affine((image + 1) / 2) * 2 - 1``.
        """
        image, _ = image_and_cover  # cover image is intentionally ignored
        rescaled = (image + 1) / 2
        warped = self.affine(rescaled)
        # Undo the rescaling so the output is back on the input's scale.
        return warped * 2 - 1