from torchvision.datasets import CIFAR10
from torchvision import transforms
from .data_utils import *


def get_dataset(dpath=None):
    """Load the CIFAR-10 training split as normalized tensors.

    Args:
        dpath: Root directory holding (or receiving) the CIFAR-10 files.
            When ``None``, falls back to ``"./data/cifar"``, the path this
            function previously always used. The data is downloaded into
            this directory if it is not already present (``download=True``).

    Returns:
        A ``torchvision.datasets.CIFAR10`` training set whose items are
        ``(image, label)`` pairs. ``image`` is a float tensor with each
        channel normalized from [0, 1] to [-1, 1].
    """
    # Previously ``dpath`` was silently ignored; honour it when given.
    root = "./data/cifar" if dpath is None else dpath
    transform = transforms.Compose([
        # CIFAR-10 images are already 32x32; kept so the output size is
        # guaranteed even if the source images differ.
        transforms.Resize(32),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        # (x - 0.5) / 0.5 maps each channel from [0, 1] to [-1, 1].
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    return CIFAR10(root, train=True, transform=transform, download=True)

def get_fl_dataset(args):
    """Build the federated-learning dataset from the CIFAR-10 training split.

    Every sample is materialized eagerly into a list, which applies the
    dataset transform to each item up front. That list is then handed to
    ``process_data`` together with ``args``.

    NOTE(review): ``process_data`` comes from ``data_utils`` via a star
    import; it presumably partitions the samples across clients according
    to ``args``. Confirm against its definition.
    """
    samples = list(get_dataset())
    return process_data(args, samples)

if __name__ == '__main__':
    # Smoke check: load the training set, then report its size and the
    # shape and value range of the first image after normalization.
    train_set = get_dataset()
    print(len(train_set))
    image, _label = train_set[0]
    print(image.shape, image.min(), image.max())
