Source code for gmmvi.experiments.evaluation.mmd

import tensorflow as tf
import tensorflow_probability as tfp

[docs]class MMD: """This class can be used for computing the Maximum Mean Discrepancy :cite:p:`Gretton2012`. The MMD can be used to compute the discrepancy between a model sample and a groundtruth sample. Note that instantiating this object can be quite slow, but computing the MMD using :py:meth:`compute_MMD<gmmvi.experiments.evaluation.mmd.MMD.compute_MMD>` should be fast. Parameters: groundtruth: tf.Tensor The groundtruth sample of shape number_of_samples x dimension alpha: tf.float32 A factor for scaling the diagonal bandwidth matrix (which is automatically chosen based on the groundtruth sample using the Median trick :cite:p:`Gretton2012`). """ def __init__(self, groundtruth, alpha): self.groundtruth = tf.cast(tf.convert_to_tensor(groundtruth), tf.float32) self.num_groundtruth = len(groundtruth) self.sigma = self.compute_sigma() self.set_alpha(alpha) @tf.function def compute_sigma(self, max_points_for_median=1000): max_points = tf.cast(tf.math.minimum(max_points_for_median, len(self.groundtruth)), tf.float32) distances = tf.TensorArray(tf.float32, size=int(0.5 * (max_points * max_points + max_points))) index = 0 for i in range(int(max_points)): for j in range(i, int(max_points)): distances = distances.write(index, tf.math.square(self.groundtruth[i] - self.groundtruth[j])) index += 1 sigma = tf.linalg.diag(tfp.stats.percentile(distances.stack(), 50, axis=0)) return sigma @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32), tf.TensorSpec(shape=[], dtype=tf.float32)]) def compute_ustat(self, sample, alpha): kernel = tf.linalg.inv(alpha * self.sigma) ustat2 = 0. for i in tf.range(tf.shape(sample)[0]): diff = sample[i] - sample ustat2 += tf.reduce_sum(tf.exp(-tf.reduce_sum(diff @ kernel * diff, axis=1))) return ustat2 @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32), tf.TensorSpec(shape=[], dtype=tf.float32)]) def kernel_mix(self, sample, alpha): kernel = tf.linalg.inv(alpha * self.sigma) ustat3 = 0. for i in range(self.num_groundtruth): diff = self.groundtruth[i] - sample ustat3 += tf.reduce_sum(tf.exp(-tf.reduce_sum(diff @ kernel * diff, axis=1))) return ustat3 def set_alpha(self, alpha): self._alpha = alpha self.ustat1 = self.compute_ustat(self.groundtruth, alpha)
[docs] def compute_MMD(self, model_sample): """ Compute the MMD between the model_sample and the groundtruth data that was provided when instantiating this object. Parameters: model_sample: tf.Tensor The sample from the model of shape number_of_samples x dimension Returns: float: The MMD between model sample and groundtruth sample """ num_1 = self.num_groundtruth num_2 = len(model_sample) MMD = self.ustat1/(num_1**2) \ + self.compute_ustat(model_sample, self._alpha)/(num_2**2) \ - 2*self.kernel_mix(model_sample, self._alpha)/(num_1*num_2) return MMD