import numpy as np
from util.distributions import sample_Gaussian, sample_categorical
from util.distributions import entropy_Gaussian as Gaussian_entropy
import tensorflow_probability as tfp
import tensorflow as tf
tfd = tfp.distributions
from math import pi

class GMM:
    """Gaussian mixture model whose parameters live in resizable ``tf.Variable``s.

    Components can be added, removed or replaced in place. Every write to the
    weights goes through :meth:`replace_weights`, which renormalizes them so
    that ``logsumexp(log_weights) == 0`` (the weights sum to one).

    Attributes:
        num_dimensions: Dimensionality ``D`` of the sample space (Python int).
        log_weights: Variable of shape ``[K]`` with the normalized log weights.
        means: Variable of shape ``[K, D]`` with the component means.
        covars: Variable of shape ``[K, D, D]`` with the component covariances.
        chol_covar: Variable of shape ``[K, D, D]`` with the lower-triangular
            Cholesky factors of ``covars``.
        l2_regularizers: Variable of shape ``[K]``; every new component starts
            with the value 1.
    """

    def __init__(self, weights, means, covars):
        """Create a mixture from initial parameters.

        Args:
            weights: Array-like of shape ``[K]`` with positive mixture weights.
                They are cast to float32 and normalized to sum to one.
            means: Array-like of shape ``[K, D]`` with the component means.
            covars: Array-like of shape ``[K, D, D]`` with positive-definite
                covariance matrices.
        """
        self.num_dimensions = len(means[0])
        # 0.5 * D * log(2*pi): Gaussian normalizer, used by the loop-based
        # `_component_log_densities_old`.
        self._const_log_det = tf.constant(0.5 * self.num_dimensions * tf.math.log(2*pi))
        # np.asarray also accepts plain Python sequences, not only ndarrays.
        self.log_weights = tf.Variable(tf.math.log(tf.convert_to_tensor(np.asarray(weights, dtype=np.float32))),
                                       shape=[None])
        covars = tf.convert_to_tensor(covars)
        self.covars = tf.Variable(covars, shape=[None, self.num_dimensions, self.num_dimensions])
        # tf.linalg.cholesky is batched over leading dimensions, so all
        # components are factorized in one call (also valid for K == 0).
        self.chol_covar = tf.Variable(tf.linalg.cholesky(covars),
                                      shape=[None, self.num_dimensions, self.num_dimensions])
        self.means = tf.Variable(tf.convert_to_tensor(means), shape=[None, self.num_dimensions])
        self.l2_regularizers = tf.Variable(tf.ones(len(weights)), shape=[None])
        # Normalize the initial weights.
        self.replace_weights(self.log_weights)

    @property
    def weights(self):
        """Mixture weights, ``exp(log_weights)``, as a tensor of shape ``[K]``."""
        return tf.math.exp(self.log_weights)

    def replace_weights(self, new_logweights):
        """Store ``new_logweights`` as the log weights after renormalizing them.

        ``new_logweights`` may have a different length than the current
        weights (see ``add_component`` / ``remove_component``). The normalizer
        is therefore computed from the *new* values, so the stored weights
        always sum to one.

        Args:
            new_logweights: Tensor of shape ``[K']`` with (unnormalized) log weights.
        """
        self.log_weights.assign(new_logweights - tf.reduce_logsumexp(new_logweights))

    def sample_from_components(self, samples_per_component):
        """Draw a fixed number of samples from each component separately.

        Args:
            samples_per_component: Indexable with one sample count per component.

        Returns:
            A list with one entry per component, each holding the output of
            ``sample_Gaussian`` for that component.
        """
        samples = []
        for i in range(self.num_components):
            this_samples = sample_Gaussian(self.num_dimensions, self.means[i], self.chol_covar[i], samples_per_component[i])
            samples.append(this_samples)
        return samples

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def _component_log_densities(self, samples):
        """Log-density of every sample under every component (vectorized).

        Args:
            samples: float32 tensor of shape ``[N, D]``.

        Returns:
            Tensor of shape ``[K, N]`` whose entry ``[k, n]`` is
            ``log N(samples[n] | means[k], covars[k])``.
        """
        # [K, N, D]: difference of every sample to every mean.
        diffs = tf.expand_dims(samples, 0) - tf.expand_dims(self.means, 1)
        # Solve L_k x = diff for all k at once; ||x||^2 is the Mahalanobis distance.
        sqrts = tf.linalg.triangular_solve(self.chol_covar, tf.transpose(diffs, [0, 2, 1]))
        mahalas = - 0.5 * tf.reduce_sum(sqrts * sqrts, axis=1)
        # -0.5 * log det(Sigma_k) = -0.5 * sum(log(diag(L_k)^2)), plus the
        # -0.5 * D * log(2*pi) normalizer.
        const_parts = - 0.5 * tf.reduce_sum(tf.math.log(tf.square(tf.linalg.diag_part(self.chol_covar))), axis=1) \
                      - 0.5 * self.num_dimensions * tf.math.log(2 * pi)
        return mahalas + tf.expand_dims(const_parts, axis=1)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def _component_log_densities_old(self, samples):
        """Loop-based equivalent of ``_component_log_densities``.

        No other method of this class calls it.

        Args:
            samples: float32 tensor of shape ``[N, D]``.

        Returns:
            Tensor of shape ``[K, N]`` with per-component log-densities.
        """
        log_pdfs = tf.TensorArray(tf.float32, size=self.num_components)
        for i in range(self.num_components):
            diff = tf.transpose(samples - self.means[i])
            sqrt = tf.linalg.triangular_solve(self.chol_covar[i], diff)
            log_pdfs = log_pdfs.write(i, - 0.5 * tf.reduce_sum(sqrt * sqrt, axis=0))
        log_pdfs = log_pdfs.stack() \
                   - 0.5 * tf.expand_dims(tf.reduce_sum(tf.math.log(tf.square(tf.linalg.diag_part(self.chol_covar))), axis=1), axis=1) \
                   - self._const_log_det
        return log_pdfs

    def component_log_densities(self, samples):
        """Public wrapper for ``_component_log_densities``; returns shape ``[K, N]``."""
        return self._component_log_densities(samples)

    def log_density_tf(self, samples):
        """Mixture log-density ``log sum_k w_k N(x | mu_k, Sigma_k)`` as a tensor.

        Args:
            samples: float32 tensor of shape ``[N, D]``.

        Returns:
            Tensor of shape ``[N]``.
        """
        log_densities = self._component_log_densities(samples)
        weighted_densities = log_densities + tf.expand_dims(self.log_weights, axis=1)
        return tf.reduce_logsumexp(weighted_densities, axis=0)

    def log_density(self, samples):
        """Mixture log-density of ``samples`` as a NumPy array of shape ``[N]``."""
        return self.log_density_tf(samples).numpy()

    def density(self, samples):
        """Mixture density of ``samples`` as a NumPy array of shape ``[N]``."""
        return tf.exp(self.log_density_tf(samples)).numpy()

    @tf.function()
    def _get_average_entropy(self):
        """Weighted average of the component entropies, ``sum_k w_k * H_k``."""
        avg_entropy = 0.
        for i in range(self.num_components):
            avg_entropy += tf.exp(self.log_weights[i]) * Gaussian_entropy(self.num_dimensions, self.chol_covar[i])
        return avg_entropy

    def get_average_entropy(self):
        """Public wrapper for ``_get_average_entropy``."""
        return self._get_average_entropy()

    def replace_components(self, new_means, new_covs, new_chols):
        """Overwrite all component parameters (weights are left untouched).

        Args:
            new_means: Sequence of ``K`` mean vectors, stacked along axis 0.
            new_covs: Sequence of ``K`` covariance matrices, stacked along axis 0.
            new_chols: Sequence of ``K`` Cholesky factors, stacked along axis 0.
                They are stored as given; the caller must keep them consistent
                with ``new_covs``.
        """
        new_means = tf.stack(new_means, axis=0)
        new_covs = tf.stack(new_covs, axis=0)
        new_chols = tf.stack(new_chols, axis=0)
        self.means.assign(new_means)
        self.covars.assign(new_covs)
        self.chol_covar.assign(new_chols)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def log_density_and_grad_tf(self, samples):
        """Mixture log-density and its gradient with respect to the samples.

        Args:
            samples: float32 tensor of shape ``[N, D]``.

        Returns:
            Tuple ``(log_density, grad)`` with shapes ``[N]`` and ``[N, D]``.
        """
        with tf.GradientTape() as gfg:
            # `samples` is a plain tensor, not a Variable, so it must be watched explicitly.
            gfg.watch(samples)
            log_densities = self._component_log_densities(samples)
            weighted_densities = log_densities + tf.expand_dims(self.log_weights, axis=1)
            target = tf.reduce_logsumexp(weighted_densities, axis=0)
        target_grad = gfg.gradient(target, samples)
        return target, target_grad

    def log_density_and_grad(self, samples):
        """Public wrapper for ``log_density_and_grad_tf``."""
        return self.log_density_and_grad_tf(samples)

    def sample(self, num_samples):
        """Draw ``num_samples`` samples from the mixture.

        A component index is drawn for every sample with ``sample_categorical``.
        Each component is then sampled as often as it was selected. The
        per-component draws are concatenated along axis 0 and shuffled along
        that axis.
        """
        sampled_components = sample_categorical(num_samples=num_samples, log_weights=self.log_weights)
        samples = []
        for i in range(self.num_components):
            n_samples = tf.reduce_sum(tf.cast(sampled_components == i, tf.int32))
            this_samples = sample_Gaussian(self.num_dimensions, self.means[i], self.chol_covar[i], n_samples)
            samples.append(this_samples)
        return tf.random.shuffle(tf.concat(samples, axis=0))

    @property
    def num_components(self):
        """Number of components ``K`` as a scalar integer tensor."""
        return tf.shape(self.log_weights)[0]

    def add_component(self, initial_weight, initial_mean, initial_covar):
        """Append a new component and renormalize all weights.

        Args:
            initial_weight: Weight (not log weight) of the new component,
                relative to the existing normalized weights.
            initial_mean: Mean vector of shape ``[D]``.
            initial_covar: Covariance matrix of shape ``[D, D]``; its Cholesky
                factor is computed here.
        """
        self.means.assign(tf.concat((self.means,  tf.expand_dims(initial_mean, axis=0)), axis=0))
        self.covars.assign(tf.concat((self.covars,  tf.expand_dims(initial_covar, axis=0)), axis=0))
        self.chol_covar.assign(tf.concat((self.chol_covar,  tf.expand_dims(tf.linalg.cholesky(initial_covar), axis=0)), axis=0))
        self.replace_weights(tf.concat((self.log_weights, tf.expand_dims(tf.math.log(initial_weight), axis=0)), axis=0))
        self.l2_regularizers.assign(tf.concat((self.l2_regularizers, tf.ones(1)), axis=0))

    def remove_component(self, idx):
        """Delete component ``idx`` and renormalize the remaining weights.

        NOTE(review): ``idx`` must be non-negative. With a negative index the
        two slices below overlap (e.g. ``idx=-1`` keeps ``[:-1]`` plus ``[0:]``),
        so callers should pass ``0 <= idx < K``.
        """
        self.replace_weights(tf.concat((self.log_weights[:idx], self.log_weights[idx+1:]), axis=0))
        self.means.assign(tf.concat((self.means[:idx], self.means[idx+1:]), axis=0))
        self.covars.assign(tf.concat((self.covars[:idx], self.covars[idx+1:]), axis=0))
        self.chol_covar.assign(tf.concat((self.chol_covar[:idx], self.chol_covar[idx+1:]), axis=0))
        self.l2_regularizers.assign(tf.concat((self.l2_regularizers[:idx], self.l2_regularizers[idx+1:]), axis=0))

