import torch

from einops import rearrange
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets, transforms
from lightning import LightningDataModule


class Cifar10_Dataset(Dataset):
    """Label-free view of the CIFAR-10 training split.

    Wraps ``torchvision.datasets.CIFAR10(train=True, download=True)`` and
    yields only the transformed image of each sample; the class label is
    discarded.

    Args:
        data_dir: Root directory passed to ``CIFAR10`` as ``root``. The
            archive is downloaded there because ``download=True``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self,
                 data_dir='your_save_path',
                 **kwargs):
        # Augmentation + preprocessing: random left/right flip, conversion to
        # a tensor, then per-channel (x - 0.5) / 0.5.
        pipeline = [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(pipeline)

        self.cifar10_dataset = datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=self.transform
        )

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.cifar10_dataset)

    def __getitem__(self, idx):
        """Return the transformed image at ``idx``, dropping its label."""
        img, _label = self.cifar10_dataset[idx]
        return img
    

class Cifar10_Loader(LightningDataModule):
    """Lightning data module splitting the CIFAR-10 training set into train/val.

    Both subsets are drawn from a single ``Cifar10_Dataset``, so they share its
    transform; validation images therefore also go through its random
    horizontal flip.

    Args:
        data_dir: Directory forwarded to ``Cifar10_Dataset``.
        val_rate: Fraction of images held out for validation, in [0, 1]. The
            validation count is ``int(len(dataset) * val_rate)``.
        batch_size: Batch size for both loaders (coerced with ``int``).
        num_workers: Number of DataLoader worker processes. The
            ``persistent_workers`` / ``prefetch_factor`` options are only
            applied when this is > 0, because ``DataLoader`` rejects them in
            single-process mode.

    Raises:
        ValueError: If ``val_rate`` is outside [0, 1].
    """

    def __init__(self, data_dir, val_rate=0.1, batch_size=128, num_workers=0):
        super().__init__()
        if not 0 <= val_rate <= 1:
            raise ValueError(f"val_rate must be in [0, 1], got {val_rate!r}")

        self.batch_size = int(batch_size)
        self.num_workers = num_workers
        dataset = Cifar10_Dataset(data_dir)
        num = len(dataset)
        n_valid = int(num * val_rate)
        # The training subset is whatever is not held out for validation.
        n_train = num - n_valid

        print('\nBatch size: {}'.format(batch_size))
        print('Total number of images {}.'.format(num))
        print('\tTraining files:', n_train)
        print('\tValidation files:', n_valid)

        # No generator is passed, so the split depends on torch's global RNG state.
        self.trSamples, self.vlSamples = random_split(dataset, lengths=[n_train, n_valid])

    def _shared_loader_kwargs(self):
        """Return the DataLoader options common to the train and val loaders."""
        kwargs = dict(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )
        # DataLoader raises ValueError for these two options when num_workers == 0.
        if self.num_workers > 0:
            kwargs.update(persistent_workers=True, prefetch_factor=8)
        return kwargs

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training subset."""
        return DataLoader(self.trSamples, shuffle=True, **self._shared_loader_kwargs())

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the validation subset.

        ``drop_last=True`` is kept from the original configuration, so a final
        partial batch of validation samples is skipped.
        """
        return DataLoader(self.vlSamples, **self._shared_loader_kwargs())
