import torchvision
from cifar10 import *
from cifar100 import *
from sub_dataset import *
import numpy as np
import torch
import data

# Evaluation-time preprocessing pipeline, applied in order:
#   1. Resize(32)
#   2. ToTensor()  -- image -> float tensor
#   3. Normalize with mean 0.5 and std 0.5; only one value is given, presumably
#      relying on broadcasting across channels -- TODO confirm for RGB input.
# A torchvision.transforms.Grayscale() step may be inserted before ToTensor()
# when single-channel inputs are wanted.
# NOTE(review): not referenced in the visible part of this script; the
# CIFAR-100 export below uses a plain ToTensor() transform instead.
test_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(32),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# CIFAR-100 test split (train=False) stored under ../../datasets; download=True
# lets torchvision fetch it there. Samples are converted with ToTensor() only,
# with no normalization step.
dataset = torchvision.datasets.CIFAR100(
    root='../../datasets', train=False, download=True,
    transform=torchvision.transforms.ToTensor(),
)
# Export every sample of `dataset` into one .npz archive so it can be reloaded
# later with plain numpy (np.load):
#   images -- stacked outputs of the dataset's transform (ToTensor() above)
#   labels -- integer class labels, aligned index-for-index with `images`
# The loop bound comes from len(dataset) rather than a hard-coded count, so the
# export can neither silently truncate nor overrun the split.
images = []
labels = []
for i in range(len(dataset)):
    # Index once per sample: each dataset[i] call loads the image and runs the
    # transform, so looking it up twice would double that work.
    image, label = dataset[i]
    images.append(image.numpy())
    labels.append(label)
images = np.array(images)
labels = np.array(labels)
np.savez('CIFAR-100_images.npz', images=images, labels=labels)

# cifar100 = data.CIFAR100()
# testset = cifar100.test_dataset
# dataloader = torch.utils.data.DataLoader(
#     dataset, batch_size=50, shuffle=False, num_workers=8
# )
# print(testset[100])
# for i, ((x1, x2), _) in enumerate(dataloader):
#     if i == 2:
#         print(x1, x2)
#         break