import tensorflow as tf


class TokenLoss(tf.keras.metrics.Metric):
    """Accumulate sparse softmax cross-entropy per sequence position.

    Two state vectors of length ``seq_len`` are maintained: the summed loss
    observed at each position and the number of examples that contributed to
    each position. ``result()`` reports the mean loss per position;
    ``overall_loss()`` reports the mean loss per token over all positions.

    Args:
        seq_len: Length of the sequence axis; fixes the shape of the state.
        name: Metric name passed to the Keras base class.
        mask_value: Optional label value marking positions to ignore. When
            ``None`` every position of every example is counted.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, seq_len, name="token_loss", mask_value=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.seq_len = seq_len
        self.mask_value = mask_value

        # Running sum of the loss at each sequence position.
        self.token_loss = self.add_weight(
            shape=[seq_len],
            name="token_loss",
            initializer="zeros",
            dtype=tf.float32,
        )
        # Number of (unmasked) examples accumulated at each sequence position.
        self.n_examples = self.add_weight(
            shape=[seq_len],
            name="n_examples",
            initializer="zeros",
            dtype=tf.int32,
        )

    def update_state(self, labels, logits, sample_weight=None):
        """Add one batch of losses to the running per-position totals.

        Args:
            labels: Integer class ids; cast to ``int32``. The leading axis is
                reduced away, so the remaining shape must be ``[seq_len]``
                (i.e. labels are expected to be ``(batch, seq_len)``).
            logits: Unnormalised scores with one extra trailing class axis
                relative to ``labels``.
            sample_weight: Accepted for Keras API compatibility; not used.
        """
        labels = tf.cast(labels, dtype=tf.int32)

        if self.mask_value is None:
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=logits
            )
            # Every position sees the full batch.
            n_examples = tf.fill([self.seq_len], value=tf.shape(losses)[0])
        else:
            mask = tf.not_equal(labels, tf.cast(self.mask_value, dtype=tf.int32))
            # The mask value may be outside [0, num_classes) (e.g. -1), which
            # makes the cross-entropy op raise on CPU or return NaN on GPU.
            # Substitute a valid label at masked positions; their loss is
            # discarded below.
            safe_labels = tf.where(mask, labels, tf.zeros_like(labels))
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=safe_labels, logits=logits
            )
            # Select rather than multiply so a non-finite loss at a masked
            # position cannot leak into the accumulator (NaN * 0 == NaN).
            losses = tf.where(mask, losses, tf.zeros_like(losses))
            n_examples = tf.reduce_sum(tf.cast(mask, dtype=tf.int32), axis=0)

        self.token_loss.assign_add(tf.reduce_sum(losses, axis=0))
        self.n_examples.assign_add(n_examples)

    def result(self):
        """Return the mean loss at each sequence position."""
        return self.loss_by_position()

    def reset_state(self):
        """Zero both accumulators."""
        # Variable.assign does not broadcast, so the zeros must have the
        # variable's own shape.
        self.token_loss.assign(
            tf.zeros(self.token_loss.shape, dtype=self.token_loss.dtype)
        )
        self.n_examples.assign(
            tf.zeros(self.n_examples.shape, dtype=self.n_examples.dtype)
        )

    def reset_states(self):
        """Alias of ``reset_state`` kept for older Keras versions and callers."""
        self.reset_state()

    def get_config(self):
        """Return the constructor arguments needed to re-create this metric."""
        config = super().get_config()
        config.update({"seq_len": self.seq_len, "mask_value": self.mask_value})
        return config

    def loss_by_position(self):
        """Compute the loss for each token based on position in the sequence."""
        # Any position with 0 examples will have loss 0; don't divide 0 by 0
        divisor = tf.math.maximum(self.n_examples, 1)
        return self.token_loss / tf.cast(divisor, self.token_loss.dtype)

    def overall_loss(self):
        """Compute the overall loss per token across all positions in the sequence."""
        # With nothing accumulated the summed loss is 0; clamp the divisor so
        # the result is 0 rather than NaN, matching loss_by_position().
        divisor = tf.math.maximum(tf.reduce_sum(self.n_examples), 1)
        return tf.reduce_sum(self.token_loss) / tf.cast(
            divisor, dtype=self.token_loss.dtype
        )
