"""
Encoder only layer used in lra
"""

import functools
import math
import jax
import numpy as np
from flax import linen as nn
from jax import numpy as jnp


from .latte import BidLatteConv, BidLatteRopeConv, BidLatte, RopeBidLatteMachSliding
from ..xPos import XPos
from latte_trans.config import LRATaskConfig, ATT_TYPE
from ..layers import (
    RMSNorm,
    GatedMLP,
    Conv1D,
    PositionalEncoding,
    RopeEmbeds,
    create_rope,
    Embedder,
)


def get_mixing_layer(config: LRATaskConfig, dtype):
    match config.attention_type:
        case "latte_convQR_bid":
            if config.embed_type in ["xpos", "rope"]:
                return BidLatteRopeConv(
                    config=config,
                    dtype=dtype,
                )
            else:
                return BidLatteConv(
                    config=config,
                    dtype=dtype,
                )
        case "latte_mach_sliding_bid":
            return RopeBidLatteMachSliding(
                config=config,
                dtype=dtype,
            )
        case "latte_bid":
            if config.embed_type in ["absolute", "none"]:
                return BidLatte(
                    config=config,
                    dtype=dtype,
                )
            else:
                raise Exception("Not yet implemented")


class GLUBlock(nn.Module):
    """Residual block: configured mixing layer followed by a gated linear unit.

    A single normalisation layer (BatchNorm when ``config.batchnorm`` is set,
    LayerNorm otherwise) is applied either before the mixing layer
    (``config.prenorm``) or once at the very end of the block.
    """

    config: LRATaskConfig
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, X, train: bool = False, **kwargs) -> jnp.array:
        """Apply the block.

        Args:
            X: input, documented by the original author as ``B L D``.
            train: enables dropout and, for BatchNorm, batch statistics.
            **kwargs: forwarded unchanged to the mixing layer.

        Returns:
            The transformed sequence.
        """
        cfg = self.config
        use_prenorm = cfg.prenorm

        # Submodules. The two Dense layers are created in a fixed order
        # because Flax auto-names them by construction order.
        norm = (
            nn.BatchNorm(use_running_average=not train, momentum=0.9)
            if cfg.batchnorm
            else nn.LayerNorm()
        )
        mixer = get_mixing_layer(cfg, self.dtype)
        dropout = nn.Dropout(cfg.dropout, deterministic=not train)
        value_proj = nn.Dense(features=cfg.hidden_dim, use_bias=True)
        gate_proj = nn.Dense(features=cfg.hidden_dim, use_bias=True)

        # Mixing sub-block with a residual connection around it.
        mixer_input = norm(X) if use_prenorm else X
        hidden = mixer(mixer_input, train=train, **kwargs) + X

        # Full GLU: value projection gated by a sigmoid of a second projection.
        activated = dropout(nn.gelu(hidden))
        gated = value_proj(activated) * nn.sigmoid(gate_proj(activated))
        result = dropout(gated) + hidden

        return result if use_prenorm else norm(result)


