"""
Custom Neural Network Blocks
"""

import jax
from typing import Optional, Callable
from jax import Array, numpy as np

from einops import rearrange
from flax.linen import (
    Dense,
    Module,
    silu,
    log_softmax,
    softmax,
    SelfAttention,
    MultiHeadDotProductAttention,
    compact,
    Sequential,
)
from flax.linen.attention import (
    functools,
    DenseGeneral,
    lax,
    jnp,
    combine_masks,
    merge_param,
)
from typing import Union, Sequence, Dict, Any


class AsymmericMultiHeadDotProductAttention(MultiHeadDotProductAttention):
    """
    Multi-head dot-product attention fed by three separate inputs.

    Query, key and value each arrive as their own array and each gets its own
    projection, which is the arrangement cross attention needs. Query and key
    heads are ``qkv_width // num_heads`` wide, where ``qkv_width`` is
    ``qkv_features`` or, when that is falsy, the last axis of ``inputs_q``.
    Value heads use that same per-head width when the last axis of
    ``inputs_v`` equals ``qkv_width``; otherwise every value head keeps the
    full channel count of ``inputs_v``.
    """

    @compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_k: Array,
        inputs_v: Array,
        mask: Optional[Array] = None,
        deterministic: Optional[bool] = None,
    ):
        """Run multi-head dot-product attention over separate q/k/v inputs.

        The three inputs are projected to per-head query, key and value
        tensors, passed through ``self.attention_fn`` and finally merged back
        into a single feature axis by an output projection.

        Args:
          inputs_q: queries of shape
            `[batch_sizes..., length, query/key features]`.
          inputs_k: keys of shape
            `[batch_sizes..., length, query/key features]`.
          inputs_v: values of shape
            `[batch_sizes..., length, value features]`.
          mask: attention mask of shape
            `[batch_sizes..., num_heads, query_length, key/value_length]`.
            Positions whose mask value is `False` are masked out; the mask
            acts on QK^T.
          deterministic: only consulted when `dropout_rate > 0`. If false,
            dropout is applied to the attention weights; if true, no dropout
            is applied.

        Returns:
          output of shape `[batch_sizes..., length, features]`, where
          `features` is `out_features` or, when that is falsy, the last axis
          of `inputs_v`.

        Raises:
          AssertionError: if the resolved qkv width is not divisible by
            `num_heads`.
          ValueError: in decode mode, if the query shape does not match the
            single-position shape implied by the initialized cache.
        """
        out_dim = self.out_features or inputs_v.shape[-1]
        qkv_width = self.qkv_features or inputs_q.shape[-1]
        assert (
            qkv_width % self.num_heads == 0
        ), "Memory dimension must be divisible by number of heads."
        per_head = qkv_width // self.num_heads

        # Value heads share the q/k head width only when the value input
        # already has `qkv_width` channels; otherwise each head is as wide as
        # the value input itself.
        value_width = inputs_v.shape[-1]
        value_per_head = per_head if value_width == qkv_width else value_width

        def _head_projection(name, head_features):
            # All three input projections share every setting except the
            # submodule name and the per-head feature count.
            return DenseGeneral(
                axis=-1,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                features=(self.num_heads, head_features),
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
                use_bias=self.use_bias,
                precision=self.precision,
                dot_general=self.qkv_dot_general,
                name=name,
            )

        # After projection the layout is
        # [batch..., length, n_heads, n_features_per_head].
        query = _head_projection("query", per_head)(inputs_q)
        key = _head_projection("key", per_head)(inputs_k)
        value = _head_projection("value", value_per_head)(inputs_v)

        # Fast autoregressive decoding feeds one position per call and
        # accumulates keys and values in the "cache" collection.
        if self.decode:
            # The cache counts as initialized only if it existed before the
            # `self.variable` calls below, so this check must come first.
            cache_ready = self.has_variable("cache", "cached_key")
            key_cache = self.variable(
                "cache", "cached_key", jnp.zeros, key.shape, key.dtype
            )
            value_cache = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            step = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
            if cache_ready:
                *lead_dims, cache_len, n_heads, depth = key_cache.value.shape
                lead_dims = tuple(lead_dims)

                # The query must be a single position shaped like a cache row.
                expected_shape = lead_dims + (1, n_heads, depth)
                if query.shape != expected_shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        "expected query shape %s instead got %s."
                        % (expected_shape, query.shape)
                    )

                # Write the new key/value slices at the current position and
                # attend over the whole (partially filled) cache.
                position = step.value
                start = (0,) * len(lead_dims) + (position, 0, 0)
                key = lax.dynamic_update_slice(key_cache.value, key, start)
                value = lax.dynamic_update_slice(value_cache.value, value, start)
                key_cache.value = key
                value_cache.value = value
                step.value = position + 1

                # Causal mask: the single query position may only see cache
                # slots written so far, not the zero-filled remainder.
                filled = jnp.arange(cache_len) <= position
                mask = combine_masks(
                    mask,
                    jnp.broadcast_to(filled, lead_dims + (1, 1, cache_len)),
                )

        # `deterministic` only has to be resolvable when dropout is active.
        if self.dropout_rate > 0.0:
            is_deterministic = merge_param(
                "deterministic", self.deterministic, deterministic
            )
        else:
            is_deterministic = True
        dropout_key = None if is_deterministic else self.make_rng("dropout")

        attended = self.attention_fn(
            query,
            key,
            value,
            mask=mask,
            dropout_rng=dropout_key,
            dropout_rate=self.dropout_rate,
            broadcast_dropout=self.broadcast_dropout,
            deterministic=is_deterministic,
            dtype=self.dtype,
            precision=self.precision,
        )  # pytype: disable=wrong-keyword-args

        # Collapse the (heads, per-head features) axes into `out_dim`.
        merge_heads = DenseGeneral(
            features=out_dim,
            axis=(-2, -1),
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=self.use_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            dot_general=self.out_dot_general,
            name="out",  # type: ignore[call-arg]
        )
        return merge_heads(attended)


