import tensorflow as tf
from math import pi


#@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32),
#                              tf.TensorSpec(shape=[None], dtype=tf.float32),
#                              tf.TensorSpec(shape=[None, None], dtype=tf.float32),
#                              tf.TensorSpec(shape=[None, None], dtype=tf.float32),
#                              tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def gaussian_log_pdf(dim, mean, chol, inv_chol, x):
    """Evaluate a multivariate Gaussian log-density at every row of ``x``.

    Computes
        -D/2 * log(2*pi) - sum(log(diag(chol))) - 1/2 * ||inv_chol @ (mean - x_n)||^2
    for each row x_n of ``x``.

    The log-determinant term uses only the diagonal of ``chol``. It is
    therefore presumably a triangular (Cholesky) factor with positive
    diagonal, cov = chol @ chol^T. ``inv_chol`` is expected to be its
    inverse; neither property is checked here.

    Args:
        dim: Dimensionality D. May be a Python number or a scalar tensor of
            any numeric dtype; it is cast to ``chol``'s dtype.
        mean: Mean vector, shape [D].
        chol: Covariance factor, shape [D, D].
        inv_chol: Inverse of ``chol``, shape [D, D].
        x: Evaluation points, one per row, shape [N, D].

    Returns:
        Tensor of shape [N] with one log-density per row of ``x``, in
        ``chol``'s dtype.
    """
    chol = tf.convert_to_tensor(chol)
    dtype = chol.dtype
    # Build every constant in chol's dtype so float64 inputs do not clash
    # with a hard-coded float32 log(2*pi).
    dim = tf.cast(dim, dtype)
    log_2pi = tf.math.log(tf.constant(2 * pi, dtype=dtype))
    log_det_chol = tf.reduce_sum(tf.math.log(tf.linalg.diag_part(chol)))
    constant_part = -0.5 * dim * log_2pi - log_det_chol
    # Column n holds inv_chol @ (mean - x[n]); its squared norm is the
    # Mahalanobis distance of x[n].
    whitened = inv_chol @ tf.transpose(mean - x)
    return constant_part - 0.5 * tf.reduce_sum(tf.square(whitened), axis=0)

@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32),
                              tf.TensorSpec(shape=[None], dtype=tf.float32),
                              tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                              tf.TensorSpec(shape=[], dtype=tf.int32)])
def sample_Gaussian(dim, mean, chol, num_samples):
    """Draw samples ``mean + chol @ z`` with z ~ N(0, I).

    The samples therefore follow N(mean, chol @ chol^T).

    Args:
        dim: Dimensionality D (int32 scalar).
        mean: Mean vector, shape [D].
        chol: Covariance factor, shape [D, D].
        num_samples: Number of samples N (int32 scalar).

    Returns:
        float32 tensor of shape [N, D], one sample per row.
    """
    # Work column-wise: each column of `columns` is one sample.
    centre = tf.expand_dims(mean, axis=-1)
    noise = tf.random.normal((dim, num_samples), mean=0., stddev=1.)
    columns = centre + tf.linalg.matmul(chol, noise)
    return tf.transpose(columns)

def sample_diag_Gaussian(dim, mean, diag_part, num_samples):
    """Draw samples ``mean + diag_part * z`` with z ~ N(0, I).

    ``diag_part`` multiplies the standard-normal noise directly, so it acts
    as a per-dimension scale (standard deviation), not a variance.

    Args:
        dim: Dimensionality D.
        mean: Mean vector; presumably shape [D].
        diag_part: Per-dimension scale; presumably shape [D].
        num_samples: Number of samples N; converted with ``int()``, so it
            must be a Python-convertible value, not a symbolic graph tensor.

    Returns:
        Tensor of shape [N, D] (for 1-D ``mean``/``diag_part``), one sample per row.
    """
    loc = tf.expand_dims(mean, axis=-1)
    scale = tf.expand_dims(diag_part, axis=-1)
    noise = tf.random.normal((dim, int(num_samples)), mean=0., stddev=1.)
    draws = loc + scale * noise
    return tf.transpose(draws)

def sample_diagonal_Gaussian(dim, len_mean, mean, cov_diag_entries, num_samples):
    """Draw ``num_samples`` diagonal-Gaussian samples for each of ``len_mean`` means.

    Computes ``expand_dims(mean, 1) + sqrt(cov_diag_entries) * z``, where z is
    standard-normal noise of shape [len_mean, num_samples, dim]. Here
    ``cov_diag_entries`` holds variances; their square root is the scale.

    NOTE(review): ``mean`` gets an axis inserted at position 1, but
    ``cov_diag_entries`` does not. If it is shaped [len_mean, dim] like
    ``mean`` presumably is, it broadcasts against the trailing
    [num_samples, dim] axes of the noise instead of its leading axis. Confirm
    the callers pass a shape such as [len_mean, 1, dim] or [dim].

    Returns:
        Tensor of shape [len_mean, num_samples, dim], provided ``mean`` and
        ``cov_diag_entries`` broadcast against the noise.
    """
    centres = tf.expand_dims(mean, 1)
    std = tf.math.sqrt(cov_diag_entries)
    noise = tf.random.normal((len_mean, num_samples, dim), mean=0., stddev=1.)
    return centres + std * noise

@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32),
                              tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def entropy_Gaussian(dim, chol):
    """Differential entropy (in nats) of a Gaussian with covariance factor ``chol``.

    Returns 0.5 * D * (log(2*pi) + 1) + sum(log(diag(chol))). The second term
    equals log|det chol| only when ``chol`` is triangular with a positive
    diagonal (presumably a Cholesky factor, cov = chol @ chol^T; not checked).

    Args:
        dim: Dimensionality D (float32 scalar).
        chol: Covariance factor, shape [D, D].

    Returns:
        float32 scalar tensor.
    """
    # Same log-determinant expression as gaussian_log_pdf uses.
    log_det_chol = tf.reduce_sum(tf.math.log(tf.linalg.diag_part(chol)))
    return 0.5 * dim * (tf.math.log(2 * pi) + 1) + log_det_chol

def sample_categorical(num_samples, log_weights):
    """Draw category indices with probabilities ``exp(log_weights)``.

    Uses inverse-CDF sampling: a uniform draw u in [0, 1) maps to the first
    category k whose cumulative weight exceeds u.

    The weights are assumed to be 1-D and normalized, i.e. exp(log_weights)
    sums to ~1 (not checked). The last category absorbs any draw that
    floating-point rounding would otherwise leave above the final cumulative
    sum.

    Args:
        num_samples: Number of indices to draw.
        log_weights: Log-probabilities, shape [K], K >= 1.

    Returns:
        int32 tensor of shape [num_samples] with values in [0, K - 1].
    """
    thresholds = tf.cumsum(tf.exp(log_weights))
    eps = tf.random.uniform(shape=[num_samples, 1], dtype=thresholds.dtype)
    # The cumulative sums are non-decreasing, so the first k with
    # eps < thresholds[k] equals the number of thresholds <= eps. Counting
    # avoids relying on tf.argmax tie-breaking, which TF does not guarantee.
    # Dropping the last threshold treats it as exactly 1.0 (eps < 1 always),
    # so an all-False row can no longer silently map to category 0.
    below = tf.cast(thresholds[:-1] <= eps, tf.int32)
    return tf.reduce_sum(below, axis=-1)