from torch.utils.data import Dataset
from torchvision import datasets, transforms
import random
import torch
class CIFAR10(Dataset):
    """Thin wrapper around ``torchvision.datasets.CIFAR10``.

    The training split is augmented with a padded random crop and a random
    horizontal flip; both splits are converted to tensors and normalized with
    fixed per-channel mean/std values. When ``with_idx`` is true, each item
    additionally carries the index it was requested with.
    """

    def __init__(self, train=True, with_idx=False):
        """Build the transform pipeline and load (downloading if needed) the data.

        Args:
            train: Selects the training split (with augmentation) when truthy,
                otherwise the test split (no augmentation).
            with_idx: If true, ``__getitem__`` returns ``(data, target, index)``
                instead of ``(data, target)``.
        """
        normalize = transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        )
        # Augmentation is applied to the training split only.
        if train:
            augmentation = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            augmentation = []
        pipeline = transforms.Compose(
            augmentation + [transforms.ToTensor(), normalize]
        )

        self.with_idx = with_idx
        self.cifar10 = datasets.CIFAR10(
            root='./data',
            train=train,
            download=True,
            transform=pipeline,
        )

    def __getitem__(self, index):
        """Return ``(data, target)`` or, if ``with_idx`` is set, ``(data, target, index)``."""
        data, target = self.cifar10[index]
        return (data, target, index) if self.with_idx else (data, target)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.cifar10)

class LazyCIFAR10(Dataset):
    """CIFAR-10 dataset that lazily re-samples the example held in each batch slot.

    The index passed to ``__getitem__`` is ignored. Instead, the dataset keeps
    one stored sample index per position ("slot") within a batch and counts
    how many items have been requested so far. On each request the slot's
    stored index is either replaced by a fresh uniformly random index or
    reused:

    * during the first batch of epoch 0 every slot is always freshly sampled;
    * afterwards a slot is re-sampled with probability ``p_t(current_batch)``
      and otherwise keeps the index it served last time.

    NOTE(review): the slot bookkeeping assumes items are requested
    sequentially, ``batch_size`` at a time, by a single consumer — presumably
    a ``DataLoader`` with the same ``batch_size`` and no worker processes;
    confirm against the caller.
    NOTE(review): the request counter is not reset by ``update_epoch``, so the
    batch number handed to ``p_t`` keeps growing across epochs — confirm this
    is intended.
    """

    def __init__(self, train=True, p_t=None, batch_size=128, with_idx=False):
        """Create the wrapped dataset and the per-slot sampling state.

        Args:
            train: Forwarded to ``CIFAR10`` to select the split.
            p_t: Callable mapping the 1-based batch number to the probability
                of re-sampling a slot. Defaults to a constant ``0.5``.
            batch_size: Number of slots, i.e. items per batch. Must be >= 1.
            with_idx: Forwarded to ``CIFAR10``; when true, items are
                ``(data, target, index)`` instead of ``(data, target)``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        # Initialize base CIFAR10 dataset
        self.cifar10 = CIFAR10(train=train, with_idx=with_idx)
        # Lazy loading parameters
        self.epoch = 0  # epoch number starts from 0
        self.p_t = p_t if p_t is not None else lambda x: 0.5  # Default probability function
        # Placeholder indices; each slot is overwritten during the first batch.
        self.prev_indices = [0] * batch_size
        self.batch_size = batch_size
        self.cnt = 0  # number of items requested so far
        self.with_idx = with_idx
        self.worker_id = None  # NOTE(review): not used anywhere in this class

    def update_epoch(self, epoch):
        """Record the current epoch number (epoch 0 forces fresh sampling in batch 1)."""
        self.epoch = epoch

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.cifar10)

    def __getitem__(self, idx):
        """Return the item for the next batch slot; ``idx`` is ignored.

        Returns:
            Whatever the wrapped dataset returns for the slot's stored index:
            ``(data, target)`` or ``(data, target, index)`` when ``with_idx``
            is set, where ``index`` is the sampled dataset index.
        """
        # Derive batch and slot from the number of requests made *before* this
        # one. Incrementing first would shift everything by one item: slot 0
        # would never be sampled in the first batch, and the last item of each
        # batch would be attributed to the following batch.
        position = self.cnt
        self.cnt += 1
        current_batch = position // self.batch_size + 1
        slot = position % self.batch_size

        # Always sample in the very first batch; afterwards only with
        # probability p_t(current_batch). Short-circuiting keeps p_t and the
        # RNG untouched in the forced case.
        force_fresh = current_batch == 1 and self.epoch == 0
        if force_fresh or random.random() < self.p_t(current_batch):
            self.prev_indices[slot] = random.randint(0, len(self.cifar10) - 1)

        # The wrapped dataset already shapes its output according to with_idx.
        return self.cifar10[self.prev_indices[slot]]