import os
from torch.utils.data import Dataset, Subset, DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import random
from collections import defaultdict


class MNISTDataset(Dataset):
    """Thin ``Dataset`` wrapper around torchvision's MNIST.

    Every image is resized to 32x32 and converted to a tensor. The data is
    stored under ``'../../data'`` (a relative path, created if missing) and
    downloaded on first use.

    Args:
        train: Forwarded to ``MNIST(train=...)`` to select the split; also
            kept on the instance as ``self.train``.
    """

    def __init__(self, train):
        self.train = train

        # The download target must exist before torchvision writes into it.
        data_dir = '../../data'
        os.makedirs(data_dir, exist_ok=True)

        # Preprocessing applied to each image on access: 32x32 resize, then
        # conversion to a tensor.
        pipeline = transforms.Compose(
            [transforms.Resize((32, 32)), transforms.ToTensor()]
        )

        # download=True fetches the files only when they are not present yet.
        self.mnist = MNIST(
            root=data_dir,
            train=train,
            download=True,
            transform=pipeline,
        )

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.mnist)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        img, target = self.mnist[idx]
        return img, target


def get_data(args, num_workers=1):
    """Build the training, validation and per-digit MNIST data loaders.

    Args:
        args: Object exposing ``train_batchsize`` and ``val_batchsize``
            attributes. They are used as the batch sizes of the loaders;
            ``val_batchsize`` is also the number of samples collected for
            each digit.
        num_workers: Worker count passed to every ``DataLoader``.

    Returns:
        A tuple ``(train_loader, val_loader, digit_loaders)``.
        ``digit_loaders`` maps each digit 0-9 to a shuffled ``DataLoader``
        over a ``Subset`` of the validation data holding up to
        ``val_batchsize`` randomly chosen samples of that digit.
    """
    train_data = MNISTDataset(train=True)
    # NOTE(review): the validation set is built from the same train=True
    # split as the training set. Confirm whether the held-out split
    # (train=False) was intended here.
    val_data = MNISTDataset(train=True)

    train_loader = DataLoader(
        train_data,
        batch_size=args.train_batchsize,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_data,
        batch_size=args.val_batchsize,
        shuffle=True,
        num_workers=num_workers,
    )

    # Visit the validation samples in random order.
    shuffled_indices = list(range(len(val_data)))
    random.shuffle(shuffled_indices)

    digits = range(10)
    samples_per_digit = args.val_batchsize
    label_to_indices = defaultdict(list)

    # Digits that still need more samples. Completion is tracked against the
    # full digit set rather than only the labels seen so far; otherwise the
    # scan could stop before some digits have appeared at all, leaving them
    # with empty subsets.
    pending = set(digits)
    for idx in shuffled_indices:
        _, label = val_data[idx]
        if label not in pending:
            continue
        bucket = label_to_indices[label]
        bucket.append(idx)
        if len(bucket) >= samples_per_digit:
            pending.discard(label)
            if not pending:
                break

    # One loader per digit, each over that digit's sampled indices.
    digit_loaders = {
        digit: DataLoader(
            Subset(val_data, label_to_indices[digit]),
            batch_size=args.val_batchsize,
            shuffle=True,
            num_workers=num_workers,
        )
        for digit in digits
    }

    return train_loader, val_loader, digit_loaders

