ValueNorm
========================================

This module defines the ``ValueNorm`` class, which normalizes vectors of observations using statistics computed over their first ``norm_axes`` dimensions.
Normalization uses a running mean and a running mean of squares. Both are maintained as exponential moving averages and corrected for the bias introduced by their zero initialization.
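
Written out from the source code below (the notation here is ours, not the library's), each call to ``update()`` with a batch :math:`x` applies, elementwise over ``input_shape``,

.. math::

   \mu_t &= w\,\mu_{t-1} + (1 - w)\,\bar{x} \\
   s_t &= w\,s_{t-1} + (1 - w)\,\overline{x^2} \\
   d_t &= w\,d_{t-1} + (1 - w)

where :math:`\bar{x}` and :math:`\overline{x^2}` are the means of :math:`x` and :math:`x^2` over the first ``norm_axes`` dimensions, :math:`w = \beta` (or :math:`w = \beta^{B}` for a batch of :math:`B` samples when ``per_element_update=True``), and :math:`\mu_0 = s_0 = d_0 = 0`. The debiased statistics are

.. math::

   \hat{\mu}_t = \frac{\mu_t}{\max(d_t, \epsilon)}, \qquad
   \hat{\sigma}^2_t = \max\left(\frac{s_t}{\max(d_t, \epsilon)} - \hat{\mu}_t^2,\ 10^{-2}\right)

With a constant :math:`w`, :math:`d_t = 1 - w^t`, so dividing by :math:`d_t` removes the bias toward zero caused by the zero initialization.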

.. raw:: html

    <br><hr>

.. py:class::
  xuance.torch.utils.value_norm.ValueNorm(input_shape, norm_axes=1, beta=0.99999, per_element_update=False, epsilon=1e-5)

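  :param input_shape: The shape of the running statistics, i.e. the dimensions of each input that remain after the first ``norm_axes`` dimensions.
  :type input_shape: Sequence[int]
  :param norm_axes: The number of leading dimensions averaged over when computing batch statistics. Default: ``1``.
  :type norm_axes: int
  :param beta: Decay factor of the exponential moving averages; values closer to 1 give a longer memory. Default: ``0.99999``.
  :type beta: float
  :param per_element_update: If ``True``, the decay weight for a batch is ``beta ** batch_size``, where ``batch_size`` is the number of elements in the first ``norm_axes`` dimensions, so each sample counts as one decay step. If ``False``, each call to ``update()`` applies a single decay of ``beta``. Default: ``False``.
  :type per_element_update: bool
  :param epsilon: Lower bound applied to the debiasing term to avoid division by zero. Default: ``1e-5``.
  :type epsilon: float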
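
  As a sketch of how ``norm_axes`` and ``input_shape`` fit together, the NumPy snippet below reproduces the reduction performed in ``update()`` on a hypothetical batch of shape ``(8, 4, 3)``. NumPy stands in for PyTorch here, and the shapes are illustrative assumptions; the point is that the dimensions left after averaging over the first ``norm_axes`` must equal ``input_shape``.

  .. code-block:: python

      import numpy as np

      # Hypothetical batch: 8 environments x 4 time steps, each with a 3-element vector.
      x = np.random.default_rng(0).normal(size=(8, 4, 3))
      norm_axes = 2  # average over the first two dimensions

      # Same reduction as ValueNorm.update: mean over dims 0 .. norm_axes - 1.
      batch_mean = x.mean(axis=tuple(range(norm_axes)))
      batch_sq_mean = (x ** 2).mean(axis=tuple(range(norm_axes)))

      # The remaining dimensions must equal input_shape, here (3,).
      print(batch_mean.shape, batch_sq_mean.shape)  # (3,) (3,)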

.. py:function::
  xuance.torch.utils.value_norm.ValueNorm.reset_parameters()

  Resets the running mean, the running mean of squares, and the debiasing term to zero.

.. py:function::
  xuance.torch.utils.value_norm.ValueNorm.running_mean_var()

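  Computes the debiased mean and variance by dividing the running mean and the running mean of squares by the debiasing term, which is first clamped below at ``epsilon``. The resulting variance is clamped below at ``1e-2``.

  :return: the debiased mean and variance.
  :rtype: Tuple[torch.Tensor, torch.Tensor]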
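
  A minimal NumPy sketch of the same computation, assuming a scalar debiasing term and a one-element ``input_shape``. The helper ``running_mean_var`` below is an illustrative stand-in, not the library method. It shows that debiasing recovers the batch statistics after a single update, and that before any update the variance falls back to the ``1e-2`` floor.

  .. code-block:: python

      import numpy as np


      def running_mean_var(running_mean, running_mean_sq, debiasing_term, epsilon=1e-5):
          """Illustrative NumPy mirror of ValueNorm.running_mean_var (not the library code)."""
          d = max(debiasing_term, epsilon)                          # clamp(min=epsilon)
          mean = running_mean / d
          var = np.maximum(running_mean_sq / d - mean ** 2, 1e-2)  # clamp(min=1e-2)
          return mean, var


      # Before any update all buffers are zero: the mean is 0 and var hits the 1e-2 floor.
      mean0, var0 = running_mean_var(np.zeros(1), np.zeros(1), 0.0)

      # One update with weight w from the zero state scales every buffer by (1 - w);
      # debiasing divides that factor back out.
      w = 0.99
      batch = np.array([[1.0], [3.0]])           # norm_axes=1, input_shape=(1,)
      rm = (1 - w) * batch.mean(axis=0)          # running_mean
      rms = (1 - w) * (batch ** 2).mean(axis=0)  # running_mean_sq
      d = 1 - w                                  # debiasing_term

      mean, var = running_mean_var(rm, rms, d)   # mean ≈ [2.], var ≈ [1.]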

.. py:function::
  xuance.torch.utils.value_norm.ValueNorm.update(input_vector)

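  Updates the running mean, the running mean of squares, and the debiasing term with the statistics of a batch, averaged over its first ``norm_axes`` dimensions. NumPy arrays are converted to tensors and moved to the device of the running statistics; gradients are not tracked.

  :param input_vector: a batch of shape ``(*batch_dims, *input_shape)``, where ``len(batch_dims) == norm_axes``.
  :type input_vector: np.ndarray, torch.Tensor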
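
  The update rule, including the effect of ``per_element_update``, can be sketched in NumPy as follows. The function ``ema_update`` and its dict-based ``state`` are hypothetical stand-ins for the module's buffers, used only to show the arithmetic, and ``beta`` is lowered to ``0.9`` so the effect is visible.

  .. code-block:: python

      import numpy as np


      def ema_update(state, batch, norm_axes=1, beta=0.99999, per_element_update=False):
          """Hypothetical NumPy sketch of ValueNorm.update.

          `state` is a dict with 'mean' and 'mean_sq' (arrays of input_shape)
          and 'debias' (float), standing in for the module's buffers.
          """
          axes = tuple(range(norm_axes))
          batch_mean = batch.mean(axis=axes)
          batch_sq_mean = (batch ** 2).mean(axis=axes)
          if per_element_update:
              # One decay step per sample in the leading norm_axes dimensions.
              weight = beta ** int(np.prod(batch.shape[:norm_axes]))
          else:
              # One decay step per call, regardless of batch size.
              weight = beta
          state["mean"] = weight * state["mean"] + (1.0 - weight) * batch_mean
          state["mean_sq"] = weight * state["mean_sq"] + (1.0 - weight) * batch_sq_mean
          state["debias"] = weight * state["debias"] + (1.0 - weight)
          return state


      # Three updates with a constant stream of 5.0.
      state = {"mean": np.zeros(1), "mean_sq": np.zeros(1), "debias": 0.0}
      for _ in range(3):
          state = ema_update(state, np.full((4, 1), 5.0), beta=0.9)

      debiased_mean = state["mean"] / state["debias"]  # ≈ [5.]

  With ``per_element_update=True``, a single batch of four samples decays the old statistics by ``beta ** 4``, as much as four separate single-sample updates would, so the effective memory is counted in samples rather than in calls.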

.. py:function::
  xuance.torch.utils.value_norm.ValueNorm.normalize(input_vector)

  Normalizes the input as ``(input_vector - mean) / sqrt(var)``, using the debiased mean and variance returned by ``running_mean_var()``.

  :param input_vector: the input vector.
  :type input_vector: np.ndarray, torch.Tensor
  :return: the normalized vector, as a tensor on the device of the running statistics.
  :rtype: torch.Tensor
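
  To illustrate the broadcasting used by ``normalize()``, the NumPy sketch below applies the same ``(None,) * norm_axes`` indexing to assumed statistics; the values of ``mean`` and ``var`` are made up for the example.

  .. code-block:: python

      import numpy as np

      norm_axes = 2
      mean = np.array([1.0, -2.0, 0.5])  # assumed debiased mean, shape == input_shape (3,)
      var = np.array([4.0, 1.0, 0.25])   # assumed debiased variance, same shape

      x = np.ones((8, 4, 3))             # batch: (8, 4) leading dims + input_shape
      # Prepend norm_axes singleton axes so the statistics broadcast over the batch dims,
      # mirroring mean[(None,) * self.norm_axes] in the source code.
      idx = (None,) * norm_axes
      y = (x - mean[idx]) / np.sqrt(var)[idx]
      print(mean[idx].shape)  # (1, 1, 3)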

.. py:function::
  xuance.torch.utils.value_norm.ValueNorm.denormalize(input_vector)

  Transforms a normalized input vector back to the original scale as ``input_vector * sqrt(var) + mean``, using the debiased mean and variance returned by ``running_mean_var()``.

  :param input_vector: the input vector.
  :type input_vector: np.ndarray, torch.Tensor
  :return: the denormalized vector, moved to the CPU and converted to a NumPy array.
  :rtype: np.ndarray
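
  Because ``denormalize()`` applies the inverse of the affine map used by ``normalize()`` with the same statistics, a round trip should recover the input up to floating-point error, for example when mapping outputs produced in the normalized space back to the original scale. A NumPy sketch with hypothetical helper functions and made-up statistics (unlike these helpers, the library's ``denormalize()`` also moves its result to the CPU as a NumPy array):

  .. code-block:: python

      import numpy as np


      def normalize(x, mean, var, norm_axes=1):
          # Hypothetical helper: shift by the mean, then divide by the standard deviation.
          idx = (None,) * norm_axes
          return (x - mean[idx]) / np.sqrt(var)[idx]


      def denormalize(y, mean, var, norm_axes=1):
          # Hypothetical helper: inverse of normalize (scale, then shift back).
          idx = (None,) * norm_axes
          return y * np.sqrt(var)[idx] + mean[idx]


      mean = np.array([10.0, 0.0])
      var = np.array([9.0, 0.01])
      values = np.array([[13.0, 0.1], [7.0, -0.2]])  # shape (batch, *input_shape)

      y = normalize(values, mean, var)         # ≈ [[1., 1.], [-1., -2.]]
      restored = denormalize(y, mean, var)     # ≈ values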

.. raw:: html

    <br><hr>

Source Code
-----------------

.. code-block:: python

    import numpy as np
    import torch
    import torch.nn as nn


    class ValueNorm(nn.Module):
        """ Normalize a vector of observations - across the first norm_axes dimensions"""

        def __init__(self, input_shape, norm_axes=1, beta=0.99999, per_element_update=False, epsilon=1e-5):
            super(ValueNorm, self).__init__()

            self.input_shape = input_shape
            self.norm_axes = norm_axes
            self.epsilon = epsilon
            self.beta = beta
            self.per_element_update = per_element_update

            self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False)
            self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False)
            self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False)

            self.reset_parameters()

        def reset_parameters(self):
            self.running_mean.zero_()
            self.running_mean_sq.zero_()
            self.debiasing_term.zero_()

        def running_mean_var(self):
            debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)
            debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon)
            debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2)
            return debiased_mean, debiased_var

        @torch.no_grad()
        def update(self, input_vector):
            if isinstance(input_vector, np.ndarray):
                input_vector = torch.from_numpy(input_vector)
            # Match the device and dtype of the running statistics (mean() needs a floating dtype)
            input_vector = input_vector.to(device=self.running_mean.device, dtype=self.running_mean.dtype)

            batch_mean = input_vector.mean(dim=tuple(range(self.norm_axes)))
            batch_sq_mean = (input_vector ** 2).mean(dim=tuple(range(self.norm_axes)))

            if self.per_element_update:
                batch_size = np.prod(input_vector.size()[:self.norm_axes])
                weight = self.beta ** batch_size
            else:
                weight = self.beta

            self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))
            self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))
            self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))

        def normalize(self, input_vector):
            # Convert to a tensor with the device and dtype (float32 by default) of the running statistics
            if isinstance(input_vector, np.ndarray):
                input_vector = torch.from_numpy(input_vector)
            input_vector = input_vector.to(device=self.running_mean.device, dtype=self.running_mean.dtype)

            mean, var = self.running_mean_var()
            out = (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes]

            return out

        def denormalize(self, input_vector):
            """ Transform normalized data back into original distribution """
            if isinstance(input_vector, np.ndarray):
                input_vector = torch.from_numpy(input_vector)
            # Match the device and dtype of the running statistics
            input_vector = input_vector.to(device=self.running_mean.device, dtype=self.running_mean.dtype)

            mean, var = self.running_mean_var()
            out = input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]

            out = out.cpu().numpy()

            return out
