import os
import torchvision
import torchvision.transforms as transforms
import torch
from torch.utils.data.distributed import DistributedSampler

def get_cifar100(data_path, network_config, num_workers=8):
    """Build distributed CIFAR-100 train/test data loaders.

    Downloads CIFAR-100 into ``data_path`` (if not already present) and wraps
    each split in a ``DistributedSampler`` so every process iterates over its
    own shard. ``DistributedSampler`` is constructed without explicit
    ``num_replicas``/``rank``, so a ``torch.distributed`` process group must
    already be initialized when this is called.

    Args:
        data_path: Directory used as the torchvision dataset root. It is
            created, including any missing parent directories, if absent.
        network_config: Mapping that provides ``'batch_size'``, the
            per-process batch size used for both loaders.
        num_workers: Number of DataLoader worker processes per loader.

    Returns:
        Tuple ``(trainloader, testloader, train_sampler, test_sampler)``.
        Call ``train_sampler.set_epoch(epoch)`` at the start of every epoch so
        the distributed shuffle order changes between epochs.
    """
    print("loading CIFAR100")
    # makedirs(exist_ok=True) creates missing parents and does not fail when
    # several ranks race to create the same directory (os.mkdir after an
    # exists() check would raise FileExistsError for the slower rank).
    os.makedirs(data_path, exist_ok=True)
    batch_size = network_config['batch_size']

    # Per-channel normalization statistics shared by both splits.
    mean = (0.5071, 0.4865, 0.4409)
    std = (0.2673, 0.2564, 0.2762)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # random 32x32 crop of the 4-px padded image
        transforms.RandomHorizontalFlip(),  # 50% chance of flipping horizontally
        transforms.autoaugment.TrivialAugmentWide(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # No augmentation at evaluation time; only tensor conversion + normalization.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # NOTE(review): every rank calls download=True on the same data_path;
    # concurrent first-time downloads can race. Consider pre-downloading, or
    # downloading on one rank per node followed by a barrier.
    trainset = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True, transform=transform_train)
    train_sampler = DistributedSampler(trainset)
    # shuffle must be False on the DataLoader when a sampler is supplied;
    # shuffling is delegated to the DistributedSampler.
    trainloader = torch.utils.data.DataLoader(
        trainset, sampler=train_sampler, batch_size=batch_size, shuffle=False, num_workers=num_workers,
    )

    testset = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True, transform=transform_test)
    # Create a distributed sampler for evaluation. A deterministic (unshuffled)
    # order is sufficient for evaluation. Note DistributedSampler pads the index
    # list so it divides evenly across ranks, so a few test samples may be
    # evaluated twice when the split size is not a multiple of the world size.
    test_sampler = DistributedSampler(testset, shuffle=False)
    testloader = torch.utils.data.DataLoader(
        testset, sampler=test_sampler, batch_size=batch_size, shuffle=False, num_workers=num_workers,
    )

    return trainloader, testloader, train_sampler, test_sampler
