import torch
from torch_geometric.loader import NeighborSampler
from torch_geometric.datasets import Reddit

# Target device for the loaded graph. Prefer the GPU, but fall back to the CPU
# so the module still works on hosts without a usable CUDA runtime.
device = "cuda" if torch.cuda.is_available() else "cpu"


def reddit(layer: int = 1):
    """Load the Reddit graph and build neighbor-sampling loaders for it.

    Args:
        layer: Currently unused. The sampling fan-out is fixed at two hops
            (``[35, 20]``) regardless of this value.

    Returns:
        A 4-tuple ``(data, train_loader, "", test_loader)``:

        * ``data`` -- the first graph of the ``Reddit`` dataset after
          ``.to(device)`` (``device`` is the module-level setting).
        * ``train_loader`` -- a shuffled ``NeighborSampler`` over the nodes
          selected by ``train_mask``.
        * ``""`` -- placeholder occupying the slot of a validation loader,
          which is currently disabled.
        * ``test_loader`` -- an unshuffled ``NeighborSampler`` over the nodes
          selected by ``test_mask``.
    """

    def mask_to_index(mask):
        # Vectorised equivalent of enumerating the mask in Python: a 1-D
        # int64 tensor holding the positions of the truthy entries.
        return torch.nonzero(mask, as_tuple=False).view(-1)

    fanout = (35, 20)
    batch_size = 512

    dataset = Reddit(root=r"./dataset")[0]

    # Collect everything the samplers need while the graph is still on the
    # host. NeighborSampler samples on the CPU, so giving it the host
    # edge_index avoids handing it a tensor that was just moved to `device`.
    train_idx = mask_to_index(dataset.train_mask)
    test_idx = mask_to_index(dataset.test_mask)
    edge_index = dataset.edge_index

    # Move the graph first, as before, so an unusable device fails early.
    data = dataset.to(device)

    train_loader = NeighborSampler(
        edge_index=edge_index,
        sizes=list(fanout),
        node_idx=train_idx,
        shuffle=True,
        batch_size=batch_size,
    )
    # The validation loader is intentionally disabled; the "" in the return
    # value keeps the tuple's positional layout. An earlier draft built it
    # from val_mask with sizes=[30, 20, 10], shuffle=True, batch_size=512.
    test_loader = NeighborSampler(
        edge_index=edge_index,
        sizes=list(fanout),
        node_idx=test_idx,
        shuffle=False,
        batch_size=batch_size,
    )
    return data, train_loader, "", test_loader



# Running this file directly performs a single load and discards the returned
# tuple, which is enough to check that the dataset and loaders can be built.
if __name__ == "__main__":
    reddit()