class TransBlock(nn.Module):
    """
    Standard transformer block whose attention layer is replaced by the
    mixing layer selected through ``get_mixing_layer``.

    ``config.prenorm`` chooses between normalising the input of each
    sub-layer (pre-norm) and normalising its residual output (post-norm).
    """

    config: LRATaskConfig
    rot_embeds: jnp.array = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        train: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        Args:
            X: jnp.array(BTD), B = Batch size, T = sequence length, D = embed dimension
            train: bool - enables dropout (and batch statistics for BatchNorm)
            **kwargs: forwarded unchanged to the mixing layer
        Returns:
            out: jnp.array(BTD) - transformed output sequence
        """
        cfg = self.config
        prenorm = cfg.prenorm

        # One normalisation layer per sub-layer, both of the same kind.
        if cfg.batchnorm:
            make_norm = functools.partial(
                nn.BatchNorm,
                use_running_average=not train,
                momentum=0.9,
                dtype=self.dtype,
            )
        else:
            make_norm = functools.partial(nn.LayerNorm, dtype=self.dtype)
        mixer_norm = make_norm()
        mlp_norm = make_norm()

        mixer = get_mixing_layer(cfg, self.dtype)

        # Two-layer MLP with a 4x expansion. The MLP dropout is created
        # before the mixer dropout so that Flax auto-naming is unchanged.
        dense_init = jax.nn.initializers.normal(stddev=cfg.initializer_range)
        mlp_expand = nn.Dense(
            4 * cfg.hidden_dim, dtype=self.dtype, kernel_init=dense_init
        )
        exact_gelu = functools.partial(nn.gelu, approximate=False)
        mlp_project = nn.Dense(
            cfg.hidden_dim, dtype=self.dtype, kernel_init=dense_init
        )
        mlp_dropout = nn.Dropout(cfg.dropout, deterministic=not train)
        mixer_dropout = nn.Dropout(cfg.dropout, deterministic=not train)

        # Mixing sub-layer (plays the role of attention).
        mixer_input = mixer_norm(X) if prenorm else X
        hidden = X + mixer_dropout(mixer(mixer_input, train=train, **kwargs))
        if not prenorm:
            hidden = mixer_norm(hidden)

        # MLP sub-layer.
        mlp_input = mlp_norm(hidden) if prenorm else hidden
        mlp_out = mlp_dropout(mlp_project(exact_gelu(mlp_expand(mlp_input))))
        result = hidden + mlp_out
        return result if prenorm else mlp_norm(result)


class SotaTransBlock(nn.Module):
    """Pre-norm block: RMSNorm -> mixing layer -> RMSNorm -> GatedMLP.

    Both sub-layers are wrapped in a residual connection.
    """

    config: LRATaskConfig
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        x,
        train: bool = False,
        attention_mask=None,
    ):
        """Apply the block.

        Args:
            x: input sequence.
            train: forwarded to the mixing layer.
            attention_mask: forwarded to the mixing layer.

        Returns:
            The output of the MLP sub-layer plus its residual input.
        """
        width = self.config.hidden_dim

        # The two RMSNorm layers are created in a fixed order because Flax
        # auto-names them by construction order.
        mixer_norm = RMSNorm(width=width, dtype=self.dtype)
        mixer = get_mixing_layer(self.config, self.dtype)
        ffn_norm = RMSNorm(width=width, dtype=self.dtype)
        ffn = GatedMLP(
            hidden_dim=width,
            intermediate_dim=4 * width,
            initializer_range=self.config.initializer_range,
            dtype=self.dtype,
        )

        # Mixing sub-layer with residual.
        mixed = mixer(mixer_norm(x), attention_mask=attention_mask, train=train)
        # Debug aid: jax.debug.print("Mix block nan: {x}", x=jnp.isnan(mixed).any())
        hidden = mixed + x

        # Gated MLP sub-layer with residual.
        return ffn(ffn_norm(hidden)) + hidden


class ConvEmbed(nn.Module):
    """
    Pass a flattened square image through a stack of 3x3 convolutions.

    The sequence axis is folded back into a square grid, run through four
    "SAME"-padded convolutions (24, 48, 96, ``hidden_dim`` features),
    flattened again, then passed through dropout and LayerNorm.
    """

    dropout: float
    hidden_dim: int

    @nn.compact
    def __call__(self, X: jnp.array, train: bool = False) -> jnp.array:
        """
        Args:
            X: (batch_size, (W*H), C) with W == H. The original code only
                supported C == 1; any channel count is now accepted.
            train: bool. Used for dropout
        Returns:
            Array of shape (batch_size, W*H, hidden_dim).
        Raises:
            ValueError: if the sequence length is not a perfect square.
        """
        conv_features = (24, 48, 96, self.hidden_dim)
        conv_layers = [
            nn.Conv(features=width, kernel_size=(3, 3), strides=1, padding="SAME")
            for width in conv_features
        ]
        norm = nn.LayerNorm()
        drop = nn.Dropout(self.dropout, deterministic=not train)

        batch_sz, seq_len, inp_ch = X.shape
        # Exact integer square root; avoids float rounding and lets a
        # non-square length be reported clearly instead of failing in reshape.
        side = math.isqrt(seq_len)
        if side * side != seq_len:
            raise ValueError(
                f"ConvEmbed expects a square image: sequence length {seq_len} "
                "is not a perfect square"
            )
        # Use the actual channel count rather than a hard-coded 1.
        X = X.reshape(batch_sz, side, side, inp_ch)
        for conv in conv_layers:
            X = conv(X)
        X = X.reshape(batch_sz, seq_len, -1)
        return norm(drop(X))


class TextImageEncoder(nn.Module):
    """
    Deals with images and text.
    """

    config: LRATaskConfig
    vocab_size: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self, X: jnp.array, train: bool = False, attention_mask=None
    ) -> jnp.array:
        """
        Args:
            X: jnp.array(BTD), B = Batch size, T = sequence length, D = embed dimension
            train: bool - used for dropout
        Returns:
            out: jnp.array(BTD) - transformed output sequence
        """
        # if self.vocab_size is None:  # images do not require embedding layer
        #     if self.config.conv_embed:
        #         embed = ConvEmbed(
        #             dropout=self.config.dropout, hidden_dim=self.config.hidden_dim
        #         )
        #     else:
        #         embed = nn.Dense(features=self.config.hidden_dim)  #
        # else:
        #     embed = nn.Embed(
        #         num_embeddings=self.vocab_size,
        #         features=self.config.hidden_dim,
        #         dtype=jnp.float32,
        #     )

        embed = nn.Dense(features=self.config.hidden_dim)  #
        if self.vocab_size is not None:
            X = jax.nn.one_hot(X, self.vocab_size)
        pos_embeds = None
        if self.config.embed_type == "absolute":  # absolute
            pos_embeds = PositionalEncoding(
                d_model=self.config.hidden_dim,
                max_len=self.config.pos_embed_max_len,
                dtype=self.dtype,
            )

        # drop_embed = nn.Dropout(self.config.dropout, deterministic=not train)
        ln = nn.LayerNorm()

        if self.config.block_type == "transformer":
            block = TransBlock
        elif self.config.block_type == "glu":
            block = GLUBlock
        elif self.config.block_type == "transformer-sota":
            block = SotaTransBlock
        else:
            raise IOError("Block type not supported")
        enc_layers = [
            block(config=self.config, dtype=self.dtype)
            for _ in range(self.config.nlayers)
        ]
        if not pos_embeds is None:
            X = pos_embeds(embed(X))
            # X = drop_embed(X)
        else:  # relative or nope
            X = embed(X)
            # X = drop_embed(X)

        if not self.config.prenorm:
            X = ln(X)
        for l in enc_layers:
            X = l(X, train, attention_mask=attention_mask)
        return X
