import sys
import numpy as np
from sklearn.metrics import f1_score


def evaluate(true_label_matrix, pred_prob_matrix, metric, weights, one_dim=False, pos_label=1, average='micro'):
    """
    Compute a weighted evaluation metric from true and predicted labels
    :param true_label_matrix: true labels; a vector of labels if one_dim is True,
        otherwise a 2-D array reduced to labels with argmax over axis 1
    :param pred_prob_matrix: predictions, in the same layout as true_label_matrix
    :param metric: name of the metric to compute ('accuracy' or 'f1')
    :param weights: per-item weights; None gives every item a weight of 1
    :param one_dim: if True, the inputs are already vectors of labels
    :param pos_label: label to use as the positive label for the one_dim f1 case
    :param average: averaging mode passed to weighted_f1 when one_dim is False

    :return: the value of the requested metric
    :raises ValueError: if metric is not 'accuracy' or 'f1'
    """
    if metric == 'accuracy':
        if one_dim:
            # Coerce to arrays so the comparison is element-wise even for plain lists.
            true_labels = np.asarray(true_label_matrix)
            pred_labels = np.asarray(pred_prob_matrix)
        else:
            true_labels = np.argmax(true_label_matrix, axis=1)
            pred_labels = np.argmax(pred_prob_matrix, axis=1)
        correct = true_labels == pred_labels
        if weights is None:
            # Uniform weights reduce the weighted accuracy to the plain mean.
            item_weights = np.ones(correct.shape[0])
        else:
            item_weights = np.asarray(weights)
        return np.dot(correct, item_weights) / item_weights.sum()

    if metric == 'f1':
        if one_dim:
            return weighted_f1(true_label_matrix, pred_prob_matrix, n_classes=2, pos_label=pos_label, weights=weights, average='binary')
        # The number of classes is the number of columns of the prediction matrix.
        _, n_classes = np.asarray(pred_prob_matrix).shape
        return weighted_f1(np.argmax(true_label_matrix, axis=1), np.argmax(pred_prob_matrix, axis=1), n_classes=n_classes, pos_label=pos_label, average=average, weights=weights)

    raise ValueError("Metric not recognized.")


def weighted_f1(true, pred, n_classes=2, pos_label=1, average='micro', weights=None):
    """
    Override f1_score in sklearn in order to deal with both binary and multiclass cases
    :param true: true labels
    :param pred: predicted labels
    :param n_classes: total number of different possible labels
    :param pos_label: label to use as the positive label for the binary case (0 or 1)
    :param average: how to calculate f1 for the multiclass case (default = 'micro');
        None is treated as 'micro'
    :param weights: optional per-item weights, passed to f1_score as sample_weight

    :return: f1 score
    """
    if n_classes == 2:
        true = np.asarray(true)
        pred = np.asarray(pred)
        # With no true positives for pos_label the F1 score is 0, so return it
        # directly instead of calling f1_score. The check is made against
        # pos_label (not hard-wired to label 1) so that pos_label=0 works too.
        has_true_positive = np.any((true == pos_label) & (pred == pos_label))
        if not has_true_positive:
            return 0.0
        return f1_score(true, pred, average='binary', labels=range(n_classes), pos_label=pos_label, sample_weight=weights)

    if average is None:
        # Fall back to micro-averaging so that a single score is requested.
        average = 'micro'
    return f1_score(true, pred, average=average, labels=range(n_classes), pos_label=None, sample_weight=weights)


def check_improvement(old_val, new_val, metric):
    """
    Decide whether a new metric value is an improvement over an old one
    :param old_val: previous value of the metric
    :param new_val: new value of the metric
    :param metric: name of the metric; 'accuracy' and 'f1' improve when they
        increase, 'calibration' and 'mae' improve when they decrease

    :return: True if new_val is strictly better than old_val, False otherwise
    """
    if metric in ('accuracy', 'f1'):
        # Higher is better.
        return new_val > old_val
    if metric in ('calibration', 'mae'):
        # Lower is better.
        return new_val < old_val

    # Unknown metric: report on stderr and exit with a non-zero status so the
    # failure is visible to whatever launched the process.
    print("Metric not recognized", file=sys.stderr)
    sys.exit(1)


