from typing import Type
from functools import partial
import math
import jax
from flax import linen as nn
from jax import numpy as jnp, lax
from ..xPos import XPos
from ..layers import RopeEmbeds
from latte_trans.config import Config

# from .init_jax import dense_init


def repeat_kv(hidden_states: jax.Array, n_rep: int) -> jax.Array:
    """
    Duplicate every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the
    copies of each head are laid out consecutively, so input head ``i`` ends up
    at output positions ``i * n_rep`` ... ``i * n_rep + n_rep - 1``.

    Args:
        hidden_states: rank-4 array (batch, num_key_value_heads, seqlen, head_dim).
        n_rep: number of copies per head.
    Returns:
        Array (batch, num_key_value_heads * n_rep, seqlen, head_dim). When
        ``n_rep == 1`` the input object itself is returned.
    """
    # Unpacking enforces the rank-4 contract even when no repetition happens.
    _batch, _kv_heads, _seq_len, _head_dim = hidden_states.shape
    if n_rep != 1:
        hidden_states = jnp.repeat(hidden_states, n_rep, axis=1)
    return hidden_states


def apply_rotation(rel_pos, k, q):
    """
    Apply a positional-rotation module to keys and queries.

    Args:
        rel_pos: callable rotation module. If it is an ``XPos`` instance, keys
            are rotated with ``offset=0, downscale=True`` and queries with
            ``offset=0, downscale=False``. Any other module is simply called
            as ``rel_pos(x)`` on queries, then on keys.
        k, q: key and query arrays, presumably (B, H, T, D) as produced by the
            attention modules in this file (not checked here).
    Returns:
        Tuple ``(k, q)`` of rotated keys and queries (keys first).
    """
    if isinstance(rel_pos, XPos):
        # XPos scales keys and queries in opposite directions.
        rotated_k = rel_pos(k, offset=0, downscale=True)
        rotated_q = rel_pos(q, offset=0, downscale=False)
        return rotated_k, rotated_q

    rotated_q = rel_pos(q)
    rotated_k = rel_pos(k)
    return rotated_k, rotated_q


def attention(attn_dropout, q, k, v, mask):
    """
    Manual scaled dot-product attention.

    Args:
        attn_dropout: callable applied to the attention weights (e.g. an
            ``nn.Dropout`` module, or an identity function).
        q: queries, (B, H, T, D).
        k, v: keys / values, (B, H_kv, S, D). When ``H_kv > 1`` divides ``H``
            and is smaller than it (grouped-query attention), each key/value
            head is repeated ``H // H_kv`` times so that query head ``h`` uses
            key/value head ``h // (H // H_kv)``.
        mask: array broadcastable to (B, H, T, S), or ``None`` for no masking.
            Positions where ``mask == 0`` receive (numerically) zero weight.
    Returns:
        Attention output, (B, H, T, D).
    Raises:
        ValueError: if the query head count is not a multiple of the
            key/value head count.
    """
    if q.ndim == 4 and k.ndim == 4 and q.shape[1] > k.shape[1] > 1:
        n_rep, remainder = divmod(q.shape[1], k.shape[1])
        if remainder:
            raise ValueError(
                f"number of query heads ({q.shape[1]}) must be a multiple of "
                f"the number of key/value heads ({k.shape[1]})"
            )
        # Same interleaved layout as repeat_kv / torch.repeat_interleave.
        k = jnp.repeat(k, n_rep, axis=1)
        v = jnp.repeat(v, n_rep, axis=1)

    att = (q @ jnp.swapaxes(k, -2, -1)) * (1.0 / math.sqrt(k.shape[-1]))

    if mask is not None:
        # Use the dtype's most negative finite value: a fixed literal such as
        # -9e15 overflows to -inf in float16, which turns fully masked rows
        # into NaN after the softmax.
        att = jnp.where(mask == 0, jnp.finfo(att.dtype).min, att)
    att = jax.nn.softmax(att, axis=-1)

    att = attn_dropout(att)
    return att @ v  # (B, nh, T, S) x (B, nh, S, hs) -> (B, nh, T, hs)


