import torch

def get_data_loaders(batch_size, features_train, labels_train, *, shuffle=True, **loader_kwargs):
    """Wrap feature/label tensors in a TensorDataset and a mini-batch DataLoader.

    Args:
        batch_size: Number of samples per batch yielded by the loader.
        features_train: Feature tensor. ``TensorDataset`` indexes every tensor
            along dimension 0, so it must have the same length there as
            ``labels_train``.
        labels_train: Label tensor aligned with ``features_train`` along dim 0.
        shuffle: Whether the loader reshuffles the samples every epoch.
            Defaults to ``True``, which matches the previous hard-coded
            behaviour. Pass ``False`` for a deterministic order, e.g. when
            evaluating.
        **loader_kwargs: Extra keyword arguments forwarded unchanged to
            ``torch.utils.data.DataLoader`` (e.g. ``drop_last``,
            ``num_workers``, ``generator``).

    Returns:
        A ``(dataset, loader)`` tuple: the ``TensorDataset`` and the
        ``DataLoader`` iterating over it.
    """
    train_dataset = torch.utils.data.TensorDataset(features_train, labels_train)
    train_iter = torch.utils.data.DataLoader(
        train_dataset, batch_size, shuffle=shuffle, **loader_kwargs
    )
    return train_dataset, train_iter