"""Adapted from https://github.com/google/flax/blob/main/examples/nlp_seq/models.py"""

from functools import partial
from typing import Any, Optional, Dict

import chex
import jax.numpy as jnp
from flax import linen as nn
from flax.struct import dataclass, field
from src.models.mha import MultiHeadAttentionRoPE


@dataclass
class TransformerLayerConfig:
    """Hyperparameters shared by the attention and MLP parts of a transformer layer.

    ``emb_dim`` is a derived field: whatever value is supplied for it is
    replaced by ``num_heads * emb_dim_per_head`` right after construction.
    """

    # Number of attention heads, forwarded to the attention module.
    num_heads: int = 8
    # Width of each head; together with ``num_heads`` it fixes ``emb_dim``.
    emb_dim_per_head: int = 16
    # The MLP hidden width is ``int(mlp_dim_factor * emb_dim)``.
    mlp_dim_factor: float = 4.0
    # Dropout rate applied to the attention output and at the end of the MLP.
    dropout_rate: float = 0.0
    # Dropout rate handed to the attention module itself.
    attention_dropout_rate: float = 0.0
    # Shared bias switch for the Dense, LayerNorm and attention modules.
    use_bias: bool = False
    # Name of the MLP non-linearity; MlpBlock accepts "relu" and "silu".
    activation: str = "silu"
    # Computation dtype passed to the Dense, LayerNorm and attention modules.
    dtype: Any = jnp.float32
    # Forwarded to the attention module as ``norm_type``.
    mha_norm_type: str = "rms_norm"
    # Derived in ``__post_init__``; the default is only a placeholder.
    emb_dim: int = field(default=None)

    def __post_init__(self):
        # flax.struct dataclasses are frozen, so the derived value is written
        # through object.__setattr__ rather than a plain attribute assignment.
        total_width = self.num_heads * self.emb_dim_per_head
        object.__setattr__(self, "emb_dim", total_width)


@dataclass
class UseUniqueIdConfig:
    """Configuration for unique ID usage.

    Held by ``EncoderTransformerConfig.use_unique_id``; the code that consumes
    these fields is not part of this section of the file.
    """

    # On/off switch for the feature (disabled by default).
    active: bool = False
    # presumably the number of distinct IDs available when active — TODO confirm
    num_unique_ids: int = 0


@dataclass
class RopeEmbeddingsConfig:
    """Configuration for rope embeddings.

    Both fields are read by ``TransformerLayer`` and handed to
    ``MultiHeadAttentionRoPE``.
    """

    # Forwarded to the attention module as its ``rope`` argument.
    active: bool = True
    # Forwarded to the attention module as its ``rope_max_freq`` argument.
    max_freq: float = 10.0


@dataclass
class LearnedPositionEmbeddingsConfig:
    """Configuration for learned position embeddings.

    The code that consumes these fields is not part of this section of the file.
    """

    # On/off switch for learned position embeddings (disabled by default).
    active: bool = False
    # NOTE(review): presumably toggles a scaling of the learned embeddings —
    # confirm against the model code that reads it.
    scale: bool = False


@dataclass
class PositionEmbeddingsConfig:
    """Configuration for position embeddings.

    Groups the rope and learned position-embedding settings. Each nested
    config is created through ``default_factory``, so a fresh default instance
    is built for every ``PositionEmbeddingsConfig``.
    """

    # Rope settings; read by TransformerLayer when building the attention module.
    rope_embeddings: RopeEmbeddingsConfig = field(default_factory=RopeEmbeddingsConfig)
    # Learned position-embedding settings.
    learned_position_embeddings: LearnedPositionEmbeddingsConfig = field(
        default_factory=LearnedPositionEmbeddingsConfig
    )


@dataclass
class EncoderTransformerConfig:
    """Hyperparameters for the encoder transformer.

    ``dtype``, ``emb_dim`` and ``max_len`` are derived fields: any value passed
    for them is overwritten in ``__post_init__`` (from ``transformer_layer`` and
    from ``max_rows * max_cols`` respectively).
    """

    # Built through a factory, like the other nested configs below, so no
    # config instance is created at class-definition time.
    transformer_layer: TransformerLayerConfig = field(default_factory=TransformerLayerConfig)
    vocab_size: int = 10
    output_vocab_size: int = 10
    num_layers: int = 2
    latent_dim: int = 32
    max_rows: int = 30
    max_cols: int = 30
    position_embeddings: PositionEmbeddingsConfig = field(default_factory=PositionEmbeddingsConfig)
    use_unique_id: UseUniqueIdConfig = field(default_factory=UseUniqueIdConfig)
    # Derived fields; the ``None`` defaults are placeholders only.
    dtype: jnp.dtype = field(default=None)
    emb_dim: int = field(default=None)
    max_len: int = field(default=None)
    variational: bool = field(default=True)

    def __post_init__(self):
        # flax.struct dataclasses are frozen, hence object.__setattr__.
        # dtype and emb_dim mirror the per-layer configuration.
        object.__setattr__(self, "dtype", self.transformer_layer.dtype)
        object.__setattr__(self, "emb_dim", self.transformer_layer.emb_dim)
        # One sequence position per cell of the largest supported grid.
        object.__setattr__(self, "max_len", self.max_rows * self.max_cols)