class CausalSelfAttention(nn.Module):
    """Causal (left-to-right) multi-head self-attention without positional rotation.

    Supports grouped-query attention: when ``config.num_key_value_heads`` is
    smaller than ``config.nheads``, each key/value head is shared by
    ``nheads // num_key_value_heads`` consecutive query heads.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, X: jnp.array, train: bool, **kwargs) -> jnp.array:
        """
        Args:
            X: input sequence, (B, T, C).
            train: when True, attention dropout is active.
            **kwargs: accepted and ignored.
        Returns:
            Array (B, T, config.hidden_dim).
        Raises:
            ValueError: if ``nheads`` is not a multiple of ``num_key_value_heads``.
        """
        nheads = self.config.nheads
        # Unset (falsy) config values fall back to plain multi-head attention.
        num_key_value_heads = self.config.num_key_value_heads or nheads
        head_dim = self.config.head_dim or self.config.hidden_dim // nheads
        if nheads % num_key_value_heads:
            raise ValueError(
                f"nheads ({nheads}) must be a multiple of "
                f"num_key_value_heads ({num_key_value_heads})"
            )
        num_key_value_groups = nheads // num_key_value_heads

        # All projections share bias / dtype / initializer settings.
        dense = partial(
            nn.Dense,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )
        q_proj = dense(nheads * head_dim, name="q_proj")
        k_proj = dense(num_key_value_heads * head_dim, name="k_proj")
        v_proj = dense(num_key_value_heads * head_dim, name="v_proj")
        c_proj = dense(self.config.hidden_dim, name="out_proj")
        # regularization
        attn_dropout = nn.Dropout(rate=self.config.dropout_att, deterministic=not train)

        B, T, _ = X.shape

        # Lower-triangular mask: position t may attend only to positions <= t.
        causal_mask = (
            jnp.tril(jnp.ones(shape=(T, T))).reshape(1, 1, T, T).astype(jnp.bool_)
        )

        q = q_proj(X).reshape(B, T, nheads, head_dim).transpose(0, 2, 1, 3)
        k = (
            k_proj(X)
            .reshape(B, T, num_key_value_heads, head_dim)
            .transpose(0, 2, 1, 3)
        )  # (B, n_kv, T, hs)
        v = (
            v_proj(X)
            .reshape(B, T, num_key_value_heads, head_dim)
            .transpose(0, 2, 1, 3)
        )  # (B, n_kv, T, hs)

        # Line key/value heads up with the query heads (no-op when n_kv == nh).
        k = repeat_kv(k, num_key_value_groups)
        v = repeat_kv(v, num_key_value_groups)

        y = attention(attn_dropout, q, k, v, mask=causal_mask)
        # (B, nh, T, hs) -> (B, T, nh * hs); nh * hs need not equal C.
        y = y.transpose(0, 2, 1, 3).reshape(B, T, nheads * head_dim)

        return c_proj(y)


class BidirectionalAttention(nn.Module):
    """Bidirectional (unmasked except for padding) multi-head self-attention.

    Supports grouped-query attention: when ``config.num_key_value_heads`` is
    smaller than ``config.nheads``, each key/value head is shared by
    ``nheads // num_key_value_heads`` consecutive query heads.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self, src: jnp.array, train: bool, attention_mask: jnp.array, **kwargs
    ) -> jnp.array:
        """
        Args:
            src: input sequence, (B, S, C).
            train: when True, attention dropout is active.
            attention_mask: (B, S) key mask where 0 marks positions that must
                not be attended to, or ``None`` for no masking.
            **kwargs: accepted and ignored.
        Returns:
            Array (B, S, config.hidden_dim).
        Raises:
            ValueError: if ``nheads`` is not a multiple of ``num_key_value_heads``.
        """
        nheads = self.config.nheads
        # Unset (falsy) config values fall back to plain multi-head attention.
        num_key_value_heads = self.config.num_key_value_heads or nheads
        head_dim = self.config.head_dim or self.config.hidden_dim // nheads
        if nheads % num_key_value_heads:
            raise ValueError(
                f"nheads ({nheads}) must be a multiple of "
                f"num_key_value_heads ({num_key_value_heads})"
            )
        num_key_value_groups = nheads // num_key_value_heads

        # All projections share bias / dtype / initializer settings.
        dense = partial(
            nn.Dense,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )
        q_proj = dense(nheads * head_dim, name="q_proj")
        k_proj = dense(num_key_value_heads * head_dim, name="k_proj")
        v_proj = dense(num_key_value_heads * head_dim, name="v_proj")
        c_proj = dense(self.config.hidden_dim, name="out_proj")
        # regularization
        attn_dropout = nn.Dropout(rate=self.config.dropout_att, deterministic=not train)

        B, S, _ = src.shape

        # Expand the key mask (B, S) -> (B, 1, S, S) so every query row sees it.
        if attention_mask is not None:
            attention_mask = jnp.broadcast_to(
                attention_mask[:, None, None, :], shape=(B, 1, S, S)
            )

        q = q_proj(src).reshape(B, S, nheads, head_dim).transpose(0, 2, 1, 3)
        k = (
            k_proj(src)
            .reshape(B, S, num_key_value_heads, head_dim)
            .transpose(0, 2, 1, 3)
        )  # (B, n_kv, S, hs)
        v = (
            v_proj(src)
            .reshape(B, S, num_key_value_heads, head_dim)
            .transpose(0, 2, 1, 3)
        )  # (B, n_kv, S, hs)

        # Line key/value heads up with the query heads (no-op when n_kv == nh).
        k = repeat_kv(k, num_key_value_groups)
        v = repeat_kv(v, num_key_value_groups)

        y = attention(attn_dropout, q, k, v, mask=attention_mask)
        # (B, nh, S, hs) -> (B, S, nh * hs); nh * hs need not equal C.
        y = y.transpose(0, 2, 1, 3).reshape(B, S, nheads * head_dim)

        return c_proj(y)


