

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch.utils.data import DataLoader
from torch.utils.data import random_split


import torchvision.datasets as datasets


def CIFAR10(batch_size, *, val_size=128 * 40, root='./data', seed=None,
            num_workers=0):
    """Build train/validation/test DataLoaders for CIFAR-10.

    Every image is resized to 224x224, converted with ``ToTensor`` and
    normalized with per-channel mean 0.5 and std 0.5 (mapping ToTensor's
    [0, 1] range to [-1, 1]). The validation set is a random subset of the
    official CIFAR-10 training split. The official test split is used as the
    test set. Both splits are downloaded into ``root`` if not already there.

    Args:
        batch_size: Batch size used by all three loaders.
        val_size: Number of training images held out for validation.
            Defaults to 5120 (128 * 40).
        root: Directory where CIFAR-10 is stored / downloaded.
        seed: Optional integer seed for the train/validation split. ``None``
            (the default) draws from torch's global RNG, so the split only
            repeats across runs if the global seed is fixed. Pass an int for
            a split that is reproducible independently of global RNG state.
        num_workers: Worker processes per DataLoader (0 means loading in the
            main process, which is DataLoader's default).

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles; the validation and test loaders iterate in a fixed order.

    Raises:
        ValueError: If ``val_size`` is negative or exceeds the size of the
            training split.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    full_train = datasets.CIFAR10(root=root, train=True, download=True,
                                  transform=transform)
    n_train_total = len(full_train)
    if not 0 <= val_size <= n_train_total:
        raise ValueError(
            f"val_size must be in [0, {n_train_total}], got {val_size}")

    # A dedicated seeded generator makes the split reproducible. Otherwise
    # use torch's global generator, which is random_split's own default.
    generator = (torch.Generator().manual_seed(seed) if seed is not None
                 else torch.default_generator)
    # The validation subset is the first element of the split.
    val_dataset, train_dataset = random_split(
        full_train, [val_size, n_train_total - val_size], generator=generator)

    test_dataset = datasets.CIFAR10(root=root, train=False, download=True,
                                    transform=transform)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    # Evaluation loaders are not shuffled: aggregate metrics do not depend on
    # order, and a fixed order keeps per-batch outputs comparable between
    # runs. It also lets val_size=0 work, because RandomSampler rejects an
    # empty dataset.
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader
