import numpy as np

from torch_geometric.datasets import Planetoid
from ogb.nodeproppred import PygNodePropPredDataset


def load_data(partitions, args, *, partition_dir='./partition',
              train_ratio=0.3, val_ratio=0.3):
    """Load a graph dataset and split every client's partition into train/val/test nodes.

    Args:
        partitions: Iterable of partition file names. Only names starting with
            ``args.dataset_name`` are used, and each selected file is treated as
            one client. NOTE(review): prefix matching means e.g. ``'Cora'`` would
            also match a file named ``'CoraFull...'`` — confirm naming scheme.
        args: Object exposing ``dataset_name`` and ``data_path`` attributes.
        partition_dir: Directory containing the partition files. Defaults to
            ``'./partition'``, the location previously hard-coded here.
        train_ratio: Fraction of each client's nodes assigned to training.
        val_ratio: Fraction of each client's nodes assigned to validation; the
            remaining nodes go to test.

    Returns:
        Tuple ``(dataset, num_clients, trainIdx, valIdx, testIdx)``:
        ``dataset`` is a ``Planetoid`` instance for Cora/Citeseer/PubMed and a
        ``PygNodePropPredDataset`` otherwise; ``num_clients`` is the number of
        selected partition files; ``trainIdx``/``valIdx``/``testIdx`` each hold
        one list of integer node ids per client, in shuffled client order.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more than 1.

    Note:
        Client order and the per-client splits use NumPy's global RNG via
        ``np.random.shuffle``; seed it beforehand for reproducible splits.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError(
            f'invalid split ratios: train_ratio={train_ratio}, val_ratio={val_ratio}'
        )

    dataset_name = args.dataset_name
    data_path = args.data_path

    # get raw dataset
    if dataset_name in ('Cora', 'Citeseer', 'PubMed'):
        dataset = Planetoid(root=f'{data_path}', name=dataset_name)
    else:
        dataset = PygNodePropPredDataset(root=f'{data_path}', name=dataset_name)

    # Each partition file whose name starts with the dataset name is one client.
    file_list = [file for file in partitions if file.startswith(dataset_name)]
    num_clients = len(file_list)

    # shuffle the order of clients
    np.random.shuffle(file_list)

    trainIdx = []
    valIdx = []
    testIdx = []

    for file in file_list:
        # ndmin=1 keeps a single-node partition as a 1-element array; without it
        # loadtxt returns a 0-d array and shuffle()/len() below would raise.
        node_list = np.loadtxt(f'{partition_dir}/{file}', ndmin=1).astype(int)

        # Randomly assign train_ratio of the nodes to training, val_ratio to
        # validation, and the remainder to test (30% / 30% / 40% by default).
        np.random.shuffle(node_list)
        num_nodes = len(node_list)
        train_end = int(train_ratio * num_nodes)
        val_end = int((train_ratio + val_ratio) * num_nodes)
        trainIdx.append(list(node_list[:train_end]))
        valIdx.append(list(node_list[train_end:val_end]))
        testIdx.append(list(node_list[val_end:]))

    return dataset, num_clients, trainIdx, valIdx, testIdx