from typing import Optional
import random
import torch
import torch.nn.functional as F
from PIL import Image


class RandAugment:
    """Batched augmentation that applies one randomly chosen transform per sample.

    Every enabled transform is computed for the whole batch; afterwards one
    candidate is selected per sample.  The inputs are indexed as 4-D tensors
    with the batch at dim 0 and the two spatial axes at dims 2 and 3
    (presumably ``(B, C, H, W)`` -- confirm against callers).

    Naming note: throughout this class the ``*_x`` quantities refer to dim 2
    of the input and the ``*_y`` quantities to dim 3.

    Args:
        identity: Include the unmodified input as a candidate.
        brightness: Base ``b`` of the multiplicative brightness factor, which
            is drawn as ``b ** (2u - 1)`` with ``u ~ U[0, 1)``.  ``None``
            disables the transform.
        translation: Maximum shift as a fraction of each spatial size.
            ``None`` disables the transform.
        cutout_v1: Maximum cut-out size as a fraction of each spatial size;
            the sampled size may be 0.  ``None`` disables the transform.
        cutout_v2: Same as ``cutout_v1`` but the sampled size is at least 1.
            ``None`` disables the transform.
        rotation: Include a random multiple-of-90-degrees rotation.

    Raises:
        ValueError: If both ``cutout_v1`` and ``cutout_v2`` are truthy.
    """

    def __init__(
        self,
        identity: bool = False,
        brightness: Optional[float] = None,
        translation: Optional[float] = None,
        cutout_v1: Optional[float] = None,
        cutout_v2: Optional[float] = None,
        rotation: bool = False,
    ):
        self.identity = identity
        self.brightness = brightness
        self.translation = translation
        self.cutout_v1 = cutout_v1
        self.cutout_v2 = cutout_v2
        self.rotation = rotation

        # NOTE: this check is based on truthiness, whereas enabling a cut-out
        # below is based on ``is not None``; a ratio of 0.0 therefore passes
        # this check but still counts as an enabled augmentation.
        if sum([bool(cutout_v1), bool(cutout_v2)]) > 1:
            raise ValueError(f"At most one of {cutout_v1=}, {cutout_v2=} can be True.")

        # Number of candidates __call__ builds; must mirror its conditions.
        self.num_rand_augments = sum((
            bool(self.identity),
            self.brightness is not None,
            self.translation is not None,
            self.cutout_v1 is not None,
            self.cutout_v2 is not None,
            bool(self.rotation),
        ))

    def id(self, x):
        """Return ``x`` unchanged."""
        return x

    def rand_brightness(self, x, r=None):
        """Scale ``x`` by a per-sample brightness factor.

        Args:
            x: Input batch.
            r: Optional previously returned parameter, i.e. the factor minus
                one.  When ``None`` a fresh factor is sampled.

        Returns:
            ``(x * factor, factor - 1)``.
        """
        if r is None:
            u = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
            # factor = brightness ** (2u - 1): log-uniform around 1.
            log_base = torch.log(torch.tensor(self.brightness, dtype=x.dtype, device=x.device))
            r = torch.exp(log_base * (2 * u - 1))
        else:
            # The stored parameter is (factor - 1); recover the factor.
            r = r + 1
        x = x * r
        return x, r - 1

    def rand_translation(self, x, translation_x=None, translation_y=None):
        """Shift each sample by an integer offset, filling with zeros.

        Args:
            x: Input batch.
            translation_x: Optional integer offsets along dim 2, broadcastable
                against a ``(B, size2, size3)`` index grid.
            translation_y: Optional integer offsets along dim 3, same layout.
                If either offset is ``None`` both are sampled.

        Returns:
            ``(shifted, (translation_x, translation_y))``.
        """
        shift_x, shift_y = int(x.size(2) * self.translation + 0.5), int(x.size(3) * self.translation + 0.5)
        if translation_x is None or translation_y is None:
            translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
            translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
        # indexing="ij" is the historical default; passing it explicitly keeps
        # the result identical and avoids the deprecation warning.
        grid_batch, grid_x, grid_y = torch.meshgrid(
            torch.arange(x.size(0), dtype=torch.long, device=x.device),
            torch.arange(x.size(2), dtype=torch.long, device=x.device),
            torch.arange(x.size(3), dtype=torch.long, device=x.device),
            indexing="ij",
        )
        # The +1 accounts for the one-pixel zero border added below; clamping
        # sends every out-of-range source index onto that border.
        grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
        grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
        x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
        # Gather in channels-last layout, then restore the original layout.
        gathered = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y]
        x = gathered.permute(0, 3, 1, 2).contiguous()
        return x, (translation_x, translation_y)

    def _rand_cutout(self, x, ratio, min_size, cutout_x=None, cutout_y=None, offset_x=None, offset_y=None):
        """Zero a per-sample rectangle; shared implementation of both cut-outs.

        Args:
            x: Input batch.
            ratio: Maximum rectangle size as a fraction of each spatial size.
            min_size: Smallest rectangle side that may be sampled.
            cutout_x, cutout_y: Optional rectangle sizes along dims 2 and 3.
                If either is ``None`` both are sampled.
            offset_x, offset_y: Optional rectangle centres along dims 2 and 3.
                If either is ``None`` both are sampled so that the rectangle
                lies inside the image.

        Returns:
            ``(masked, (cutout_x, cutout_y, offset_x, offset_y))``.
        """
        size_x, size_y = x.size(2), x.size(3)
        if cutout_x is None or cutout_y is None:
            cutout_x = torch.randint(min_size, int(size_x * ratio + 0.5) + 1, size=[x.size(0), 1], device=x.device)
            cutout_y = torch.randint(min_size, int(size_y * ratio + 0.5) + 1, size=[x.size(0), 1], device=x.device)
        if offset_x is None or offset_y is None:
            # Centres are drawn from [low, high) so the rectangle stays in bounds.
            low_x = (cutout_x - 1) // 2
            high_x = size_x - cutout_x // 2
            offset_x = low_x + ((high_x - low_x) * torch.rand_like(low_x.float())).long()
            low_y = (cutout_y - 1) // 2
            high_y = size_y - cutout_y // 2
            offset_y = low_y + ((high_y - low_y) * torch.rand_like(low_y.float())).long()

        x_start = offset_x - (cutout_x - 1) // 2
        y_start = offset_y - (cutout_y - 1) // 2
        x_end = x_start + cutout_x
        y_end = y_start + cutout_y

        assert (x_start >= 0).all() and (y_start >= 0).all() and (x_end <= x.size(2)).all() and (y_end <= x.size(3)).all(), f"{x_start.view(-1)=}, {x_end.view(-1)=}, {y_start.view(-1)=}, {y_end.view(-1)=}"

        rows = torch.arange(size_x, device=x.device).view(1, -1, 1)
        cols = torch.arange(size_y, device=x.device).view(1, 1, -1)

        # True inside the rectangle of each sample; shape (B, size2, size3).
        inside = (
            (rows >= x_start.view(-1, 1, 1)) & (rows < x_end.view(-1, 1, 1)) &
            (cols >= y_start.view(-1, 1, 1)) & (cols < y_end.view(-1, 1, 1))
        )
        keep = (~inside).to(x.dtype)
        # The same spatial mask is applied to every channel.
        x = x * keep.unsqueeze(1)
        return x, (cutout_x, cutout_y, offset_x, offset_y)

    def rand_cutout_v1(self, x, cutout_x=None, cutout_y=None, offset_x=None, offset_y=None):
        """Zero a random rectangle whose sampled side may be 0 (no cut-out).

        Sizes are bounded by ``self.cutout_v1`` times each spatial size.

        Returns:
            ``(masked, (cutout_x, cutout_y, offset_x, offset_y))``.
        """
        return self._rand_cutout(x, self.cutout_v1, 0, cutout_x, cutout_y, offset_x, offset_y)

    def rand_cutout_v2(self, x, cutout_x=None, cutout_y=None, offset_x=None, offset_y=None):
        """Zero a random rectangle whose sampled sides are at least 1.

        Sizes are bounded by ``self.cutout_v2`` times each spatial size.

        Returns:
            ``(masked, (cutout_x, cutout_y, offset_x, offset_y))``.
        """
        return self._rand_cutout(x, self.cutout_v2, 1, cutout_x, cutout_y, offset_x, offset_y)

    def rand_rotation(self, x, k=None):
        """Rotate each sample by ``k`` quarter turns in the plane of dims 2 and 3.

        Args:
            x: Input batch.  The per-sample results are stacked, so samples
                rotated by different amounts must end up with equal shapes
                (e.g. square spatial dims).
            k: Optional ``(B, 1)`` tensor of quarter-turn counts.  When
                ``None`` values in ``{0, 1, 2, 3}`` are sampled.

        Returns:
            ``(rotated, k)``.
        """
        if k is None:
            k = torch.randint(0, 4, size=[x.size(0), 1], device=x.device)
        # Read all turn counts in one transfer instead of one per sample.
        turns = k[:, 0].tolist()
        x_rotated = torch.stack(
            [torch.rot90(sample, k=int(turns[i]), dims=[1, 2]) for i, sample in enumerate(x)]
        )
        return x_rotated, k

    @staticmethod
    def _store_or_check(params, name, drawn):
        """Record ``drawn`` under ``name``, or check it matches what is stored."""
        if name not in params:
            params[name] = drawn
            return
        stored = params[name]
        if isinstance(drawn, tuple):
            assert all((s == d).all() for s, d in zip(stored, drawn))
        else:
            assert (stored == drawn).all()

    def __call__(self, x, params, index=None):
        """Augment ``x`` with one enabled transform per sample.

        Args:
            x: Input batch.
            params: Dict of transform parameters keyed by transform name.
                Entries that are present are reused; missing entries are
                sampled and written back into this dict (it is mutated).
            index: Optional long tensor with one candidate index per sample.
                Candidates are ordered identity, brightness, translation,
                cutout_v1, cutout_v2, rotation, counting only enabled ones.
                Sampled uniformly when ``None``.

        Returns:
            ``(x_aug, params, augment_params, index)`` where
            ``augment_params`` is a float tensor with one row per sample
            holding the parameters of the transform selected for that sample
            (zeros elsewhere): brightness (1 column), translation (2),
            cut-out sizes (2) and rotation (1), in candidate order.

        Raises:
            RuntimeError: If no augmentation is enabled.
        """
        x_candidates = []
        # (name, number of exported parameter columns or None) per candidate.
        augment_info = []

        if self.identity:
            x_candidates.append(self.id(x))
            augment_info.append(('identity', None))
        if self.brightness is not None:
            x_b, p = self.rand_brightness(x, params.get('brightness', None))
            self._store_or_check(params, 'brightness', p)
            x_candidates.append(x_b)
            augment_info.append(('brightness', 1))
        if self.translation is not None:
            x_t, p = self.rand_translation(x, *params.get('translation', (None, None)))
            self._store_or_check(params, 'translation', p)
            x_candidates.append(x_t)
            augment_info.append(('translation', 2))
        if self.cutout_v1 is not None:
            x_c6, p = self.rand_cutout_v1(x, *params.get('cutout_v1', (None, None, None, None)))
            self._store_or_check(params, 'cutout_v1', p)
            x_candidates.append(x_c6)
            augment_info.append(('cutout_v1', 2))
        if self.cutout_v2 is not None:
            x_c7, p = self.rand_cutout_v2(x, *params.get('cutout_v2', (None, None, None, None)))
            self._store_or_check(params, 'cutout_v2', p)
            x_candidates.append(x_c7)
            augment_info.append(('cutout_v2', 2))
        if self.rotation:
            x_r, p = self.rand_rotation(x, params.get('rotation', None))
            self._store_or_check(params, 'rotation', p)
            x_candidates.append(x_r)
            augment_info.append(('rotation', 1))

        if not x_candidates:
            # Fail with a clear message instead of an opaque randint/stack error.
            raise RuntimeError("RandAugment has no augmentation enabled; enable at least one before calling it.")

        batch_size = x.shape[0]
        if index is None:
            index = torch.randint(0, self.num_rand_augments, (batch_size,), device=x.device)
        x_candidates = torch.stack(x_candidates, dim=0)
        # Pick, for sample b, the candidate numbered index[b].
        x = x_candidates[index, torch.arange(batch_size, device=x.device)]

        param_dim = sum(dim for _, dim in augment_info if dim is not None)
        augment_params = torch.zeros(x.size(0), param_dim, device=x.device)

        offset = 0
        for aug_idx, (name, dim) in enumerate(augment_info):
            if not dim:
                continue

            selected = (index == aug_idx)

            stored = params[name]
            parts = stored if isinstance(stored, (tuple, list)) else (stored,)
            # Only the first `dim` entries are exported: for the cut-outs that
            # is the two sizes, not the offsets.
            data = torch.cat([part.view(-1, 1) for part in parts[:dim]], dim=1)

            augment_params[selected, offset:offset + dim] = data[selected].float()
            offset += dim

        return x, params, augment_params, index

