import torch
import numpy as np
import torch.nn.functional as F

from . import BaseMetric

from sklearn.metrics import f1_score, roc_auc_score, recall_score, precision_score


def _to_numpy(output, target):
    """Detach both tensors, move them to the CPU and convert them to NumPy.

    Args:
        output: tensor with at least two dimensions; the class axis is
            assumed to be axis 1 (presumably ``(N, C)`` scores -- TODO confirm).
        target: tensor of ground-truth labels, returned unchanged apart from
            the conversion.

    Returns:
        ``(predicted, labels)`` where ``predicted`` is the argmax of
        ``output`` along axis 1 and ``labels`` is ``target`` as an ndarray.
    """
    labels = target.detach().cpu().numpy()
    scores = output.detach().cpu().numpy()
    predicted = scores.argmax(axis=1)
    return predicted, labels


class F1(BaseMetric):
    """F1 score of argmax predictions, computed by sklearn's ``f1_score``.

    Args:
        name: metric name passed to ``BaseMetric``.
        average: ``average`` argument forwarded to ``f1_score``
            (default ``'weighted'``).
        **kwargs: any further keyword arguments, forwarded verbatim to
            ``f1_score``.
    """

    def __init__(self, name: str = 'f1', average: str = 'weighted', **kwargs):
        super().__init__(name)
        self.average, self.kwargs = average, kwargs

    def call(self, output, target):
        """Score ``argmax(output, axis=1)`` against ``target`` with ``f1_score``."""
        y_pred, y_true = _to_numpy(output, target)
        return f1_score(y_true, y_pred, average=self.average, **self.kwargs)


class Precision(BaseMetric):
    """Precision of argmax predictions, computed by sklearn's ``precision_score``.

    Args:
        name: metric name passed to ``BaseMetric``.
        average: ``average`` argument forwarded to ``precision_score``
            (default ``'macro'``).
        **kwargs: any further keyword arguments, forwarded verbatim to
            ``precision_score``.
    """

    def __init__(self, name: str = 'precision', average: str = 'macro', **kwargs):
        super().__init__(name)
        self.average, self.kwargs = average, kwargs

    def call(self, output, target):
        """Score ``argmax(output, axis=1)`` against ``target`` with ``precision_score``."""
        y_pred, y_true = _to_numpy(output, target)
        return precision_score(y_true, y_pred, average=self.average, **self.kwargs)


class Recall(BaseMetric):
    """Recall of argmax predictions, computed by sklearn's ``recall_score``.

    Args:
        name: metric name passed to ``BaseMetric``.
        average: ``average`` argument forwarded to ``recall_score``
            (default ``'macro'``).
        **kwargs: any further keyword arguments, forwarded verbatim to
            ``recall_score``.
    """

    def __init__(self, name: str = 'recall', average: str = 'macro', **kwargs):
        super().__init__(name)
        self.average, self.kwargs = average, kwargs

    def call(self, output, target):
        """Score ``argmax(output, axis=1)`` against ``target`` with ``recall_score``."""
        predicted, actual = _to_numpy(output, target)
        score = recall_score(actual, predicted, average=self.average, **self.kwargs)
        return score


class AUC(BaseMetric):
    """Area under the ROC curve, computed by sklearn's ``roc_auc_score``.

    ROC AUC needs a continuous score per sample, because it sweeps a decision
    threshold over those scores. Passing hard ``argmax`` labels instead
    collapses the curve to a single operating point, so this metric ranks
    samples by a score derived from ``output``.

    Args:
        name: metric name passed to ``BaseMetric``.
        **kwargs: extra keyword arguments forwarded to ``roc_auc_score``
            (e.g. ``multi_class``, ``average``, ``labels``).
    """

    def __init__(self, name: str = 'auc', **kwargs):
        super().__init__(name)
        self.kwargs = kwargs

    def call(self, output, target):
        """Return the ROC AUC of the scores in ``output`` against ``target``.

        Args:
            output: 2-D tensor of per-class scores, shape ``(N, C)``.
                Presumably logits or probabilities -- TODO confirm with callers.
            target: 1-D tensor of class labels, shape ``(N,)``.

        Returns:
            The value returned by ``roc_auc_score``.

        Raises:
            ValueError: if ``output`` is not 2-D or ``target`` is not 1-D.
        """
        # Move to CPU before upcasting: float64 is not available on every
        # device. float64 is a lossless upcast for all real float dtypes and
        # lets bfloat16 inputs convert to NumPy.
        scores = output.detach().cpu().double()
        labels = target.detach().cpu().numpy()
        if scores.dim() != 2:
            raise ValueError(
                f'AUC expects output of shape (N, C), got {tuple(scores.shape)}')
        if labels.ndim != 1:
            raise ValueError(
                f'AUC expects target of shape (N,), got {labels.shape}')

        kwargs = self.kwargs
        n_columns = scores.shape[1]
        if n_columns == 1:
            # A single column is already the positive-class score.
            y_score = scores[:, 0]
        elif n_columns == 2:
            # Binary: AUC depends only on the ordering of scores. The margin
            # between the two columns orders samples like the positive-class
            # probability for logits, log-probabilities and probabilities.
            y_score = scores[:, 1] - scores[:, 0]
        else:
            # Multiclass: sklearn requires each row to sum to 1, so normalize
            # anything that is not already a probability distribution.
            y_score = scores
            row_sums = y_score.sum(dim=1)
            if not torch.allclose(row_sums, torch.ones_like(row_sums)):
                y_score = F.softmax(y_score, dim=1)
            # sklearn raises for multiclass input unless multi_class is set;
            # default to one-vs-rest but let an explicit caller setting win.
            kwargs = {'multi_class': 'ovr', **kwargs}
        return roc_auc_score(labels, y_score.numpy(), **kwargs)
