import tensorflow as tf
from tensorflow.keras import losses


def log10(x):
  """ return the element-wise base-10 logarithm of x

  Computed via the change-of-base identity ln(x) / ln(10).
  Args:
    x: tensor or tensor-like value (Python scalar, list, numpy array).
      Integer inputs are cast to float32 because tf.math.log does not
      accept integer dtypes.
  Returns:
    tensor with the same shape as x.
  """
  # accept Python scalars / lists too; the original `x.dtype` required a tensor
  x = tf.convert_to_tensor(x)
  if x.dtype.is_integer:
    x = tf.cast(x, tf.float32)
  # build ln(10) in x's dtype so the division never mixes dtypes
  return tf.math.log(x) / tf.math.log(tf.constant(10, dtype=x.dtype))


def radial_basis_function(x, center=0, bandwidth=0.1):
  """ return the Gaussian radial basis function of x around center

  Computes exp(-(x - center)^2 / (2 * bandwidth^2)) element-wise.
  Args:
    x: tensor or tensor-like value; integer inputs are cast to float32.
    center: scalar or tensor broadcastable against x, cast to x's dtype.
    bandwidth: scalar or tensor broadcastable against x, cast to x's dtype.
  Returns:
    tensor with the broadcast shape of x and center.
  """
  x = tf.convert_to_tensor(x)
  if x.dtype.is_integer:
    x = tf.cast(x, tf.float32)
  # cast the parameters to x's dtype: the default Python-float bandwidth would
  # otherwise become a float32 tensor, and multiplying it with a float64 x
  # raises a dtype-mismatch error in TensorFlow
  center = tf.cast(center, x.dtype)
  bandwidth = tf.cast(bandwidth, x.dtype)
  return tf.math.exp(-1 / (2 * tf.square(bandwidth)) * tf.square(x - center))


def per_sample_mean(x: tf.Tensor):
  """ return the per sample mean of x

  Averages over every axis except the leading (sample) axis.
  Args:
    x: tensor with shape (N, ...) where N is the number of samples
  Returns:
    tensor with shape (N,); for a rank-1 x no axis is reduced.
  """
  x = tf.convert_to_tensor(x)
  rank = x.shape.rank
  if rank is None:
    # rank unknown at trace time (e.g. tf.function with a loose input
    # signature): len(x.shape) would raise, so build the axes dynamically
    axis = tf.range(1, tf.rank(x))
  else:
    axis = list(range(1, rank))
  return tf.reduce_mean(x, axis=axis)


def mean_absolute_error(y_true, y_pred):
  """ return the mean absolute error of each sample

  A trailing axis of size 1 is appended to both inputs, so the Keras loss
  (which averages over the last axis) yields the element-wise absolute
  error. That error is then averaged over all non-sample axes.
  Args:
    y_true: ground-truth tensor with shape (N, ...)
    y_pred: predicted tensor, expected to match y_true's shape
  Returns:
    tensor with shape (N,)
  """
  expanded_true = tf.expand_dims(y_true, axis=-1)
  expanded_pred = tf.expand_dims(y_pred, axis=-1)
  elementwise_error = losses.mean_absolute_error(y_true=expanded_true,
                                                 y_pred=expanded_pred)
  return per_sample_mean(elementwise_error)


def mean_squared_error(y_true, y_pred):
  """ return the mean squared error of each sample

  A trailing axis of size 1 is appended to both inputs, so the Keras loss
  (which averages over the last axis) yields the element-wise squared
  error. That error is then averaged over all non-sample axes.
  Args:
    y_true: ground-truth tensor with shape (N, ...)
    y_pred: predicted tensor, expected to match y_true's shape
  Returns:
    tensor with shape (N,)
  """
  expanded_true = tf.expand_dims(y_true, axis=-1)
  expanded_pred = tf.expand_dims(y_pred, axis=-1)
  elementwise_error = losses.mean_squared_error(y_true=expanded_true,
                                                y_pred=expanded_pred)
  return per_sample_mean(elementwise_error)


def huber_loss(y_true, y_pred, delta: float = 1.0):
  """ return the per sample huber loss

  A trailing axis of size 1 is appended to both inputs, so the Keras loss
  (which averages over the last axis) yields the element-wise Huber loss.
  That loss is then averaged over all non-sample axes.
  Args:
    y_true: ground-truth tensor with shape (N, ...)
    y_pred: predicted tensor, expected to match y_true's shape
    delta: point where the loss switches from quadratic to linear; the
      default 1.0 matches the value previously used implicitly.
  Returns:
    tensor with shape (N,)
  """
  # use the module's `losses` alias like the sibling loss helpers do
  elementwise_loss = losses.huber(y_true=tf.expand_dims(y_true, axis=-1),
                                  y_pred=tf.expand_dims(y_pred, axis=-1),
                                  delta=delta)
  return per_sample_mean(elementwise_loss)


def binary_cross_entropy(y_true, y_pred, from_logits: bool = False):
  """ return the binary cross entropy of each sample

  A trailing axis of size 1 is appended to both inputs, so the Keras loss
  (which averages over the last axis) yields the element-wise cross
  entropy. That is then averaged over all non-sample axes.
  Args:
    y_true: ground-truth tensor with shape (N, ...)
    y_pred: predictions, expected to match y_true's shape
    from_logits: forwarded to Keras; True when y_pred holds logits rather
      than probabilities
  Returns:
    tensor with shape (N,)
  """
  true_column = tf.expand_dims(y_true, axis=-1)
  pred_column = tf.expand_dims(y_pred, axis=-1)
  elementwise_bce = losses.binary_crossentropy(y_true=true_column,
                                               y_pred=pred_column,
                                               from_logits=from_logits)
  return per_sample_mean(elementwise_bce)


def get_loss_function(name: str):
  """ return the per sample loss function registered under name

  Args:
    name: one of the (case-sensitive) aliases
      'mse' / 'mean_squared_error',
      'mae' / 'mean_absolute_error',
      'huber' / 'huber_loss',
      'bce' / 'binary_cross_entropy' / 'binary_crossentropy'
  Returns:
    the matching loss function from this module
  Raises:
    NotImplementedError: if name matches none of the aliases
  """
  alias_table = (
      (('mse', 'mean_squared_error'), mean_squared_error),
      (('mae', 'mean_absolute_error'), mean_absolute_error),
      (('huber', 'huber_loss'), huber_loss),
      (('bce', 'binary_cross_entropy', 'binary_crossentropy'),
       binary_cross_entropy),
  )
  for aliases, loss_fn in alias_table:
    if name in aliases:
      return loss_fn
  raise NotImplementedError(f'loss function {name} not found.')
