from torch.utils.data import DataLoader, Dataset as ptDataset
from sklearn import metrics
import torch
import numpy as np
import pandas as pd

from graph_learning.tasker import Tasker, TaskerConfig, DataloaderTasker
from graph_learning.tasker.utils import cop_namespace

import graph_learning.utils as u
from graph_learning.utils import merge_metrics

@TaskerConfig.register('model-metrics',
                       help='[Metrics] Metrics returned from models.')
class ModelMetricsTaskerConfig(TaskerConfig):
    """Config entry for :class:`ModelMetricsTasker`.

    Registered on ``TaskerConfig`` under the key ``'model-metrics'``
    (presumably the name used to select this tasker; confirm in
    ``TaskerConfig.register``). It contributes no options of its own to
    the parser.
    """

    @classmethod
    def define_parser(cls, parser):
        # Nothing tasker-specific to add: delegate to the base class only.
        # Zero-argument super() is deliberate -- the class decorator may
        # rebind the module-level name, so super(Name, cls) is not safe.
        super().define_parser(parser)

    @property
    def builder(self):
        # The tasker class associated with this config (presumably
        # instantiated by the framework from it).
        return ModelMetricsTasker

class ModelMetricsTasker(Tasker):
    """Tasker that forwards metrics already produced by the model.

    No metrics are computed here. Both ``valid_metrics`` and
    ``test_metrics`` return the ``'metrics'`` entry of the model outputs,
    or an empty dict when that entry is missing.
    """

    def __init__(self, logger):
        super().__init__()
        # Stored for external/base-class use; not read by this class itself.
        self.logger = logger

    def _metrics(self, data, outputs):
        """Return ``outputs['metrics']``, or ``{}`` if the key is absent.

        ``data`` is unused. It is accepted so the signature mirrors the
        ``valid_metrics``/``test_metrics`` hooks that call this helper.
        ``outputs`` only needs to support ``in`` and ``[]`` (presumably
        a dict produced by the model; confirm against callers).
        """
        return outputs['metrics'] if 'metrics' in outputs else {}

    # NOTE(review): cop_namespace('metrics') presumably places the returned
    # mapping under a 'metrics' namespace; verify in
    # graph_learning.tasker.utils.
    @cop_namespace('metrics')
    def valid_metrics(self, data, outputs):
        """Metrics hook for validation: the model-reported metrics."""
        return self._metrics(data, outputs)

    @cop_namespace('metrics')
    def test_metrics(self, data, outputs):
        """Metrics hook for testing: the model-reported metrics."""
        return self._metrics(data, outputs)
