from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader
from typing import Any, Tuple, Optional, List, Dict
import torchvision.transforms as transforms
import os

class CIFAR10Warrper(Dataset):
    """CIFAR-10 training split exposed as a ``Dataset`` of dict samples.

    Each item is ``{"images": <transformed image>, "labels": <target>}``, where
    the target is whatever torchvision's ``CIFAR10`` returns for that image.
    The data can be truncated to its first ``chunk_size`` images and virtually
    repeated ``repeat`` times: ``len(self) == chunk_size * repeat`` and indices
    wrap around modulo ``chunk_size``.

    The class name keeps its historical spelling ("Warrper") because callers
    import it under that name.

    Args:
        mode: Selects the transform pipeline.
            ``'train'``: Resize(32), random horizontal flip, ToTensor,
            Normalize(0.5, 0.5).
            ``'real'``: Resize(32), ToTensor only (no normalization).
            Any other value: Resize(32), ToTensor, Normalize(0.5, 0.5)
            (no augmentation).
            Every mode reads the *training* split.
        num_samples: How many leading images to use. ``-1`` (default), ``None``
            or any negative value selects the whole split; values larger than
            the split are clamped to its size.
        repeat: Number of times the (possibly truncated) data is repeated when
            computing ``len``.
        root: Directory the CIFAR-10 files are downloaded to / read from.
            Defaults to ``data/cifar10`` relative to the working directory.
    """

    def __init__(
        self,
        mode: str = 'train',
        num_samples: Optional[int] = -1,
        repeat: int = 1,
        root: Optional[str] = None,
    ):
        self.num_samples = num_samples
        if root is None:
            root = os.path.join("data", "cifar10")

        self.dataset = CIFAR10(
            root,
            train=True,  # all modes intentionally read the training split
            download=True,
            transform=self._build_transform(mode),
        )
        self.chunk_size = self._resolve_chunk_size(num_samples, len(self.dataset))
        self.repeat = repeat

    @staticmethod
    def _build_transform(mode: str):
        """Return the torchvision transform pipeline for ``mode``."""
        steps = [transforms.Resize(32)]
        if mode == 'train':
            # Augmentation is applied only when training.
            steps.append(transforms.RandomHorizontalFlip(p=0.5))
        steps.append(transforms.ToTensor())
        if mode != 'real':
            # 'real' keeps the un-normalized ToTensor output; every other mode
            # shifts/scales pixels with mean 0.5 and std 0.5.
            steps.append(transforms.Normalize([0.5], [0.5]))
        return transforms.Compose(steps)

    @staticmethod
    def _resolve_chunk_size(num_samples: Optional[int], dataset_len: int) -> int:
        """Return how many leading images to use.

        ``None`` or a negative value means "the whole dataset"; otherwise the
        request is clamped to ``dataset_len`` so wrapped indices stay in range.
        """
        if num_samples is None or num_samples < 0:
            return dataset_len
        return min(num_samples, dataset_len)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return sample ``index % chunk_size`` as ``{"images", "labels"}``."""
        index = index % self.chunk_size  # wrap so repeated epochs reuse the chunk
        img, target = self.dataset[index]
        return {
            "images": img,
            "labels": target,
        }

    def __len__(self) -> int:
        """Virtual length: the chunk size times the repeat factor."""
        return self.chunk_size * self.repeat
    

if __name__ == "__main__":
    # Smoke test: build a 100-image training subset and print one image's shape.
    # Arguments are passed by keyword only; the constructor has no positional
    # parameter before ``mode``, so an extra leading positional would collide
    # with ``mode=`` and raise TypeError.
    dataset = CIFAR10Warrper(mode='train', num_samples=100)
    print(dataset[0]['images'].shape)