import torch
import torchvision
import ffcv.transforms
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder


def _label_pipeline(device):
    """Return a new label pipeline: decode the integer label and move it to ``device``."""
    return [
        IntDecoder(),
        ffcv.transforms.ToTensor(),
        ffcv.transforms.ToDevice(device, non_blocking=True),
    ]


def _image_pipeline(device, normalize, augment):
    """Return a new image pipeline; ``augment`` adds the training-only flip/translate."""
    ops = [SimpleRGBImageDecoder()]
    if augment:
        ops += [
            ffcv.transforms.RandomHorizontalFlip(),
            ffcv.transforms.RandomTranslate(padding=4, fill=(0, 0, 0)),
        ]
    ops += [
        ffcv.transforms.ToTensor(),
        ffcv.transforms.ToDevice(device, non_blocking=True),
        ffcv.transforms.ToTorchImage(),
        ffcv.transforms.Convert(torch.float32),
        normalize,
    ]
    return ops


def cifar100(
    batch_size: int,
    workers: int,
    *,
    train_path: str = '/d1/xxx/DBQ/CIFAR100_train.beton',
    val_path: str = '/d1/xxx/DBQ/CIFAR100_test.beton',
    device: str = 'cuda:0',
):
    """Build the FFCV training and validation loaders for CIFAR-100.

    Args:
        batch_size: Training batch size. The validation loader uses
            ``batch_size * 4``.
        workers: Number of FFCV worker threads for each loader.
        train_path: Path to the training ``.beton`` file.
        val_path: Path to the validation ``.beton`` file.
        device: Device the image and label batches are moved to.

    Returns:
        A ``(train_loader, val_loader)`` tuple of ``ffcv.loader.Loader``
        objects. Each yields batches for the ``'image'`` and ``'label'``
        fields.

    Note:
        The validation loader keeps its final partial batch, so the last
        batch may be smaller than ``batch_size * 4``.
    """
    # The pipeline never rescales pixels (Convert only casts to float32), so
    # the mean/std are expressed on the 0..255 scale.
    # NOTE(review): these values match the commonly quoted CIFAR-10 channel
    # statistics - confirm they are the ones intended for CIFAR-100.
    normalize = torchvision.transforms.Normalize(
        (0.4914 * 255, 0.4822 * 255, 0.4465 * 255),
        (0.247 * 255, 0.243 * 255, 0.261 * 255),
    )

    # Each loader gets its own operation instances rather than sharing one
    # label pipeline list between the two loaders.
    train_loader = Loader(
        train_path,
        batch_size=batch_size,
        num_workers=workers,
        order=OrderOption.RANDOM,
        pipelines={
            'image': _image_pipeline(device, normalize, augment=True),
            'label': _label_pipeline(device),
        },
    )

    # FFCV's Loader drops the last incomplete batch by default. For
    # evaluation that would silently skip part of the test set, so keep it.
    val_loader = Loader(
        val_path,
        batch_size=batch_size * 4,
        num_workers=workers,
        drop_last=False,
        pipelines={
            'image': _image_pipeline(device, normalize, augment=False),
            'label': _label_pipeline(device),
        },
    )
    return train_loader, val_loader
