"""
Flax NNX version of
https://github.com/google-deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
"""

from typing import NamedTuple

from flax import nnx
from flax.typing import Dtype
from jax import Array, lax, numpy as jnp
from jax.nn import one_hot
from jax.nn.initializers import lecun_uniform
from optax import squared_error

from offline.lbp.tc.modules.moving_averages import ExponentialMovingAverage


class VQResults(NamedTuple):
    """Outputs returned by ``VectorQuantizerEMA.__call__``.

    The field order is part of the tuple's positional interface.
    """

    # Index of the codebook entry selected for each input vector; shaped like
    # the input without its last axis.
    encoding_indices: Array
    # One-hot encoding of the selected entries, one row per flattened input
    # vector and one column per codebook entry.
    encodings: Array
    # Mean squared error between the inputs and the gradient-stopped
    # quantized values.
    loss: Array
    # exp(entropy) of the batch-averaged code usage.
    perplexity: Array
    # Quantized inputs; gradients pass straight through to the inputs.
    quantized: Array


class VQVariable(nnx.Variable[Array]):
    """Variable type holding the codebook of ``VectorQuantizerEMA``.

    It adds no behaviour of its own; it only gives the embeddings a distinct
    variable type.
    NOTE(review): presumably this lets state filters tell the codebook apart
    from optimizer-trained parameters -- confirm against the training code.
    """

    pass


