"""List of loss functions for inference."""

from absl import flags
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

# Small additive constant that keeps the log() arguments in mix_znbp_mle
# strictly positive.
EPS = 1.0e-7
# absl flag registry; huber_loss, quantile_loss and mix_znbp_mle read
# huber_delta, quantile and alpha from it respectively.
FLAGS = flags.FLAGS


@tf.function
def mae(model_outputs, targets):
  """Mean absolute error between the first output channel and the targets.

  Args:
    model_outputs: t * b * 1; channel 0 is used as the prediction.
    targets: t * b

  Returns:
    Scalar: mean of |prediction - target| over all t * b entries.
  """
  predictions = model_outputs[:, :, 0]
  absolute_errors = tf.abs(predictions - targets)
  return tf.reduce_mean(absolute_errors)


@tf.function
def mape(model_outputs, targets):
  """Mean absolute percentage error over the entries with non-zero targets.

  Where a target is exactly zero, both the prediction and the target are
  replaced by 1. This keeps the ratio finite and cuts the gradient path to the
  prediction. Those entries are also given zero weight. The remaining entries
  are averaged with equal weight. If every target is zero the result is 0.

  Args:
    model_outputs: t * b * 1; channel 0 is used as the prediction.
    targets: t * b

  Returns:
    scalar mape loss
  """
  is_zero_target = tf.equal(targets, 0.)
  ones = tf.ones_like(targets)
  predictions = model_outputs[:, :, 0]

  denominators = tf.where(is_zero_target, ones, targets)
  masked_predictions = tf.where(is_zero_target, ones, predictions)
  mask = tf.where(is_zero_target, tf.zeros_like(targets), ones)

  relative_errors = tf.abs(
      tf.divide(masked_predictions - denominators, denominators))
  # divide_no_nan yields 0 instead of nan when no target is non-zero.
  normalized = tf.math.divide_no_nan(relative_errors * mask,
                                     tf.reduce_sum(mask))
  return tf.reduce_sum(normalized)


@tf.function
def mse(model_outputs, targets):
  """Mean squared error between the first output channel and the targets.

  Args:
    model_outputs: t * b * 1; channel 0 is used as the prediction.
    targets: t * b

  Returns:
    Scalar: mean of (prediction - target)^2 over all t * b entries.
  """
  residuals = model_outputs[:, :, 0] - targets
  squared_residuals = tf.square(residuals)
  return tf.reduce_mean(squared_residuals)


@tf.function
def huber_loss(model_outputs, targets):
  """Huber loss function.

  With d = |target - prediction| and delta = FLAGS.huber_delta, each entry
  contributes
    0.5 * d^2                   if d <= delta
    delta * d - 0.5 * delta^2   otherwise,
  i.e. quadratic for small deviations and linear for large ones.

  FLAGS.huber_delta is read as a Python value while this tf.function is
  traced, so the delta is fixed into each traced graph.

  Args:
    model_outputs: t * b * 1; channel 0 is used as the prediction.
    targets: t * b

  Returns:
    scalar huber loss, averaged over all t * b entries.
  """
  delta = FLAGS.huber_delta
  abs_dev = tf.abs(targets - model_outputs[:, :, 0])
  l2_loss = 0.5 * tf.square(abs_dev)
  # Linear branch; offset by 0.5 * delta^2 so both branches agree at d == delta.
  dev_loss = delta * abs_dev - 0.5 * delta * delta
  loss = tf.where(tf.less_equal(abs_dev, delta), l2_loss, dev_loss)
  return tf.reduce_mean(loss)


