import tensorflow as tf
from sklearn.metrics import f1_score
import numpy as np


def masked_softmax_cross_entropy(preds, labels, mask):
    """Softmax cross-entropy averaged over the entries selected by ``mask``.

    The mask is rescaled by its own mean before it weights the per-example
    losses. The returned mean is therefore (mathematically)
    ``sum(loss * mask) / sum(mask)``: masked-out entries contribute nothing,
    and the result does not depend on how many entries are masked out.
    If ``mask`` is all zeros, the rescaling divides by zero and the result
    is NaN.

    Args:
        preds: unnormalized scores, passed as ``logits``.
        labels: targets passed as ``labels``; presumably one-hot rows
            matching ``preds`` -- TODO confirm with callers.
        mask: per-example weights (cast to float32); presumably 0/1 or
            bool and broadcastable against the per-example loss.

    Returns:
        Scalar tensor holding the masked mean loss.
    """
    weights = tf.cast(mask, dtype=tf.float32)
    weights = weights / tf.reduce_mean(weights)
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
        logits=preds, labels=labels
    )
    return tf.reduce_mean(per_example_loss * weights)


def masked_accuracy(preds, labels, mask):
    """Arg-max accuracy averaged over the entries selected by ``mask``.

    Uses the same mean-rescaled mask as the masked loss, so the result is
    (mathematically) ``sum(correct * mask) / sum(mask)``. An all-zero mask
    divides by zero and yields NaN.

    Args:
        preds: scores whose arg-max over axis 1 is the predicted class.
        labels: targets whose arg-max over axis 1 is the true class
            (presumably one-hot).
        mask: per-example weights (cast to float32); presumably 0/1 or bool.

    Returns:
        Scalar float32 tensor with the masked accuracy.
    """
    weights = tf.cast(mask, dtype=tf.float32)
    weights = weights / tf.reduce_mean(weights)
    hits = tf.cast(
        tf.equal(tf.argmax(preds, 1), tf.argmax(labels, 1)), tf.float32
    )
    return tf.reduce_mean(hits * weights)


def softmax_cross_entropy(preds, labels):
    """Unmasked mean softmax cross-entropy between ``preds`` (logits) and ``labels``."""
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels)
    )


def accuracy(preds, labels):
    """Unmasked fraction of rows whose arg-max prediction matches the label.

    Args:
        preds: scores whose arg-max over axis 1 is the predicted class.
        labels: targets whose arg-max over axis 1 is the true class
            (presumably one-hot).

    Returns:
        Scalar float32 tensor with the mean per-row correctness.
    """
    predicted_class = tf.argmax(preds, 1)
    true_class = tf.argmax(labels, 1)
    hits = tf.cast(tf.equal(predicted_class, true_class), tf.float32)
    return tf.reduce_mean(hits)


def _safe_div(numerator, denominator):
    """Divide elementwise, returning 0 wherever ``denominator`` is exactly 0.

    This is only meaningful when ``numerator`` is itself 0 wherever
    ``denominator`` is 0, which holds for every caller in this module.
    Replacing a zero denominator by 1 then yields 0 instead of the NaN that
    ``0 / 0`` would produce.
    """
    safe_denominator = tf.where(
        tf.equal(denominator, 0),
        tf.ones_like(denominator),
        denominator,
    )
    return numerator / safe_denominator


def _f1_from_counts(y_true, y_pred, axis):
    """Return the F1 score built from TP/FP/FN counts reduced over ``axis``."""
    # Every nonzero product counts as a hit, so the inputs are presumably
    # 0/1 tensors -- TODO confirm with callers.
    tp = tf.cast(tf.count_nonzero(y_pred * y_true, axis=axis), tf.float64)
    fp = tf.cast(tf.count_nonzero(y_pred * (y_true - 1), axis=axis), tf.float64)
    fn = tf.cast(tf.count_nonzero((y_pred - 1) * y_true, axis=axis), tf.float64)
    # 2PR / (P + R) simplifies algebraically to 2TP / (2TP + FP + FN). The
    # two are mathematically identical whenever TP > 0. Unlike the
    # precision/recall form, this one gives 0 instead of NaN when TP == 0
    # (e.g. a class that is never predicted). The numerator 2TP is 0
    # whenever the denominator is 0, so _safe_div applies.
    return _safe_div(2.0 * tp, 2.0 * tp + fp + fn)


def f1score(y_true, y_pred, average="micro"):
    """Compute an F1 score (micro-averaged by default).

    Supported averages:
        micro: a single F1 score pooled across all classes
        macro: unweighted mean of the per-class F1 scores
        weighted: mean of the per-class F1 scores, weighted by each
                  class's support (column sum of ``y_true``)

    Any score that would be 0/0 (no true positives, or no support at all)
    is reported as 0 rather than NaN.

    Args:
        y_true (Tensor): labels, with shape (batch, num_classes)
        y_pred (Tensor): model's predictions, same shape as y_true. Every
            nonzero entry counts as a positive prediction, so this is
            presumably a binarized 0/1 tensor rather than raw scores.
        average (str): "micro" (default; the only score previously
            returned), "macro" or "weighted".

    Returns:
        Tensor: scalar float64 F1 score for the requested average.

    Raises:
        ValueError: if ``average`` is not one of the supported values.
    """
    if average not in ("micro", "macro", "weighted"):
        raise ValueError(
            "average must be 'micro', 'macro' or 'weighted', got {!r}".format(
                average
            )
        )

    y_true = tf.cast(y_true, tf.float64)
    y_pred = tf.cast(y_pred, tf.float64)

    if average == "micro":
        # axis=None pools the counts over every entry of the batch.
        return _f1_from_counts(y_true, y_pred, axis=None)

    # axis=0 reduces over the batch, leaving one score per class column.
    per_class_f1 = _f1_from_counts(y_true, y_pred, axis=0)
    if average == "macro":
        return tf.reduce_mean(per_class_f1)

    # Weighted: each class's share of the total support. If there is no
    # support at all, every column sum is 0, so the weights become 0.
    support = tf.reduce_sum(y_true, axis=0)
    weights = _safe_div(support, tf.reduce_sum(support))
    return tf.reduce_sum(per_class_f1 * weights)
