# Portions of this file are adapted from
#   SAM (https://github.com/davda54/sam, MIT License).
# Changes: refactored for IAM algorithm, etc.

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random

class Cutout:
    """Randomly zero out a square patch of an image tensor.

    On each call, with probability ``p``, a ``size x size`` patch is set to 0.
    The patch's top-left corner may start up to ``size // 2`` pixels outside
    the image, in which case the patch is clipped to the image bounds (so the
    erased area can be smaller than ``size x size``).

    Args:
        size: Side length of the square patch, in pixels.
        p: Probability of applying the cutout on a given call.
    """

    def __init__(self, size=16, p=0.5):
        self.size = size
        self.half_size = size // 2
        self.p = p

    def __call__(self, image):
        """Apply cutout to ``image`` in place and return it.

        Args:
            image: Tensor indexed as ``(channels, height, width)``. It is
                modified in place, so this transform must come after
                ``ToTensor()`` in a pipeline.

        Returns:
            The same tensor object, possibly with a patch set to 0.
        """
        if torch.rand(1).item() > self.p:
            return image

        _, h, w = image.shape
        # Corners may start half a patch outside the image so that border
        # regions are erased as often as interior ones (patch gets clipped).
        left = torch.randint(-self.half_size, w - self.half_size, [1]).item()
        top = torch.randint(-self.half_size, h - self.half_size, [1]).item()
        right = min(w, left + self.size)
        bottom = min(h, top + self.size)

        # Dim 1 is height (rows: top..bottom), dim 2 is width (cols: left..right).
        image[:, max(0, top):bottom, max(0, left):right] = 0
        return image

def get_cifar10_loaders(batch_size=128, num_workers=4):
    """Build CIFAR-10 train/test DataLoaders.

    The datasets are stored under ``./data`` and downloaded if missing.
    Training images get random-crop (32, padding 4) and horizontal-flip
    augmentation; both splits are converted to tensors and normalized per
    channel.

    Args:
        batch_size: Batch size of the shuffled training loader.
        num_workers: Worker processes per loader; ``0`` loads data in the
            main process.

    Returns:
        ``(train_loader, test_loader)``. The test loader is unshuffled and
        uses a fixed batch size of 1000.
    """
    # NOTE(review): this std triple is the one commonly copied between
    # CIFAR-10 codebases and reportedly differs from the dataset's per-pixel
    # std (~0.247, 0.243, 0.261). Kept unchanged so results stay comparable.
    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        # Cutout works on (C, H, W) tensors; if enabled it must be placed
        # after ToTensor(), not before it.
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    # persistent_workers / prefetch_factor are only valid with worker
    # processes; DataLoader raises ValueError for them when num_workers == 0.
    worker_kwargs = (
        {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True, **worker_kwargs)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False,
                             num_workers=num_workers, pin_memory=True)

    return train_loader, test_loader

def get_cifar100_loaders(batch_size=128, num_workers=4):
    """Build CIFAR-100 train/test DataLoaders.

    The datasets are stored under ``./data`` and downloaded if missing.
    Training images get random-crop (32, padding 4) and horizontal-flip
    augmentation; both splits are converted to tensors and normalized with
    CIFAR-100 per-channel statistics.

    Args:
        batch_size: Batch size of the shuffled training loader.
        num_workers: Worker processes per loader; ``0`` loads data in the
            main process.

    Returns:
        ``(train_loader, test_loader)``. The test loader is unshuffled and
        uses a fixed batch size of 1000.
    """
    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std = (0.2675, 0.2565, 0.2761)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])

    train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

    # persistent_workers / prefetch_factor are only valid with worker
    # processes; DataLoader raises ValueError for them when num_workers == 0.
    worker_kwargs = (
        {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True, **worker_kwargs)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False,
                             num_workers=num_workers, pin_memory=True)

    return train_loader, test_loader

