from torch.utils.data import DataLoader, Dataset as ptDataset
from sklearn import metrics
import torch
import numpy as np
import pandas as pd

from graph_learning.tasker import Tasker, TaskerConfig, DataloaderTasker
from graph_learning.tasker.utils import cop_namespace

import graph_learning.utils as u
from graph_learning.utils import merge_metrics

from .data import GraphDataMixin

@TaskerConfig.register('pairwise-node-cls',
                       help='[Interface] for pairwise node classification task.')
class PairNodeClsConfig(TaskerConfig):
    """Config registered with ``TaskerConfig`` under ``'pairwise-node-cls'``.

    Adds no options or state of its own; its ``builder`` is ``PairNodeCls``.
    """

    def __init__(self, args, context):
        # Pure pass-through: both arguments go to TaskerConfig unchanged.
        super().__init__(args, context)

    @property
    def builder(self):
        """Return the tasker class associated with this config (``PairNodeCls``)."""
        return PairNodeCls

    @classmethod
    def define_parser(cls, parser):
        """Delegate to the parent ``define_parser``; no extra arguments are added.

        Note that the parent's return value (if any) is not propagated.
        """
        super().define_parser(parser)

class PairNodeCls(GraphDataMixin, Tasker):
    """Accessors for pair-related entries stored in a graph's ``gdata``.

    Every accessor resolves the graph through ``self.graph(data)`` (not
    defined in this class; presumably supplied by ``GraphDataMixin`` -- TODO
    confirm) and indexes that graph's ``gdata``. Except for
    ``pair_labels``, the key is built from ``data.gdata["mode"]``.
    """

    def pair(self, data):
        """Return the ``pair_<mode>`` entry of the graph's ``gdata``."""
        graph_store = self.graph(data).gdata
        key = f'pair_{data.gdata["mode"]}'
        return graph_store[key]

    def pair_labels(self, data):
        """Return the ``pair_labels`` entry of the graph's ``gdata``.

        Unlike the other accessors, this key does not depend on the mode.
        """
        graph_store = self.graph(data).gdata
        return graph_store['pair_labels']

    def pair_pos(self, data):
        """Return the ``pair_<mode>_pos`` entry of the graph's ``gdata``."""
        graph_store = self.graph(data).gdata
        key = f'pair_{data.gdata["mode"]}_pos'
        return graph_store[key]

    def pair_neg(self, data):
        """Return the ``pair_<mode>_neg`` entry of the graph's ``gdata``."""
        graph_store = self.graph(data).gdata
        key = f'pair_{data.gdata["mode"]}_neg'
        return graph_store[key]

# NOTE(review): the help text below says "node classification", while the
# registered name and the tasker it builds are pairwise -- confirm whether the
# wording should read "pairwise node classification".
@TaskerConfig.register('pairwise-node-cls-metrics',
                       help='[Metrics] for node classification task. (Auc)')
class PairNodeClsMetricsTaskerConfig(TaskerConfig):
    """Config registered with ``TaskerConfig`` under ``'pairwise-node-cls-metrics'``.

    Adds no options of its own; its ``builder`` is ``PairNodeClsMetricsTasker``.
    """

    @property
    def builder(self):
        """Return the tasker class associated with this config."""
        return PairNodeClsMetricsTasker

    @classmethod
    def define_parser(cls, parser):
        """Delegate to the parent ``define_parser``; no extra arguments are added.

        Note that the parent's return value (if any) is not propagated.
        """
        super().define_parser(parser)

def _pairwise_auc(outputs):
    """Compute ROC AUC from an ``outputs`` mapping as a plain Python float.

    Args:
        outputs: Mapping with ``'labels'`` and ``'probs'`` entries. Each is
            flattened, moved to CPU and converted to a NumPy array before
            scoring (so they are expected to be torch tensors).

    Returns:
        The value of ``sklearn.metrics.roc_auc_score(labels, probs)`` as a
        built-in ``float``.

    Raises:
        ValueError: Propagated from ``roc_auc_score``, e.g. when ``labels``
            contains a single class only.
    """
    labels = outputs['labels'].flatten().cpu().numpy()
    probs = outputs['probs'].flatten().cpu().numpy()
    # ``np.asscalar`` was removed in NumPy 1.23; ``float`` performs the same
    # scalar-to-Python conversion and works on every NumPy version.
    return float(metrics.roc_auc_score(labels, probs))


class PairNodeClsMetricsTasker(Tasker):
    """Tasker reporting ROC AUC for validation and test outputs.

    Both metric hooks are decorated with
    ``cop_namespace('pairwise-node-cls')`` and return ``{'auc': <float>}``.
    """

    def __init__(self, logger):
        """Store ``logger`` on the instance after initialising ``Tasker``."""
        super().__init__()
        self.logger = logger

    @cop_namespace('pairwise-node-cls')
    def valid_metrics(self, data, outputs):
        """Return ``{'auc': ...}`` for validation outputs.

        Args:
            data: Unused here; accepted to satisfy the hook signature.
            outputs: Mapping with ``'labels'`` and ``'probs'`` entries.
        """
        return {'auc': _pairwise_auc(outputs)}

    @cop_namespace('pairwise-node-cls')
    def test_metrics(self, data, outputs):
        """Return ``{'auc': ...}`` for test outputs.

        Args:
            data: Unused here; accepted to satisfy the hook signature.
            outputs: Mapping with ``'labels'`` and ``'probs'`` entries.
        """
        return {'auc': _pairwise_auc(outputs)}
