from sklearn import metrics
import torch
import numpy as np
import pandas as pd

from graph_learning.tasker import Tasker, TaskerConfig, DataloaderTasker
from graph_learning.tasker.utils import cop_namespace

import graph_learning.utils as u

import dgl
from graph_learning.dataset.graph import GLGraph, gl_batch, gl_unbatch
from .data import GraphDataMixin

@TaskerConfig.register('node-cls',
                       help='[Interface] for node classification task.')
class SG_NodeClsConfig(TaskerConfig):
    """Config registered as ``node-cls``.

    Adds a ``--no-mask`` command-line flag and names ``SG_NodeCls`` as
    its builder (presumably the class the framework instantiates).
    """

    def __init__(self, args, context):
        # Pure pass-through to the base config.
        super().__init__(args, context)

    @classmethod
    def define_parser(cls, parser):
        """Register base options, then the ``--no-mask`` flag."""
        super().define_parser(parser)
        parser.add_argument('--no-mask', action='store_true')

    @property
    def builder(self):
        """Tasker class associated with this config."""
        return SG_NodeCls

class SG_NodeCls(GraphDataMixin,
                 Tasker):
    """Node-classification interface: class count, node labels, node masks.

    The graph is obtained through ``self.graph(data)``, which is not defined
    in this class (presumably supplied by ``GraphDataMixin``).
    """

    def __init__(self, no_mask):
        super().__init__()
        # When truthy, node_mask selects every node instead of a split mask.
        self.no_mask = no_mask

    def class_num(self, data):
        """Return ``data.class_num``."""
        return data.class_num

    def node_labels(self, data):
        """Return the graph's ``'labels'`` node feature."""
        g = self.graph(data)
        return g.ndata['labels']

    def node_mask(self, data):
        """Return a boolean node mask for the current split.

        If ``self.no_mask`` is set, an all-True bool tensor with one entry per
        node (on the graph's device) is returned. Otherwise the mask feature
        matching ``data.gdata['mode']`` is looked up: 'train' -> 'train_mask',
        'valid' -> 'val_mask', 'test' -> 'test_mask'. Any other mode raises
        ``KeyError``.
        """
        g = self.graph(data)

        if not self.no_mask:
            split_to_key = {
                'train': 'train_mask',
                'valid': 'val_mask',
                'test': 'test_mask',
            }
            return g.ndata[split_to_key[data.gdata['mode']]]
        return torch.ones(g.number_of_nodes(), device=g.device, dtype=torch.bool)

@TaskerConfig.register('node-cls-metrics',
                       help='[Metrics] for node classification task. (accuracy & macro f1)')
class NodeClsMetricsTaskerConfig(TaskerConfig):
    """Config registered as ``node-cls-metrics``; builds ``NodeClsMetricsTasker``.

    Adds no command-line options of its own.
    """

    @classmethod
    def define_parser(cls, parser):
        """Only the base config's options are registered."""
        super().define_parser(parser)

    @property
    def builder(self):
        """Tasker class associated with this config."""
        return NodeClsMetricsTasker

class NodeClsMetricsTasker(Tasker):
    """Accuracy and macro-F1 metrics for node classification.

    Metrics are computed only over the nodes selected by the ``node-cls``
    interface's ``node_mask``. A full sklearn classification report is also
    sent to ``self.logger.log``.
    """

    def __init__(self, logger):
        super().__init__()
        self.logger = logger

    def _cls_metrics(self, labels, preds, name):
        """Log a classification report and return accuracy / macro F1.

        Args:
            labels: ground-truth class ids (array-like accepted by sklearn).
            preds: predicted class ids aligned with ``labels``.
            name: tag passed as the first argument to ``self.logger.log``.

        Returns:
            dict with Python-float values under 'accuracy' and 'macro_f1'.
        """
        self.logger.log(name, metrics.classification_report(labels, preds))
        # float() replaces np.asscalar, which was removed in NumPy 1.23; it
        # accepts both NumPy scalars and plain Python floats returned by sklearn.
        return {
            'accuracy': float(metrics.accuracy_score(labels, preds)),
            'macro_f1': float(metrics.f1_score(labels, preds, average='macro')),
        }

    def _masked_labels_preds(self, data, outputs):
        """Return (labels, preds) numpy arrays restricted to ``data.node_mask()``.

        Everything is moved to CPU before masking, so the mask, labels and
        probabilities need not share a device. Predictions are the argmax over
        dim 1 of ``outputs['probs']``.
        """
        node_mask = data.node_mask().cpu().numpy()
        labels = data.node_labels().cpu().numpy().round()
        preds = outputs['probs'].cpu().numpy().argmax(1)
        return labels[node_mask], preds[node_mask]

    @cop_namespace('node-cls')
    def valid_metrics(self, data, outputs):
        """Validation-split metrics (see ``_cls_metrics``)."""
        labels, preds = self._masked_labels_preds(data, outputs)
        return self._cls_metrics(labels, preds, 'valid')

    @cop_namespace('node-cls')
    def test_metrics(self, data, outputs):
        """Test-split metrics (see ``_cls_metrics``)."""
        labels, preds = self._masked_labels_preds(data, outputs)
        return self._cls_metrics(labels, preds, 'test')

@TaskerConfig.register('ogbn-evaluate',
                       help='[Metrics] Open Graph Benchmark node classification evaluation.')
class OgbnEvalTaskerConfig(TaskerConfig):
    """Config registered as ``ogbn-evaluate``; builds ``OgbnEvalTasker``.

    Adds no command-line options of its own.
    """

    @classmethod
    def define_parser(cls, parser):
        """Only the base config's options are registered."""
        super().define_parser(parser)

    @property
    def builder(self):
        """Tasker class associated with this config."""
        return OgbnEvalTasker

class OgbnEvalTasker(Tasker):
    """Node-classification evaluation via an OGB-style evaluator object.

    The evaluator is read from ``data.graph().gdata['evaluator']``, and the
    same computation is used for the validation and test splits.
    """

    def __init__(self):
        super().__init__()

    def _ogbn_metrics(self, data, outputs):
        """Evaluate masked predictions with the graph's evaluator.

        Predictions are the argmax over dim 1 of ``outputs['probs']``. Both
        predictions and labels are restricted to ``data.node_mask()``, and
        the ``None`` index adds a trailing axis of size 1.

        Returns:
            Whatever ``evaluator.eval`` returns for the given
            ``{'y_true': ..., 'y_pred': ...}`` input.
        """
        mask = data.node_mask()
        y_pred = outputs['probs'].argmax(1)[mask, None]
        y_true = data.node_labels()[mask, None]
        ogb_evaluator = data.graph().gdata['evaluator']
        payload = {'y_true': y_true, 'y_pred': y_pred}
        return ogb_evaluator.eval(payload)

    @cop_namespace('ogbn')
    def valid_metrics(self, data, outputs):
        """Validation metrics; identical computation to ``test_metrics``."""
        return self._ogbn_metrics(data, outputs)

    @cop_namespace('ogbn')
    def test_metrics(self, data, outputs):
        """Test metrics; identical computation to ``valid_metrics``."""
        return self._ogbn_metrics(data, outputs)