def get_cifar100_loaders_semi(batch_size=128, num_workers=4, val_split=0.1, seed=42):
    """Build CIFAR-100 train/validation/test DataLoaders.

    The official training set is split into a training part and a held-out
    validation part using a permutation seeded by ``seed``, so the split is
    reproducible. The training part is augmented (random crop + horizontal
    flip); the validation and test parts only get tensor conversion and
    normalization. Data is stored under ``./data`` and downloaded if missing.

    Args:
        batch_size: Batch size of the training and validation loaders.
        num_workers: Worker processes per loader; ``0`` loads data in the
            main process.
        val_split: Fraction of the training set held out for validation, in
            ``[0, 1)``. The count is truncated with ``int()``; ``0`` yields an
            empty validation set.
        seed: Seed of the permutation that defines the train/val split.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        is shuffled; the test loader uses a fixed batch size of 1000.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
    """
    # Validate before any download; a negative fraction would otherwise
    # silently produce a nonsensical split.
    if not 0 <= val_split < 1:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std = (0.2675, 0.2565, 0.2761)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])
    transform_val_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])

    # Two views of the same training images: one augmented (for training)
    # and one with the eval transform (for validation).
    train_full = datasets.CIFAR100(
        root='./data',
        train=True, download=True,
        transform=transform_train
    )
    # download=False: the files were fetched by the call above.
    val_full = datasets.CIFAR100(
        root='./data',
        train=True, download=False,
        transform=transform_val_test
    )
    test_dataset = datasets.CIFAR100(
        root='./data',
        train=False, download=True,
        transform=transform_val_test
    )

    # One seeded permutation feeds both subsets, so the train and val
    # indices are disjoint and identical across runs with the same seed.
    num_train = len(train_full)
    num_val = int(num_train * val_split)
    gen = torch.Generator().manual_seed(seed)
    indices = torch.randperm(num_train, generator=gen)
    train_idx = indices[num_val:].tolist()
    val_idx = indices[:num_val].tolist()

    train_subset = Subset(train_full, train_idx)
    val_subset = Subset(val_full, val_idx)

    # persistent_workers / prefetch_factor are only valid with worker
    # processes; DataLoader raises ValueError for them when num_workers == 0.
    worker_kwargs = (
        {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}
    )

    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs,
    )
    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader, test_loader

def get_cifar10_loaders_semi(batch_size=128, num_workers=4, val_split=0.1, seed=42):
    """Build CIFAR-10 train/validation/test DataLoaders.

    The official training set is split into a training part and a held-out
    validation part using a permutation seeded by ``seed``, so the split is
    reproducible. The training part is augmented (random crop + horizontal
    flip); the validation and test parts only get tensor conversion and
    normalization. Data is stored under ``./data`` and downloaded if missing.

    Args:
        batch_size: Batch size of the training and validation loaders.
        num_workers: Worker processes per loader; ``0`` loads data in the
            main process.
        val_split: Fraction of the training set held out for validation, in
            ``[0, 1)``. The count is truncated with ``int()``; ``0`` yields an
            empty validation set.
        seed: Seed of the permutation that defines the train/val split.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        is shuffled; the test loader uses a fixed batch size of 1000.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
    """
    # Validate before any download; a negative fraction would otherwise
    # silently produce a nonsensical split.
    if not 0 <= val_split < 1:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])
    transform_val_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])

    # Two views of the same training images: one augmented (for training)
    # and one with the eval transform (for validation).
    train_full = datasets.CIFAR10(
        root='./data', train=True, download=True,
        transform=transform_train
    )
    # download=False: the files were fetched by the call above.
    val_full = datasets.CIFAR10(
        root='./data', train=True, download=False,
        transform=transform_val_test
    )
    test_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True,
        transform=transform_val_test
    )

    # random_split cannot be used here: it would wrap a single dataset,
    # but train and val need different transforms. Instead one seeded
    # permutation is sliced into disjoint index lists for the two views.
    num_train = len(train_full)
    num_val = int(num_train * val_split)
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(num_train, generator=generator)
    val_indices = indices[:num_val].tolist()
    train_indices = indices[num_val:].tolist()

    train_subset = Subset(train_full, train_indices)
    val_subset = Subset(val_full, val_indices)

    # persistent_workers / prefetch_factor are only valid with worker
    # processes; DataLoader raises ValueError for them when num_workers == 0.
    worker_kwargs = (
        {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}
    )

    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs,
    )
    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader, test_loader