import torch
from torchvision import transforms
import torchvision.transforms.functional as TF

import random

class RandomTransformations:
    """Produce one clean view and several randomly rotated/translated views of an image.

    Each call returns the input converted to a tensor, plus ``num_samples``
    augmented copies. Every augmented copy gets an independent rotation angle
    and (x, y) translation, each drawn uniformly from the integer ranges
    configured at construction time.

    Args:
        num_samples: Number of augmented copies generated per call.
        rotate: Maximum absolute rotation in degrees (truncated to int);
            angles are drawn from the integers in ``[-rotate, rotate]``,
            inclusive.
        translate: Maximum absolute translation in pixels (truncated to int);
            each axis offset is drawn from the integers in
            ``[-translate, translate]``, inclusive.
        size: Passed as the ``size`` argument of ``TF.resize`` for the clean
            view and for every augmented copy.
        normalize: If truthy, apply ``TF.normalize`` with mean 0.5 and
            std 0.5 after tensor conversion.
    """

    def __init__(self, num_samples, rotate, translate, size, normalize):
        self.num_samples = num_samples
        self.size = size
        self.normalize = normalize

        # Symmetric, inclusive ranges. ``range(-r, r)`` never produced +r and
        # was empty for r == 0, which made ``random.choice`` raise IndexError.
        max_angle = abs(int(rotate))
        max_shift = abs(int(translate))
        self.rotate = list(range(-max_angle, max_angle + 1))
        self.translate = list(range(-max_shift, max_shift + 1))

    def _finalize(self, img):
        """Resize ``img`` with ``self.size``, convert it to a tensor, optionally normalize."""
        img = TF.resize(img, self.size)
        img = TF.to_tensor(img)
        if self.normalize:
            # For 8-bit inputs to_tensor yields values in [0, 1];
            # (v - 0.5) / 0.5 then maps them to [-1, 1].
            img = TF.normalize(img, (0.5,), (0.5,))
        return img

    def __call__(self, x):
        """Return ``(clean, samples)`` for the input image ``x``.

        ``x`` must be accepted by both ``TF.affine`` and ``TF.to_tensor``
        (presumably a PIL image — confirm against callers).

        Returns:
            clean: ``x`` passed through resize / to_tensor / optional normalize.
            samples: The ``num_samples`` augmented views stacked along a new
                leading dimension. This is an empty batch with the clean
                view's trailing shape when ``num_samples`` is 0.
        """
        samples = []
        for _ in range(self.num_samples):
            # Same draw order as before: angle, then x shift, then y shift.
            angle = random.choice(self.rotate)
            shift_x = random.choice(self.translate)
            shift_y = random.choice(self.translate)
            augmented = TF.affine(
                img=x,
                angle=angle,
                translate=(shift_x, shift_y),
                scale=1,
                shear=0,
            )
            samples.append(self._finalize(augmented))

        clean = self._finalize(x)
        if samples:
            stacked = torch.stack(samples)
        else:
            # torch.stack rejects an empty list; return an empty batch
            # whose trailing dims match the clean view instead of raising.
            stacked = clean.new_empty((0, *clean.shape))
        return clean, stacked
