import torch
import numpy as np
import torch.nn.functional as F

from . import BaseMetric

from sklearn.metrics import f1_score, roc_auc_score


class F1(BaseMetric):
    """F1 score of argmax predictions, computed with scikit-learn's ``f1_score``.

    ``output`` must be a 2-D tensor of per-class scores and ``target`` a 1-D
    tensor of class labels.  Each row's predicted label is the index of its
    largest (post-``apply``) score.
    """

    def __init__(self, name: str = 'f1', average: str = 'weighted', zero_division: float = 0, apply=lambda x: x, **kwargs):
        """
        Args:
            name: metric name handed to ``BaseMetric``.
            average: ``average`` argument forwarded to ``f1_score``.
            zero_division: ``zero_division`` argument forwarded to ``f1_score``.
            apply: callable applied to the raw ``output`` tensor before it is
                detached and converted to numpy (identity by default).
            **kwargs: any further keyword arguments for ``f1_score``.
        """
        super().__init__(name)
        self.apply = apply
        self.kwargs = kwargs
        self.zero_division = zero_division
        self.average = average

    def call(self, output, target):
        """Return the F1 score of ``argmax(apply(output), axis=1)`` against ``target``."""
        scores = self.apply(output).detach().cpu().numpy()
        labels = target.detach().cpu().numpy()
        assert scores.ndim == 2
        assert labels.ndim == 1
        predicted = scores.argmax(axis=1)
        return f1_score(
            labels,
            predicted,
            average=self.average,
            zero_division=self.zero_division,
            **self.kwargs,
        )


class AUC(BaseMetric):
    """ROC AUC computed with scikit-learn's ``roc_auc_score`` from continuous scores.

    ROC AUC is a ranking metric, so it must be fed per-sample scores, not
    hard argmax labels.  ``output`` must be a 2-D tensor of per-class scores
    and ``target`` a 1-D tensor of class labels.  How ``output`` is reduced
    depends on its number of columns:

    * 1 column  -> that column is used as the positive-class score.
    * 2 columns -> ``output[:, 1] - output[:, 0]`` is used.  It ranks samples
      exactly like P(class 1), whether the columns are probabilities
      (giving ``2*p1 - 1``) or logits (giving the log-odds).
    * >2 columns -> the full matrix is passed through.  scikit-learn then
      needs ``multi_class='ovr'`` or ``'ovo'`` in ``kwargs`` and rows that
      sum to 1.  For logits, pass e.g. ``apply=lambda x: F.softmax(x, dim=1)``.
    """

    def __init__(self, name: str = 'auc', apply=lambda x: x, **kwargs):
        """
        Args:
            name: metric name handed to ``BaseMetric``.
            apply: callable applied to the raw ``output`` tensor before it is
                detached and converted to numpy (identity by default).
            **kwargs: keyword arguments forwarded to ``roc_auc_score``.
        """
        super().__init__(name)
        self.kwargs = kwargs
        self.apply = apply

    def call(self, output, target):
        """Return ``roc_auc_score(target, scores)`` with scores derived from ``output``."""
        output = self.apply(output)
        output = output.detach().cpu().numpy()
        target = target.detach().cpu().numpy()
        assert len(output.shape) == 2
        assert len(target.shape) == 1
        num_columns = output.shape[1]
        if num_columns == 1:
            scores = output[:, 0]
        elif num_columns == 2:
            # Binary: roc_auc_score expects a 1-D positive-class score.  The
            # column difference is monotone in P(class 1) for both logits
            # and probabilities, so the AUC does not depend on which we got.
            scores = output[:, 1] - output[:, 0]
        else:
            scores = output
        return roc_auc_score(target, scores, **self.kwargs)
