import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

"""
Evaluation functions from OGB.
https://github.com/snap-stanford/ogb/blob/master/ogb/graphproppred/evaluate.py
"""


def eval_rocauc(y_true, y_pred):
    """
    compute ROC-AUC averaged across tasks

    Args:
        y_true: labels of shape (num_samples, num_tasks), expected to be 0/1,
            with NaN marking a missing label. A 1-D array (or any
            array-like) is accepted and treated as a single task.
        y_pred: prediction scores indexed the same way as ``y_true``
            (1-D is likewise treated as a single task).

    Returns:
        dict with key "rocauc": the mean ROC-AUC over the tasks that contain
        at least one label equal to 1 and one label equal to 0.

    Raises:
        RuntimeError: if no task contains both a 1 and a 0 label.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Allow a single task to be passed as a flat vector.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    rocauc_list = []

    for i in range(y_true.shape[1]):
        column = y_true[:, i]
        # AUC is only defined when both classes are present. NaN compares
        # unequal to both 0 and 1, so missing labels never satisfy this test.
        if np.sum(column == 1) > 0 and np.sum(column == 0) > 0:
            # NaN != NaN, so this mask keeps only the labeled rows.
            is_labeled = column == column
            rocauc_list.append(
                roc_auc_score(column[is_labeled], y_pred[is_labeled, i])
            )

    if len(rocauc_list) == 0:
        raise RuntimeError(
            "No positively labeled data available. Cannot compute ROC-AUC."
        )

    return {"rocauc": sum(rocauc_list) / len(rocauc_list)}


def eval_ap(y_true, y_pred):
    """
    compute Average Precision (AP) averaged across tasks

    Args:
        y_true: labels of shape (num_samples, num_tasks), expected to be 0/1,
            with NaN marking a missing label. A 1-D array (or any
            array-like) is accepted and treated as a single task.
        y_pred: prediction scores indexed the same way as ``y_true``
            (1-D is likewise treated as a single task).

    Returns:
        dict with key "ap": the mean average precision over the tasks that
        contain at least one label equal to 1 and one label equal to 0.

    Raises:
        RuntimeError: if no task contains both a 1 and a 0 label.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Allow a single task to be passed as a flat vector.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    ap_list = []

    for i in range(y_true.shape[1]):
        column = y_true[:, i]
        # Only score tasks that have both positive and negative labels (same
        # rule as eval_rocauc). NaN equals neither 0 nor 1, so missing labels
        # are never counted here.
        if np.sum(column == 1) > 0 and np.sum(column == 0) > 0:
            # NaN != NaN, so this mask keeps only the labeled rows.
            is_labeled = column == column
            ap = average_precision_score(
                column[is_labeled], y_pred[is_labeled, i]
            )

            ap_list.append(ap)

    if len(ap_list) == 0:
        raise RuntimeError(
            "No positively labeled data available. Cannot compute Average Precision."
        )

    return {"ap": sum(ap_list) / len(ap_list)}


def eval_rmse(y_true, y_pred):
    """
    compute RMSE score averaged across tasks

    Args:
        y_true: regression targets of shape (num_samples, num_tasks), with
            NaN marking a missing target. A 1-D array (or any array-like)
            is accepted and treated as a single task.
        y_pred: predictions indexed the same way as ``y_true``
            (1-D is likewise treated as a single task).

    Returns:
        dict with key "rmse": the mean of the per-task RMSEs, computed over
        the labeled rows of each task. Tasks with no labeled rows are skipped.

    Raises:
        RuntimeError: if no task has any labeled row.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Allow a single task to be passed as a flat vector.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    rmse_list = []

    for i in range(y_true.shape[1]):
        # ignore nan values (NaN != NaN)
        is_labeled = y_true[:, i] == y_true[:, i]
        if not is_labeled.any():
            # The mean over an empty slice is NaN and would poison the
            # average across tasks, so unlabeled tasks are skipped.
            continue
        diff = y_true[is_labeled, i] - y_pred[is_labeled, i]
        rmse_list.append(np.sqrt((diff ** 2).mean()))

    if len(rmse_list) == 0:
        raise RuntimeError("No labeled data available. Cannot compute RMSE.")

    return {"rmse": sum(rmse_list) / len(rmse_list)}


def eval_acc(y_true, y_pred):
    """
    compute accuracy averaged across tasks

    Args:
        y_true: labels of shape (num_samples, num_tasks), with NaN marking a
            missing label. A 1-D array (or any array-like) is accepted and
            treated as a single task.
        y_pred: predicted labels indexed the same way as ``y_true``
            (1-D is likewise treated as a single task).

    Returns:
        dict with key "acc": the mean over tasks of the fraction of labeled
        rows where ``y_pred`` equals ``y_true``. Tasks with no labeled rows
        are skipped.

    Raises:
        RuntimeError: if no task has any labeled row.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Allow a single task to be passed as a flat vector.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    acc_list = []

    for i in range(y_true.shape[1]):
        # ignore nan values (NaN != NaN)
        is_labeled = y_true[:, i] == y_true[:, i]
        if not is_labeled.any():
            # Accuracy is undefined for a task with no labels. Dividing by
            # zero here used to raise ZeroDivisionError.
            continue
        correct = y_true[is_labeled, i] == y_pred[is_labeled, i]
        acc_list.append(float(np.sum(correct)) / len(correct))

    if len(acc_list) == 0:
        raise RuntimeError("No labeled data available. Cannot compute accuracy.")

    return {"acc": sum(acc_list) / len(acc_list)}


def eval_F1(seq_ref, seq_pred):
    """
    compute F1 score averaged over samples

    Each reference/prediction pair is compared as sets of elements, so order
    and duplicates are ignored. Precision, recall and F1 are computed per
    sample (0 when their denominator would be 0) and then averaged with
    ``np.average``.

    Args:
        seq_ref: iterable of reference sequences.
        seq_pred: iterable of predicted sequences, aligned with ``seq_ref``.

    Returns:
        dict with keys "precision", "recall" and "F1".

    Raises:
        ValueError: if ``seq_ref`` and ``seq_pred`` contain a different
            number of samples.
    """
    # Materialize so lengths can be compared (iterators are still accepted).
    seq_ref = list(seq_ref)
    seq_pred = list(seq_pred)
    # zip() would silently drop the unmatched tail and skew the averages.
    if len(seq_ref) != len(seq_pred):
        raise ValueError(
            "seq_ref and seq_pred must have the same length, "
            f"got {len(seq_ref)} and {len(seq_pred)}"
        )

    precision_list = []
    recall_list = []
    f1_list = []

    for ref, pred in zip(seq_ref, seq_pred):
        label = set(ref)
        prediction = set(pred)
        true_positive = len(label.intersection(prediction))
        false_positive = len(prediction - label)
        false_negative = len(label - prediction)

        if true_positive + false_positive > 0:
            precision = true_positive / (true_positive + false_positive)
        else:
            precision = 0

        if true_positive + false_negative > 0:
            recall = true_positive / (true_positive + false_negative)
        else:
            recall = 0

        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0

        precision_list.append(precision)
        recall_list.append(recall)
        f1_list.append(f1)

    return {
        "precision": np.average(precision_list),
        "recall": np.average(recall_list),
        "F1": np.average(f1_list),
    }
