import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

def get_cifar100_loaders(batch_size, data_dir, num_workers=4, pin_memory=True, augmentation_config=None):
    """
    Build CIFAR-100 train and test data loaders.

    Args:
        batch_size (int): Batch size for both data loaders.
        data_dir (str): Directory where the CIFAR-100 dataset is downloaded/stored.
        num_workers (int): Number of workers passed to each DataLoader.
        pin_memory (bool): Passed as ``pin_memory`` to each DataLoader.
        augmentation_config (dict, optional): Augmentation/normalization settings.
            ``None`` (the default) means "use every default". Recognized keys:
              - 'normalization': dict with optional 'mean' and 'std' per-channel
                lists; defaults are [0.5071, 0.4865, 0.4409] and
                [0.2673, 0.2564, 0.2761].
              - 'random_flip' (bool, default True): add RandomHorizontalFlip to
                the training transform.
              - 'random_crop' (bool, default True): add RandomCrop(32) to the
                training transform.
              - 'crop_padding' (int, default 4): padding used by RandomCrop.

    Returns:
        tuple: ``(train_loader, test_loader)``. The training loader shuffles and
        applies the configured augmentations; the test loader does not shuffle
        and applies only ToTensor + Normalize.
    """
    # The signature advertises None as the default; treat it as "no overrides"
    # rather than crashing on None.get(...).
    if augmentation_config is None:
        augmentation_config = {}

    # Default normalization for CIFAR-100. `or {}` also tolerates an explicit
    # 'normalization': None entry in the config.
    norm_config = augmentation_config.get('normalization') or {}
    mean = norm_config.get('mean', [0.5071, 0.4865, 0.4409])
    std = norm_config.get('std', [0.2673, 0.2564, 0.2761])

    # Augmentations for training
    transform_train_list = []

    if augmentation_config.get('random_flip', True):
        transform_train_list.append(transforms.RandomHorizontalFlip())

    if augmentation_config.get('random_crop', True):
        crop_padding = augmentation_config.get('crop_padding', 4)
        transform_train_list.append(transforms.RandomCrop(32, padding=crop_padding))

    # Tensor conversion and normalization always come last, after the
    # image-level augmentations above.
    transform_train_list.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    transform_train = transforms.Compose(transform_train_list)

    # Transform for testing (no augmentation, same normalization as training)
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Load the CIFAR-100 dataset (download=True fetches it into data_dir if missing)
    train_dataset = torchvision.datasets.CIFAR100(
        root=data_dir,
        train=True,
        transform=transform_train,
        download=True
    )

    test_dataset = torchvision.datasets.CIFAR100(
        root=data_dir,
        train=False,
        transform=transform_test,
        download=True
    )

    # DataLoader for training
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    # DataLoader for testing (deterministic order for evaluation)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    return train_loader, test_loader