@tf.function
def quantile_loss(model_outputs, targets):
  """Quantile (pinball) loss function.

  With dev = target - prediction and q = FLAGS.quantile, each entry
  contributes
    2 * q * dev          if dev >= 0 (under-prediction)
    -2 * (1 - q) * dev   if dev < 0  (over-prediction).
  The factor 2 makes q = 0.5 coincide with the mean absolute error.

  Args:
    model_outputs: t * b * 1; channel 0 is used as the prediction.
    targets: t * b

  Returns:
    scalar quantile loss, averaged over all t * b entries.
  """
  qt = FLAGS.quantile
  dev = targets - model_outputs[:, :, 0]
  loss_first = dev * qt
  loss_second = -dev * (1 - qt)
  # Branch on the sign of dev itself, not of dev * qt. The product is
  # identically 0 when qt == 0, which would zero out the over-prediction
  # penalty. For qt in (0, 1] the two conditions select the same branch.
  loss = 2 * tf.where(tf.greater_equal(dev, 0.), loss_first, loss_second)
  return tf.reduce_mean(loss)


@tf.custom_gradient
def convert_to_positive(x: tf.Tensor):
  """Maps real inputs to strictly positive outputs while preserving order.

  Piecewise definition:
    x > 0:  f(x) = x + 1        (so f(x) > 1)
    x <= 0: f(x) = 1 / (1 - x)  (so 0 < f(x) <= 1)
  Both pieces meet at f(0) = 1 with slope 1. The curve resembles softplus but
  avoids exp()/log(), and for negative inputs it approaches zero only like
  1/|x|, i.e. more slowly than softplus does.

  Args:
    x: a floating point tensor.

  Returns:
    The pair (f(x), gradient function) expected by tf.custom_gradient. Callers
    of the decorated function receive only f(x).
  """
  positive_piece = x + 1.
  non_positive_piece = 1. / (1. - x)
  y = tf.where(tf.greater(x, 0.), positive_piece, non_positive_piece)

  def grad(dy: tf.Tensor) -> tf.Tensor:
    """Back-propagates dy through f using only the forward output y.

    On the x > 0 piece y > 1 and df/dx = 1. On the x <= 0 piece y <= 1 and
    df/dx = (1 - x)^(-2) = y^2.

    Args:
      dy: upstream gradient with respect to y.

    Returns:
      The gradient with respect to x.
    """
    local_slope = tf.where(tf.greater(y, 1.), tf.ones_like(y), y**2)
    return dy * local_slope

  return y, grad


@tf.function
def znb_mle(model_outputs, targets):
  """Zero-inflated negative binomial negative log-likelihood.

  Args:
    model_outputs: rank-3 tensor with at least 3 channels on the last axis.
      Channel 0 is the zero-inflation logit x, where pi = sigmoid(x) is the
      extra probability mass at zero. Channel 1 is passed through
      convert_to_positive to give the negative binomial total_count.
      Channel 2 holds the negative binomial logits.
    targets: t * b observed values. Negative targets are clamped to 0 and
      scored as zeros.

  Returns:
    scalar mean negative log-likelihood.
  """
  gate_logits = model_outputs[:, :, 0]
  nb = tfp.distributions.NegativeBinomial(
      total_count=convert_to_positive(model_outputs[:, :, 1]),
      logits=model_outputs[:, :, 2])

  # y == 0:
  #   -log(pi + (1 - pi) * P(0))
  # = softplus(-x) - softplus(-x + log P(0)),
  # using -log(sigmoid(x)) = softplus(-x) and (1 - pi) / pi = exp(-x).
  nb_log_prob_at_zero = nb.log_prob(tf.zeros_like(targets))
  zero_case = tf.math.softplus(-gate_logits) - tf.math.softplus(
      -gate_logits + nb_log_prob_at_zero)

  # Clamping avoids nan for negative targets; they then take the zero branch.
  clipped = tf.maximum(0.0, targets)
  is_zero = tf.equal(clipped, 0.)
  # Placeholder of 1 keeps the non-zero branch finite where it is discarded.
  nonzero_targets = tf.where(is_zero, tf.ones_like(clipped), clipped)

  # y != 0:
  #   -log((1 - pi) * P(y))
  # = softplus(x) - log P(y).
  nonzero_case = tf.math.softplus(gate_logits) - nb.log_prob(nonzero_targets)

  return tf.reduce_mean(tf.where(is_zero, zero_case, nonzero_case))


