import numpy as np
import torch
from torchvision.transforms import transforms

from augment.utils import Augment


def to_patch_position(img_h, img_w, patch_h, patch_w):
    """Sample a random top-left corner for a patch lying fully inside an image.

    Args:
        img_h: Image height in pixels.
        img_w: Image width in pixels.
        patch_h: Patch height in pixels.
        patch_w: Patch width in pixels.

    Returns:
        A ``(patch_i, patch_j)`` tuple of row and column offsets with
        ``0 <= patch_i <= img_h - patch_h`` and
        ``0 <= patch_j <= img_w - patch_w``.

    Raises:
        ValueError: If the patch is larger than the image along either axis.
    """
    if patch_h > img_h or patch_w > img_w:
        raise ValueError(
            f"patch ({patch_h}x{patch_w}) does not fit in image ({img_h}x{img_w})"
        )
    # randint's upper bound is exclusive, so add 1: this makes the last valid
    # offset reachable and allows a patch that exactly fills a dimension
    # (the only valid offset is then 0).
    patch_i = np.random.randint(img_h - patch_h + 1)
    patch_j = np.random.randint(img_w - patch_w + 1)
    return patch_i, patch_j


class CutPaste(Augment):
    """Copy a randomly located patch of each image and paste it elsewhere.

    For every image in the batch, a rectangular patch is cut at a random
    position, passed through a ``ColorJitter`` transform, and written back
    into a copy of the image at a second, independently drawn position.
    """

    def __init__(self, scale, ratio, color_jitter=0.):
        """Store the patch geometry and build the colour-jitter transform.

        Args:
            scale: Patch area as a fraction of the image area.
            ratio: Patch aspect ratio (width / height).
            color_jitter: Strength passed unchanged as the brightness,
                contrast, saturation and hue arguments of ``ColorJitter``.
        """
        super().__init__()
        jitter_kwargs = {
            name: color_jitter
            for name in ("brightness", "contrast", "saturation", "hue")
        }
        self.transform = transforms.ColorJitter(**jitter_kwargs)
        self.scale = scale
        self.ratio = ratio

    def to_patch_size(self, img):
        """Return ``(patch_h, patch_w)`` derived from ``scale`` and ``ratio``.

        The target area is ``scale`` times the product of the last two
        dimensions of ``img``; height and width are the truncated square
        roots of ``area / ratio`` and ``area * ratio`` respectively.
        """
        height, width = img.shape[-2], img.shape[-1]
        target_area = self.scale * (height * width)
        return (
            int(np.sqrt(target_area / self.ratio)),
            int(np.sqrt(target_area * self.ratio)),
        )

    def forward(self, imgs):
        """Apply the cut-and-paste operation to every image in ``imgs``.

        Args:
            imgs: A ``torch.Tensor`` indexed as ``imgs[n, :, rows, cols]``
                (presumably an (N, C, H, W) batch - confirm with callers).

        Returns:
            A clone of ``imgs`` with one jittered patch pasted per image;
            the input tensor itself is not written to.
        """
        assert isinstance(imgs, torch.Tensor)
        height, width = imgs.shape[-2:]
        cut_h, cut_w = self.to_patch_size(imgs)

        def window(index, top, left):
            # Index tuple selecting all channels of a cut_h x cut_w region.
            return (
                index,
                slice(None),
                slice(top, top + cut_h),
                slice(left, left + cut_w),
            )

        pasted = imgs.clone()
        for index in range(len(imgs)):
            # Source location is drawn first, then the destination.
            src_top, src_left = to_patch_position(height, width, cut_h, cut_w)
            jittered = self.transform(imgs[window(index, src_top, src_left)])
            dst_top, dst_left = to_patch_position(height, width, cut_h, cut_w)
            pasted[window(index, dst_top, dst_left)] = jittered
        return pasted

    def get_parameters(self):
        """Return the patch parameters as ``[scale, ratio]``."""
        return [self.scale, self.ratio]
