import torch


def diff_augment(
        x: torch.Tensor,
        brightness: torch.Tensor,
        saturation: torch.Tensor,
        contrast: torch.Tensor,
        translation_x: torch.Tensor,
        translation_y: torch.Tensor,
        offset_x: torch.Tensor,
        offset_y: torch.Tensor,
        channels_first=True
) -> torch.Tensor:
    """Run the full augmentation chain on a 4-D batch with caller-supplied parameters.

    The stages are applied in a fixed order: brightness, saturation,
    contrast, translation, cutout. No randomness is drawn here; every
    augmentation parameter is passed in by the caller.

    Args:
        x: Input batch. Processed with dim 1 as the channel axis; see
            ``channels_first``.
        brightness: Forwarded to ``rand_brightness``.
        saturation: Forwarded to ``rand_saturation``.
        contrast: Forwarded to ``rand_contrast``.
        translation_x: Forwarded to ``rand_translation`` (offset along dim 2).
        translation_y: Forwarded to ``rand_translation`` (offset along dim 3).
        offset_x: Forwarded to ``rand_cutout`` (patch position along dim 2).
        offset_y: Forwarded to ``rand_cutout`` (patch position along dim 3).
        channels_first: When false, ``x`` is permuted with ``(0, 3, 1, 2)``
            before the stages run and permuted back with ``(0, 2, 3, 1)``
            afterwards, so the output layout matches the input layout.

    Returns:
        The augmented batch as a contiguous tensor.
    """
    # Bring the channel axis to dim 1 for the helpers when the input has it last.
    augmented = x if channels_first else x.permute(0, 3, 1, 2)

    # Photometric stages.
    augmented = rand_brightness(augmented, brightness)
    augmented = rand_saturation(augmented, saturation)
    augmented = rand_contrast(augmented, contrast)

    # Geometric stages.
    augmented = rand_translation(augmented, translation_x, translation_y)
    augmented = rand_cutout(augmented, offset_x, offset_y)

    # Restore the caller's layout before handing the result back.
    if not channels_first:
        augmented = augmented.permute(0, 2, 3, 1)
    return augmented.contiguous()


def rand_brightness(x: torch.Tensor, brightness: torch.Tensor):
    """Return ``x`` shifted by ``brightness``.

    The addition uses ordinary PyTorch broadcasting and produces a new
    tensor; ``x`` is not modified in place.
    """
    return torch.add(x, brightness)


def rand_saturation(x: torch.Tensor, saturation: torch.Tensor):
    """Scale each element's deviation from the dim-1 mean by ``saturation``.

    The mean is taken over dim 1 (the channel axis for a channels-first
    batch) with ``keepdim=True``. A factor of 1 leaves the values at
    ``(x - mean) + mean``; a factor of 0 replaces every element with the
    dim-1 mean at its position.
    """
    channel_mean = torch.mean(x, dim=1, keepdim=True)
    deviation = x - channel_mean
    return deviation * saturation + channel_mean


def rand_contrast(x: torch.Tensor, contrast: torch.Tensor):
    """Scale each element's deviation from its sample mean by ``contrast``.

    The mean is taken jointly over dims 1, 2 and 3 with ``keepdim=True``,
    giving one value per entry along dim 0. A factor of 0 replaces every
    element with that per-sample mean.
    """
    sample_mean = torch.mean(x, dim=(1, 2, 3), keepdim=True)
    deviation = x - sample_mean
    return deviation * contrast + sample_mean


