"""Haiku implementation of the original VQ-VAE (van den Oord et al., 2017)."""
#  type: ignore
from typing import Any, Callable, Dict, Optional, Union

import haiku as hk
import jax
import jax.numpy as jnp


class ResidualStack(hk.Module):
    """Stack of pre-activation residual layers followed by a final ReLU.

    Each layer computes ``h = h + conv1x1(relu(conv3x3(relu(h))))``; after the
    last layer a ReLU is applied to the running sum.

    Args:
        num_hiddens: Output channels of every 1x1 convolution. Because the
            1x1 output is added to the running activation, the input is
            expected to carry this many channels.
        num_residual_layers: Number of residual layers in the stack.
        num_residual_hiddens: Output channels of every intermediate 3x3
            convolution (the bottleneck width).
        name: Optional Haiku module name.
    """

    def __init__(
        self,
        num_hiddens: int,
        num_residual_layers: int,
        num_residual_hiddens: int,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        # Haiku derives parameter paths from these module names
        # ("res3x3_<i>", "res1x1_<i>"), so they must stay stable.
        self._layers = []
        for i in range(num_residual_layers):
            conv3 = hk.Conv2D(
                output_channels=num_residual_hiddens,
                kernel_shape=(3, 3),
                stride=(1, 1),
                name=f"res3x3_{i}",
            )
            conv1 = hk.Conv2D(
                output_channels=num_hiddens,
                kernel_shape=(1, 1),
                stride=(1, 1),
                name=f"res1x1_{i}",
            )
            self._layers.append((conv3, conv1))

    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Apply every residual layer to ``inputs`` and return relu(result)."""
        h = inputs
        for conv3, conv1 in self._layers:
            conv3_out = conv3(jax.nn.relu(h))
            conv1_out = conv1(jax.nn.relu(conv3_out))
            # Out-of-place add: ``h += ...`` would write into the caller's
            # ``inputs`` buffer when a mutable NumPy array is passed in.
            h = h + conv1_out
        return jax.nn.relu(h)  # Resnet V1 style


class Encoder(hk.Module):
    """Convolutional VQ-VAE encoder.

    Pipeline: 4x4/stride-2 conv -> ReLU -> 4x4/stride-2 conv -> ReLU ->
    3x3/stride-1 conv -> ReLU -> ResidualStack.

    Args:
        num_hiddens: Channel width of the encoder. The first convolution uses
            ``num_hiddens // 2`` channels.
        num_residual_layers: Number of layers in the residual stack.
        num_residual_hiddens: Bottleneck width inside the residual stack.
        name: Optional Haiku module name.
    """

    def __init__(
        self,
        num_hiddens: int,
        num_residual_layers: int,
        num_residual_hiddens: int,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        half_width = num_hiddens // 2
        # Creation order and explicit names are fixed because Haiku builds
        # parameter paths from them.
        self._enc_1 = hk.Conv2D(
            half_width, kernel_shape=(4, 4), stride=(2, 2), name="enc_1"
        )
        self._enc_2 = hk.Conv2D(
            num_hiddens, kernel_shape=(4, 4), stride=(2, 2), name="enc_2"
        )
        self._enc_3 = hk.Conv2D(
            num_hiddens, kernel_shape=(3, 3), stride=(1, 1), name="enc_3"
        )
        self._residual_stack_1 = ResidualStack(
            num_hiddens, num_residual_layers, num_residual_hiddens
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Encode ``x``; every plain conv is followed by a ReLU."""
        features = x
        for conv in (self._enc_1, self._enc_2, self._enc_3):
            features = jax.nn.relu(conv(features))
        return self._residual_stack_1(features)


class Decoder(hk.Module):
    """Convolutional VQ-VAE decoder.

    Pipeline: 3x3/stride-1 conv -> ResidualStack -> 4x4/stride-2 transposed
    conv -> ReLU -> 4x4/stride-2 transposed conv. The final layer has no
    activation, so the reconstruction is unbounded.

    Args:
        num_hiddens: Channel width of the decoder. The first transposed
            convolution uses ``num_hiddens // 2`` channels.
        num_residual_layers: Number of layers in the residual stack.
        num_residual_hiddens: Bottleneck width inside the residual stack.
        name: Optional Haiku module name.
        output_channels: Channels of the reconstruction produced by the last
            transposed convolution. Defaults to 3 (e.g. RGB images).
    """

    def __init__(
        self,
        num_hiddens: int,
        num_residual_layers: int,
        num_residual_hiddens: int,
        name: Optional[str] = None,
        *,
        output_channels: int = 3,
    ):
        super().__init__(name=name)
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens
        self._output_channels = output_channels

        # Module names/creation order determine Haiku parameter paths; keep them.
        self._dec_1 = hk.Conv2D(
            output_channels=self._num_hiddens,
            kernel_shape=(3, 3),
            stride=(1, 1),
            name="dec_1",
        )
        self._residual_stack_1 = ResidualStack(
            self._num_hiddens, self._num_residual_layers, self._num_residual_hiddens
        )
        self._dec_2 = hk.Conv2DTranspose(
            output_channels=self._num_hiddens // 2,
            kernel_shape=(4, 4),
            stride=(2, 2),
            name="dec_2",
        )
        self._dec_3 = hk.Conv2DTranspose(
            output_channels=self._output_channels,
            kernel_shape=(4, 4),
            stride=(2, 2),
            name="dec_3",
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Decode (quantized) latents ``x`` into a reconstruction."""
        # No ReLU after dec_1: ResidualStack applies one before its first conv.
        h = self._dec_1(x)
        h = self._residual_stack_1(h)
        h = jax.nn.relu(self._dec_2(h))
        x_recon = self._dec_3(h)
        return x_recon


class VQVAEModel(hk.Module):
    """Encoder -> 1x1 projection -> vector quantizer -> decoder, with loss.

    Args:
        encoder: Module producing encoder features from the inputs.
        decoder: Module mapping the quantizer's ``"quantize"`` output back to
            input space.
        vqvae: Vector quantizer, either ``hk.nets.VectorQuantizer`` or
            ``hk.nets.VectorQuantizerEMA``; it is called with ``is_training``
            and must return a dict containing ``"quantize"`` and ``"loss"``.
        pre_vq_conv1: Callable projecting encoder features to the quantizer's
            embedding dimension.
        data_variance: Scalar dividing the reconstruction MSE.
        name: Optional Haiku module name.
    """

    def __init__(
        self,
        encoder: Encoder,
        decoder: Decoder,
        vqvae: Union[hk.nets.VectorQuantizer, hk.nets.VectorQuantizerEMA],
        pre_vq_conv1: Callable[[jnp.ndarray], jnp.ndarray],
        data_variance: Union[float, jnp.ndarray],
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self._encoder = encoder
        self._decoder = decoder
        self._vqvae = vqvae
        self._pre_vq_conv1 = pre_vq_conv1
        self._data_variance = data_variance

    def __call__(
        self, inputs: jnp.ndarray, is_training: bool
    ) -> Dict[str, Any]:
        """Reconstruct ``inputs`` and compute the training loss.

        Args:
            inputs: Batch to reconstruct. ``x_recon - inputs`` is computed
                elementwise, so the decoder output is expected to match the
                shape of ``inputs``.
            is_training: Forwarded to the vector quantizer.

        Returns:
            Dict with ``"z"`` (pre-quantization features), ``"x_recon"``
            (reconstruction), ``"loss"`` (``recon_error`` plus the quantizer's
            ``"loss"``), ``"recon_error"`` (MSE divided by ``data_variance``)
            and ``"vq_output"`` (the quantizer's full output dict).
        """
        z = self._pre_vq_conv1(self._encoder(inputs))
        vq_output = self._vqvae(z, is_training=is_training)
        x_recon = self._decoder(vq_output["quantize"])
        recon_error = jnp.mean((x_recon - inputs) ** 2) / self._data_variance
        loss = recon_error + vq_output["loss"]
        return {
            "z": z,
            "x_recon": x_recon,
            "loss": loss,
            "recon_error": recon_error,
            "vq_output": vq_output,
        }


def build_vq_vae_fn(
    num_hiddens: int,
    num_residual_hiddens: int,
    num_residual_layers: int,
    embedding_dim: int,
    num_embeddings: int,
    decay: float,
    vq_use_ema: bool,
    commitment_cost: float,
    train_data_variance: float,
) -> Callable[[jnp.ndarray, bool], Dict[str, Any]]:
    """Build the VQ-VAE forward function.

    The returned function constructs all Haiku modules when called, so it is
    meant to be wrapped with ``hk.transform`` / ``hk.transform_with_state``.

    Args:
        num_hiddens: Channel width of encoder and decoder.
        num_residual_hiddens: Bottleneck width inside the residual stacks.
        num_residual_layers: Number of layers in each residual stack.
        embedding_dim: Dimension of each codebook vector (and output channels
            of the 1x1 pre-quantization convolution).
        num_embeddings: Number of codebook vectors.
        decay: EMA decay; only used when ``vq_use_ema`` is True.
        vq_use_ema: Use ``hk.nets.VectorQuantizerEMA`` instead of
            ``hk.nets.VectorQuantizer``.
        commitment_cost: Weight of the commitment loss term.
        train_data_variance: Divisor applied to the reconstruction MSE.

    Returns:
        ``vq_vae_fn(data, is_training)`` returning the model's output dict
        (``"z"``, ``"x_recon"``, ``"loss"``, ``"recon_error"``,
        ``"vq_output"``).
    """

    def vq_vae_fn(data: jnp.ndarray, is_training: bool) -> Dict[str, Any]:
        encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
        decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
        pre_vq_conv1 = hk.Conv2D(
            output_channels=embedding_dim,
            kernel_shape=(1, 1),
            stride=(1, 1),
            name="to_vq",
        )

        if vq_use_ema:
            vq_vae = hk.nets.VectorQuantizerEMA(
                embedding_dim=embedding_dim,
                num_embeddings=num_embeddings,
                commitment_cost=commitment_cost,
                decay=decay,
            )
        else:
            vq_vae = hk.nets.VectorQuantizer(
                embedding_dim=embedding_dim,
                num_embeddings=num_embeddings,
                commitment_cost=commitment_cost,
            )

        model = VQVAEModel(
            encoder, decoder, vq_vae, pre_vq_conv1, data_variance=train_data_variance
        )
        return model(data, is_training)

    return vq_vae_fn
