import argparse

from ogb.linkproppred import *
from ogb.nodeproppred import *

from dgl.data import CitationGraphDataset


def load_graph(name):
    cite_graphs = ["cora", "citeseer", "pubmed"]

    if name in cite_graphs:
        dataset = CitationGraphDataset(name)
        graph = dataset[0]

        nodes = graph.nodes()
        y = graph.ndata["label"]
        train_mask = graph.ndata["train_mask"]
        val_mask = graph.ndata["test_mask"]

        nodes_train, y_train = nodes[train_mask], y[train_mask]
        nodes_val, y_val = nodes[val_mask], y[val_mask]
        eval_set = [(nodes_train, y_train), (nodes_val, y_val)]

    elif name.startswith("ogbn"):

        dataset = DglNodePropPredDataset(name)
        graph, y = dataset[0]
        split_nodes = dataset.get_idx_split()
        nodes = graph.nodes()

        train_idx = split_nodes["train"]
        val_idx = split_nodes["valid"]

        nodes_train, y_train = nodes[train_idx], y[train_idx]
        nodes_val, y_val = nodes[val_idx], y[val_idx]
        eval_set = [(nodes_train, y_train), (nodes_val, y_val)]

    else:
        raise ValueError("Dataset name error!")

    return graph, eval_set


def parse_arguments():
    """
    Parse arguments
    """
    parser = argparse.ArgumentParser(description="Node2vec")
    parser.add_argument("--dataset", type=str, default="cora")
    # 'train' for training node2vec model, 'time' for testing speed of random walk
    parser.add_argument("--task", type=str, default="train")
    parser.add_argument("--runs", type=int, default=10)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--embedding_dim", type=int, default=128)
    parser.add_argument("--walk_length", type=int, default=50)
    parser.add_argument("--p", type=float, default=0.25)
    parser.add_argument("--q", type=float, default=4.0)
    parser.add_argument("--num_walks", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=128)

    args = parser.parse_args()

    return args