def rand_translation(
        x: torch.Tensor,
        translation_x: torch.Tensor,
        translation_y: torch.Tensor
) -> torch.Tensor:
    """Shift a 4-D batch by integer offsets along dims 2 and 3, zero-filling gaps.

    Output position ``(i, j)`` reads the input at
    ``(i + translation_x, j + translation_y)``; positions that would read
    outside the input become zero.

    Args:
        x: Batch indexed as (dim 0 = sample, dim 1 = channel, dims 2/3 =
            spatial).
        translation_x: Offset along dim 2. Must be integer-valued so the sum
            can be used as an index, and must broadcast against a
            ``(N, H, W)`` index grid. Presumably shaped ``(N, 1, 1)`` for
            per-sample shifts; confirm against the caller.
        translation_y: Offset along dim 3, same requirements as
            ``translation_x``.

    Returns:
        The shifted batch with the same shape as ``x`` (a permuted view of
        the gathered result, so not necessarily contiguous).
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    device = x.device

    # Build the (N, H, W) index grids as expanded views instead of calling
    # torch.meshgrid: meshgrid without ``indexing`` warns on recent PyTorch
    # and relies on a default that is announced to change.
    grid_shape = (batch, height, width)
    grid_batch = torch.arange(batch, dtype=torch.long, device=device).view(batch, 1, 1).expand(grid_shape)
    grid_x = torch.arange(height, dtype=torch.long, device=device).view(1, height, 1).expand(grid_shape)
    grid_y = torch.arange(width, dtype=torch.long, device=device).view(1, 1, width).expand(grid_shape)

    # The +1 accounts for the one-pixel zero border added below; clamping to
    # [0, size + 1] routes every out-of-range read onto that border.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, height + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, width + 1)

    # Pad dims 3 and 2 by one zero on each side; dims 1 and 0 are untouched.
    x_pad = torch.nn.functional.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])

    # Gather in channels-last layout so one (sample, row, col) index triple
    # fetches all channels at once, then restore channels-first.
    gathered = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y]
    return gathered.permute(0, 3, 1, 2)


def rand_cutout(
        x: torch.Tensor,
        offset_x: torch.Tensor,
        offset_y: torch.Tensor,
        ratio: float = 0.2
) -> torch.Tensor:
    """Zero out one rectangular patch per sample in a 4-D batch.

    The patch measures ``int(H * ratio + 0.5)`` by ``int(W * ratio + 0.5)``
    (H, W being the sizes of dims 2 and 3) and is centred on
    ``(offset_x, offset_y)``. Patch coordinates falling outside the image are
    clamped onto the nearest edge row/column, so the zeroed region shrinks
    near the borders instead of wrapping around.

    Args:
        x: Batch indexed as (dim 0 = sample, dim 1 = channel, dims 2/3 =
            spatial).
        offset_x: Patch centre along dim 2. Must be integer-valued and
            broadcast against a ``(N, cut_h, cut_w)`` index grid. Presumably
            shaped ``(N, 1, 1)`` for per-sample positions; confirm against
            the caller.
        offset_y: Patch centre along dim 3, same requirements as
            ``offset_x``.
        ratio: Patch side length as a fraction of the matching image side.

    Returns:
        A new tensor equal to ``x`` with the patch set to zero in every
        channel.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    device = x.device

    # Adding 0.5 before truncation rounds the patch size to the nearest int.
    cut_h = int(height * ratio + 0.5)
    cut_w = int(width * ratio + 0.5)

    # Build the (N, cut_h, cut_w) index grids as expanded views instead of
    # calling torch.meshgrid: meshgrid without ``indexing`` warns on recent
    # PyTorch and relies on a default that is announced to change.
    grid_shape = (batch, cut_h, cut_w)
    grid_batch = torch.arange(batch, dtype=torch.long, device=device).view(batch, 1, 1).expand(grid_shape)
    grid_x = torch.arange(cut_h, dtype=torch.long, device=device).view(1, cut_h, 1).expand(grid_shape)
    grid_y = torch.arange(cut_w, dtype=torch.long, device=device).view(1, 1, cut_w).expand(grid_shape)

    # Centre the patch on the offsets and keep every index inside the image.
    grid_x = torch.clamp(grid_x + offset_x - cut_h // 2, min=0, max=height - 1)
    grid_y = torch.clamp(grid_y + offset_y - cut_w // 2, min=0, max=width - 1)

    # One mask per sample, shared across channels via the unsqueeze below.
    mask = torch.ones(batch, height, width, dtype=x.dtype, device=device)
    mask[grid_batch, grid_x, grid_y] = 0
    return x * mask.unsqueeze(1)
