import torch

from graph_learning.dataset.graph import GLGraph
from graph_learning.dataset import DatasetConfig
from ogb.nodeproppred import DglNodePropPredDataset
from ogb.nodeproppred import Evaluator

@DatasetConfig.register('ogbn-arxiv',
                        help='Ogbn-arxiv dataset.')
class ObgArxivDatasetConfig(DatasetConfig):
    """Dataset config for the OGB ``ogbn-arxiv`` node-classification dataset.

    Registered under the name ``'ogbn-arxiv'``. Besides the base options it
    adds a ``--self-loop`` flag that decides whether self-loops are added to
    the graph after preprocessing.
    """

    # NOTE(review): "Obg" looks like a typo for "Ogb"; the class name is kept
    # unchanged so existing references to it keep working.

    def __init__(self, args, context):
        # No arxiv-specific state; everything is handled by the base class.
        super().__init__(args, context)

    @classmethod
    def define_parser(cls, parser):
        """Register the base options plus the ``--self-loop`` boolean flag."""
        super().define_parser(parser)
        parser.add_argument('--self-loop', action='store_true')

    def build_dataset(self):
        """Load ogbn-arxiv and return it as a preprocessed ``GLGraph``.

        Node data on the returned graph:
          * ``'x'``: the original ``'feat'`` tensor, moved to a new key
            (``'feat'`` itself is removed).
          * ``'labels'``: the OGB labels with their trailing dimension squeezed.
          * ``'train_mask'`` / ``'val_mask'`` / ``'test_mask'``: boolean node
            masks built from the OGB split returned by ``get_idx_split()``.

        Structure: a reversed copy of every edge is appended, all existing
        self-loops are removed, and self-loops are added back only when
        ``self.self_loop`` is true.

        Graph data: ``'class_num'`` (``num_classes`` of the OGB dataset),
        ``'name'`` and an OGB ``Evaluator`` stored under ``'evaluator'``.
        """
        ogb_data = DglNodePropPredDataset(name='ogbn-arxiv')
        g = ogb_data.graph[0]

        # Expose features under the key used downstream ('x') and drop 'feat'.
        g.ndata['x'] = g.ndata['feat']
        del g.ndata['feat']
        # The loader yields labels with a trailing singleton dim; remove it.
        g.ndata['labels'] = ogb_data.labels.squeeze(-1)

        # Turn each index split into a boolean mask over all nodes.
        n_nodes = g.number_of_nodes()
        split = ogb_data.get_idx_split()
        for mask_key, split_key in (('train_mask', 'train'),
                                    ('val_mask', 'valid'),
                                    ('test_mask', 'test')):
            mask = torch.zeros(n_nodes, dtype=torch.bool)
            mask[split[split_key]] = True
            g.ndata[mask_key] = mask

        # Append the reverse of every existing edge (add_edges mutates g).
        # NOTE(review): node pairs already connected in both directions end up
        # with duplicate edges here; confirm that is acceptable downstream.
        src, dst = g.all_edges()
        g.add_edges(dst, src)

        # Normalise self-loops: strip them all, then optionally add them back.
        g = g.remove_self_loop()
        # Assumes the base DatasetConfig exposes the parsed --self-loop flag as
        # self.self_loop — TODO confirm.
        g = g.add_self_loop() if self.self_loop else g

        result = GLGraph(g)
        result.gdata['class_num'] = ogb_data.num_classes
        result.gdata['name'] = 'ogbn_arxiv'
        result.gdata['evaluator'] = Evaluator(name='ogbn-arxiv')
        return result
