import torch
import torchvision.transforms.functional as TF


class MyRotationTransform(torch.nn.Module):
    """Rotate an image by a fixed angle with probability ``p``.

    Args:
        p: Probability of applying the rotation on each call. It is compared
            against one uniform sample in ``[0, 1)`` from ``torch.rand``, so
            values >= 1 always rotate and values <= 0 never do.
        angle: Rotation angle in degrees, passed straight to
            ``torchvision.transforms.functional.rotate``.
        expand: Forwarded to ``TF.rotate``. When False (the default, which
            matches the previous behavior) the output keeps the input's size,
            so rotating a non-square image by 90 degrees crops it. Set True to
            enlarge the output so the whole rotated image fits.
    """

    def __init__(self, p, angle, *, expand=False):
        super().__init__()
        self.angle = angle
        self.p = p
        self.expand = expand

    def forward(self, img):
        """Return ``img`` rotated by ``self.angle`` with probability ``self.p``.

        If the rotation is not applied, ``img`` itself is returned unchanged.
        """
        # One uniform draw per call decides whether this sample is rotated.
        if torch.rand(1).item() < self.p:
            return TF.rotate(img, self.angle, expand=self.expand)
        return img

    def extra_repr(self):
        # Used by nn.Module.__repr__, so printed pipelines show the settings.
        return f"p={self.p}, angle={self.angle}, expand={self.expand}"


class OneEightyRotation(MyRotationTransform):
    """Randomly apply a half-turn (180-degree) rotation.

    Args:
        p: Probability that the rotation is applied on a given call.
    """

    def __init__(self, p):
        super().__init__(angle=180, p=p)


class NinetyRotation(MyRotationTransform):
    """Randomly apply a quarter-turn (90-degree) rotation.

    Args:
        p: Probability that the rotation is applied on a given call.
    """

    def __init__(self, p):
        super().__init__(angle=90, p=p)


class Identity(torch.nn.Module):
    """No-op transform: hands its input back untouched.

    Args:
        p: Ignored. Presumably accepted so this class can be built with the
            same ``p=`` keyword as the rotation transforms; TODO confirm.
    """

    def __init__(self, p=1.):
        super().__init__()
        del p  # deliberately unused

    def forward(self, img):
        # Pass-through: the very same object is returned, not a copy.
        result = img
        return result
