import argparse
import os
import random

import numpy as np
import torch

from data_loader import load_data
from fedgls import train_fedgls


def get_args(argv=None):
    """Parse the command-line options for a FedGLS training run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so the
            command-line behaviour is unchanged; passing a list allows
            programmatic calls and tests.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name', type=str, default='Cora',
                        help='select one dataset used for training: Cora, Citeseer, PubMed, ogbn-arxiv')
    parser.add_argument('--data_path', type=str, default='./data',
                        help='the folder to store datasets')
    # Numeric defaults are real ints (not strings) so they match their
    # declared `type` without relying on argparse re-parsing string defaults.
    parser.add_argument('--num_graphless', type=int, default=4,
                        help='the number of graphless clients')
    parser.add_argument('--rounds', type=int, default=200,
                        help='the number of rounds for training')
    parser.add_argument('--epochs', type=int, default=5,
                        help='the number of local epochs (default: 5)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--num_hidden', type=int, default=16,
                        help='hidden size (default: 16)')
    parser.add_argument('--graph_learner', type=str, default='Attentive',
                        help='graph learner: Attentive or MLP (default: Attentive)')
    parser.add_argument('--k', type=int, default=20,
                        help='the number of neighbors (default: 20)')
    parser.add_argument('--seed', type=int, default=0,
                        help='seed')
    return parser.parse_args(argv)


def set_random_seed(seed):
    """Seed the random number generators used by this script.

    Seeds Python's built-in ``random`` module, NumPy's global RNG and
    PyTorch's RNG, plus every visible CUDA device when CUDA is available.

    Args:
        seed: Integer seed value.
    """
    # Python's `random` was previously left unseeded, so any downstream use
    # of it would differ between otherwise identical runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # `torch.cuda.manual_seed` only seeds the current device; seed all of
        # them so multi-GPU runs are reproducible as well.
        torch.cuda.manual_seed_all(seed)


if __name__ == '__main__':
    args = get_args()

    # Validate with an explicit check rather than `assert`: assertions are
    # stripped under `python -O`, which would let an unsupported dataset
    # name through to the loader.
    supported_datasets = ('Cora', 'Citeseer', 'PubMed', 'ogbn-arxiv')
    if args.dataset_name not in supported_datasets:
        raise SystemExit('Please use correct datasets: Cora, Citeseer, PubMed, ogbn-arxiv')

    set_random_seed(args.seed)

    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    # Sort them so the partition list handed to load_data is the same on
    # every machine and run, in line with the seeding above.
    partitions_names = sorted(os.listdir('./partition'))
    dataset, num_clients, trainIdx, valIdx, testIdx = load_data(partitions_names, args)

    best_acc = train_fedgls(dataset, num_clients, trainIdx, valIdx, testIdx, args)
    print('Best acc: {:9.5f}'.format(best_acc))
