import numpy as np
import tensorflow as tf
import functools


def tf_random_choice(inputs, n_samples):
    """
    Draw rows of ``inputs`` uniformly at random, with replacement.

    Params:
      inputs (Tensor): Shape [n_states, n_features]
      n_samples (int): How many rows to draw.
    Returns:
      sampled_inputs (Tensor): Shape [n_samples, n_features]
    """
    n_states = tf.shape(inputs)[0]

    # Identical (all-zero) logits give every row the same probability.
    # tf.random.categorical wants 2-D logits of shape (batch, n_classes),
    # so add a leading batch axis of size 1.
    flat_logits = tf.zeros(n_states)
    batched_logits = tf.expand_dims(flat_logits, 0)

    drawn = tf.random.categorical(batched_logits, n_samples)  # (1, n_samples)
    row_ids = tf.squeeze(drawn, 0, name="random_choice_ind")  # (n_samples,)

    return tf.gather(inputs, row_ids, name="random_choice")


def limit_bijector_to_dim(target_dim):
    """
    Class decorator that restricts a bijector to a single feature dimension.

    The decorated class's ``_forward``, ``_inverse`` and
    ``_forward_log_det_jacobian`` are replaced in place on the class. Each
    wrapper applies the original method only to the slice
    ``x[..., target_dim:target_dim + 1]``. Every other feature along the
    last axis passes through the identity map.

    A separately defined ``_inverse_log_det_jacobian`` on the class is NOT
    wrapped. NOTE(review): confirm the decorated classes rely on the
    framework deriving it from the forward log-det-Jacobian.

    Params:
      target_dim (int): Non-negative index into the last axis of the input.
        Negative indices are not supported: for example, with -1 the target
        slice ``x[..., -1:0]`` would be empty.
    Returns:
      decorate_class (callable): Takes a class, patches it, and returns it.
    """
    def _identity(t):
        return t

    def decorate(func, passthrough=_identity):
        """Wrap ``func`` so it only sees feature ``target_dim``.

        ``passthrough`` maps the untouched features to their output slot.
        It is the identity for ``_forward``/``_inverse``.
        """
        @functools.wraps(func)
        def wrapper_decorator(self, x, **kwargs):
            # NOTE: do not rebind ``target_dim`` here. Previously it was
            # reset to 0, which silently ignored the decorator argument.
            x_before = x[..., :target_dim]
            x_target = x[..., target_dim:target_dim + 1]
            x_after = x[..., target_dim + 1:]

            return tf.concat(
                (
                    passthrough(x_before),
                    func(self, x_target, **kwargs),
                    passthrough(x_after),
                ),
                axis=-1,
            )

        return wrapper_decorator

    def decorate_class(original_class):
        original_class._forward = decorate(original_class._forward)
        original_class._inverse = decorate(original_class._inverse)
        # Untouched features go through the identity map, whose
        # log-det-Jacobian is 0, not the feature values themselves.
        # Assumes the wrapped log-det-Jacobian is elementwise (same shape as
        # its input). The concat along the last axis already relies on
        # that. TODO confirm.
        original_class._forward_log_det_jacobian = decorate(
            original_class._forward_log_det_jacobian,
            passthrough=tf.zeros_like,
        )

        return original_class

    return decorate_class


def kl_divergence(y_true, y_pred):
    """
    Return the mean of the elementwise difference ``y_true - y_pred``.

    NOTE(review): this is a Monte-Carlo estimate of KL(p || q) only if
    ``y_true`` and ``y_pred`` are log-densities log p(x) and log q(x),
    evaluated at samples x drawn from p. Confirm against callers.
    """
    log_ratio = y_true - y_pred
    return np.mean(log_ratio)


def mse(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred``."""
    residual = y_true - y_pred
    squared = residual ** 2
    return np.mean(squared)
