
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import pyro.distributions as dist
#from mpi4py import MPI
from scipy.io import savemat

# Device selection: CUDA when PyTorch reports it is available, otherwise CPU.
# NOTE(review): not referenced by the data-export code in this script;
# confirm it is still needed.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def _first_n_as_numpy(dataset, n):
    """Return the first ``n`` samples of ``dataset`` as two NumPy arrays.

    Samples are taken in dataset order (no shuffling), so the result is
    deterministic. They are collated into a single batch by ``DataLoader``'s
    default collation and then converted to NumPy. If ``n`` exceeds the
    dataset length, the whole dataset is used.

    Args:
        dataset: A map-style dataset whose items collate into an
            ``(inputs, labels)`` pair.
        n: Number of leading samples to take.

    Returns:
        ``(inputs, labels)`` as NumPy arrays whose leading dimension is
        ``min(n, len(dataset))``.

    Raises:
        ValueError: If no samples would be selected (``n <= 0`` or an empty
            dataset). A zero-sized batch cannot be collated into one batch.
    """
    count = min(n, len(dataset))
    if count <= 0:
        raise ValueError(f"expected a positive sample count, got {count}")
    subset = Subset(dataset, range(count))
    # A single batch spanning the whole subset, so one next() collects it all.
    loader = DataLoader(subset, batch_size=count, shuffle=False)
    inputs, labels = next(iter(loader))
    return inputs.cpu().numpy(), labels.cpu().numpy()


if __name__ == '__main__':

    N_tr = 1000  # Number of training samples to use.
    N_val = 1000  # Number of validation samples to use.

    # --- Data Loading using torchvision.datasets.MNIST ---
    transform = transforms.Compose([transforms.ToTensor()])
    full_train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform)
    full_val_dataset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform)

    # presumably images are (N, 1, 28, 28) float32 in [0, 1] given ToTensor,
    # and labels are (N,) integers -- TODO confirm against the consumer.
    x_train, y_train = _first_n_as_numpy(full_train_dataset, N_tr)
    # The train=False split is exported under the "test" keys below.
    x_test, y_test = _first_n_as_numpy(full_val_dataset, N_val)

    # --- Save the Results ---
    # NOTE(review): savemat's default oned_as='row' stores the 1-D label
    # arrays as 1xN; confirm that is what the MATLAB side expects.
    savemat('BayesianNN_MNIST_data.mat', {
        'x_train': x_train,
        'y_train': y_train,
        'x_test': x_test,
        'y_test': y_test,
    })
    
