import tensorflow as tf
from math import pi

class SampleDB:
    """Append-only store of samples and of the Gaussian components they belong to.

    Per-sample tensors (``samples``, ``target_lnpdfs``, ``target_grads``,
    ``mapping``) have one row per stored sample. Per-component tensors
    (``means``, ``chols``, ``inv_chols``) have one row per stored component.
    ``mapping[i]`` is the row of the per-component tensors associated with
    sample ``i``.
    """

    def __init__(self, dim):
        """Create an empty database for ``dim``-dimensional samples.

        Args:
            dim: Dimensionality of samples and means; factors are ``dim x dim``.
        """
        self._dim = dim
        # Per-sample storage.
        self.samples = tf.zeros((0, dim))
        self.target_lnpdfs = tf.zeros(0)
        self.target_grads = tf.zeros((0, dim))
        # Per-sample index into the per-component storage below.
        self.mapping = tf.zeros(0, dtype=tf.int32)
        # Per-component storage.
        self.means = tf.zeros((0, dim))
        self.chols = tf.zeros((0, dim, dim))
        self.inv_chols = tf.zeros((0, dim, dim))

    def add_samples(self, samples, means, chols, target_lnpdfs, target_grads, mapping):
        """Append a batch of samples together with the components of that batch.

        Args:
            samples: New samples, shape ``(n, dim)``.
            means: Means of this batch's components, shape ``(k, dim)``.
            chols: Factors of this batch's components, shape ``(k, dim, dim)``.
                Their matrix inverses are computed here and cached.
            target_lnpdfs: Target log-density per new sample, shape ``(n,)``.
            target_grads: Target gradient per new sample, shape ``(n, dim)``.
            mapping: int32 tensor of shape ``(n,)``. Its entries index into this
                batch's ``means``/``chols`` (they are local to the batch).
        """
        # Batch-local component indices are shifted past the components that are
        # already stored. This count must be read before `self.chols` is extended.
        component_offset = len(self.chols)
        self.mapping = tf.concat((self.mapping, mapping + component_offset), axis=0)
        self.means = tf.concat((self.means, means), axis=0)
        self.chols = tf.concat((self.chols, chols), axis=0)
        self.inv_chols = tf.concat((self.inv_chols, tf.linalg.inv(chols)), axis=0)
        self.samples = tf.concat((self.samples, samples), axis=0)
        self.target_lnpdfs = tf.concat((self.target_lnpdfs, target_lnpdfs), axis=0)
        self.target_grads = tf.concat((self.target_grads, target_grads), axis=0)

    def get_random_sample(self, N):
        """Return up to ``N`` stored samples, drawn uniformly without replacement.

        Returns:
            Tuple ``(samples, target_lnpdfs)`` for the chosen rows. Fewer than
            ``N`` rows are returned when fewer samples are stored.
        """
        chosen_indices = tf.random.shuffle(tf.range(len(self.samples)))[:N]
        return tf.gather(self.samples, chosen_indices), tf.gather(self.target_lnpdfs, chosen_indices)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def gaussian_log_pdf(self, mean, chol, inv_chol, x):
        """Log-density of ``N(mean, chol @ chol^T)`` evaluated at each row of ``x``.

        The log-determinant term sums ``log`` of the diagonal of ``chol``. That
        is only correct when ``chol`` is triangular with a positive diagonal
        (presumably a Cholesky factor; confirm at call sites).

        Args:
            mean: Component mean, shape ``(dim,)``.
            chol: Covariance factor, shape ``(dim, dim)``.
            inv_chol: Inverse of ``chol``, shape ``(dim, dim)``.
            x: Evaluation points, shape ``(n, dim)``.

        Returns:
            float32 tensor of shape ``(n,)``.
        """
        constant_part = - 0.5 * self._dim * tf.math.log(2*pi) - tf.reduce_sum(tf.math.log(tf.linalg.diag_part(chol)))
        # Squared Mahalanobis distance: ||inv_chol @ (x - mean)||^2, one column per point.
        return constant_part - 0.5 * tf.reduce_sum(tf.square(inv_chol @ tf.transpose(mean - x)), axis=0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def evaluate_background(self, weights, means, chols, inv_chols, samples):
        """Log-density of a Gaussian mixture evaluated at each row of ``samples``.

        Computes ``log(sum_k weights[k] * N(samples; means[k], chols[k] @ chols[k]^T))``.
        ``weights`` must be non-empty, because component 0 seeds the accumulator.

        Returns:
            float32 tensor with one log-density per sample.
        """
        log_weights = tf.math.log(weights)
        log_pdfs = self.gaussian_log_pdf(means[0], chols[0], inv_chols[0], samples) + log_weights[0]

        # Fold in the remaining components one at a time, using logsumexp for stability.
        for i in range(1, len(weights)):
            log_pdfs = tf.reduce_logsumexp(tf.stack((
                log_pdfs,
                self.gaussian_log_pdf(means[i], chols[i], inv_chols[i], samples) + log_weights[i]
            ), axis=0), axis=0)
        return log_pdfs

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None], dtype=tf.int32),
                                  tf.TensorSpec(shape=[None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None, None], dtype=tf.float32),
                                  tf.TensorSpec(shape=[None, None, None], dtype=tf.float32)])
    def _get_newest_samples(self, num, samples, mapping, target_lnpdfs, target_grads, means, chols, inv_chols):
        """Select the last ``num`` samples and score them under their own mixture.

        The mixture consists of the components referenced by the selected
        samples. Each component is weighted by the fraction of selected samples
        that map to it. If ``num`` exceeds the number of stored samples, all of
        them are used. Callers must pass ``num > 0`` on a non-empty store.

        Returns:
            ``(log_pdfs, active_samples, active_target_lnpdfs, active_target_grads)``.
        """
        active_sample_index = tf.maximum(0, len(samples) - num)
        active_sample = samples[active_sample_index:]
        active_mapping = mapping[active_sample_index:]
        active_target_lnpdfs = target_lnpdfs[active_sample_index:]
        active_target_grads = target_grads[active_sample_index:]

        # Build the mixture from the components referenced by the active samples.
        active_components, _, count = tf.unique_with_counts(active_mapping)
        means = tf.gather(means, active_components)
        chols = tf.gather(chols, active_components)
        inv_chols = tf.gather(inv_chols, active_components)
        count = tf.cast(count, tf.float32)
        weight = count / tf.reduce_sum(count)

        log_pdfs = self.evaluate_background(weight, means, chols, inv_chols, active_sample)
        return log_pdfs, active_sample, active_target_lnpdfs, active_target_grads

    def get_newest_samples(self, num):
        """Return the ``num`` most recent samples with their mixture log-densities.

        Returns empty tensors when the store is empty or ``num`` is not positive.
        Without that guard, a negative ``num`` would select no samples and then
        crash when indexing the first mixture component.

        Returns:
            ``(log_pdfs, samples, target_lnpdfs, target_grads)``.
        """
        if self.samples.shape[0] == 0 or num <= 0:
            return tf.zeros(0), tf.zeros((0, self._dim)), tf.zeros(0), tf.zeros((0, self._dim))
        return self._get_newest_samples(num, self.samples, self.mapping, self.target_lnpdfs,
                                        self.target_grads, self.means, self.chols, self.inv_chols)

