from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp


class Stack(nn.Module):
    """IMPALA-style downsampling stack: a conv, a stride-2 max-pool, then two residual blocks.

    Attributes:
        stack_size: number of output channels of every convolution in the stack.
        layer_norm: when True, each residual block starts with ``nn.LayerNorm``.
    """

    stack_size: int
    layer_norm: bool

    @nn.compact
    def __call__(self, x):
        """Apply the stack to ``x``.

        NOTE(review): presumably ``x`` is channels-last (flax ``nn.Conv`` default
        layout), e.g. (batch, height, width, channels) — confirm against callers.
        """
        kernel_init = nn.initializers.variance_scaling(
            scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
        )

        out = nn.Conv(self.stack_size, (3, 3), kernel_init=kernel_init)(x)
        # With SAME padding and stride 2, each spatial dim becomes ceil(dim / 2).
        out = nn.max_pool(out, (3, 3), strides=(2, 2), padding="SAME")

        n_residual_blocks = 2
        for _ in range(n_residual_blocks):
            skip = out
            if self.layer_norm:
                out = nn.LayerNorm()(out)
            out = nn.relu(out)
            # The two convs in the residual branch use flax's default kernel init.
            out = nn.Conv(self.stack_size, (3, 3))(out)
            out = nn.relu(out)
            out = nn.Conv(self.stack_size, (3, 3))(out)
            out = out + skip

        return out


class QuantileEmbedding(nn.Module):
    """Cosine embedding of randomly sampled quantile fractions.

    Samples ``n_quantiles`` fractions ``tau`` with ``jax.random.uniform``.
    Each fraction is expanded into ``cos(pi * i * tau)`` for
    ``i = 1 .. quantile_embedding_dim``. The result is projected with a Dense
    layer followed by ReLU to ``n_features`` dimensions.

    Attributes:
        n_features: output width. It must match the width of the state features
            the embedding is multiplied with. NOTE(review): the default 7744
            presumably corresponds to a flattened 11 x 11 x 64 conv output —
            verify for the torso in use.
        quantile_embedding_dim: number of cosine basis functions.
    """

    n_features: int = 7744
    quantile_embedding_dim: int = 64

    @nn.compact
    def __call__(self, key, n_quantiles):
        """Sample quantile fractions and embed them.

        Args:
            key: PRNG key used to sample the fractions.
            n_quantiles: number of fractions to sample.

        Returns:
            ``(embedding, quantiles)``. ``embedding`` has shape
            (n_quantiles, n_features) and ``quantiles`` has shape (n_quantiles,).
        """
        kernel_init = nn.initializers.variance_scaling(
            scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
        )

        taus = jax.random.uniform(key, shape=(n_quantiles, 1))  # (n_quantiles, 1)
        frequencies = jnp.arange(1, self.quantile_embedding_dim + 1)[jnp.newaxis, :]
        # The inner dimension is 1, so this matmul is an outer product:
        # entry [j, i] is pi * tau_j * frequencies[i].
        cosine_basis = jnp.cos(jnp.pi * taus @ frequencies)

        projected = nn.Dense(self.n_features, kernel_init=kernel_init)(cosine_basis)
        return nn.relu(projected), taus[:, 0]


class IQNNet(nn.Module):
    """Implicit Quantile Network: state torso, quantile embedding and dense head.

    Attributes:
        features: layer widths. For "cnn" and "impala", the first three entries
            size the torso and the remaining entries size the hidden Dense layers.
            For "fc", every entry is a hidden Dense layer.
        architecture_type: one of "cnn", "impala" or "fc".
        n_actions: number of outputs of the final Dense layer.
        layer_norm: when True, LayerNorm is applied inside the "impala" torso and
            after every hidden Dense layer. The "cnn" torso never uses it.
    """

    features: Sequence[int]
    architecture_type: str
    n_actions: int
    layer_norm: bool = False

    @nn.compact
    def __call__(self, x, key, n_quantiles):
        """Compute action values for ``n_quantiles`` sampled quantile fractions.

        Args:
            x: observation. For "cnn"/"impala" it is promoted to at least 4-D and
                divided by 255. NOTE(review): presumably uint8 pixels in
                channels-last layout — confirm against callers.
            key: PRNG key used to sample the quantile fractions.
            n_quantiles: number of quantile fractions to sample.

        Returns:
            ``(values, quantiles)``. For a single (unbatched) observation,
            ``values`` has shape (n_quantiles, n_actions) and ``quantiles`` has
            shape (n_quantiles,).

        Raises:
            ValueError: if ``architecture_type`` is not "cnn", "impala" or "fc".
        """
        # Shared by the "cnn" and "impala" torsos and by the Dense head after them.
        scaled_uniform = nn.initializers.variance_scaling(
            scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
        )

        if self.architecture_type == "cnn":
            # NO Support for LN in CNN Module
            initializer = scaled_uniform
            idx_feature_start = 3
            x = nn.relu(
                nn.Conv(
                    features=self.features[0],
                    kernel_size=(8, 8),
                    strides=(4, 4),
                    kernel_init=initializer,
                )(jnp.array(x, ndmin=4) / 255.0)
            )
            x = nn.relu(
                nn.Conv(
                    features=self.features[1],
                    kernel_size=(4, 4),
                    strides=(2, 2),
                    kernel_init=initializer,
                )(x)
            )
            x = nn.relu(
                nn.Conv(
                    features=self.features[2],
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    kernel_init=initializer,
                )(x)
            )
            x = x.reshape((x.shape[0], -1))
        elif self.architecture_type == "impala":
            initializer = scaled_uniform
            idx_feature_start = 3
            x = Stack(self.features[0], self.layer_norm)(jnp.array(x, ndmin=4) / 255.0)
            x = Stack(self.features[1], self.layer_norm)(x)
            x = Stack(self.features[2], self.layer_norm)(x)
            if self.layer_norm:
                x = nn.LayerNorm()(x)
            x = nn.relu(x)
            x = x.reshape((x.shape[0], -1))
        elif self.architecture_type == "fc":
            # Not used for IQN, but kept for consistency with other architectures
            initializer = nn.initializers.lecun_normal()
            idx_feature_start = 0
        else:
            # Previously this fell through to an UnboundLocalError on `initializer`.
            raise ValueError(
                f"Unknown architecture_type {self.architecture_type!r}; "
                "expected 'cnn', 'impala' or 'fc'."
            )

        # Drops singleton axes, e.g. the batch axis added by ndmin=4 for one observation.
        x = jnp.squeeze(x)

        # Size the embedding from the actual state-feature width. QuantileEmbedding's
        # hard-coded default only matches one torso/input configuration, and any
        # mismatch breaks the elementwise product below.
        quantiles_features, quantiles = QuantileEmbedding(n_features=x.shape[-1])(
            key, n_quantiles
        )

        # mapping over the quantiles | output (n_quantiles, n_features)
        x = jax.vmap(
            lambda quantile_features, state_features_: quantile_features
            * state_features_,
            in_axes=(0, None),
        )(quantiles_features, x)

        for idx_layer in range(idx_feature_start, len(self.features)):
            x = nn.Dense(self.features[idx_layer], kernel_init=initializer)(x)
            if self.layer_norm:
                x = nn.LayerNorm()(x)
            x = nn.relu(x)

        return nn.Dense(self.n_actions, kernel_init=initializer)(x), quantiles
