import os

import numpy as np
import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from PIL import Image

# -----------------------------
# Data augmentation / preprocessing
# -----------------------------
# Per-channel (mean, std) for transforms.Normalize: (x - 0.5) / 0.5 maps the
# [0, 1] output of ToTensor onto [-1, 1].
NORM = (
    (0.5,) * 3,  # mean (R, G, B)
    (0.5,) * 3,  # std  (R, G, B)
)

# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
te_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(*NORM)])

# Training pipeline: random 32x32 crop taken from the image padded by 4 pixels
# on every side, random horizontal flip, then the same conversion and
# normalization as te_transforms.
tr_transforms = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(*NORM),
    ]
)

# The 15 common corruption types of CIFAR-10-C. Each name is also a file stem:
# prepare_test_data reads <dataroot>/CIFAR-10-C/<name>.npy for the chosen one.
common_corruptions = [
    'gaussian_noise', 'shot_noise', 'impulse_noise',
    'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur',
    'snow', 'frost', 'fog', 'brightness',
    'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression',
]

# -----------------------------
# Training-set DataLoader
# -----------------------------
def prepare_train_data(args):
    """
    Build the original (uncorrupted) CIFAR-10 training set and its DataLoader.

    Args:
        args: namespace-like object read for ``dataset`` (only ``'cifar10'`` is
            supported), ``dataroot`` (root directory passed to torchvision,
            downloaded into if missing), ``batch_size`` and, optionally,
            ``workers``. If ``workers`` is absent it is set to 1 on ``args``.

    Returns:
        (trset, trloader): the ``torchvision.datasets.CIFAR10`` training split
        with ``tr_transforms`` applied, and a shuffling DataLoader over it.

    Raises:
        ValueError: if ``args.dataset`` is not ``'cifar10'``.
    """
    print('Preparing training data (CIFAR-10)...')
    if args.dataset != 'cifar10':
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers keep working.
        raise ValueError('Dataset not found! Only cifar10 is supported here.')

    trset = torchvision.datasets.CIFAR10(
        root=args.dataroot,
        train=True,
        download=True,
        transform=tr_transforms,
    )

    # Default to a single loader worker; the default is stored on ``args`` itself.
    if not hasattr(args, 'workers'):
        args.workers = 1

    trloader = torch.utils.data.DataLoader(
        trset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
    )
    return trset, trloader

# -----------------------------
# Test-set DataLoader
# -----------------------------
def prepare_test_data(args):
    """
    Build one CIFAR-10-C (corruption, severity level) test split and its DataLoader.

    Args:
        args: namespace-like object read for ``dataset`` (only ``'cifar10'`` is
            supported), ``corruption`` (one of ``common_corruptions``), ``level``
            (severity, an integer in [1, 5]), ``dataroot`` (must contain
            ``CIFAR-10-C/<corruption>.npy``; the clean CIFAR-10 test set is
            downloaded there if missing), ``batch_size`` and, optionally,
            ``workers``. If ``workers`` is absent it is set to 1 on ``args``.

    Returns:
        (teset, teloader): a ``torchvision.datasets.CIFAR10`` test split whose
        ``data`` is replaced by the selected corrupted images (``te_transforms``
        applied), and a DataLoader over it with ``shuffle=False``.

    Raises:
        ValueError: unsupported dataset, ``level`` outside [1, 5], unknown
            corruption type, or a ``.npy`` file too short to contain the
            requested level.
    """
    print('Preparing test data (CIFAR-10-C)...')

    if args.dataset != 'cifar10':
        raise ValueError('Dataset not found! Only cifar10 is supported here.')
    # Checked before any I/O. Without it, the slice below silently returns an
    # empty array for level 0 or > 5. For negative levels it returns a
    # *different* level's images (e.g. level -1 -> the slice of level 4).
    if not 1 <= args.level <= 5:
        raise ValueError(f'Corruption level must be in [1, 5], got {args.level!r}.')
    if args.corruption not in common_corruptions:
        raise ValueError(
            f'Unknown corruption type "{args.corruption}". '
            f'Please choose from {common_corruptions}.'
        )

    tesize = 10000  # images per severity level (= size of the CIFAR-10 test set)
    print('Test on %s level %d' % (args.corruption, args.level))

    # Each <corruption>.npy stacks the 5 severity levels back to back
    # (5 * 10000 images). Memory-map the file and copy out only the requested
    # level. Slicing a fully loaded array would give a view that keeps all 5
    # levels alive inside the dataset.
    npy_path = os.path.join(args.dataroot, 'CIFAR-10-C', '%s.npy' % args.corruption)
    all_levels = np.load(npy_path, mmap_mode='r')
    start = (args.level - 1) * tesize
    teset_raw = np.array(all_levels[start:start + tesize])
    if len(teset_raw) != tesize:
        raise ValueError(
            f'{npy_path} holds {len(all_levels)} images; expected at least '
            f'{start + tesize} for level {args.level}.'
        )

    # Reuse torchvision's CIFAR10 test split for its targets, transform and
    # __getitem__, then swap in the corrupted images.
    # NOTE(review): relies on every CIFAR-10-C level keeping the CIFAR-10
    # test-set image order, so the original targets still match — confirm.
    teset = torchvision.datasets.CIFAR10(
        root=args.dataroot,
        train=False,
        download=True,
        transform=te_transforms,
    )
    teset.data = teset_raw

    # Default to a single loader worker; the default is stored on ``args`` itself.
    if not hasattr(args, 'workers'):
        args.workers = 1

    teloader = torch.utils.data.DataLoader(
        teset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )
    return teset, teloader