class RopeBidAtt(nn.Module):
    """Bidirectional multi-head self-attention with rotary (RoPE or XPos) embeddings.

    Uses ``config.nheads`` heads of size ``config.hidden_dim // config.nheads``
    for queries, keys and values; ``num_key_value_heads`` and ``head_dim`` are
    not consulted by this module.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self, src: jnp.array, attention_mask: jnp.array, train: bool, **kwargs
    ) -> jnp.array:
        """
        Args:
            src: input sequence, (B, S, C).
            attention_mask: (B, S) key mask where 0 marks positions that must
                not be attended to, or ``None`` for no masking.
            train: when True, attention dropout is active.
            **kwargs: accepted and ignored.
        Returns:
            Array (B, S, config.hidden_dim).
        Raises:
            ValueError: if ``config.embed_type`` is neither "rope" nor "xpos".
        """
        nheads = self.config.nheads
        head_dim = self.config.hidden_dim // nheads

        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=head_dim,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=head_dim,
                scale_base=self.config.max_seq_len,
            )
        else:
            raise ValueError(
                f"Unsupported embed_type {self.config.embed_type!r}; "
                "expected 'rope' or 'xpos'"
            )

        # All projections share bias / dtype / initializer settings.
        dense = partial(
            nn.Dense,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )
        q_proj = dense(self.config.hidden_dim, name="q_proj")
        k_proj = dense(self.config.hidden_dim, name="k_proj")
        v_proj = dense(self.config.hidden_dim, name="v_proj")
        c_proj = dense(self.config.hidden_dim, name="out_proj")
        # regularization
        attn_dropout = nn.Dropout(rate=self.config.dropout_att, deterministic=not train)

        B, S, _ = src.shape

        # Expand the key mask (B, S) -> (B, 1, S, S) so every query row sees it.
        if attention_mask is not None:
            attention_mask = jnp.broadcast_to(
                attention_mask[:, None, None, :], shape=(B, 1, S, S)
            )

        # Head size comes from hidden_dim (the projection width), not from the
        # input width C, so inputs with C != hidden_dim still reshape correctly.
        q = q_proj(src).reshape(B, S, nheads, head_dim).transpose(0, 2, 1, 3)
        k = k_proj(src).reshape(B, S, nheads, head_dim).transpose(0, 2, 1, 3)
        v = v_proj(src).reshape(B, S, nheads, head_dim).transpose(0, 2, 1, 3)

        # Pass the rotation module itself (as RopeCausal does): apply_rotation
        # calls it and dispatches on isinstance(..., XPos), which a slice of
        # the module would not satisfy.
        k, q = apply_rotation(rot_embeds, k, q)  # BHTD

        y = attention(attn_dropout, q, k, v, mask=attention_mask)
        # (B, nh, S, hs) -> (B, S, nh * hs)
        y = y.transpose(0, 2, 1, 3).reshape(B, S, -1)

        return c_proj(y)


class RopeCausal(nn.Module):
    """Causal multi-head self-attention with rotary (RoPE or XPos) embeddings.

    Supports grouped-query attention: when ``config.num_key_value_heads`` is
    smaller than ``config.nheads``, each key/value head is shared by
    ``nheads // num_key_value_heads`` consecutive query heads.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, X: jnp.array, train: bool, **kwargs) -> jnp.array:
        """
        Args:
            X: input sequence, (B, T, C).
            train: when True, attention dropout is active.
            **kwargs: accepted and ignored.
        Returns:
            Array (B, T, config.hidden_dim).
        Raises:
            ValueError: if ``config.embed_type`` is neither "rope" nor "xpos",
                or if ``nheads`` is not a multiple of ``num_key_value_heads``.
        """
        nheads = self.config.nheads
        # Unset (falsy) config values fall back to plain multi-head attention.
        num_key_value_heads = self.config.num_key_value_heads or nheads
        head_dim = self.config.head_dim or self.config.hidden_dim // nheads
        if nheads % num_key_value_heads:
            raise ValueError(
                f"nheads ({nheads}) must be a multiple of "
                f"num_key_value_heads ({num_key_value_heads})"
            )
        num_key_value_groups = nheads // num_key_value_heads

        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=head_dim,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=head_dim,
                scale_base=self.config.max_seq_len,
            )
        else:
            raise ValueError(
                f"Unsupported embed_type {self.config.embed_type!r}; "
                "expected 'rope' or 'xpos'"
            )

        # All projections share bias / dtype / initializer settings.
        dense = partial(
            nn.Dense,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )
        q_proj = dense(nheads * head_dim, name="q_proj")
        k_proj = dense(num_key_value_heads * head_dim, name="k_proj")
        v_proj = dense(num_key_value_heads * head_dim, name="v_proj")
        c_proj = dense(self.config.hidden_dim, name="out_proj")
        # regularization
        attn_dropout = nn.Dropout(rate=self.config.dropout_att, deterministic=not train)

        B, T, _ = X.shape

        # Lower-triangular mask: position t may attend only to positions <= t.
        causal_mask = jnp.tril(jnp.ones(shape=(T, T))).reshape(1, 1, T, T)

        q = q_proj(X).reshape(B, T, nheads, head_dim).transpose(0, 2, 1, 3)
        k = (
            k_proj(X)
            .reshape(B, T, num_key_value_heads, head_dim)
            .transpose(0, 2, 1, 3)
        )  # (B, n_kv, T, hs)
        v = (
            v_proj(X)
            .reshape(B, T, num_key_value_heads, head_dim)
            .transpose(0, 2, 1, 3)
        )  # (B, n_kv, T, hs)

        # Rotate before expanding the KV heads so each KV head is rotated once.
        k, q = apply_rotation(rot_embeds, k, q)  # BHTD

        k = repeat_kv(k, num_key_value_groups)
        v = repeat_kv(v, num_key_value_groups)

        # Every causal row keeps its diagonal entry, so no row is fully masked
        # and the shared helper matches the previous -inf masking exactly.
        y = attention(attn_dropout, q, k, v, mask=causal_mask)
        # (B, nh, T, hs) -> (B, T, nh * hs)
        y = y.transpose(0, 2, 1, 3).reshape(B, T, -1)

        return c_proj(y)
