"""Implementation of Welford's online variance estimator."""
from typing import Optional

import tensorflow as tf


class VarianceAccumulator:
    """Streaming per-element mean and sample variance (Welford's algorithm).

    Samples are folded in one at a time with :meth:`update`, or a whole batch
    at once with :meth:`batch_update`. The batch path merges the batch's own
    mean and squared-deviation sum into the running statistics with the
    pairwise formula of Chan, Golub & LeVeque. That is algebraically
    equivalent to calling :meth:`update` for every sample, but it runs as a
    few vectorized TensorFlow ops instead of a Python loop.

    The running ``mean`` and ``M2`` are non-trainable float64 ``tf.Variable``
    objects (float64 to help reduce numerical issues). Their shape is fixed by
    the first sample seen, and later samples must have the same shape.
    ``count`` is a Python int, and :meth:`batch_update` converts the batch size
    to a Python int. This class therefore assumes eager execution.
    """

    # Number of samples seen so far; None until the first sample arrives.
    count: Optional[int] = None

    # Running per-element mean of all samples seen so far (float64).
    mean: Optional[tf.Variable] = None

    # Running per-element sum of squared distances from the mean (float64).
    M2: Optional[tf.Variable] = None

    def _initialize(self, x):
        """Start the statistics from a single sample ``x`` (count becomes 1)."""
        self.count = 1
        self.mean = tf.Variable(tf.cast(x, tf.float64), trainable=False)
        # A single sample has zero deviation from itself.
        self.M2 = tf.Variable(tf.zeros_like(x, dtype=tf.float64), trainable=False)

    def _maybe_initialize(self, x):
        """Initialize from ``x`` if no sample has been seen yet.

        Returns:
            True if ``x`` was consumed as the first sample, False otherwise.
        """
        if self.count is None:
            self._initialize(x)
            return True
        return False

    def update(self, x):
        """Fold a single, unbatched sample ``x`` into the running statistics."""
        if self._maybe_initialize(x):
            return
        x = tf.cast(x, tf.float64)
        self.count += 1
        delta = x - self.mean
        self.mean.assign_add(delta / self.count)
        # Welford's update uses the deviation from both the old and new mean.
        delta2 = x - self.mean
        self.M2.assign_add(delta * delta2)

    def batch_update(self, x):
        """Fold a batch of samples into the running statistics in one step.

        Args:
            x: Array-like whose first dimension is the batch dimension; each
                ``x[i]`` is one sample. An empty batch is a no-op.

        Raises:
            TypeError: If ``x`` is a scalar (it has no batch dimension).
        """
        x = tf.cast(x, tf.float64)
        if x.shape.rank == 0:
            raise TypeError(
                "batch_update expects a leading batch dimension, got a scalar."
            )
        n_b = int(tf.shape(x)[0])
        if n_b == 0:
            # Nothing to merge; also avoids reduce_mean over an empty axis (NaN).
            return

        # Statistics of the batch on its own.
        mean_b = tf.reduce_mean(x, axis=0)
        m2_b = tf.reduce_sum(tf.square(x - mean_b), axis=0)

        if self.count is None:
            # First data ever seen: the batch statistics become the running ones.
            self.count = n_b
            self.mean = tf.Variable(mean_b, trainable=False)
            self.M2 = tf.Variable(m2_b, trainable=False)
            return

        # Chan et al. pairwise merge of (count, mean, M2) with the batch.
        n_a = self.count
        n = n_a + n_b
        delta = mean_b - self.mean  # snapshot taken before mean is modified
        self.mean.assign_add(delta * (n_b / n))
        self.M2.assign_add(m2_b + tf.square(delta) * (n_a * n_b / n))
        self.count = n

    @property
    def variance(self) -> tf.Tensor:
        """Sample variance ``M2 / (count - 1)`` of all samples seen so far.

        Uses the unbiased (n - 1) denominator. With exactly one sample, M2 is
        zero, so the result is NaN (0 / 0). Reading this before any sample has
        been seen raises TypeError, because ``count`` is still None.
        """
        return self.M2 / (self.count - 1)