@dataclass
class DecoderTransformerConfig:
    """Hyperparameters for the decoder transformer.

    ``dtype``, ``emb_dim`` and ``max_len`` are derived fields: any value passed
    for them is overwritten in ``__post_init__`` (from ``transformer_layer`` and
    from ``max_rows * max_cols`` respectively).
    """

    # Built through a factory, like ``position_embeddings`` below, so no
    # config instance is created at class-definition time.
    transformer_layer: TransformerLayerConfig = field(default_factory=TransformerLayerConfig)
    vocab_size: int = 10
    output_vocab_size: int = 10
    num_layers: int = 2
    max_rows: int = 30
    max_cols: int = 30
    position_embeddings: PositionEmbeddingsConfig = field(default_factory=PositionEmbeddingsConfig)
    next_position_embeddings: bool = True
    next_position_embeddings_new_input_embeds: bool = False
    logits_projection_bias: bool = False
    # Derived fields; the ``None`` defaults are placeholders only.
    dtype: jnp.dtype = field(default=None)
    emb_dim: int = field(default=None)
    max_len: int = field(default=None)

    def __post_init__(self):
        # flax.struct dataclasses are frozen, hence object.__setattr__.
        # dtype and emb_dim mirror the per-layer configuration.
        object.__setattr__(self, "dtype", self.transformer_layer.dtype)
        object.__setattr__(self, "emb_dim", self.transformer_layer.emb_dim)
        # One sequence position per cell of the largest supported grid.
        object.__setattr__(self, "max_len", self.max_rows * self.max_cols)


class MlpBlock(nn.Module):
    """Feed-forward block: Dense -> activation -> Dense -> Dropout.

    Attributes:
        config: TransformerLayerConfig dataclass containing hyperparameters.
    """

    config: TransformerLayerConfig

    def setup(self) -> None:
        # Resolve the activation by name once; anything else is rejected.
        supported = {"relu": nn.relu, "silu": nn.silu}
        if self.config.activation not in supported:
            raise ValueError(f"Unsupported activation: {self.config.activation}")
        self.activation = supported[self.config.activation]

    @nn.compact
    def __call__(self, inputs: chex.Array, dropout_eval: bool) -> chex.Array:
        """Run the feed-forward block on ``inputs``.

        Args:
            inputs: array whose last axis is the feature axis.
            dropout_eval: passed to Dropout as ``deterministic``; when true,
                dropout is disabled.

        Returns:
            Array with the same last-axis size as ``inputs``.
        """
        cfg = self.config
        hidden_width = int(cfg.mlp_dim_factor * cfg.emb_dim)
        output_width = inputs.shape[-1]

        # Expand to the hidden width, apply the non-linearity, project back.
        hidden = nn.Dense(features=hidden_width, use_bias=cfg.use_bias, dtype=cfg.dtype)(inputs)
        hidden = self.activation(hidden)
        projected = nn.Dense(features=output_width, use_bias=cfg.use_bias, dtype=cfg.dtype)(hidden)
        return nn.Dropout(rate=cfg.dropout_rate)(projected, deterministic=dropout_eval)


class TransformerLayer(nn.Module):
    """Transformer layer: normalized self-attention then normalized MLP, each added residually.

    Attributes:
        config: TransformerLayerConfig dataclass containing hyperparameters.
        position_embeddings_config: PositionEmbeddingsConfig whose rope settings
            are forwarded to the attention module.
    """

    config: TransformerLayerConfig
    position_embeddings_config: PositionEmbeddingsConfig

    @nn.compact
    def __call__(
        self,
        embeddings: chex.Array,
        pos_embeddings: chex.Array,
        dropout_eval: bool,
        pad_mask: Optional[chex.Array] = None,
    ) -> chex.Array:
        """Applies TransformerLayer module.

        Args:
            embeddings: input embeddings; must have at least 3 dimensions.
            pos_embeddings: position information, passed through unchanged to
                the attention module.
            dropout_eval: if false dropout is applied otherwise it is not.
            pad_mask: mask to apply on the inputs to avoid attending to padding tokens.

        Returns:
            output after transformer encoder layer.
        """
        cfg = self.config
        assert embeddings.ndim >= 3

        # Both sub-blocks use an identically configured, scale-free LayerNorm.
        make_layer_norm = partial(
            nn.LayerNorm, dtype=cfg.dtype, use_bias=cfg.use_bias, use_scale=False
        )

        # --- Attention sub-block ---
        normed = make_layer_norm()(embeddings)
        rope_cfg = self.position_embeddings_config.rope_embeddings
        attention = MultiHeadAttentionRoPE(
            num_heads=cfg.num_heads,
            dtype=cfg.dtype,
            dropout_rate=cfg.attention_dropout_rate,
            use_bias=cfg.use_bias,
            attention_fn=partial(nn.dot_product_attention, force_fp32_for_softmax=True),
            norm_type=cfg.mha_norm_type,
            rope=rope_cfg.active,
            rope_max_freq=rope_cfg.max_freq,
        )
        attended = attention(
            inputs_q=normed,
            mask=pad_mask,
            pos_embeddings=pos_embeddings,
            deterministic=dropout_eval,
        )
        embeddings += nn.Dropout(rate=cfg.dropout_rate)(attended, deterministic=dropout_eval)

        # --- MLP sub-block ---
        normed = make_layer_norm()(embeddings)
        embeddings += MlpBlock(config=cfg)(normed, dropout_eval=dropout_eval)
        return embeddings


