# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loss functions."""
import enum
from typing import Tuple, Mapping, Optional, Union

from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np


@jax.custom_vjp
def cross_entropy_with_logits(
    logits: jnp.ndarray, targets: jnp.ndarray,
    z_loss: float) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Computes cross entropy loss with stable custom gradient.

  Computes a stabilized-gradient version of:
    -jnp.sum(targets * nn.log_softmax(logits), axis=-1)

  If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2
  will be added to the cross entropy loss (z = softmax normalization constant).
  The two uses of z_loss are:
  1. To keep the logits from drifting too far from zero, which can cause
     unacceptable roundoff errors in bfloat16.
  2. To encourage the logits to be normalized log-probabilities.

  The custom gradient treats `z_loss` as a constant (its gradient is zero) and
  ignores any cotangent flowing into the second output, which is intended for
  logging only.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical one-hot (or soft) targets [batch, length, num_classes]
      float array. Rows do not have to sum to one; the gradient accounts for
      the actual sum of each row.
    z_loss: coefficient for auxiliary z-loss loss term.

  Returns:
    Tuple with the total loss (cross entropy plus z-loss term) and the z-loss
    term alone, both float arrays with shape [batch, length].
  """
  logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
  log_softmax = logits - logits_sum
  loss = -jnp.sum(targets * log_softmax, axis=-1)
  # Add auxiliary z-loss term.
  log_z = jnp.squeeze(logits_sum, axis=-1)
  total_z_loss = z_loss * jax.lax.square(log_z)
  loss += total_z_loss
  return loss, total_z_loss


def _cross_entropy_with_logits_fwd(
    logits: jnp.ndarray,
    targets: jnp.ndarray,
    z_loss: float = 0.0
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray],
           Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray, jnp.ndarray,
                 jnp.ndarray, jnp.ndarray]]:
  """Forward pass for the custom VJP of `cross_entropy_with_logits`.

  Recomputes the same outputs as the primal function while keeping the
  intermediate values that the backward pass needs.

  Args:
    logits: same as in `cross_entropy_with_logits`.
    targets: same as in `cross_entropy_with_logits`.
    z_loss: same as in `cross_entropy_with_logits`.

  Returns:
    A pair `((loss, total_z_loss), residuals)` where `residuals` is the tuple
    `(logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z)`.
  """
  # Shift by the per-row maximum before exponentiating so that the largest
  # exponent is exp(0) = 1.
  max_logit = logits.max(axis=-1, keepdims=True)
  shifted = logits - max_logit
  exp_shifted = jnp.exp(shifted)
  sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
  log_softmax = shifted - jnp.log(sum_exp)
  loss = -jnp.sum(targets * log_softmax, axis=-1)
  # Add auxiliary z-loss term. Adding `max_logit` back undoes the shift, so
  # `log_z` is the log of the unshifted softmax normalizer.
  log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
  total_z_loss = z_loss * jax.lax.square(log_z)
  loss += total_z_loss
  return (loss, total_z_loss), (logits, targets, z_loss, exp_shifted, sum_exp,
                                log_softmax, log_z)


def _cross_entropy_with_logits_bwd(
    res: Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray, jnp.ndarray,
               jnp.ndarray, jnp.ndarray], g: Tuple[jnp.ndarray, jnp.ndarray]
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Backward pass for the custom VJP of `cross_entropy_with_logits`.

  Args:
    res: residuals saved by `_cross_entropy_with_logits_fwd`.
    g: cotangents for the `(loss, total_z_loss)` outputs. Only the first one is
      used.

  Returns:
    Cotangents for `(logits, targets, z_loss)`. The `z_loss` cotangent is
    always zero.
  """
  g = g[0]  # Ignore z_loss component as that is only used for logging.
  logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res
  # d(loss)/d(logits) = sum(targets) * softmax - targets for the cross entropy
  # term, plus 2 * z_loss * log_z * softmax for the z-loss term. Using the
  # actual row sum of `targets` (instead of assuming it is 1) keeps the
  # gradient consistent with the forward pass for unnormalized targets.
  target_mass = jnp.sum(targets, axis=-1)
  deriv = (
      jnp.expand_dims(target_mass + 2 * z_loss * log_z, -1) * exp_shifted /
      sum_exp - targets)
  g_logits = jnp.expand_dims(g, axis=-1) * deriv
  # d(loss)/d(targets) = -log_softmax.
  g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax
  return (jnp.asarray(g_logits,
                      logits.dtype), jnp.asarray(g_targets, targets.dtype),
          jnp.array(0.0))  # sets z-loss coeff gradient to 0


cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd,
                                 _cross_entropy_with_logits_bwd)


def compute_weighted_cross_entropy(
    logits: jnp.ndarray,
    targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    label_smoothing: float = 0.0,
    z_loss: float = 0.0,
    loss_normalizing_factor: Optional[float] = None
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Computes the summed, optionally weighted, cross entropy and z-loss.

  The integer targets are turned into (label-smoothed) one-hot vectors, the
  per-token loss is computed with `cross_entropy_with_logits`, and the entropy
  of the smoothed target distribution is subtracted from it.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch, length].
   label_smoothing: label smoothing constant, used to determine the on and off
     values.
   z_loss: coefficient for auxiliary z-loss loss term.
   loss_normalizing_factor: Constant to divide loss by. If not specified, loss
     will not be normalized. Intended for backward compatibility with T5-MTF
     training. Should not normally be used.

  Returns:
    Tuple of scalar loss, z_loss, and weight sum. The weight sum is the sum of
    `weights` when given, otherwise the number of elements in `targets`.

  Raises:
    ValueError: if `logits` does not have exactly one more dimension than
      `targets`.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))

  num_classes = logits.shape[-1]
  on_value = 1.0 - label_smoothing
  off_value = (1.0 - on_value) / (num_classes - 1)

  # Entropy of the smoothed target distribution. The small epsilon keeps the
  # log finite when `off_value` is zero (i.e. no label smoothing).
  on_term = on_value * jnp.log(on_value)
  off_term = (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
  target_entropy = -(on_term + off_term)

  smoothed_targets = common_utils.onehot(
      targets, num_classes, on_value=on_value, off_value=off_value)
  per_token_loss, per_token_z_loss = cross_entropy_with_logits(
      logits, smoothed_targets, z_loss=z_loss)
  per_token_loss = per_token_loss - target_entropy

  if weights is None:
    weight_sum = np.prod(targets.shape)
  else:
    per_token_loss = per_token_loss * weights
    per_token_z_loss = per_token_z_loss * weights
    weight_sum = jnp.sum(weights)

  # By default, we do not normalize loss based on anything.
  # We don't normalize based on batch size because the optimizers we use are
  # pretty much scale invariant, so this simplifies things.
  # We don't normalize based on number of non-padding tokens in order to treat
  # each token as equally important regardless of sequence length.
  if loss_normalizing_factor is not None:
    per_token_loss = per_token_loss / loss_normalizing_factor
    per_token_z_loss = per_token_z_loss / loss_normalizing_factor

  return jnp.sum(per_token_loss), jnp.sum(per_token_z_loss), weight_sum


@enum.unique
class SpecialLossNormalizingFactor(enum.Enum):
  """Loss normalizing factors that are computed per batch, not given as constants.

  Attributes:
    NUM_REAL_TARGET_TOKENS: Divide the loss by the number of real
      (non-padding) tokens in the current target batch. If
      'decoder_loss_weights' are specified, this is the sum of the weights.
      Otherwise it is the number of positions whose 'decoder_target_tokens'
      id is greater than zero.
    NUM_TOTAL_TARGET_TOKENS: Divide the loss by the total number of target
      tokens, i.e., batch_size * target_seq_length (including padding).
    AVERAGE_PER_SEQUENCE: First compute the per-sequence loss (averaged over
      the number of real target tokens in the sequence), and then average that
      over the sequences. This can be preferable to NUM_REAL_TARGET_TOKENS for
      finetuning, because it weighs all examples equally, regardless of
      sequence length (which can be especially important for multi-task
      finetuning).
  """
  # `enum.auto()` numbers the members 1, 2, 3 in definition order.
  NUM_REAL_TARGET_TOKENS = enum.auto()
  NUM_TOTAL_TARGET_TOKENS = enum.auto()
  AVERAGE_PER_SEQUENCE = enum.auto()


def convert_special_loss_normalizing_factor_to_enum(
    x: str) -> SpecialLossNormalizingFactor:
  """Maps the name of a SpecialLossNormalizingFactor member to the member.

  The lookup is case-insensitive. This is useful because gin dynamic
  registration does not (currently) have support for enum.

  Args:
    x: stringified version of SpecialLossNormalizingFactor enum.

  Returns:
    SpecialLossNormalizingFactor enum instance.

  Raises:
    ValueError: if `x` does not name a SpecialLossNormalizingFactor member.
  """
  name = x.upper()
  member = SpecialLossNormalizingFactor.__members__.get(name)
  if member is None:
    raise ValueError(
        'Could not convert string \"%s\" to SpecialLossNormalizingFactor' % name)
  return member


def get_loss_normalizing_factor_and_weights(
    loss_normalizing_factor: Optional[Union[float, int, str,
                                            SpecialLossNormalizingFactor]],
    batch: Mapping[str, jnp.ndarray]):
  """Resolves the loss normalizing factor and the per-token loss weights.

  If loss_normalizing_factor is a number or None, it is returned unchanged
  together with batch['decoder_loss_weights'] (or None if that key is absent).

  If loss_normalizing_factor is a SpecialLossNormalizingFactor (or its name as
  a string), the numeric factor and the loss weights are derived from the
  batch. See SpecialLossNormalizingFactor for more details.

  Args:
    loss_normalizing_factor: The input LNF, which may be a float, None, or
      SpecialLossNormalizingFactor (or a stringified SLNF).
    batch: Input data batch. 'decoder_target_tokens' is read whenever a special
      LNF is used without 'decoder_loss_weights', and for
      NUM_TOTAL_TARGET_TOKENS.

  Returns:
    Tuple of (output_loss_normalizing_factor, loss_weights).
      'output_loss_normalizing_factor' is a scalar float (Python float
      or jnp float).
      'loss_weights' is the per token loss weight JNP array.

  Raises:
    ValueError: if a string LNF does not name a SpecialLossNormalizingFactor,
      or if the special LNF is not one of the supported members.
  """
  loss_weights = batch.get('decoder_loss_weights', None)

  # Numbers and None need no resolution (isinstance is False for None).
  if not isinstance(loss_normalizing_factor,
                    (str, SpecialLossNormalizingFactor)):
    return (loss_normalizing_factor, loss_weights)

  special_lnf = loss_normalizing_factor
  if isinstance(special_lnf, str):
    special_lnf = convert_special_loss_normalizing_factor_to_enum(special_lnf)

  # If `loss_weights` are not provided, we assume that the padding id is 0 and
  # that non-padding tokens in the decoder all correspond to the positions
  # where loss should be taken. If more fine-grained behavior (e.g., taking
  # loss on subset of 'decoder_target_tokens') is desired, provide
  # `loss_weights` that account for this.
  if loss_weights is None:
    loss_weights = jnp.asarray(batch['decoder_target_tokens'] > 0, jnp.float32)

  if special_lnf == SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS:
    return (jnp.sum(loss_weights), loss_weights)

  if special_lnf == SpecialLossNormalizingFactor.NUM_TOTAL_TARGET_TOKENS:
    return (np.prod(batch['decoder_target_tokens'].shape), loss_weights)

  if special_lnf == SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE:
    # Rescale each sequence's weights by their sum over the last axis; the
    # 1e-3 offset presumably guards against division by zero for sequences
    # whose weights are all zero.
    per_sequence_total = jnp.sum(loss_weights, axis=-1, keepdims=True) + 1e-3
    loss_weights /= per_sequence_total
    return (jnp.sum(loss_weights), loss_weights)

  raise ValueError('Unsupported value of loss_normalizing_factor: %s' %
                   str(special_lnf))
