import tensorflow as tf
import tensorflow_probability as tfp

class MMD:
    """Biased estimate of the squared Maximum Mean Discrepancy (MMD).

    The kernel is ``k(x, y) = exp(-(x - y)^T K (x - y))`` where ``K`` is the
    inverse of ``alpha * diag(m)`` and ``m`` holds, per feature, the median of
    the squared differences between ground-truth points (pairs ``i <= j``).

    The ground-truth kernel sum is computed once at construction time and
    reused by every call to :meth:`compute_MMD`.

    Samples are expected to be 2-D, one row per point, with the same number of
    columns as the ground truth (the code indexes rows and multiplies row
    differences by the square kernel matrix).
    """

    def __init__(self, groundtruth, alpha):
        """Store the ground truth and precompute bandwidth and kernel sum.

        Args:
            groundtruth: Array-like of ground-truth points; converted to a
                float32 tensor.
            alpha: Scalar multiplier applied to the median matrix before it is
                inverted to obtain the kernel matrix.
        """
        self.groundtruth = tf.cast(tf.convert_to_tensor(groundtruth), tf.float32)
        self.num_groundtruth = len(groundtruth)
        self.alpha = alpha
        self.sigma, self.kernel = self.compute_sigma()
        self.ustat1 = self.estimate_kernel_on_gt()

    @staticmethod
    def _as_float32(values):
        """Return ``values`` as a float32 tensor (matches the kernel dtype)."""
        return tf.cast(tf.convert_to_tensor(values), tf.float32)

    def _kernel_sum(self, anchors, num_anchors, others):
        """Sum ``k(anchors[i], y)`` over ``i < num_anchors`` and all rows ``y``.

        The loop keeps peak memory at one ``(len(others), dim)`` difference
        matrix instead of materialising every pair at once.
        """
        total = 0.
        for i in range(num_anchors):
            diff = anchors[i] - others
            # Row-wise quadratic form (x - y)^T K (x - y).
            quad_form = tf.reduce_sum(diff @ self.kernel * diff, axis=1)
            total += tf.reduce_sum(tf.exp(-quad_form))
        return total

    @tf.function
    def compute_sigma(self, max_points_for_median=1000):
        """Compute the median-based bandwidth matrix and its scaled inverse.

        Only the first ``max_points_for_median`` ground-truth points are used.
        Squared per-feature differences are collected for every pair
        ``i <= j`` (self-pairs included) and their 50th percentile is taken
        per feature.

        Args:
            max_points_for_median: Upper bound on the number of ground-truth
                points entering the median computation.

        Returns:
            Tuple ``(sigma, kernel)``: ``sigma`` is the diagonal matrix of
            per-feature medians, ``kernel`` is ``inv(alpha * sigma)``.
        """
        num_points = min(max_points_for_median, self.groundtruth.shape[0])
        points = self.groundtruth[:num_points]

        # All pairwise squared differences at once: shape (n, n, dim).
        squared = tf.math.square(points[:, tf.newaxis, :] - points[tf.newaxis, :, :])

        # Keep the pairs with i <= j so each unordered pair is counted once.
        order = tf.range(num_points)
        upper = order[:, tf.newaxis] <= order[tf.newaxis, :]
        distances = tf.boolean_mask(squared, upper)

        sigma = tf.linalg.diag(tfp.stats.percentile(distances, 50, axis=0))
        kernel = tf.linalg.inv(self.alpha * sigma)
        return sigma, kernel

    def estimate_kernel_on_gt(self):
        """Return the sum of kernel values over all ordered ground-truth pairs."""
        return self._kernel_sum(
            self.groundtruth, self.groundtruth.shape[0], self.groundtruth
        )

    def estimate_kernel_on_model(self, sample, num_model_samples):
        """Return the kernel sum between the model sample and itself.

        Args:
            sample: Array-like of model points; converted to float32.
            num_model_samples: Number of leading rows of ``sample`` used as the
                first kernel argument; each is compared against every row.
        """
        sample = self._as_float32(sample)
        return self._kernel_sum(sample, num_model_samples, sample)

    def kernel_mix(self, sample):
        """Return the kernel sum between ground-truth points and ``sample``.

        Args:
            sample: Array-like of model points; converted to float32.
        """
        sample = self._as_float32(sample)
        return self._kernel_sum(self.groundtruth, self.num_groundtruth, sample)

    def compute_MMD(self, model_sample):
        """Return the squared-MMD estimate between ground truth and a sample.

        The result is
        ``ustat1 / n1**2 + ustat2 / n2**2 - 2 * ustat3 / (n1 * n2)``
        where the sums include the self-pairs, i.e. the biased estimator.

        Args:
            model_sample: Array-like of model points. It is stored as a tensor
                in ``self.sample_1``; a float32 cast of it is used for the
                computation so lists and non-float32 arrays are handled.
        """
        self.sample_1 = tf.convert_to_tensor(model_sample)
        sample = tf.cast(self.sample_1, tf.float32)
        num_1 = self.num_groundtruth
        num_2 = len(model_sample)
        mmd = (
            self.ustat1 / (num_1 ** 2)
            + self.estimate_kernel_on_model(sample, num_2) / (num_2 ** 2)
            - 2 * self.kernel_mix(sample) / (num_1 * num_2)
        )
        return mmd
