"""
The new architecture, which is more complex than the vanilla transformer. Includes convolutions and gating before the mixing block.
"""

from typing import Literal, Type, Optional
from functools import partial
import math
import torch
import torch.nn as nn

from latte_trans.config import LMTaskConfig


class Gemma2RotaryEmbedding(nn.Module):
    """Produce the cos/sin tables used by rotary position embeddings.

    Args:
        dim: Rotary dimension. One inverse frequency is generated for every
            even index in ``range(0, dim, 2)``.
        max_position_embeddings: Stored on the module; not used by the
            computation in this class.
        base: Base of the inverse-frequency progression.
        device: Accepted for signature compatibility; not used here.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim
        inv_freq = 1.0 / (self.base**exponents)
        # Non-persistent: the buffer is not written to the state_dict.
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` for the requested positions.

        Args:
            x: Tensor whose device and dtype the outputs follow. When
                ``position_ids`` is None, ``x.shape[2]`` is used as the
                sequence length, i.e. x is read as
                [bs, num_attention_heads, seq_len, head_size].
            position_ids: Integer positions of shape [bs, seq_len], or None to
                use ``0 .. x.shape[2] - 1`` with a leading batch dim of 1.
            seq_len: Unused; kept for signature compatibility.

        Returns:
            Tuple ``(cos, sin)``, each of shape
            [position_ids.shape[0], seq_len, dim] and cast to ``x.dtype``.
        """
        if position_ids is None:
            position_ids = torch.arange(0, x.shape[2], device=x.device).unsqueeze(0)

        # `Tensor.to` is not in-place, so the returned tensor must be kept;
        # calling it for its side effect leaves the buffer where it was.
        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
        inv_freq_expanded = inv_freq[None, :, None].expand(
            position_ids.shape[0], -1, 1
        )
        # Keep both matmul operands on the same device as `x`.
        position_ids_expanded = position_ids[:, None, :].to(
            device=x.device, dtype=torch.float32
        )

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            # [bs, dim/2, 1] @ [bs, 1, seq] -> [bs, dim/2, seq] -> [bs, seq, dim/2]
            freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
            # Duplicate the angles so the table spans the full rotary dim.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Swap the two parts of the last axis, negating the trailing part.

    The last dimension is cut at ``x.shape[-1] // 2``; the result is the
    negated trailing part followed by the leading part.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((trailing.neg(), leading), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with a rotary position embedding.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine table of the rotary embedding.
        sin (`torch.Tensor`): Sine table of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated; ignored by this function.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos` and `sin` so they broadcast against `q`
            and `k`. With tables shaped [batch_size, seq_len, head_dim], use 1
            when q/k are [batch_size, heads, seq_len, head_dim] and 2 when
            they are [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and the rotated key.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Standard rotary update: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head `n_rep` times along the head axis.

    Matches torch.repeat_interleave(x, dim=1, repeats=n_rep): the input of
    shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With n_rep == 1
    the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis next to the head axis, then fold it into heads.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states


class Gemma2Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings and logit soft-capping.

    Keys and values are projected to ``num_key_value_heads`` heads and each
    is repeated ``num_heads // num_key_value_heads`` times to line up with the
    query heads. Scores are scaled by a fixed ``256**-0.5`` and squashed with
    ``tanh`` at a cap of 50.0 before the (optional) additive mask and softmax.
    """

    def __init__(self, config: LMTaskConfig):
        super().__init__()
        self.config = config

        self.attention_dropout = config.dropout_att
        self.hidden_size = config.hidden_dim
        self.num_heads = config.nheads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.pos_embed_max_len
        self.rope_theta = 10000.0
        self.is_causal = True
        self.scaling = 256**-0.5  # fixed value; not read from the config
        self.attn_logit_softcapping = 50.0

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})."
            )

        query_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias

        # Creation order (q, k, v, o) is kept so parameter initialisation
        # consumes the RNG in the same sequence.
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, self.hidden_size, bias=use_bias)
        self.rotary_emb = Gemma2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_size).
            attention_mask: Optional 4-D additive mask; its last axis is
                sliced to the key length before being added to the scores.
            position_ids: Optional positions forwarded to the rotary
                embedding module.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).

        Raises:
            ValueError: If the attention output does not have shape
                (batch, num_heads, seq_len, head_dim).
        """
        bsz, q_len, _ = hidden_states.size()

        def _split_heads(projected, n_heads):
            # (bsz, q_len, n_heads * head_dim) -> (bsz, n_heads, q_len, head_dim)
            return projected.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        query_states = _split_heads(self.q_proj(hidden_states), self.num_heads)
        key_states = _split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        value_states = _split_heads(
            self.v_proj(hidden_states), self.num_key_value_heads
        )

        # value_states is passed to the rotary module as its reference tensor.
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        # Bring keys/values up to the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        scores = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        cap = self.attn_logit_softcapping
        if cap is not None:
            # Soft cap: squash the logits into (-cap, cap) via tanh.
            scores = torch.tanh(scores / cap) * cap

        if attention_mask is not None:  # no matter the length, we just slice it
            scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

        # Softmax in fp32, then back to the working dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )
        probs = nn.functional.dropout(
            probs, p=self.attention_dropout, training=self.training
        )
        context = torch.matmul(probs, value_states)

        expected_shape = (bsz, self.num_heads, q_len, self.head_dim)
        if context.size() != expected_shape:
            raise ValueError(
                f"`attn_output` should be of size {expected_shape}, but is {context.size()}"
            )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, heads * head_dim)
        merged = context.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        return self.o_proj(merged)


class Gemma2RMSNorm(nn.Module):
    """RMS normalisation over the last axis with a learned ``1 + weight`` scale.

    ``weight`` starts at zero, so a freshly constructed module applies a
    scale of exactly one.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide ``x`` by the root-mean-square of its last axis (plus eps)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, scale, then cast back to the dtype of ``x``."""
        normed = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scale = 1.0 + self.weight.float()
        return (normed * scale).type_as(x)

    def extra_repr(self):
        """Show the weight shape and eps in the module repr."""
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"


class GemmaMLP(nn.Module):
    """Gated feed-forward block: ``down(up(x) * gelu_tanh(gate(x)))``.

    Args:
        config: Object providing ``hidden_dim`` (input/output width) and
            ``intermediate_dim`` (inner width). All projections are bias-free.
    """

    def __init__(self, config: LMTaskConfig):
        super().__init__()
        embed_dim = config.hidden_dim
        inner_dim = config.intermediate_dim
        self.config = config
        self.gate_proj = nn.Linear(embed_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, embed_dim, bias=False)
        self.up_proj = nn.Linear(embed_dim, inner_dim, bias=False)

    def forward(self, hidden_states):
        """Apply the gated MLP to ``hidden_states``.

        Defined as ``forward`` rather than ``__call__`` so that invoking the
        module goes through ``nn.Module.__call__`` and registered hooks run;
        ``mlp(x)`` keeps working unchanged.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_dim``.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        up_proj_states = self.up_proj(hidden_states)
        gate_states = nn.functional.gelu(
            self.gate_proj(hidden_states), approximate="tanh"
        )
        return self.down_proj(up_proj_states * gate_states)


class GemmaDecoderLayer(nn.Module):
    """One decoder block: attention and MLP sub-layers, each with a residual.

    Each sub-layer is wrapped in its own pair of RMS norms: one applied to
    the sub-layer input and one applied to its output before the residual add.
    """

    def __init__(self, config: LMTaskConfig):
        super().__init__()
        self.config = config
        self.self_attn = Gemma2Attention(config=self.config)
        width = self.config.hidden_dim
        self.input_layernorm = Gemma2RMSNorm(width)
        self.post_attention_layernorm = Gemma2RMSNorm(width)
        self.pre_feedforward_layernorm = Gemma2RMSNorm(width)
        self.post_feedforward_layernorm = Gemma2RMSNorm(width)
        self.mlp = GemmaMLP(self.config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
    ):
        """Run the attention sub-layer, then the feed-forward sub-layer.

        Args:
            hidden_states: Input activations for this layer.
            attention_mask: Optional mask forwarded to the attention module.
            position_ids: Optional positions forwarded to the attention module.

        Returns:
            The activations after both residual sub-layers.
        """
        # Attention sub-layer: norm -> attention -> norm -> residual add.
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = hidden_states + self.post_attention_layernorm(attn_out)

        # Feed-forward sub-layer: norm -> MLP -> norm -> residual add.
        mlp_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        return hidden_states + self.post_feedforward_layernorm(mlp_out)


class Gemma2MLP(nn.Module):
    """Bias-free gated MLP: ``down(gelu_tanh(gate(x)) * up(x))``.

    Widths come from ``config.hidden_dim`` and ``config.intermediate_dim``.
    """

    def __init__(self, config: LMTaskConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_dim
        self.intermediate_size = config.intermediate_dim
        outer, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(outer, inner, bias=False)
        self.up_proj = nn.Linear(outer, inner, bias=False)
        self.down_proj = nn.Linear(inner, outer, bias=False)

    def forward(self, x):
        """Return the gated projection of ``x`` mapped back to the input width."""
        gate = nn.functional.gelu(self.gate_proj(x), approximate="tanh")
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)


class Gemma2DecoderLayer(nn.Module):
    """Decoder block built from `Gemma2Attention` and `Gemma2MLP`.

    Both sub-layers are normalised on the way in and on the way out, and each
    is added back onto its input as a residual. All RMS norms use the default
    eps of `Gemma2RMSNorm`.
    """

    def __init__(self, config: LMTaskConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_dim

        self.self_attn = Gemma2Attention(self.config)
        self.mlp = Gemma2MLP(config)

        self.input_layernorm = Gemma2RMSNorm(config.hidden_dim)
        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_dim)
        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_dim)
        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_dim)

        # Stored only; forward() does not apply any sliding-window masking.
        self.sliding_window = 4096

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Apply the attention and feed-forward sub-layers in sequence.

        Args:
            hidden_states: Input activations for this layer.
            attention_mask: Optional mask passed through to the attention
                module unchanged.
            position_ids: Optional positions passed to the attention module.

        Returns:
            The activations after both residual sub-layers.
        """
        # Self attention with pre- and post-normalisation.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        after_attn = hidden_states + self.post_attention_layernorm(attn_out)

        # Feed-forward with pre- and post-normalisation.
        mlp_out = self.mlp(self.pre_feedforward_layernorm(after_attn))
        return after_attn + self.post_feedforward_layernorm(mlp_out)


class GemmaDecoder(nn.Module):
    """Stack of ``config.nlayers`` `GemmaDecoderLayer` blocks plus a final norm."""

    def __init__(self, config: LMTaskConfig):
        super().__init__()
        self.config = config
        self.norm = Gemma2RMSNorm(self.config.hidden_dim)
        layers = [GemmaDecoderLayer(self.config) for _ in range(self.config.nlayers)]
        self.residual_block = nn.ModuleList(layers)

    def forward(self, X, attention_mask):
        """Pass ``X`` through every layer in order, then apply the final RMS norm.

        Args:
            X: Input activations fed to the first layer.
            attention_mask: Mask handed to every layer; position ids are not
                supplied, so each layer receives its default of None.

        Returns:
            The normalised output of the last layer.
        """
        hidden = X
        for block in self.residual_block:
            hidden = block(hidden, attention_mask=attention_mask)
        # Final normalisation of the stack output.
        return self.norm(hidden)