class CrossAttention(AsymmericMultiHeadDotProductAttention):
    """Cross attention with per-input Dense projections and padding masks.

    Query, key and value inputs each pass through their own ``Dense(in_ch)``
    layer. The projected arrays are then handed to the parent attention block
    together with an attention mask built from ``q_mask`` and ``k_mask``.
    """

    # Width of the three input projections. The infinite default is only a
    # sentinel so the field can be declared with a default value; callers
    # must set a real integer, and `__call__` rejects the sentinel.
    in_ch: int = np.inf  # tricky setting to avoid non-default arg

    @compact
    def __call__(
        self,
        inputs_q: Array,
        q_mask: Optional[Array],  # type: ignore
        inputs_k: Array,
        k_mask: Array,
        inputs_v: Array,
        deterministic: Optional[bool] = None,
    ):
        """Applies multi-head dot product cross-attention on the input data.

        Projects each input to `in_ch` features, zeroes masked query
        positions, and delegates to the parent attention block.

        Args:
          inputs_q: query input; its last axis is projected to `in_ch`.
          q_mask: validity mask over query positions, or `None` to treat
            every query position as valid. When given, the projected queries
            are multiplied by `q_mask[..., None]` and the attention mask is
            `q_mask.reshape(-1, 1) & k_mask`.
            NOTE(review): the reshape flattens `q_mask`, so this presumably
            expects an unbatched boolean `[query_length]` mask (e.g. under
            `vmap`) - confirm against callers.
          inputs_k: key input; its last axis is projected to `in_ch`.
          k_mask: validity mask over key positions, combined with `q_mask`
            via `&`. Presumably boolean of shape `[key_length]` - confirm
            against callers. When `q_mask` is `None` it is passed to the
            parent as the attention mask unchanged.
          inputs_v: value input; its last axis is projected to `in_ch`.
          deterministic: forwarded unchanged to the parent attention block.

        Returns:
          Whatever the parent attention block returns for the projected
          inputs.

        Raises:
          ValueError: if `in_ch` was left at its sentinel default.
        """
        if self.in_ch == np.inf:
            # Fail early with a readable message instead of letting the
            # non-finite feature count reach `Dense`.
            raise ValueError(
                "CrossAttention requires `in_ch` (projection width) to be set "
                "explicitly."
            )

        # Creation order is kept as query, key, value so the auto-generated
        # Dense submodule names stay stable.
        projs_q = Dense(features=self.in_ch)(inputs_q)
        projs_k = Dense(features=self.in_ch)(inputs_k)
        projs_v = Dense(features=self.in_ch)(inputs_v)

        if q_mask is None:
            # No query mask: leave the queries untouched and mask keys only.
            mask = k_mask
        else:
            # Zero the projected features of masked query positions.
            projs_q = projs_q * q_mask[..., None]
            # Pairwise mask: a (query, key) pair is kept only if both are valid.
            mask = q_mask.reshape(-1, 1) & k_mask

        return super().__call__(
            inputs_q=projs_q,
            inputs_k=projs_k,
            inputs_v=projs_v,
            mask=mask,
            deterministic=deterministic,
        )


class MaskSequential(Sequential):
    """Chain of callables that all receive the same ``mask`` keyword.

    The first layer is called with the positional and keyword arguments given
    to this module plus ``mask=mask``. Each following layer is called with
    the previous layer's output plus ``mask=mask``:

    * a ``tuple`` output is unpacked into positional arguments,
    * a ``dict`` output is unpacked into keyword arguments,
    * any other output is passed as the single positional argument.

    Because ``mask`` is passed to every layer unconditionally, each entry of
    ``layers`` must accept a ``mask`` keyword argument. Extra keyword
    arguments given to this module reach only the first layer.
    """

    layers: Sequence[Callable[..., Any]]

    def __call__(self, *args, mask: Optional[bool] = None, **kwargs):
        """Run the layers in order, threading ``mask`` through each call.

        Args:
          *args: positional arguments for the first layer.
          mask: value passed as ``mask=`` to every layer.
          **kwargs: extra keyword arguments for the first layer only.

        Returns:
          The output of the last layer.

        Raises:
          ValueError: if ``layers`` is empty.
        """
        if not self.layers:
            raise ValueError(f"Empty Sequential module {self.name}.")

        head, *tail = self.layers
        result = head(*args, mask=mask, **kwargs)
        for stage in tail:
            # Decide how the previous output is spread over the next call.
            if isinstance(result, tuple):
                positional, keyword = result, {}
            elif isinstance(result, Dict):
                positional, keyword = (), result
            else:
                positional, keyword = (result,), {}
            result = stage(*positional, mask=mask, **keyword)
        return result
