import torch
from torch import nn
import kornia.augmentation as K


class ImageAugmentations(nn.Module):
    def __init__(self, output_size, augmentations_number, p=0.7, variance=1.0):
        super().__init__()
        self.output_size = output_size
        self.augmentations_number = augmentations_number
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # self.augmentations = K.RandomGaussianNoise(mean=0, std=variance, p=1)
        self.augmentations = nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"),  # type: ignore
            K.RandomPerspective(0.7, p=p),
        ).to(device)

        self.avg_pool = nn.AdaptiveAvgPool2d((self.output_size, self.output_size))

    def forward(self, input):
        """Extents the input batch with augmentations

        If the input is consists of images [I1, I2] the extended augmented output
        will be [I1_resized, I2_resized, I1_aug1, I2_aug1, I1_aug2, I2_aug2 ...]

        Args:
            input ([type]): input batch of shape [batch, C, H, W]

        Returns:
            updated batch: of shape [batch * augmentations_number, C, H, W]
        """
        # We want to multiply the number of images in the batch in contrast to regular augmentations
        # that do not change the number of samples in the batch)
        if self.augmentations_number <= 1:
            return input
        else:
            resized_images = self.avg_pool(input)
            print('resized to', resized_images.shape)
            resized_images = torch.tile(resized_images, dims=(self.augmentations_number, 1, 1, 1))

            batch_size = input.shape[0]
            # We want at least one non augmented image
            non_augmented_batch = resized_images[:batch_size]
            # augmented_batch = self.augmentations(resized_images[batch_size:])
            # variance = 1
            augmented_batch = self.augmentations(resized_images[batch_size:])
            updated_batch = torch.cat([non_augmented_batch, augmented_batch], dim=0)

            # return input
            return updated_batch