import torch
from sklearn.metrics import average_precision_score, roc_auc_score
import numpy as np
import torch.nn as nn


def get_link_prediction_metrics(predicts: torch.Tensor, labels: torch.Tensor):
    """
    compute average precision and ROC AUC for the link prediction task
    :param predicts: Tensor, shape (num_samples, ), predicted scores (detached from the autograd graph before use)
    :param labels: Tensor, shape (num_samples, ), ground-truth labels
    :return:
        dictionary with the keys 'average_precision' and 'roc_auc'
    """
    # scikit-learn works on host-side numpy arrays, so move both tensors off the device first
    scores = predicts.detach().cpu().numpy()
    targets = labels.cpu().numpy()

    return {
        'average_precision': average_precision_score(y_true=targets, y_score=scores),
        'roc_auc': roc_auc_score(y_true=targets, y_score=scores),
    }


def get_node_classification_metrics(predicts: torch.Tensor, labels: torch.Tensor):
    """
    compute ROC AUC for the node classification task
    :param predicts: Tensor, shape (num_samples, ), predicted scores (detached from the autograd graph before use)
    :param labels: Tensor, shape (num_samples, ), ground-truth labels
    :return:
        dictionary with the single key 'roc_auc'
    """
    # scikit-learn works on host-side numpy arrays, so move both tensors off the device first
    scores = predicts.detach().cpu().numpy()
    targets = labels.cpu().numpy()

    return {'roc_auc': roc_auc_score(y_true=targets, y_score=scores)}

class LossFunction:
    """
    training loss over positive and negative logits

    Supported values of ``loss_type``:
      * 'pointwise': binary cross-entropy with logits, positives labelled 1 and negatives labelled 0
      * 'listwise': softmax cross-entropy over rows of [positive, its negatives], where the positive
        (column 0) is the target class

    The loss type is only checked when the loss is computed, not at construction time.
    """

    def __init__(self, loss_type: str):
        """
        :param loss_type: str, 'pointwise' or 'listwise'
        """
        self.loss_type = loss_type

    def __call__(self, positive_logits: torch.Tensor, negative_logits: torch.Tensor):
        """
        make the instance callable like a torch module; simply delegates to :meth:`forward`
        """
        return self.forward(positive_logits, negative_logits)

    def forward(self, positive_logits: torch.Tensor, negative_logits: torch.Tensor):
        """
        compute the loss
        :param positive_logits: Tensor, shape (num_positives, )
        :param negative_logits: Tensor, for 'listwise' its length must be a non-zero multiple of num_positives
        :return:
            scalar Tensor holding the mean-reduced loss
        :raises ValueError: if the loss type is unknown, or (listwise only) if the negatives cannot be
            split evenly across the positives
        """
        if self.loss_type == 'pointwise':
            predicts = torch.cat([positive_logits, negative_logits])
            labels = torch.cat([torch.ones_like(positive_logits), torch.zeros_like(negative_logits)])
            loss = torch.nn.functional.binary_cross_entropy_with_logits(input=predicts, target=labels)
        elif self.loss_type == 'listwise':
            num_positives = len(positive_logits)
            num_negatives = len(negative_logits)
            # explicit check instead of `assert`, which is stripped when running with `python -O`;
            # this also covers the empty inputs that would otherwise surface as a ZeroDivisionError
            # or as an ambiguous-reshape error
            if num_positives == 0 or num_negatives == 0 or num_negatives % num_positives != 0:
                raise ValueError(
                    f"listwise loss needs a non-zero multiple of {num_positives} negative logits, "
                    f"got {num_negatives}"
                )
            neg_per_edge = num_negatives // num_positives
            # NOTE(review): the row-major reshape assumes the negatives of positive i occupy the
            # contiguous slice [i * neg_per_edge, (i + 1) * neg_per_edge) -- confirm against the sampler
            logits = torch.cat(
                [positive_logits[:, None], negative_logits.reshape(num_positives, neg_per_edge)], dim=1
            )
            # the positive logit sits in column 0 of every row, so the target class index is always 0
            labels = torch.zeros_like(positive_logits, dtype=torch.long)
            loss = torch.nn.functional.cross_entropy(input=logits, target=labels)
        else:
            raise ValueError("Not Implemented Loss Type")
        return loss
