import os
import torchvision
import torchvision.transforms as transforms
import torch
from tools.helpfunc import print_rank0
from torch.utils.data.distributed import DistributedSampler

def get_svhn(data_path, network_config):
    """Build distributed train/test DataLoaders for the SVHN dataset.

    Images are converted to tensors and normalized with the ImageNet
    per-channel mean/std (not SVHN-specific statistics).

    Args:
        data_path: Directory where the SVHN files are stored; it is created
            (including missing parent directories) if it does not exist, and
            the dataset is downloaded into it when absent.
        network_config: Mapping-like object; only ``network_config['batch_size']``
            is read here and used as the per-loader batch size.

    Returns:
        A ``(trainloader, testloader, train_sampler)`` tuple. The train
        sampler is returned alongside the loaders — presumably so the caller
        can call ``train_sampler.set_epoch(epoch)``; verify against caller.

    Note:
        Both loaders use ``DistributedSampler`` with default arguments, which
        expects the default process group to be initialized before this
        function is called — confirm against the training entry point.
    """
    print_rank0("loading SVHN with ImageNet normalization")

    # makedirs(exist_ok=True) instead of "exists() then mkdir()": it creates
    # missing parent directories and does not raise when another process
    # creates the directory between the check and the mkdir call.
    os.makedirs(data_path, exist_ok=True)

    batch_size = network_config['batch_size']

    # ImageNet normalization parameters (per-channel mean, std).
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Train and test currently share the same preprocessing; they are kept as
    # separate pipelines so either one can be changed independently.
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Use the SVHN dataset.
    # NOTE(review): download=True is executed by every process that calls this
    # function; if several ranks share data_path they may download
    # concurrently — confirm the data is pre-fetched or guarded by the caller.
    trainset = torchvision.datasets.SVHN(
        root=data_path, split='train', download=True, transform=transform_train
    )
    train_sampler = DistributedSampler(trainset)
    # shuffle stays False on the DataLoader because ordering is delegated to
    # the sampler.
    trainloader = torch.utils.data.DataLoader(
        trainset,
        sampler=train_sampler,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
    )

    testset = torchvision.datasets.SVHN(
        root=data_path, split='test', download=True, transform=transform_test
    )
    test_sampler = DistributedSampler(testset)
    testloader = torch.utils.data.DataLoader(
        testset,
        sampler=test_sampler,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
    )

    return trainloader, testloader, train_sampler
