import tensorflow as tf
from tensorflow.keras.metrics import MeanIoU, SparseCategoricalAccuracy
from tensorflow.keras.losses import SparseCategoricalCrossentropy

class MeanIoUIgnore(MeanIoU):
    """Mean intersection-over-union that skips elements labelled ``ignore_label``.

    ``update_state`` expects per-class scores in ``y_pred`` (last axis = classes). It turns
    them into hard predictions with ``argmax`` and removes every element whose label equals
    ``ignore_label``. The remaining labels/predictions go to ``MeanIoU.update_state``.

    Args:
        num_classes: Number of classes, forwarded to ``MeanIoU``.
        name: Metric name.
        dtype: Metric dtype.
        ignore_label: Label value excluded from the confusion matrix (default 255).
    """

    def __init__(self, num_classes, name='miou_ignore', dtype=None, ignore_label=255):
        super(MeanIoUIgnore, self).__init__(num_classes, name=name, dtype=dtype)
        self.ignore_label = ignore_label
        # NOTE(review): not read anywhere in this class. Some Keras versions' MeanIoU consult
        # an attribute with this name to drop matching labels. num_classes is never a valid
        # class index, so it should not mask anything — confirm before removing.
        self.ignore_class = num_classes

    @staticmethod
    def _flat_masked_weights(sample_weight, labels, keep):
        """Broadcast ``sample_weight`` to ``labels``' shape, flatten it, and apply ``keep``."""
        sample_weight = tf.convert_to_tensor(sample_weight)
        weight_rank = sample_weight.shape.rank
        label_rank = labels.shape.rank
        if weight_rank is not None and label_rank is not None and weight_rank <= label_rank:
            # Append unit axes so lower-rank weights (e.g. one per image) line up with the
            # leading label axes. Then tile them out to one weight per label element.
            for _ in range(label_rank - weight_rank):
                sample_weight = tf.expand_dims(sample_weight, -1)
            sample_weight = tf.broadcast_to(sample_weight, tf.shape(labels))
        # Otherwise the weights must already hold exactly one entry per label element.
        return tf.boolean_mask(tf.reshape(sample_weight, [-1]), keep)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the confusion matrix over non-ignored elements.

        Args:
            y_true: Integer labels; flattened to one label per element.
            y_pred: Scores whose last axis indexes classes. All leading axes are flattened,
                so there is one score vector per label element.
            sample_weight: Optional weights.
                - Lower-rank weights (e.g. one per image) are broadcast over the trailing
                  label axes.
                - Otherwise they must contain one weight per label element.
                - Weights belonging to ignored elements are dropped.
        """
        y_true = tf.convert_to_tensor(y_true)
        labels = tf.reshape(y_true, [-1])
        scores = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        keep = tf.not_equal(labels, self.ignore_label)
        labels = tf.boolean_mask(labels, keep)
        predictions = tf.argmax(tf.boolean_mask(scores, keep, axis=0), axis=-1)
        if sample_weight is not None:
            # Mask the weights with the same selector so they stay aligned with the kept labels.
            sample_weight = self._flat_masked_weights(sample_weight, y_true, keep)
        super(MeanIoUIgnore, self).update_state(labels, predictions, sample_weight=sample_weight)

    def get_config(self):
        """Return a config that ``from_config`` can pass back into ``__init__``.

        Only the base-class keys this constructor accepts are kept, because Keras versions
        differ in which extra keys the parent emits. ``ignore_label`` is added.
        """
        base_config = super(MeanIoUIgnore, self).get_config()
        config = {key: base_config[key]
                  for key in ('num_classes', 'name', 'dtype') if key in base_config}
        config['ignore_label'] = self.ignore_label
        return config

class SparseCategoricalAccuracyIgnore(SparseCategoricalAccuracy):
    """Sparse categorical accuracy that skips elements labelled ``ignore_label``.

    Labels are flattened to one per element, and scores to one class-score vector per
    element (last axis = classes). Elements whose label equals ``ignore_label`` are removed
    before delegating to ``SparseCategoricalAccuracy.update_state``.

    Args:
        name: Metric name.
        dtype: Metric dtype.
        ignore_label: Label value excluded from the accuracy (default 255).
    """

    def __init__(self, name='sparse_categorical_accuracy_ignore', dtype=None, ignore_label=255):
        super(SparseCategoricalAccuracyIgnore, self).__init__(name=name, dtype=dtype)
        self.ignore_label = ignore_label

    @staticmethod
    def _flat_masked_weights(sample_weight, labels, keep):
        """Broadcast ``sample_weight`` to ``labels``' shape, flatten it, and apply ``keep``."""
        sample_weight = tf.convert_to_tensor(sample_weight)
        weight_rank = sample_weight.shape.rank
        label_rank = labels.shape.rank
        if weight_rank is not None and label_rank is not None and weight_rank <= label_rank:
            # Append unit axes so lower-rank weights (e.g. one per image) line up with the
            # leading label axes. Then tile them out to one weight per label element.
            for _ in range(label_rank - weight_rank):
                sample_weight = tf.expand_dims(sample_weight, -1)
            sample_weight = tf.broadcast_to(sample_weight, tf.shape(labels))
        # Otherwise the weights must already hold exactly one entry per label element.
        return tf.boolean_mask(tf.reshape(sample_weight, [-1]), keep)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate accuracy over non-ignored elements.

        Args:
            y_true: Integer labels; flattened to one label per element.
            y_pred: Scores whose last axis indexes classes. All leading axes are flattened.
            sample_weight: Optional weights.
                - Lower-rank weights (e.g. one per image) are broadcast over the trailing
                  label axes.
                - Otherwise they must contain one weight per label element.
                - Weights belonging to ignored elements are dropped.
        """
        y_true = tf.convert_to_tensor(y_true)
        labels = tf.reshape(y_true, [-1])
        scores = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        keep = tf.not_equal(labels, self.ignore_label)
        scores = tf.boolean_mask(scores, keep, axis=0)
        labels = tf.boolean_mask(labels, keep)
        if sample_weight is not None:
            # Mask the weights with the same selector so they stay aligned with the kept labels.
            sample_weight = self._flat_masked_weights(sample_weight, y_true, keep)
        super(SparseCategoricalAccuracyIgnore, self).update_state(
            labels, scores, sample_weight=sample_weight)

    def get_config(self):
        """Return a config that ``from_config`` can pass back into ``__init__``.

        Only the base-class keys this constructor accepts are kept, and ``ignore_label``
        is added.
        """
        base_config = super(SparseCategoricalAccuracyIgnore, self).get_config()
        config = {key: base_config[key] for key in ('name', 'dtype') if key in base_config}
        config['ignore_label'] = self.ignore_label
        return config

class SparseCategoricalCrossentropyIgnore(SparseCategoricalCrossentropy):
    """Sparse categorical cross-entropy that skips elements labelled ``ignore_label``.

    Args:
        class_weight: Optional per-class weights indexed by label value (converted to a
            float32 tensor). Each kept element's loss is multiplied by the weight of its label.
        from_logits: Whether ``y_pred`` holds logits rather than probabilities.
        name: Loss name.
        ignore_label: Label value whose elements are excluded (default 255).
    """

    def __init__(self, class_weight=None, from_logits=False, name='sparse_categorical_crossentropy_ignore', ignore_label=255):
        super(SparseCategoricalCrossentropyIgnore, self).__init__(from_logits=from_logits, name=name)
        if class_weight is None:
            self.class_weight = class_weight
        else:
            self.class_weight = tf.convert_to_tensor(class_weight, dtype=tf.float32)
        self.ignore_label = ignore_label

    def call(self, y_true, y_pred):
        """Return the (optionally class-weighted) loss of every non-ignored element, as 1-D.

        Ignored elements are removed rather than zero-weighted. With a mean-style reduction,
        the average is therefore taken over kept elements only.

        NOTE(review): a ``sample_weight`` given to ``__call__`` is applied to this masked
        1-D output. It presumably has to match that length — confirm before using
        per-image weights.
        """
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        not_ignore = tf.not_equal(y_true, self.ignore_label)
        y_pred = tf.boolean_mask(y_pred, not_ignore, axis=0)
        y_true = tf.boolean_mask(y_true, not_ignore)
        if self.class_weight is not None:
            # Look up each kept element's weight by its (integer) label value.
            weight = tf.gather(self.class_weight, tf.cast(y_true, tf.int32))
        else:
            weight = 1.0

        return super(SparseCategoricalCrossentropyIgnore, self).call(y_true, y_pred) * weight

    def get_config(self):
        """Return a config that ``from_config`` can pass back into ``__init__``.

        The base Loss config includes keys (e.g. ``reduction``) that this constructor does
        not accept, so only ``from_logits`` and ``name`` are kept. ``class_weight`` and
        ``ignore_label`` are added.
        """
        base_config = super(SparseCategoricalCrossentropyIgnore, self).get_config()
        config = {key: base_config[key] for key in ('from_logits', 'name') if key in base_config}
        class_weight = self.class_weight
        if class_weight is not None:
            # Tensors are not JSON-serializable; store plain Python floats instead.
            class_weight = class_weight.numpy().tolist()
        config['class_weight'] = class_weight
        config['ignore_label'] = self.ignore_label
        return config
