import torch
import torchvision

import torchvision.transforms as transforms
import numpy as np

# Name of the torch device to use: 'cuda' when PyTorch can see a GPU, else 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def gen_data(train_set = True, root='./data', batch_size=128, shuffle=True):
    """Download one SVHN split and return it as in-memory NumPy arrays.

    Every image goes through ``transforms.ToTensor()``. The stacked batch
    array is then transposed with axes ``(0, 2, 3, 1)``, which converts the
    channel-first tensors produced by ``ToTensor`` into channel-last
    (N, H, W, C) layout.

    Args:
        train_set: If truthy, load the ``'train'`` split; otherwise load the
            ``'test'`` split.
        root: Directory where torchvision stores / looks for the dataset
            files (downloaded there if missing).
        batch_size: Number of samples the DataLoader yields per batch while
            extracting. It affects only speed and memory, not the result.
        shuffle: Whether samples are shuffled before being concatenated.
            Images and labels are shuffled together, so pairs stay aligned.
            The default is True to match the historical output; pass False
            for a deterministic ordering.

    Returns:
        A ``(data, labels)`` tuple of NumPy arrays. ``data`` is the
        channel-last image array; ``labels`` is the 1-D array of targets in
        the same order.
    """
    # SVHN (not CIFAR-10). ToTensor is the only transform applied.
    transform = transforms.Compose([transforms.ToTensor()])
    split = 'train' if train_set else 'test'
    dataset = torchvision.datasets.SVHN(root=root, split=split,
                                        download=True, transform=transform)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=shuffle)

    # Collect the batches as NumPy arrays. The DataLoader already yields CPU
    # tensors, so they are converted directly; the previous
    # ``.to(device)`` + ``.cpu()`` round trip through the GPU did no work.
    image_batches = []
    label_batches = []
    for images, targets in dataloader:
        image_batches.append(images.numpy())
        label_batches.append(targets.numpy())

    # Stack all batches, then move the channel axis last: (N, C, H, W) -> (N, H, W, C).
    data = np.transpose(np.concatenate(image_batches, axis=0), (0, 2, 3, 1))
    labels = np.concatenate(label_batches, axis=0)

    return data, labels

# Build both splits before writing anything, so no .npy file is produced
# unless both the train and test splits loaded successfully.
train_data, train_labels = gen_data(train_set=True)
test_data, test_labels = gen_data(train_set=False)

# Write each array to its own NumPy .npy file, in the same order as before.
_npy_outputs = (
    ('SVHN_train_data.npy', train_data),
    ('SVHN_train_labels.npy', train_labels),
    ('SVHN_test_data.npy', test_data),
    ('SVHN_test_labels.npy', test_labels),
)
for _npy_path, _npy_array in _npy_outputs:
    np.save(_npy_path, _npy_array)
# Drop the loop temporaries so the module namespace matches the original script.
del _npy_outputs, _npy_path, _npy_array

# import torch
# import torchvision

# import torchvision.transforms as transforms
# import numpy as np

# device = 'cuda' if torch.cuda.is_available() else 'cpu'

# # Load SVHN dataset
# transform = transforms.Compose(
#     [transforms.ToTensor(),])

# trainset = torchvision.datasets.SVHN(root='./data', split='train',
#                                         download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

# # Extract data and labels
# train_data = []
# train_labels = []
# for data, target in trainloader:
#     data, target = data.to(device), target.to(device)
#     train_data.append(data.cpu().numpy())
#     train_labels.append(target.cpu().numpy())

# # Concatenate data and labels
# train_data = np.concatenate(train_data, axis=0)
# train_data = np.transpose(train_data, (0, 2, 3, 1))
# train_labels = np.concatenate(train_labels, axis=0)

# # Save to NumPy file
# np.save('SVHN_train_data.npy', train_data)
# np.save('SVHN_train_labels.npy', train_labels)