import torch
import torch.nn as nn
import kornia.augmentation as K
import kornia

class RC(nn.Module):
    """Random square crop with a size-dependent sampling bias, then resize.

    On every forward pass a single crop side length ``s`` is drawn from the
    integer range ``[min_crop_size, max_crop_size]``. Each candidate ``s`` is
    weighted by ``(output_size[0] / s) ** 2``, so smaller crops are drawn more
    often. The drawn ``s`` is shared by the whole batch and clamped to the
    image height and width. A random ``s x s`` window is then cut from each
    image independently (``same_on_batch=False``). Finally the crops are
    resized to ``output_size``.

    Args:
        min_crop_size: Smallest crop side length, inclusive. Coerced to int.
        max: Largest crop side length, inclusive. Coerced to int. ``None``
            (the default) means "same as ``min_crop_size``", i.e. a single
            fixed crop size. That matches the previous behavior, where this
            argument was ignored. The name shadows the builtin but is kept
            for backward compatibility with keyword callers.
        output_size: ``(height, width)`` of the resized output. Its first
            element is also the reference length for the sampling weights.

    Raises:
        ValueError: If ``min_crop_size < 1`` or ``max < min_crop_size``.
    """

    def __init__(self, min_crop_size=13, max=None, output_size=(128, 128)):
        super().__init__()
        self.min_crop_size = int(min_crop_size)
        # Previously this was ``int(min_crop_size)``, which silently ignored
        # ``max``; ``None`` keeps that single-size behavior as the default.
        self.max_crop_size = self.min_crop_size if max is None else int(max)
        self.output_size = tuple(output_size)

        sizes, probs = self._build_size_table(
            self.min_crop_size, self.max_crop_size, float(self.output_size[0])
        )
        self.register_buffer("sizes_tbl", sizes)  # [K] candidate side lengths
        self.register_buffer("probs_tbl", probs)  # [K] sampling probabilities

        # Shared resize applied to every crop.
        self.resize = kornia.geometry.transform.Resize(self.output_size)

    @staticmethod
    def _build_size_table(min_size, max_size, base):
        """Return ``(sizes, probs)`` for the integer range ``[min_size, max_size]``.

        ``sizes`` is a float32 tensor of the candidate side lengths. ``probs``
        holds their normalized weights ``(base / s) ** 2``, which sum to one.

        Raises:
            ValueError: If ``min_size < 1`` (a zero size would make the weights
                infinite) or ``max_size < min_size`` (the table would be empty).
        """
        if min_size < 1:
            raise ValueError(f"min_crop_size must be >= 1, got {min_size}")
        if max_size < min_size:
            raise ValueError(
                f"max crop size ({max_size}) must be >= min_crop_size ({min_size})"
            )
        sizes = torch.arange(min_size, max_size + 1, dtype=torch.float32)
        weights = (base / sizes) ** 2
        return sizes, weights / weights.sum()

    def forward(self, image_and_cover):
        """Randomly crop and resize the image of an ``(image, cover)`` pair.

        Args:
            image_and_cover: Two-element sequence. The first element is a
                4-D image batch, presumably ``(B, C, H, W)``; the second is
                ignored.

        Returns:
            The cropped batch resized to ``self.output_size``.
        """
        image, _ = image_and_cover
        _, _, height, width = image.shape

        # One crop size per call, shared by the whole batch.
        idx = torch.multinomial(self.probs_tbl, num_samples=1).item()
        side = int(self.sizes_tbl[idx].item())
        # Never request a crop larger than the image.
        side = min(side, height, width)

        crop = K.RandomCrop(size=(side, side), p=1.0, keepdim=False, same_on_batch=False)
        return self.resize(crop(image))

    
class Resize(nn.Module):
    """Rescale an image batch by a constant factor.

    Args:
        alpha: Scale factor applied to both height and width. Each target
            dimension is rounded to the nearest integer and never goes below 1.
    """

    def __init__(self, alpha: float):
        super().__init__()
        self.alpha = alpha

    def forward(self, image_and_cover):
        """Resize the image of an ``(image, cover)`` pair; ``cover`` is unused.

        The image is expected to be 4-D, presumably ``(B, C, H, W)``; its
        last two dimensions are scaled by ``alpha``.
        """
        image, _cover = image_and_cover
        _, _, height, width = image.shape

        target = tuple(
            max(1, int(round(dim * self.alpha))) for dim in (height, width)
        )
        return kornia.geometry.transform.Resize(target)(image)
    
class Crop(nn.Module):
    """Randomly crop a fixed-size square from every image in a batch.

    Unlike ``RC`` the crop size is not clamped to the image. For inputs
    smaller than ``size``, kornia's ``RandomCrop`` presumably raises.
    NOTE(review): confirm against the installed kornia version.

    Args:
        size: Side length of the square crop in pixels. Coerced to ``int``,
            because ``RandomCrop`` expects integer sizes.

    Raises:
        ValueError: If ``size`` is smaller than 1.
    """

    def __init__(self, size: int):
        super().__init__()
        size = int(size)
        if size < 1:
            raise ValueError(f"crop size must be >= 1, got {size}")
        self.size = size

    def forward(self, image_and_cover):
        """Crop the image of an ``(image, cover)`` pair; ``cover`` is unused.

        Args:
            image_and_cover: Two-element sequence whose first element is a
                4-D image batch, presumably ``(B, C, H, W)``.

        Returns:
            A batch of ``size x size`` crops, one random window per image.

        Raises:
            ValueError: If the image is not 4-D.
        """
        image, _ = image_and_cover
        if len(image.shape) != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) image, got shape {tuple(image.shape)}"
            )

        crop = K.RandomCrop(size=(self.size, self.size), p=1.0, keepdim=False, same_on_batch=False)
        return crop(image)