import torch
from torch import nn
from torch.nn import functional as F
from sklearn.metrics import average_precision_score, accuracy_score, f1_score


def acc_f1(output, labels, average='binary'):
    """Compute accuracy and F1 score of arg-max predictions against labels.

    Args:
        output: 2-D tensor of per-class scores, one row per sample; the
            predicted class of each row is the index of its maximum along dim 1.
        labels: tensor of ground-truth class indices, one per row of ``output``.
        average: forwarded to ``sklearn.metrics.f1_score`` (e.g. ``'binary'``,
            ``'micro'``, ``'macro'``, ``'weighted'``).

    Returns:
        Tuple ``(accuracy, f1)`` as returned by scikit-learn.
    """
    preds = output.max(1)[1].type_as(labels)
    # sklearn needs host-side NumPy data: move off any accelerator (CUDA, MPS,
    # ...) and drop autograd tracking. Both calls are no-ops for CPU tensors.
    y_pred = preds.detach().cpu().numpy()
    y_true = labels.detach().cpu().numpy()
    # sklearn's signature is (y_true, y_pred). The order matters for
    # average='weighted' (weights come from the true-label support) and for
    # the wording of sklearn's ill-defined-metric warnings.
    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average=average)
    return accuracy, f1


class TaskModel(nn.Module):
    """Classification head that scores encoder embeddings against labels.

    Holds an ``encoder`` whose ``logit`` method maps embeddings to per-class
    scores, and exposes loss/metric helpers plus "is this better" logic
    keyed on F1.
    """

    def __init__(self, args, encoder) -> None:
        super().__init__()
        self.args = args
        self.encoder = encoder

        num_classes = args.n_classes
        # 'binary' F1 is only defined for two classes; use micro averaging
        # when there are more.
        self.f1_average = 'micro' if num_classes > 2 else 'binary'

        # Uniform per-class weights for the NLL loss. Stored as a frozen
        # Parameter rather than a plain tensor so it is registered on the
        # module: it follows `.to(device)` / `.cuda()` and lands in state_dict.
        self.weights = nn.Parameter(torch.ones(num_classes), requires_grad=False)

    def compute_metrics(self, embeddings, data, split):
        """Evaluate loss, accuracy and F1 on one data split.

        Args:
            embeddings: encoder output, indexed with ``data['idx_<split>']``.
            data: mapping providing ``'idx_<split>'`` and ``'labels'``.
            split: suffix selecting which ``'idx_<split>'`` entry to use.

        Returns:
            dict with keys ``'loss'`` (weighted NLL loss tensor), ``'acc'``
            and ``'f1'``.
        """
        idx = data[f'idx_{split}']
        # NOTE(review): F.nll_loss expects log-probabilities, so encoder.logit
        # presumably ends in a log_softmax — confirm against the encoder.
        scores = self.encoder.logit(embeddings[idx])
        split_labels = data['labels'][idx]
        loss = F.nll_loss(scores, split_labels, weight=self.weights)
        acc, f1 = acc_f1(scores, split_labels, average=self.f1_average)
        return {'loss': loss, 'acc': acc, 'f1': f1}

    def init_metric_dict(self):
        """Return sentinel "best so far" metrics.

        -1 lies below any achievable accuracy or F1 (both are in [0, 1]),
        so the first real evaluation always counts as an improvement.
        """
        return dict(acc=-1, f1=-1)

    def has_improved(self, m1, m2):
        """Return True when ``m2`` has a strictly higher F1 than ``m1``."""
        return m2["f1"] > m1["f1"]

