import os

import torch
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import TensorDataset
from torchvision import datasets, transforms

def _select_balanced_indices(targets, num_digits, sample_per_digit, generator):
    """Pick up to ``sample_per_digit`` dataset indices for each label in ``range(num_digits)``.

    Candidate indices (those whose label is one of the wanted digits) are
    visited in a random order drawn from ``generator``. The first
    ``sample_per_digit`` hits per label are kept. The returned list preserves
    that shuffled visiting order, so labels are interleaved rather than grouped.

    Args:
        targets: Per-example integer labels (tensor or sequence).
        num_digits: Labels ``0 .. num_digits - 1`` are wanted.
        sample_per_digit: Maximum number of indices to keep per label.
        generator: ``torch.Generator`` used for the shuffle.

    Returns:
        List of int indices into the dataset. It can be shorter than
        ``num_digits * sample_per_digit`` if some label has too few examples.
    """
    wanted = set(range(num_digits))
    if not wanted or sample_per_digit <= 0:
        return []

    labels = torch.as_tensor(targets).tolist()
    candidates = [idx for idx, label in enumerate(labels) if label in wanted]
    order = torch.randperm(len(candidates), generator=generator).tolist()

    taken = dict.fromkeys(wanted, 0)
    still_needed = len(wanted)  # digits that have not reached their quota yet
    selected = []
    for pos in order:
        idx = candidates[pos]
        label = labels[idx]
        if taken[label] >= sample_per_digit:
            continue
        selected.append(idx)
        taken[label] += 1
        if taken[label] == sample_per_digit:
            still_needed -= 1
            if still_needed == 0:
                break  # every digit is full; no need to scan the rest
    return selected


def SelectMNIST(num_digits=2, sample_per_digit=1024):
    """Build and save a small, class-balanced subset of the MNIST training set.

    The MNIST training split is downloaded to ``./data`` if needed. Images are
    converted with ``ToTensor`` and then ``Normalize((0.5,), (0.5,))``. Up to
    ``sample_per_digit`` examples are kept for each digit in
    ``range(num_digits)``. Examples are drawn in a shuffled order that is fixed
    by seed 0, and the result is saved with ``torch.save`` to
    ``./dataset/SmallMNIST_numdigits_<num_digits>_sampleperdigit_<sample_per_digit>.data``.

    The shuffle uses a private ``torch.Generator``, so calling this function
    does not reset the global torch RNG.

    Args:
        num_digits: Keep digits ``0 .. num_digits - 1``.
        sample_per_digit: Maximum number of examples kept per digit.

    Returns:
        The saved dict ``{"images": images, "labels": labels}``:
        ``images`` is stacked along a new leading dimension (shape
        ``(0, 1, 28, 28)`` when nothing is selected). ``labels`` is 1-D in
        torch's default floating dtype.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    filename = "./dataset/SmallMNIST_numdigits_%d_sampleperdigit_%d.data" % (num_digits, sample_per_digit)

    # Fixed seed gives reproducible selection without touching global RNG state.
    generator = torch.Generator().manual_seed(0)
    # Select on the raw label tensor so the transform only runs on kept images,
    # not on all training images.
    indices = _select_balanced_indices(train_dataset.targets, num_digits, sample_per_digit, generator)

    samples = [train_dataset[idx] for idx in indices]
    if samples:
        # Collect-then-stack is linear; repeated torch.cat would be quadratic.
        images = torch.stack([image for image, _ in samples])
    else:
        images = torch.empty(0, 1, 28, 28)
    # Labels are stored in the default floating dtype. This matches files saved
    # by earlier versions, where cat() promoted int labels to float.
    labels = torch.tensor([label for _, label in samples], dtype=torch.get_default_dtype())

    data_dict = {
        "images": images,
        "labels": labels
    }

    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    # Save both tensors to the same file
    torch.save(data_dict, filename)
    return data_dict