def create_coordinate_encoding_encoder(sequence_len, row_size, col_size):
    """
    Generates a coordinate encoding for the encoder's grid-based input sequence.

    The sequence is laid out as 5 leading non-grid positions, then grid1, then
    grid2, each grid flattened row by row. Grid positions receive their (x, y)
    cell coordinate; every other position keeps the coordinate (0, 0).

    Args:
        sequence_len: The total sequence length of the input.
        row_size: The number of rows in each grid.
        col_size: The number of columns in each grid.

    Returns:
        pos: An array of shape (sequence_len, 2), in the default floating-point
            dtype of ``jnp.zeros``, where each element is the (x, y) coordinate
            for a position in the sequence.
    """

    grid_size = row_size * col_size

    # Every position starts at (0, 0); only the grid slices are overwritten.
    pos = jnp.zeros((sequence_len, 2))

    # Row-major flattening: a row holds `col_size` cells, so the column (x) is
    # the remainder and the row (y) the quotient of the flat index by
    # `col_size`. Dividing by `row_size` is only right for square grids.
    flat_index = jnp.arange(grid_size)
    grid_pos = jnp.stack(
        [
            flat_index % col_size,  # x coordinates
            flat_index // col_size,  # y coordinates
        ],
        axis=-1,
    )

    # NOTE(review): the first 5 positions are non-grid tokens left at (0, 0);
    # an earlier comment spoke of 2 leading positions — confirm what they hold.
    grid1_start = 5
    # grid2 follows grid1 directly; both grids share the same coordinates.
    grid2_start = grid1_start + grid_size

    # Update positions at the correct indices
    pos = pos.at[grid1_start : grid1_start + grid_size].set(grid_pos)
    pos = pos.at[grid2_start : grid2_start + grid_size].set(grid_pos)

    return pos


def create_coordinate_encoding_decoder(sequence_len, row_size, col_size):
    """
    Generates a coordinate encoding for the decoder's grid-based input sequence.

    The sequence is treated as 2 leading non-grid positions, then grid1, then
    a context span, then grid2, each grid flattened row by row. The context
    length is inferred as ``sequence_len - 2 * row_size * col_size - 4``. Grid
    positions receive their (x, y) cell coordinate; every other position keeps
    the coordinate (0, 0).

    Args:
        sequence_len: The total sequence length of the input.
        row_size: The number of rows in each grid.
        col_size: The number of columns in each grid.

    Returns:
        pos: An array of shape (sequence_len, 2), in the default floating-point
            dtype of ``jnp.zeros``, where each element is the (x, y) coordinate
            for a position in the sequence.
    """

    grid_size = row_size * col_size

    # Whatever is left after the 2 grids and the 2 shapes (a row and a col
    # entry each, 4 positions in total) is counted as context.
    context_size = sequence_len - 2 * grid_size - 4

    # Every position starts at (0, 0); only the grid slices are overwritten.
    pos = jnp.zeros((sequence_len, 2))

    # Row-major flattening: a row holds `col_size` cells, so the column (x) is
    # the remainder and the row (y) the quotient of the flat index by
    # `col_size`. Dividing by `row_size` is only right for square grids.
    flat_index = jnp.arange(grid_size)
    grid_pos = jnp.stack(
        [
            flat_index % col_size,  # x coordinates
            flat_index // col_size,  # y coordinates
        ],
        axis=-1,
    )

    # The first 2 positions are shape entries left at (0, 0); grid1 follows.
    grid1_start = 2
    # NOTE(review): grid2 starts right after the context, so the final 2
    # positions of the sequence stay at (0, 0). If the second shape pair sits
    # between the context and grid2, this offset is 2 too small — confirm
    # against the decoder's token ordering.
    grid2_start = 2 + grid_size + context_size

    # Update positions at the correct indices
    pos = pos.at[grid1_start : grid1_start + grid_size].set(grid_pos)
    pos = pos.at[grid2_start : grid2_start + grid_size].set(grid_pos)

    return pos
