import tensorflow as tf
import tensorflow_compression as tfc
from tensorflow_compression.python.ops import math_ops
from tensorflow_compression.python.distributions import helpers
import functools

# from configs import log_prob_lowerbound
# The minimum value log probabilities are allowed to be.
# -29.8 is ≈ (a bit higher than) log2(1e-9), i.e. a floor of ~1e-9 if read as a base-2 log.
# NOTE(review): the `log_probs` clipped by the entropy models in this file are natural logs
# (their rate is obtained by dividing the summed log_probs by -ln 2). Used as a floor for
# those, -29.8 corresponds to a probability of about exp(-29.8) ≈ 1.1e-13, not 1e-9.
# Confirm which base was intended. This constant is not referenced in this file itself; the
# methods below receive the floor through their own `log_prob_lowerbound` parameter.
log_prob_lowerbound = -29.8


class MyContinuousBatchedEntropyModel(tfc.ContinuousBatchedEntropyModel):
    """`tfc.ContinuousBatchedEntropyModel` with optional flooring of log probabilities.

    Behaves like the parent class, except that `__call__` takes an extra
    `log_prob_lowerbound` argument. When training, every elementwise log
    probability is lower-bounded by that value before being summed into the rate.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @tf.Module.with_name_scope
    def __call__(self, bottleneck, training=True, log_prob_lowerbound=None):
        """Perturbs a tensor with (quantization) noise and estimates its rate.

        Args:
          bottleneck: `tf.Tensor` containing the data to be compressed. Must have at
            least `self.coding_rank` dimensions, and the innermost dimensions must
            be broadcastable to `self.prior_shape`.
          training: Boolean. If `True`, `bottleneck` is perturbed via
            `math_ops.perturb_and_apply` and the returned rate is a differentiable
            *upper* bound. If `False`, `bottleneck` is quantized with
            `self.quantize` and the rate is the Shannon information under
            `self.prior` (a tight, non-differentiable *lower* bound on the bits
            needed by `compress()`).
          log_prob_lowerbound: Optional floor for the elementwise log probabilities
            (natural log), applied with `math_ops.lower_bound`. `None` disables
            flooring. Ignored when `training` is `False`.

        Returns:
          A tuple `(bottleneck_perturbed, bits)`. `bottleneck_perturbed` is the
          noisy (training) or quantized (evaluation) version of `bottleneck`;
          `bits` is the rate, with the shape of `bottleneck` minus its
          `self.coding_rank` innermost dimensions.
        """
        prior_log_prob = functools.partial(self._log_prob_from_prior, self.prior)

        if not training:
            perturbed = self.quantize(bottleneck)
            log_probs = prior_log_prob(perturbed)
        else:
            log_probs, perturbed = math_ops.perturb_and_apply(
                prior_log_prob, bottleneck, expected_grads=self.expected_grads)
            # The only deviation from the parent class: floor each element's log
            # probability before it is aggregated into the rate.
            if log_prob_lowerbound is not None:
                log_probs = math_ops.lower_bound(log_probs, log_prob_lowerbound)

        coding_axes = tuple(range(-self.coding_rank, 0))
        # Dividing natural-log probabilities by -ln(2) converts them to bits.
        neg_ln2 = -tf.math.log(tf.constant(2, dtype=log_probs.dtype))
        bits = tf.reduce_sum(log_probs, axis=coding_axes) / neg_ln2
        return perturbed, bits


class MyContinuousIndexedEntropyModel(tfc.ContinuousIndexedEntropyModel):
    """`tfc.ContinuousIndexedEntropyModel` with log-prob flooring and prior helpers.

    Additions relative to the parent class:
      * `__call__` accepts `log_prob_lowerbound`, which floors the elementwise log
        probabilities during training before they are summed into the rate.
      * `get_prior_dist` returns the indexed prior distribution object.
      * `eval_log_prior` evaluates the summed log density of a tensor under the
        indexed prior, without adding noise or quantizing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @tf.Module.with_name_scope
    def __call__(self, bottleneck, indexes, training=True, log_prob_lowerbound=None):
        """Perturbs a tensor with (quantization) noise and estimates rate.

        Same as the parent class method, except that during training the
        elementwise log probabilities under the prior are lower-bounded by
        `log_prob_lowerbound` before being aggregated into the rate.

        Args:
          bottleneck: `tf.Tensor` containing the data to be compressed.
          indexes: `tf.Tensor` specifying the scalar distribution for each element
            in `bottleneck`. See class docstring for examples.
          training: Boolean. If `False`, computes the Shannon information of
            `bottleneck` under the distribution computed by `self.prior_fn`,
            which is a non-differentiable, tight *lower* bound on the number of bits
            needed to compress `bottleneck` using `compress()`. If `True`, returns a
            somewhat looser, but differentiable *upper* bound on this quantity.
          log_prob_lowerbound: Optional floor for the elementwise (natural-log)
            log probabilities; `None` disables flooring. Only applied when
            `training` is `True`.

        Returns:
          A tuple (bottleneck_perturbed, bits) where `bottleneck_perturbed` is
          `bottleneck` perturbed with (quantization) noise and `bits` is the rate.
          `bits` has the same shape as `bottleneck` without the `self.coding_rank`
          innermost dimensions.
        """
        indexes = self._normalize_indexes(indexes)
        if training:
            def log_prob_fn(bottleneck_perturbed, indexes):
                # When using expected_grads=True, we will use a tf.custom_gradient on
                # this function. In this case, all non-Variable tensors that determine
                # the result of this function need to be declared explicitly, i.e we
                # need `indexes` to be a declared argument and `prior` instantiated
                # here. If we would instantiate it outside this function declaration and
                # reference here via a closure, we would get a `None` gradient for
                # `indexes`.
                prior = self._make_prior(indexes)
                return self._log_prob_from_prior(prior, bottleneck_perturbed)

            # `perturb_and_apply` returns the perturbed bottleneck together with the
            # log probs, so no separate noise tensor is drawn here.
            log_probs, bottleneck_perturbed = math_ops.perturb_and_apply(
                log_prob_fn, bottleneck, indexes, expected_grads=self._expected_grads)
        else:
            # The prior is only needed outside of `log_prob_fn` in this branch.
            prior = self._make_prior(indexes)
            offset = helpers.quantization_offset(prior)
            bottleneck_perturbed = self._quantize(bottleneck, offset)
            log_probs = self._log_prob_from_prior(prior, bottleneck_perturbed)

        # Only difference to the parent class: floor the elementwise log probs.
        if training and log_prob_lowerbound is not None:
            log_probs = math_ops.lower_bound(log_probs, log_prob_lowerbound)

        axes = tuple(range(-self.coding_rank, 0))
        bits = tf.reduce_sum(log_probs, axis=axes) / (
            -tf.math.log(tf.constant(2, dtype=log_probs.dtype)))
        return bottleneck_perturbed, bits

    @tf.Module.with_name_scope
    def get_prior_dist(self, indexes):
        """Makes the indexed prior as a tfp distribution object.

        The indexes are normalized with `self._normalize_indexes` and the prior is
        built with `self._make_prior` (see `ContinuousIndexedEntropyModel._make_prior`
        in continuous_indexed.py).

        Args:
          indexes: `tf.Tensor` specifying the scalar distribution for each element
            in `bottleneck`. See class docstring for examples.

        Returns:
          The prior distribution object returned by `self._make_prior`.
        """
        indexes = self._normalize_indexes(indexes)
        prior = self._make_prior(indexes)
        return prior

    @tf.Module.with_name_scope
    def eval_log_prior(self, bottleneck, indexes, log_prob_lowerbound=None):
        """Computes the log density of the `bottleneck` tensor under the prior.

        This implements the second part of `__call__` (the first part adds
        quantization noise, the second evaluates the rate under the prior), but
        evaluates `bottleneck` as given and returns natural-log density instead
        of bits. The density comes from `prior.log_prob` directly, where `prior`
        is obtained from `self.get_prior_dist(indexes)`.

        Args:
          bottleneck: `tf.Tensor` containing the data to be evaluated.
          indexes: `tf.Tensor` specifying the scalar distribution for each element
            in `bottleneck`. See class docstring for examples.
          log_prob_lowerbound: Optional floor for the elementwise log densities,
            applied before summation; `None` disables flooring. Unlike
            `__call__`, this is applied unconditionally (there is no training flag).

        Returns:
          The elementwise `prior.log_prob(bottleneck)` (optionally floored), summed
          over the `self.coding_rank` innermost dimensions; i.e. it has the same
          shape as `bottleneck` without those dimensions.
        """
        prior = self.get_prior_dist(indexes)
        log_probs = prior.log_prob(bottleneck)

        # Main difference to the parent class' rate computation: optional flooring.
        if log_prob_lowerbound is not None:
            log_probs = math_ops.lower_bound(log_probs, log_prob_lowerbound)

        axes = tuple(range(-self.coding_rank, 0))
        log_probs = tf.reduce_sum(log_probs, axis=axes)
        return log_probs


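# Multiple inheritance (the commented-out declaration below) is deliberately not used, since tfc.LocationScaleIndexedEntropyModel's
# methods would then reach MyContinuousIndexedEntropyModel.__call__ only implicitly through the MRO. Instead, the class below
# copies those methods and calls MyContinuousIndexedEntropyModel's methods explicitly via super().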
# class MyLocationScaleIndexedEntropyModel(MyContinuousIndexedEntropyModel, tfc.LocationScaleIndexedEntropyModel):

class MyLocationScaleIndexedEntropyModel(MyContinuousIndexedEntropyModel):
    """
    A copy of tfc.LocationScaleIndexedEntropyModel, except it inherits from
    MyContinuousIndexedEntropyModel instead of tfc.ContinuousIndexedEntropyModel
    to allow for lower bounding log_probs in the __call__ method.
    Relative to tfc.LocationScaleIndexedEntropyModel, this class adds the
    `log_prob_lowerbound` param to `__call__` (forwarded to super().__call__), and
    the `get_prior_dist` and `eval_log_prior` helpers.

    Indexed entropy model for location-scale family of random variables.

    This class is a common special case of `ContinuousIndexedEntropyModel`. The
    specified distribution is parameterized with `num_scales` values of scale
    parameters. An element-wise location parameter is handled by shifting the
    distributions to zero. Note: this only works for shift-invariant
    distributions, where the `loc` parameter really denotes a translation (i.e.,
    not for the log-normal distribution).
    """

    def __init__(self,
                 prior_fn,
                 num_scales,
                 scale_fn,
                 coding_rank,
                 compression=False,
                 stateless=False,
                 expected_grads=False,
                 tail_mass=2 ** -8,
                 range_coder_precision=12,
                 dtype=tf.float32,
                 laplace_tail_mass=0):
        """Initializes the instance.

        Args:
          prior_fn: A callable returning a `tfp.distributions.Distribution` object,
            typically a `Distribution` class or factory function. This is a density
            model fitting the marginal distribution of the bottleneck data with
            additive uniform noise, which is shared a priori between the sender and
            the receiver. For best results, the distributions should be flexible
            enough to have a unit-width uniform distribution as a special case,
            since this is the marginal distribution for bottleneck dimensions that
            are constant. The callable receives `loc` and `scale` keyword
            arguments: `loc` is always 0 (locations are handled by shifting the
            data), and `scale` is `scale_fn(indexes)`.
          num_scales: Integer. Values in `indexes` must be in the range
            `[0, num_scales)`.
          scale_fn: Callable. `indexes` is passed to the callable, and the return
            value is given as `scale` keyword argument to `prior_fn`.
          coding_rank: Integer. Number of innermost dimensions considered a coding
            unit. Each coding unit is compressed to its own bit string, and the
            bits in the `__call__` method are summed over each coding unit.
          compression: Boolean. If set to `True`, the range coding tables used by
            `compress()` and `decompress()` will be built on instantiation. If set
            to `False`, these two methods will not be accessible.
          stateless: Boolean. If `False`, range coding tables are created as
            `Variable`s. This allows the entropy model to be serialized using the
            `SavedModel` protocol, so that both the encoder and the decoder use
            identical tables when loading the stored model. If `True`, creates range
            coding tables as `Tensor`s. This makes the entropy model stateless and
            allows it to be constructed within a `tf.function` body, for when the
            range coding tables are provided manually. If `compression=False`, then
            `stateless=True` is implied and the provided value is ignored.
          expected_grads: If True, will use analytical expected gradients during
            backpropagation w.r.t. additive uniform noise.
          tail_mass: Float. Approximate probability mass which is range encoded with
            less precision, by using a Golomb-like code.
          range_coder_precision: Integer. Precision passed to the range coding op.
          dtype: `tf.dtypes.DType`. The data type of all floating-point
            computations carried out in this class.
          laplace_tail_mass: Float. If positive, will augment the prior with a
            laplace mixture for training stability. (experimental)
        """
        num_scales = int(num_scales)
        super().__init__(
            prior_fn=prior_fn,
            index_ranges=(num_scales,),
            parameter_fns=dict(
                loc=lambda _: 0.,
                scale=scale_fn,
            ),
            coding_rank=coding_rank,
            channel_axis=None,
            compression=compression,
            stateless=stateless,
            expected_grads=expected_grads,
            tail_mass=tail_mass,
            range_coder_precision=range_coder_precision,
            dtype=dtype,
            laplace_tail_mass=laplace_tail_mass,
        )

    @tf.Module.with_name_scope
    def get_prior_dist(self, scale_indexes, loc=None):
        """Creates the indexed prior distribution (a tfp distribution object).

        Mirrors `ContinuousIndexedEntropyModel._make_prior` (continuous_indexed.py),
        but optionally overrides the location parameter with `loc`.

        Args:
          scale_indexes: `tf.Tensor` indexing the scale parameter for each element
            in `bottleneck`. Must have the same shape as `bottleneck`.
          loc: `None` or `tf.Tensor`. If `None`, the location parameter for all
            elements is assumed to be zero. Otherwise, specifies the location
            parameter for each element in `bottleneck`. Must have the same shape as
            `bottleneck`.

        Returns:
          The distribution object returned by `self.prior_fn(loc=..., scale=...)`.
        """
        indexes = self._normalize_indexes(scale_indexes)
        # Copied from self._make_prior, which casts the indexes to the model dtype
        # before evaluating the parameter functions (matters for integer indexes).
        indexes = tf.cast(indexes, self.dtype)
        parameters = {k: f(indexes) for k, f in self.parameter_fns.items()}
        # The entropy model is constructed to be location independent: `loc` only
        # shifts the data in __call__/quantize/compress/decompress, and the
        # 'loc' parameter function always returns 0 (see the constructor). To build
        # a prior located at `loc`, override that default here.
        if loc is not None:
            parameters['loc'] = loc
        prior = self.prior_fn(**parameters)
        return prior

    @tf.Module.with_name_scope
    def __call__(self, bottleneck, scale_indexes, loc=None, training=True, log_prob_lowerbound=None):
        """Perturbs a tensor with (quantization) noise and estimates rate.

        Args:
          bottleneck: `tf.Tensor` containing the data to be compressed.
          scale_indexes: `tf.Tensor` indexing the scale parameter for each element
            in `bottleneck`. Must have the same shape as `bottleneck`.
          loc: `None` or `tf.Tensor`. If `None`, the location parameter for all
            elements is assumed to be zero. Otherwise, specifies the location
            parameter for each element in `bottleneck`. Must have the same shape as
            `bottleneck`.
          training: Boolean. If `False`, computes the Shannon information of
            `bottleneck` under the distribution computed by `self.prior_fn`,
            which is a non-differentiable, tight *lower* bound on the number of bits
            needed to compress `bottleneck` using `compress()`. If `True`, returns a
            somewhat looser, but differentiable *upper* bound on this quantity.
          log_prob_lowerbound: Optional floor for the elementwise log probabilities,
            forwarded to `MyContinuousIndexedEntropyModel.__call__` (only applied
            during training); `None` disables it.

        Returns:
          A tuple (bottleneck_perturbed, bits) where `bottleneck_perturbed` is
          `bottleneck` perturbed with (quantization) noise and `bits` is the rate.
          `bits` has the same shape as `bottleneck` without the `self.coding_rank`
          innermost dimensions.
        """
        if loc is None:
            loc = 0.0
        bottleneck_centered = bottleneck - loc
        bottleneck_centered_perturbed, bits = super().__call__(
            bottleneck_centered, scale_indexes, training=training, log_prob_lowerbound=log_prob_lowerbound)
        bottleneck_perturbed = bottleneck_centered_perturbed + loc
        return bottleneck_perturbed, bits

    @tf.Module.with_name_scope
    def eval_log_prior(self, bottleneck, scale_indexes, loc=None, log_prob_lowerbound=None):
        """Computes the log density of the `bottleneck` tensor under the prior.

        This implements the second part of __call__ (the first part adds
        quantization noise, the second evaluates the rate under the prior), but
        returns natural-log density instead of bits. `bottleneck` is centered by
        `loc` and evaluated under the zero-location prior.

        Args:
          bottleneck: `tf.Tensor` containing the data to be evaluated.
          scale_indexes: `tf.Tensor` indexing the scale parameter for each element
            in `bottleneck`. Must have the same shape as `bottleneck`.
          loc: `None` or `tf.Tensor`. If `None`, the location parameter for all
            elements is assumed to be zero. Otherwise, specifies the location
            parameter for each element in `bottleneck`. Must have the same shape as
            `bottleneck`.
          log_prob_lowerbound: Optional floor for the elementwise log densities,
            applied before summation; `None` disables it.

        Returns:
          The elementwise log density of `bottleneck - loc` under the indexed
          prior (optionally floored), summed over the `self.coding_rank` innermost
          dimensions; same shape as `bottleneck` without those dimensions.
        """
        if loc is None:
            loc = 0.0
        bottleneck_centered = bottleneck - loc
        log_prior = super().eval_log_prior(bottleneck_centered, scale_indexes, log_prob_lowerbound=log_prob_lowerbound)
        return log_prior

    @tf.Module.with_name_scope
    def quantize(self, bottleneck, scale_indexes, loc=None):
        """Quantizes a floating-point tensor.

        To use this entropy model as an information bottleneck during training, pass
        a tensor through this function. The tensor is rounded to integer values
        modulo a quantization offset, which depends on `indexes`. For instance, for
        Gaussian distributions, the returned values are rounded to the location of
        the mode of the distributions plus or minus an integer.

        The gradient of this rounding operation is overridden with the identity
        (straight-through gradient estimator).

        Args:
          bottleneck: `tf.Tensor` containing the data to be quantized.
          scale_indexes: `tf.Tensor` indexing the scale parameter for each element
            in `bottleneck`. Must have the same shape as `bottleneck`.
          loc: `None` or `tf.Tensor`. If `None`, the location parameter for all
            elements is assumed to be zero. Otherwise, specifies the location
            parameter for each element in `bottleneck`. Must have the same shape as
            `bottleneck`.

        Returns:
          A `tf.Tensor` containing the quantized values.
        """
        if loc is None:
            return super().quantize(bottleneck, scale_indexes)
        else:
            return super().quantize(bottleneck - loc, scale_indexes) + loc

    @tf.Module.with_name_scope
    def compress(self, bottleneck, scale_indexes, loc=None):
        """Compresses a floating-point tensor.

        Compresses the tensor to bit strings. `bottleneck` is first quantized
        as in `quantize()`, and then compressed using the probability tables derived
        from `indexes`. The quantized tensor can later be recovered by calling
        `decompress()`.

        The innermost `self.coding_rank` dimensions are treated as one coding unit,
        i.e. are compressed into one string each. Any additional dimensions to the
        left are treated as batch dimensions.

        Args:
          bottleneck: `tf.Tensor` containing the data to be compressed. Not
            modified, even if it is a mutable array.
          scale_indexes: `tf.Tensor` indexing the scale parameter for each element
            in `bottleneck`. Must have the same shape as `bottleneck`.
          loc: `None` or `tf.Tensor`. If `None`, the location parameter for all
            elements is assumed to be zero. Otherwise, specifies the location
            parameter for each element in `bottleneck`. Must have the same shape as
            `bottleneck`.

        Returns:
          A `tf.Tensor` having the same shape as `bottleneck` without the
          `self.coding_rank` innermost dimensions, containing a string for each
          coding unit.
        """
        if loc is not None:
            # Non-mutating subtraction: an in-place `-=` would modify a caller's
            # NumPy array.
            bottleneck = bottleneck - loc
        return super().compress(bottleneck, scale_indexes)

    @tf.Module.with_name_scope
    def decompress(self, strings, scale_indexes, loc=None):
        """Decompresses a tensor.

        Reconstructs the quantized tensor from bit strings produced by `compress()`.

        Args:
          strings: `tf.Tensor` containing the compressed bit strings.
          scale_indexes: `tf.Tensor` indexing the scale parameter for each output
            element.
          loc: `None` or `tf.Tensor`. If `None`, the location parameter for all
            output elements is assumed to be zero. Otherwise, specifies the location
            parameter for each output element. Must have the same shape as
            `scale_indexes`.

        Returns:
          A `tf.Tensor` of the same shape as `scale_indexes`.
        """
        values = super().decompress(strings, scale_indexes)
        if loc is not None:
            values = values + loc
        return values


class MyDeepFactorized(tfc.DeepFactorized):
    """
    A copy of tfc.DeepFactorized, except I added an implementation for sampling.

    `_sample_n` performs inverse-CDF sampling: uniform draws are mapped through
    the inverse CDF, found numerically by bisection on `self._logits_cumulative`.
    The bisection uses Python control flow on tensor values (`if tf.reduce_all(...)`,
    `err <= tol`), so it presumably only works eagerly. NOTE(review): confirm it
    is never traced inside a `tf.function`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # TODO: a dedicated `_quantile(u)` (inverse CDF for u in (0, 1)) could replace
    # the inline bisection in `_sample_n`.

    def _sample_n(self, n, seed=None, max_its=1000, tol=1e-6, verbose=False):
        """
        Do inverse-CDF sampling to draw n samples from the batch of distributions; result will have shape n by batch_shape.
        Boilerplate code borrowed from the Cauchy example:
        https://github.com/tensorflow/probability/blob/f3777158691787d3658b5e80883fe1a933d48989/tensorflow_probability/python/distributions/cauchy.py#L166
        The tfp.distribution base class will implement the sample() method based on _sample_n() automatically,
        see https://github.com/tensorflow/probability/blob/v0.15.0/tensorflow_probability/python/distributions/distribution.py#L1218-L1234

        :param n: Number of samples to draw per batch member.
        :param seed: Seed passed to `samplers.uniform` for the uniform draws.
        :param max_its: Maximum number of bisection iterations. If <= 0, the
            midpoint of the initial bracket is returned.
        :param tol: Stop once the *mean* bracket width over all samples is <= tol.
            Individual samples may therefore be less accurate than `tol`.
        :param verbose: If True, print a message when bisection stops early because
            the mean width stopped shrinking, or when `max_its` is reached.
        :return: Tensor of shape [n] + batch_shape with the (approximate) samples.
        :raises RuntimeError: If a bracket endpoint cannot be found within 101 doublings.
        """
        import numpy as np
        from tensorflow_probability.python.internal import prefer_static as ps
        from tensorflow_probability.python.internal import samplers

        float_type = self.dtype
        batch_shape = self.batch_shape_tensor()
        shape = ps.concat([[n], batch_shape], 0)

        # Use the smallest positive normal float as minval so that u > 0 and log(u)
        # stays finite. `tf.as_dtype(...).as_numpy_dtype` works whether `self.dtype`
        # is a TF or a NumPy dtype. Note: `1. - eps` rounds to exactly 1.0 (eps is far
        # below machine epsilon), so the upper end is not actually clipped; u < 1
        # relies on the sampler excluding maxval (NOTE(review): confirm).
        eps = np.finfo(tf.as_dtype(float_type).as_numpy_dtype).tiny
        eps = tf.cast(eps, float_type)
        u = samplers.uniform(
            shape=shape, minval=eps, maxval=1. - eps, dtype=float_type, seed=seed)

        # Do bisection search to compute z such that CDF(z) = u. Equivalently we want CDF(z) - u = 0, or
        # sigmoid(logit(F(z))) - sigmoid(logit(u)) = 0, or logit(F(z)) - logit(u) = 0.
        # So we need to find the root of the function f(z) := logit(F(z)) - logit(u)
        logit_u = tf.math.log(u) - tf.math.log(1 - u)

        def f(z):  # this is vectorized; n by batch_shape -> n by batch_shape
            return self._logits_cumulative(z) - logit_u

        def expand_endpoint(endpoint, is_good, side):
            # Double (i.e. move away from zero) every entry of `endpoint` for which
            # `is_good(f(endpoint))` is False, until it holds for all entries.
            for _ in range(101):
                good = is_good(f(endpoint))
                if tf.reduce_all(good):
                    return endpoint
                good = tf.cast(good, float_type)
                endpoint = endpoint * good + 2 * endpoint * (1 - good)
            raise RuntimeError(f'Stuck in infinite loop when setting the {side} endpoint for bisection search')

        # Simple heuristic initialization, with [a, b] = [-1, 1]. Keep expanding the interval until f(a)*f(b)<=0;
        # here we use the additional fact that f is increasing, so we just need f(a)<=0 and f(b)>=0
        a = expand_endpoint(tf.ones_like(u, dtype=float_type) * -1., lambda f_val: f_val <= 0, 'left')
        b = expand_endpoint(tf.ones_like(u, dtype=float_type) * 1., lambda f_val: f_val >= 0, 'right')

        # Defined up front so that `max_its <= 0` still returns a valid estimate.
        mid = 0.5 * (a + b)
        err = tf.reduce_mean(b - a)
        prev_err = float('inf')
        for t in range(max_its):
            mid = 0.5 * (a + b)
            f_mid = f(mid)
            pos = f_mid > 0
            neg = f_mid < 0
            non_pos = tf.logical_not(pos)
            non_neg = tf.logical_not(neg)
            # a := mid for the "neg" coordinates, b := mid for the "pos" coordinates; otherwise the endpoints remain unchanged
            a = a * tf.cast(non_neg, float_type) + mid * tf.cast(neg, float_type)
            b = b * tf.cast(non_pos, float_type) + mid * tf.cast(pos, float_type)

            err = tf.reduce_mean(b - a)

            if tf.reduce_all(tf.logical_and(non_pos, non_neg)) or err <= tol:  # success
                break

            if err == prev_err:
                if verbose:
                    print(f'bisection terminated after {t} its with err {err}')
                    print('no more improvement possible given the finite numerical accuracy')
                break
            prev_err = err
        else:
            if verbose:
                print(f'bisection max its ({max_its}) reached; terminated with err {err}')

        return mid


class MyUniformNoiseAdapter(tfc.UniformNoiseAdapter):
    """
    Variant of tfc.UniformNoiseAdapter whose sampling can optionally return
    quantized draws. With `quantized=True`, samples come from the discretized
    distribution used for compression (base samples rounded to the quantization
    grid). Otherwise, as in the parent, they come from the base distribution
    convolved with unit-width uniform noise (the "noisy" training distribution).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _sample_n(self, n, seed=None, quantized=False, **kwargs):
        # Extra keyword arguments are forwarded to the base distribution's sample().
        with tf.name_scope("transform"):
            num_samples = tf.convert_to_tensor(n, name="n")
            base_samples = self.base.sample(num_samples, seed=seed, **kwargs)

            if not quantized:
                uniform_noise = tf.random.uniform(
                    tf.shape(base_samples), minval=-.5, maxval=.5, dtype=base_samples.dtype)
                return base_samples + uniform_noise

            # Round onto the grid {offset + k : k integer} used for compression.
            grid_offset = helpers.quantization_offset(self.base)
            return tf.round(base_samples - grid_offset) + grid_offset


class MyNoisyNormal(MyUniformNoiseAdapter):
    """A Normal distribution convolved with i.i.d. unit-width uniform noise.

    Mirrors the tfc version, but builds on `MyUniformNoiseAdapter` so that its
    sampling also supports the `quantized` option.
    """

    def __init__(self, name="MyNoisyNormal", **kwargs):
        # All keyword arguments are passed to `tfp.distributions.Normal`.
        import tensorflow_probability as tfp
        base_distribution = tfp.distributions.Normal(**kwargs)
        super().__init__(base_distribution, name=name)


class MyNoisyDeepFactorized(MyUniformNoiseAdapter):
    """A `MyDeepFactorized` density convolved with unit-width uniform noise.

    Mirrors the tfc version, but builds on `MyUniformNoiseAdapter` so that its
    sampling also supports the `quantized` option.
    """

    def __init__(self, name="MyNoisyDeepFactorized", **kwargs):
        # All keyword arguments are passed to `MyDeepFactorized`.
        base_distribution = MyDeepFactorized(**kwargs)
        super().__init__(base_distribution, name=name)