class VectorQuantizerEMA(nnx.Module):
    """NNX module representing the VQ-VAE layer.

    Implements a slightly modified version of the algorithm presented in
    'Neural Discrete Representation Learning' by van den Oord et al.
    https://arxiv.org/abs/1711.00937

    This module uses exponential moving average to update the embedding
    vectors instead of an auxiliary loss. This has the advantage that the
    embedding updates are independent of the choice of optimizer (SGD, RMSProp,
    Adam, K-Fac, ...) used for the encoder, decoder and other parts of the
    architecture. For most experiments the EMA version trains faster than the
    non-EMA version.

    Input any array to be quantized. Last dimension will be used as space in
    which to quantize. All other dimensions will be flattened and will be seen
    as different examples to quantize.

    The output array will have the same shape as the input.

    For example an array with shape (16, 32, 32, 64) will be reshaped into
    (16384, 64) and all 16384 vectors (each of 64 dimensions) will be quantized
    independently.

    The returned loss is the plain latent (commitment) term; this module does
    not multiply it by any commitment-cost weight.

    Attributes:
        decay: decay for the moving averages.
        embedding_dim:
            integer representing the dimensionality of the tensors in the
            quantized space. Inputs to the modules must be in this format as
            well.
        epsilon: small float constant to avoid numerical instability.
        embeddings:
            VQVariable holding the codebook, with shape
            (embedding_dim, num_embeddings).
        ema_cluster_size:
            moving average fed with the number of inputs assigned to each
            codebook entry.
        ema_dw:
            moving average fed with the per-entry sums of the assigned
            inputs; it starts from the initial embeddings.
        update_ema:
            if False, calling the module leaves the moving averages and the
            embeddings untouched.
    """

    def __init__(
        self,
        decay: float,
        embedding_dim: int,
        num_embeddings: int,
        rngs: nnx.Rngs,
        dtype: Dtype = jnp.float32,
        epsilon: float = 1e-5,
        update_ema: bool = True,
    ):
        """Initialize a VQ-VAE EMA module.

        Args:
            decay:
                float between 0 and 1, controls the speed of the Exponential
                Moving Averages.
            embedding_dim:
                integer representing the dimensionality of the arrays in the
                quantized space. Inputs to the modules must be in this format as
                well.
            num_embeddings: the number of vectors in the quantized space.
            rngs: NNX Rng for initializing the embeddings.
            dtype: dtype for the embeddings variable, defaults to jnp.float32.
            epsilon:
                small constant to aid numerical stability, defaults to 1e-5.
            update_ema:
                if False, ema stats will not be updated, defaults to True.
        """
        self.decay = decay
        self.embedding_dim = embedding_dim
        self.epsilon = epsilon
        # The codebook is stored column-wise: one column per embedding vector.
        embeddings = lecun_uniform()(
            dtype=dtype,
            key=rngs.params(),
            shape=(embedding_dim, num_embeddings),
        )
        self.embeddings = VQVariable(embeddings)
        self.ema_cluster_size = ExponentialMovingAverage(
            decay=decay, values=jnp.zeros(num_embeddings, dtype=dtype)
        )
        self.ema_dw = ExponentialMovingAverage(decay=decay, values=embeddings)
        self.update_ema = update_ema

    def __call__(self, inputs: Array) -> VQResults:
        """Connects the module to some inputs.

        When ``update_ema`` is True, this call also updates both moving
        averages and overwrites ``embeddings``. The returned ``quantized``
        and ``loss`` are computed from the embeddings as they were before
        that update.

        Args:
            inputs:
                final dimension must be equal to embedding_dim. All other
                leading dimensions will be flattened and treated as a
                large batch.

        Returns:
            NamedTuple containing the following values:
                encoding_indices:
                    Array containing the discrete encoding indices, i.e., which
                    element of the quantized space each input element was
                    mapped to. Shaped like inputs without the last dimension.
                encodings:
                    Array containing the one-hot encodings of those indices,
                    one row per flattened input vector.
                loss:
                    Array containing the loss to optimize: the mean squared
                    error between inputs and the gradient-stopped quantized
                    values (no commitment-cost factor is applied).
                perplexity: Array containing the perplexity of the encodings.
                quantized:
                    Array containing the quantized version of the input.
        """
        flat_inputs = jnp.reshape(inputs, (-1, self.embedding_dim))
        embeddings = self.embeddings.value
        # Squared Euclidean distance between every input row and every
        # codebook column, expanded as ||x||^2 - 2 x.e + ||e||^2.
        distances = (
            jnp.sum(flat_inputs**2, 1, keepdims=True)
            - 2 * jnp.matmul(flat_inputs, embeddings)
            + jnp.sum(embeddings**2, 0, keepdims=True)
        )
        # Nearest codebook entry per input row.
        encoding_indices = jnp.argmax(-distances, 1)
        encodings = one_hot(encoding_indices, embeddings.shape[1])
        encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
        quantized = self.quantize(encoding_indices)
        # The codebook side is gradient-stopped, so this term only pulls the
        # inputs (encoder outputs) towards their selected embeddings.
        e_latent_loss = jnp.mean(
            squared_error(inputs, lax.stop_gradient(quantized))
        )

        if self.update_ema:
            # Per-entry assignment counts for this batch.
            self.ema_cluster_size.update(values=jnp.sum(encodings, axis=0))
            updated_ema_cluster_size = self.ema_cluster_size.compute()
            # Column k of dw is the sum of the inputs assigned to entry k.
            dw = jnp.matmul(flat_inputs.T, encodings)
            self.ema_dw.update(values=dw)
            updated_ema_dw = self.ema_dw.compute()
            n = jnp.sum(updated_ema_cluster_size)
            # Smooth the cluster sizes with epsilon (rescaled to keep their
            # total at n) so that an entry with no assignments does not give
            # a zero denominator below, as long as n itself is non-zero.
            updated_ema_cluster_size = (
                (updated_ema_cluster_size + self.epsilon)
                / (n + embeddings.shape[1] * self.epsilon)
                * n
            )
            normalized_updated_ema_w = updated_ema_dw / jnp.reshape(
                updated_ema_cluster_size, (1, -1)
            )
            self.embeddings.value = normalized_updated_ema_w

        # Straight-through estimator: the forward value is `quantized`, while
        # the gradient flows to `inputs` unchanged.
        quantized = inputs + lax.stop_gradient(quantized - inputs)
        avg_probs = jnp.mean(encodings, 0)
        # exp(entropy) of the average code usage; 1e-10 guards log(0).
        perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))
        return VQResults(
            encoding_indices=encoding_indices,
            encodings=encodings,
            loss=e_latent_loss,
            perplexity=perplexity,
            quantized=quantized,
        )

    def quantize(self, encoding_indices: Array) -> Array:
        """Return the embedding vectors for a batch of indices.

        Args:
            encoding_indices:
                integer array of any shape whose values index the codebook
                entries.

        Returns:
            Array of shape ``encoding_indices.shape + (embedding_dim,)``
            holding the current embedding vector for every index.
        """
        # Read the array through `.value`, as `__call__` does, instead of
        # relying on attribute forwarding from the Variable wrapper.
        codebook = self.embeddings.value.T
        return jnp.take(codebook, encoding_indices, axis=0)
