import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from PIL import Image
import os

# Export every CIFAR-10 training image of one class to ``<dataset index>.png``.

save_dir = '/home/admin1/Syh/Training-free-quant/mixed_bit/cifar10/train'
dataset_root = '/home/admin1/dataset/Dataset/cifar-10'
label_to_save = 2


def save_class_images(dataset, label, out_dir):
    """Save every sample of class *label* in *dataset* as ``<index>.png`` in *out_dir*.

    Args:
        dataset: An indexable dataset exposing a ``targets`` sequence (as
            torchvision's ``CIFAR10`` does) whose items are ``(image, target)``
            pairs, where ``image`` has a ``save(path)`` method (e.g. a PIL image,
            i.e. the dataset was built with ``transform=None``).
        label: Class id to export.
        out_dir: Destination directory; created if it does not exist.

    Returns:
        The number of images written.
    """
    os.makedirs(out_dir, exist_ok=True)
    # Filter on the label list first so only matching samples are decoded,
    # instead of materializing every image in the dataset.
    indices = [idx for idx, target in enumerate(dataset.targets) if target == label]
    for idx in indices:
        image, _ = dataset[idx]
        image.save(os.path.join(out_dir, f"{idx}.png"))
    return len(indices)


def main():
    """Load the CIFAR-10 training split and dump class ``label_to_save`` to ``save_dir``."""
    # transform=None yields the original PIL images. Going through ToTensor and
    # back via ToPILImage quantizes with mul(255).byte(), which truncates, so
    # float rounding error can shift pixel values; saving the PIL image is lossless.
    dataset = datasets.CIFAR10(root=dataset_root, train=True, download=True, transform=None)
    save_class_images(dataset, label_to_save, save_dir)
    print("end")


if __name__ == "__main__":
    main()