import torch
import numpy as np
from utils.config import *

class Cutout(object):
    """Batch-level Cutout: overwrite random square patches of each image.

    Every patch is filled with ``(0 - mean) / std`` per channel, i.e. the
    value a zero-valued pixel takes after mean/std normalization.
    NOTE(review): this presumes ``img`` was already normalized with the same
    statistics -- confirm against the data pipeline.

    Args:
        n_holes: Number of patches cut out of each image.
        args: Object exposing ``dataset`` and ``model`` string attributes.
            ``args.dataset.lower()`` is used as a key into ``DATA_CONFIG``
            (presumably supplied by the ``utils.config`` wildcard import);
            the selected entry must provide ``'mean'`` and ``'std'``.
        length: Nominal side length of a patch in pixels. A patch spans
            ``length // 2`` pixels on each side of its center, so an odd
            ``length`` yields a side of ``length - 1``; patches are clipped
            at the image border.
    """

    def __init__(self, n_holes, args, length):
        self.n_holes = n_holes
        self.length = length

        cfg, is_vit = DATA_CONFIG[args.dataset.lower()], "vit" in args.model.lower()

        # Models whose name contains "vit" use fixed statistics (these are the
        # widely used ImageNet values) instead of the dataset's own.
        self.mean = [0.485, 0.456, 0.406] if is_vit else cfg['mean']
        self.std = [0.229, 0.224, 0.225] if is_vit else cfg['std']

    def __call__(self, img, target):
        """Cut ``n_holes`` patches out of every image of the batch, in place.

        Args:
            img: 4-D tensor unpacked as ``(batch, C, H, W)``; ``C`` must match
                the number of entries in ``self.mean`` / ``self.std``.
            target: Labels; returned untouched.

        Returns:
            The tuple ``(img, target)``; ``img`` is the same tensor object
            that was passed in, modified in place.
        """
        batch_size, C, H, W = img.size()

        # Shaped (C, 1, 1) so it broadcasts over any (C, h, w) patch.
        mean_tensor = torch.tensor(self.mean, device=img.device).view(C, 1, 1)
        std_tensor = torch.tensor(self.std, device=img.device).view(C, 1, 1)
        replacement = (0 - mean_tensor) / std_tensor

        half = self.length // 2

        for i in range(batch_size):
            # Previously n_holes was stored but ignored and a single patch was
            # always cut; honor the requested number of patches.
            for _ in range(self.n_holes):
                cx = np.random.randint(W)
                cy = np.random.randint(H)
                x1 = np.clip(cx - half, 0, W)
                y1 = np.clip(cy - half, 0, H)
                x2 = np.clip(cx + half, 0, W)
                y2 = np.clip(cy + half, 0, H)
                img[i, :, y1:y2, x1:x2] = replacement

        return img, target

class CutMix(object):
    """Batch-level CutMix: paste one shared rectangle from shuffled samples.

    A single rectangle is drawn per call; inside it, every sample receives
    the pixels of another sample chosen by a random permutation of the
    batch. Targets are blended using the fraction of the image area that
    was left untouched.

    Args:
        alpha: Parameter of the symmetric ``Beta(alpha, alpha)`` distribution
            the initial mixing ratio is drawn from. With ``alpha <= 0`` the
            ratio is fixed to 1, which makes the rectangle empty.
    """

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def __call__(self, img, target):
        """Apply CutMix to a batch; ``img`` is modified in place.

        Args:
            img: 4-D tensor unpacked as ``(batch, C, H, W)``.
            target: Tensor indexed along its first dimension by sample.
                NOTE(review): the blend is only meaningful for per-class
                score targets (e.g. one-hot) -- confirm with the caller.

        Returns:
            The tuple ``(img, mixed_target)``.
        """
        if self.alpha > 0:
            ratio = np.random.beta(self.alpha, self.alpha)
        else:
            ratio = 1
        perm = torch.randperm(img.size(0), device=img.device)

        _, _, height, width = img.size()

        # Rectangle sides scale with sqrt(1 - ratio), so its area is roughly
        # (1 - ratio) of the image before clipping.
        side_scale = np.sqrt(1. - ratio)
        half_w = int(width * side_scale) // 2
        half_h = int(height * side_scale) // 2

        center_x = np.random.randint(width)
        center_y = np.random.randint(height)

        left = np.clip(center_x - half_w, 0, width)
        top = np.clip(center_y - half_h, 0, height)
        right = np.clip(center_x + half_w, 0, width)
        bottom = np.clip(center_y + half_h, 0, height)

        rows = slice(top, bottom)
        cols = slice(left, right)
        img[:, :, rows, cols] = img[perm, :, rows, cols]

        # Clipping may have shrunk the rectangle, so recompute the ratio from
        # the area that was actually replaced.
        replaced_area = (right - left) * (bottom - top)
        ratio = 1 - (replaced_area / (height * width))

        mixed_target = ratio * target + (1. - ratio) * target[perm]
        return img, mixed_target

class MixUp(object):
    """Batch-level MixUp: convex blend of each sample with a shuffled one.

    Args:
        alpha: Parameter of the symmetric ``Beta(alpha, alpha)`` distribution
            the mixing weight is drawn from. With ``alpha <= 0`` the weight
            is fixed to 1, so the permuted sample contributes with weight 0.
    """

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def __call__(self, img, target):
        """Blend images and targets with one weight shared by the batch.

        Args:
            img: Tensor indexed along its first dimension by sample.
            target: Tensor indexed along its first dimension by sample.
                NOTE(review): the blend is only meaningful for per-class
                score targets (e.g. one-hot) -- confirm with the caller.

        Returns:
            The tuple ``(mixed_img, mixed_target)``; both are new tensors and
            the inputs are not modified.
        """
        weight = 1
        if self.alpha > 0:
            weight = np.random.beta(self.alpha, self.alpha)

        # One permutation pairs every sample with its mixing partner.
        partner = torch.randperm(img.size(0), device=img.device)

        mixed_img = weight * img + (1 - weight) * img[partner]
        mixed_target = weight * target + (1 - weight) * target[partner]
        return mixed_img, mixed_target