@tf.function
def nb_mle(model_outputs, targets):
  """Negative binomial negative log-likelihood (no zero inflation).

  Args:
    model_outputs: rank-3 tensor with at least 2 channels on the last axis.
      Channel 0 is passed through convert_to_positive to give the negative
      binomial total_count. Channel 1 holds the negative binomial logits.
    targets: t * b observed values. Negative values are clamped to 0.

  Returns:
    scalar mean negative log-likelihood over all t * b entries.
  """
  num = convert_to_positive(model_outputs[:, :, 0])
  logits = model_outputs[:, :, 1]
  # Negative observations are scored as zeros.
  targets = tf.maximum(0.0, targets)

  loss_nb = -tfp.distributions.NegativeBinomial(
      total_count=num, logits=logits).log_prob(targets)
  return tf.reduce_mean(loss_nb)


@tf.function
def mix_znbp_mle(model_outputs, targets):
  """Negative log-likelihood of a zero / negative binomial / Pareto mixture.

  Mixture weights are softmax(model_outputs[:, :, 0:3]) over three components:
    0: a point mass at zero,
    1: a negative binomial with total_count = convert_to_positive(channel 3)
       and logits = channel 4,
    2: a Pareto with concentration FLAGS.alpha and scale
       softplus(channel 5) + 1e-3.
  The mixture density is formed in probability space, and EPS is added before
  the log so the loss stays finite when every term is zero.

  Args:
    model_outputs: rank-3 tensor with at least 6 channels on the last axis.
    targets: t * b observed values. Non-positive targets are scored as zeros.

  Returns:
    scalar mean negative log-likelihood.
  """
  mixture_probs = tf.nn.softmax(model_outputs[:, :, 0:3], axis=2)
  nb = tfp.distributions.NegativeBinomial(
      total_count=convert_to_positive(model_outputs[:, :, 3]),
      logits=model_outputs[:, :, 4])
  pareto_scale = tf.nn.softplus(model_outputs[:, :, 5]) + 1e-3
  alpha = FLAGS.alpha

  # Zero observations: the point mass plus the negative binomial mass at 0.
  # The Pareto support starts at pareto_scale > 0, so it contributes nothing.
  prob_zero_nb = nb.prob(tf.zeros_like(targets))
  loss_zero = -tf.math.log(EPS + mixture_probs[:, :, 0] +
                           mixture_probs[:, :, 1] * prob_zero_nb)

  # One mask decides both the placeholder and the branch. A negative target
  # must take the zero branch; otherwise its placeholder value of 1 below would
  # be scored as an observed 1. This matches the clamping in znb_mle and nb_mle.
  is_zero = tf.less_equal(targets, 0.)
  # Placeholder of 1 avoids nan in the non-zero branch where it is discarded.
  safe_targets = tf.where(is_zero, tf.ones_like(targets), targets)

  prob_non_zero_nb = nb.prob(safe_targets)
  # NOTE(review): the 1e-20 offset presumably guards against an exactly-zero
  # Pareto density (targets below pareto_scale); it is negligible next to EPS.
  prob_non_zero_pareto = tfp.distributions.Pareto(
      concentration=alpha, scale=pareto_scale).prob(safe_targets) + 1e-20
  loss_non_zero = -tf.math.log(EPS + mixture_probs[:, :, 1] * prob_non_zero_nb +
                               mixture_probs[:, :, 2] * prob_non_zero_pareto)

  # Combining zero and non-zero loss.
  loss = tf.where(is_zero, loss_zero, loss_non_zero)
  return tf.reduce_mean(loss)



# Maps each loss name to the corresponding loss function defined above.
LOSS_DICT = dict(
    mae=mae,
    mse=mse,
    znb_mle=znb_mle,
    nb_mle=nb_mle,
    mape=mape,
    huber_loss=huber_loss,
    quantile_loss=quantile_loss,
    mix_znbp_mle=mix_znbp_mle,
)
