import os 
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def dataset_to_numpy(dataset, batch_size):
    """Materialize an entire dataset as a pair of NumPy arrays.

    The dataset is read sequentially (no shuffling) in batches of
    ``batch_size``; every batch is expected to unpack into an
    ``(images, labels)`` pair of tensors. The per-batch arrays are then
    joined along the first (sample) axis.

    Args:
        dataset: Dataset accepted by ``torch.utils.data.DataLoader`` whose
            items collate into ``(images, labels)`` tensor pairs.
        batch_size: Number of samples pulled per DataLoader iteration.

    Returns:
        Tuple ``(images, labels)`` of NumPy arrays in dataset order. For
        MNIST loaded with ``ToTensor`` this is presumably ``(N, 1, 28, 28)``
        images and ``(N,)`` labels.
    """
    # Convert each batch to NumPy as soon as it is produced.
    converted = [
        (batch_images.numpy(), batch_labels.numpy())
        for batch_images, batch_labels in DataLoader(
            dataset, batch_size=batch_size, shuffle=False
        )
    ]
    image_chunks = [pair[0] for pair in converted]
    label_chunks = [pair[1] for pair in converted]

    # Stitch the batches back together along the sample axis.
    all_images = np.concatenate(image_chunks, axis=0)
    all_labels = np.concatenate(label_chunks, axis=0)
    return all_images, all_labels

def save_mnist_npz(save_path, batch_size=1000,
                   data_root='/public/torchvision_datasets', seed=None):
    """Export the MNIST train and test splits as ``.npz`` archives.

    Both splits are loaded through torchvision with ``ToTensor`` followed by
    ``Normalize`` (mean 0.1307, std 0.3081), converted to NumPy arrays and
    written to ``mnist_train.npz`` and ``mnist_test.npz`` inside
    ``save_path``. Each archive stores two arrays under the keys ``images``
    and ``labels``. The training split is shuffled before saving; the test
    split keeps the dataset's own order.

    Args:
        save_path: Output directory; created if it does not exist.
        batch_size: Batch size used while converting the datasets to arrays.
        data_root: Directory torchvision downloads MNIST into (and reads it
            from). The default is the location that used to be hard-coded.
        seed: Optional seed for the training-set shuffle. With ``None`` the
            shuffle draws from NumPy's global random state, exactly as
            before, so the order is only repeatable if the caller seeded
            that state.
    """
    # Create the save path directory if it doesn't exist
    os.makedirs(save_path, exist_ok=True)

    # Define transformation: convert to tensor and normalize
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # 1. Download (if needed) and load both MNIST splits with the transform.
    train_dataset = datasets.MNIST(root=data_root, train=True,
                                   download=True, transform=transform)
    test_dataset = datasets.MNIST(root=data_root, train=False,
                                  download=True, transform=transform)

    # 2. Load data using DataLoader and convert to numpy arrays
    train_data, train_labels = dataset_to_numpy(train_dataset, batch_size=batch_size)
    test_data, test_labels = dataset_to_numpy(test_dataset, batch_size=batch_size)

    # 3. Shuffle the training data. Images and labels are indexed with the
    #    same permutation so every image stays paired with its label.
    indices = np.arange(len(train_data))
    if seed is None:
        # Preserve the historical behaviour: use the global NumPy RNG.
        np.random.shuffle(indices)
    else:
        # A dedicated generator gives a repeatable order without touching
        # the global random state.
        np.random.default_rng(seed).shuffle(indices)
    train_data = train_data[indices]
    train_labels = train_labels[indices]

    # 4. Save files to the specified path
    train_save_path = os.path.join(save_path, 'mnist_train.npz')
    np.savez(train_save_path, images=train_data, labels=train_labels)
    print(f'Saved {train_save_path}, images shape: {train_data.shape}, labels shape: {train_labels.shape}')

    test_save_path = os.path.join(save_path, 'mnist_test.npz')
    np.savez(test_save_path, images=test_data, labels=test_labels)
    print(f'Saved {test_save_path}, images shape: {test_data.shape}, labels shape: {test_labels.shape}')

# Script entry point: when run directly, write the MNIST archives into the
# current working directory.
if __name__ == '__main__':
    output_dir = './'
    save_mnist_npz(output_dir)

