from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset, Subset
import random
import matplotlib.pyplot as plt
import numpy as np
import math
from collections import OrderedDict
import tensorflow as tf
from PIL import Image
import os
import itertools
from typing import List
from torch.cuda.amp import GradScaler, autocast

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter
import gc

from ffcv.transforms import ToTensor, ToDevice, Squeeze, NormalizeImage, \
    RandomHorizontalFlip, ToTorchImage
from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
    RandomResizedCropRGBImageDecoder
from ffcv.fields.basics import IntDecoder
from pathlib import Path

def to_chunks(it, size):
    """Split an iterable into consecutive tuples of at most ``size`` items.

    ``size`` may be a float; it is rounded up with ``math.ceil``. The final
    chunk may be shorter than the others.

    Args:
        it: Any iterable. It is consumed lazily and only once.
        size: Maximum chunk length; must round up to at least 1.

    Returns:
        An iterator yielding tuples of consecutive items from ``it``.

    Raises:
        ValueError: If ``size`` rounds up to less than 1.
    """
    chunk_len = int(math.ceil(size))
    if chunk_len < 1:
        # Without this guard a size of 0 would silently yield no chunks
        # (discarding every input item), and a negative size would only
        # fail lazily inside islice on the first next().
        raise ValueError(f"size must round up to at least 1, got {size!r}")
    source = iter(it)
    # iter(callable, sentinel) keeps calling the lambda until it returns the
    # empty tuple, i.e. until the source iterator is exhausted.
    return iter(lambda: tuple(itertools.islice(source, chunk_len)), ())

# Hard-coded to the CUDA GPU with index 1 (the second GPU). It is not used in
# the lines shown here; presumably it is consumed by training code later in
# the file.
# NOTE(review): this will likely error at first use on machines with fewer
# than two GPUs — consider making it configurable.
device = torch.device("cuda:1")

# Raw torchvision CIFAR-10 splits rooted at /tmp; download=True fetches the
# data when it is not already present. No transform is passed, so samples are
# returned unmodified.
cifar10_datasets = {
    'train': datasets.CIFAR10('/tmp', train=True, download=True),
    'test': datasets.CIFAR10('/tmp', train=False, download=True)
}

# Convert each split into an FFCV .beton file (/tmp/cifar_train.beton and
# /tmp/cifar_test.beton). There is no check for existing files, so the
# conversion is redone every time this module runs.
for (name, ds) in cifar10_datasets.items():
    # NOTE(review): ffcv pairs these fields positionally with each sample
    # tuple returned by `ds`, so the order presumably must stay image, label.
    writer = DatasetWriter(f'/tmp/cifar_{name}.beton', {
        'image': RGBImageField(),
        'label': IntField()
    })
    writer.from_indexed_dataset(ds)
