from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import torch

def get_cifar10(batch_size, num_batches):
    if num_batches > 780:
        num_batches = 780
    # Define a simple transformation to match MobileNetV2's expected input
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Load the CIFAR-10 dataset and select a single image
    dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
    subset_indices = list(range(batch_size*num_batches))
    single_batch_dataset = Subset(dataset, subset_indices)  # Use only the first batch for simplicity
    dataloader = DataLoader(single_batch_dataset, batch_size=batch_size, shuffle=True)
    return dataloader