import torch
import kornia as K
from torch import Tensor
import kornia.augmentation as KA
from dataclasses import dataclass
from typing import Tuple
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class AugCFG:
    """Hyper-parameters for the ``DiffAugment`` augmentation pipeline.

    Fields are listed in constructor order. Only the ``blur_*`` group has
    defaults, so every other field must be supplied by the caller.
    """

    # --- global -------------------------------------------------------------
    # Target spatial size; presumably (height, width) -- TODO confirm.
    image_size: Tuple[int, int]
    # Probability of a random horizontal flip.
    h_flip_prob: float

    # --- random affine (values are forwarded to kornia's RandomAffine) ------
    affine_prob: float  # probability of applying the affine transform
    affine_rotate_range: Tuple[float, float]  # passed as ``degrees``
    # NOTE(review): kornia reads ``translate`` as the maximum absolute
    # (horizontal, vertical) shift as a fraction of image size, not as a
    # (min, max) range -- confirm the configured values match that meaning.
    affine_translate_range: Tuple[float, float]
    affine_scale_range: Tuple[float, float]  # passed as ``scale``

    # --- random resized crop --------------------------------------------------
    # Presumably intended as RandomResizedCrop's ``scale`` / ``ratio``.
    crop_scale_range: Tuple[float, float]
    crop_ratio_range: Tuple[float, float]
    crop_pad: int  # not read by the visible code; presumably pixels

    # --- misc ------------------------------------------------------------------
    max_resize_ratio: float  # not read by the visible code

    # --- Gaussian blur ---------------------------------------------------------
    blur_prob: float = 0.3
    blur_kernel_size: Tuple[int, int] = (3, 3)
    blur_sigma: Tuple[float, float] = (0.1, 2.0)


class DiffAugment(nn.Module):
    """Random augmentation pipeline assembled with kornia from an ``AugCFG``.

    The pipeline always applies a random horizontal flip followed by a random
    affine transform. A random resized crop and a Gaussian blur can be
    appended by opting in with ``enable_crop`` / ``enable_blur``. Both are off
    by default, so the default pipeline is flip + affine only.

    This module never reseeds torch's RNG, so a caller-set
    ``torch.manual_seed`` governs the sampled augmentation parameters.
    """

    def __init__(
        self,
        cfg: AugCFG,
        *,
        enable_crop: bool = False,
        enable_blur: bool = False,
    ):
        """Build the augmentation pipeline.

        Args:
            cfg: Augmentation hyper-parameters.
            enable_crop: Append ``RandomResizedCrop`` driven by
                ``cfg.image_size``, ``cfg.crop_scale_range`` and
                ``cfg.crop_ratio_range``.
            enable_blur: Append ``RandomGaussianBlur`` driven by
                ``cfg.blur_kernel_size``, ``cfg.blur_sigma`` and
                ``cfg.blur_prob``.
        """
        super().__init__()
        self.cfg = cfg

        stages = [
            KA.RandomHorizontalFlip(p=cfg.h_flip_prob),
            KA.RandomAffine(
                degrees=cfg.affine_rotate_range,
                translate=cfg.affine_translate_range,
                scale=cfg.affine_scale_range,
                p=cfg.affine_prob,
                shear=None,
            ),
        ]
        if enable_crop:
            # Always applied (kornia default p=1.0); the crop is resized to
            # cfg.image_size.
            stages.append(
                KA.RandomResizedCrop(
                    size=cfg.image_size,
                    scale=cfg.crop_scale_range,
                    ratio=cfg.crop_ratio_range,
                )
            )
        if enable_blur:
            stages.append(
                KA.RandomGaussianBlur(
                    kernel_size=cfg.blur_kernel_size,
                    sigma=cfg.blur_sigma,
                    p=cfg.blur_prob,
                )
            )
        # The attribute keeps its historical name ``geometric`` so existing
        # attribute access and state_dict keys stay valid, even though the
        # optional blur stage is photometric.
        self.geometric = KA.AugmentationSequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the augmentation pipeline to ``x``.

        Args:
            x: Image tensor; presumably a (B, C, H, W) batch as kornia
                expects -- TODO confirm with callers.

        Returns:
            The augmented tensor.
        """
        # Do not call torch.seed() here: it reseeds the global CPU RNG and
        # every CUDA generator with fresh OS entropy on each call. That
        # silently overrides any torch.manual_seed set by the caller and breaks
        # reproducibility of unrelated randomness (dropout, sampling).
        return self.geometric(x)
