# dataset.py

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

class Cutout(object):
    """
    Implements the Cutout data augmentation technique.

    With probability ``p``, zeroes a ``size`` x ``size`` square of the input,
    clipped at the image borders. The square's top-left corner sits
    ``size // 2`` pixels above and to the left of a uniformly random pixel.
    The same square is zeroed in every channel (and every leading dimension).
    """
    def __init__(self, size=16, p=0.5):
        # size: side length, in pixels, of the square to zero out.
        # p: probability that the augmentation is applied to a given call.
        self.size = size
        self.p = p

    def __call__(self, image):
        """
        Apply Cutout to ``image``.

        Args:
            image: tensor whose last two dimensions are (H, W), e.g. the
                (C, H, W) output of ``transforms.ToTensor``.

        Returns:
            ``image`` itself when the augmentation is skipped; otherwise a new
            tensor with the same shape, dtype and device with the square zeroed.
        """
        # All randomness comes from torch's RNG, so torch.manual_seed (and the
        # DataLoader's per-worker torch seeding) fully determines the result.
        if torch.rand(1).item() > self.p:
            return image

        h, w = image.shape[-2], image.shape[-1]
        cy = int(torch.randint(h, (1,)).item())
        cx = int(torch.randint(w, (1,)).item())

        # Span exactly `size` pixels from the top-left corner so odd sizes are
        # not shortened by one, then clip the span to the image bounds.
        top = cy - self.size // 2
        left = cx - self.size // 2
        y1, y2 = max(top, 0), min(top + self.size, h)
        x1, x2 = max(left, 0), min(left + self.size, w)

        # Build the mask on the image's own device/dtype so CUDA and
        # half-precision inputs work and the output dtype is preserved.
        mask = torch.ones((h, w), dtype=image.dtype, device=image.device)
        mask[y1:y2, x1:x2] = 0
        # (H, W) mask broadcasts across all leading (e.g. channel) dimensions.
        return image * mask

def get_cifar10(datadir: str, *, cutout: bool = False):
    """
    Fetches the CIFAR-10 dataset.

    Downloads CIFAR-10 into ``datadir`` if it is not already there. The
    training split gets random-crop / horizontal-flip augmentation plus
    normalization. The test split, returned as the validation set, gets
    normalization only.

    Args:
        datadir: root directory passed to ``torchvision.datasets.CIFAR10``.
        cutout: when True, append ``Cutout(size=16, p=0.5)`` to the training
            transform. Defaults to False, which reproduces the original pipeline.

    Returns:
        ``(train_data, val_data, n_classes)`` with ``n_classes == 10``.
    """
    # Per-channel normalization statistics, given on the 0-255 scale and
    # rescaled to [0, 1] to match the output range of ToTensor.
    mean = np.array([125.3, 123.0, 113.9]) / 255.0
    std = np.array([63.0, 62.1, 66.7]) / 255.0

    train_steps = [
        transforms.RandomCrop(size=(32, 32), padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    if cutout:
        # Cutout works on tensors, so it must come after ToTensor/Normalize.
        train_steps.append(Cutout(size=16, p=0.5))
    train_transform = transforms.Compose(train_steps)

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_data = torchvision.datasets.CIFAR10(root=datadir, train=True, transform=train_transform, download=True)
    val_data = torchvision.datasets.CIFAR10(root=datadir, train=False, transform=test_transform, download=True)

    return train_data, val_data, 10  # n_classes

def get_dataloaders(args, train_data, val_data):
    """
    Build the (train, validation) DataLoader pair from parsed arguments.

    Both loaders read ``args.batch_size``, ``args.num_workers`` and
    ``args.pin_memory``. The training loader shuffles and discards the final
    incomplete batch. The validation loader keeps dataset order and yields
    every sample.
    """
    shared = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
    )
    make_loader = torch.utils.data.DataLoader

    return (
        make_loader(train_data, shuffle=True, drop_last=True, **shared),
        make_loader(val_data, shuffle=False, **shared),
    )