from torchvision.transforms import transforms

from augment.utils import Augment, apply_to_batch


class CutOut(Augment):
    """Erase a rectangular region using a fixed-size, fixed-shape RandomErasing.

    Thin wrapper around ``torchvision.transforms.RandomErasing``. It is
    configured with ``p=1.0``. Its ``scale`` and ``ratio`` ranges are collapsed
    to single values, so those quantities are fixed instead of sampled from a
    range.

    The configured transform is wrapped with ``apply_to_batch`` from
    ``augment.utils``. Presumably that helper applies the transform across a
    batch; confirm against its definition.

    Args:
        scale: Used as both ends of RandomErasing's ``scale`` range. Per the
            torchvision docs, this is the proportion of the input area erased.
        ratio: Its reciprocal, ``1 / ratio``, is used as both ends of
            RandomErasing's ``ratio`` (aspect-ratio) range. A value of zero
            raises ``ZeroDivisionError`` during construction.

    Attributes:
        scale: The ``scale`` argument, stored unchanged.
        ratio: The reciprocal of the ``ratio`` argument, not the value that
            was passed in.
        transform: The batch-wrapped RandomErasing callable used by ``forward``.

    NOTE(review): the inversion presumably lets callers express the aspect
    ratio in the opposite orientation to torchvision's convention. Confirm
    this is intended.
    """

    def __init__(self, scale, ratio):
        super().__init__()
        self.scale = scale
        # Stored already inverted; this inverted value is what
        # get_parameters() reports.
        self.ratio = 1 / ratio

        # Identical lower/upper bounds pin both quantities to one value.
        fixed_scale = (scale,) * 2
        fixed_ratio = (self.ratio,) * 2

        # p=1.0: per the torchvision docs, the erase is performed on every call.
        eraser = transforms.RandomErasing(
            p=1.0,
            scale=fixed_scale,
            ratio=fixed_ratio,
        )
        self.transform = apply_to_batch(eraser)

    def forward(self, x):
        """Pass ``x`` through the batch-wrapped RandomErasing and return the result."""
        batch_eraser = self.transform
        return batch_eraser(x)

    def get_parameters(self):
        """Return a new list ``[scale, inverted_ratio]`` built from the stored attributes.

        NOTE(review): the second entry is ``1 / ratio``, not the ``ratio``
        given to ``__init__``. Confirm that callers expect the inverted value.
        """
        scale_and_ratio = [self.scale, self.ratio]
        return scale_and_ratio
