import abc
import functools

import torch
from torch import nn
from torchvision.transforms import transforms


class Augment(nn.Module, abc.ABC):
    """Abstract base class for augmentation modules.

    Combines ``nn.Module`` with ``abc.ABC``. Because ``get_parameters`` is
    declared with ``abc.abstractmethod``, ``Augment`` itself (and any subclass
    that does not override ``get_parameters``) cannot be instantiated.
    """

    @abc.abstractmethod
    def get_parameters(self):
        """Return the augmentation's parameters; concrete subclasses define what that means.

        The base implementation has no body beyond this docstring and
        therefore returns ``None`` if reached via ``super()``.
        """


def _map_and_stack(transform, batch):
    """Apply ``transform`` to every item of ``batch`` and stack the results."""
    # Iterating a tensor yields its slices along dim 0; torch.stack joins the
    # per-item outputs along a new dim 0, so they must all share one shape.
    return torch.stack([transform(item) for item in batch])


def apply_to_batch(transform):
    """Lift a per-sample transform so it operates on a whole batch.

    Args:
        transform: callable applied independently to each element obtained by
            iterating the batch (for a tensor, each slice along dim 0).

    Returns:
        A ``transforms.Lambda`` that maps ``transform`` over the batch and
        stacks the outputs with ``torch.stack`` along a new leading dimension.
        All per-item outputs must have the same shape. An empty batch is
        rejected by ``torch.stack`` (it needs at least one tensor).
    """
    # A closure lambda cannot be pickled; a partial over a module-level
    # function can (as long as ``transform`` itself is picklable), so the
    # returned transform can be serialized, e.g. for spawned worker processes.
    return transforms.Lambda(functools.partial(_map_and_stack, transform))
