from typing import Callable, List, Optional, Tuple, Union
import math
import torch
import torch.nn as nn
from fancy_einsum import einsum
import einops
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
    apply_rotary_pos_emb,
    repeat_kv,
)
from torch.nn.attention import sdpa_kernel, SDPBackend
from transformers.integrations.sdpa_attention import use_gqa_in_sdpa


class HookPoint(nn.Module):

    def __init__(self):
        super().__init__()
        self.name = None

    def forward(self, x):
        return x


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3))
    attn_weights = module.hook_qk_logits(attn_weights)
    attn_weights = attn_weights * scaling
    #attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].to(
            attn_weights.device
        )
        attn_weights = attn_weights + causal_mask

    #attn_weights = module.hook_qk_logits(attn_weights)
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query.dtype
    )
    attn_weights = nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    # value_states.shape: [batch, heads, seq, head_dim]
    value_states = module.hook_value_states_post_attn(value_states)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).to(query.device)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
        logger.warning_once(
            "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )
    sdpa_kwargs = {}
    if hasattr(module, "num_key_value_groups"):
        if not use_gqa_in_sdpa(attention_mask, key):
            key = repeat_kv(key, module.num_key_value_groups)
            value = repeat_kv(value, module.num_key_value_groups)
        else:
            # Goes here.
            sdpa_kwargs = {"enable_gqa": True}

    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
    if is_causal is None:
        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
        is_causal = (
            query.shape[2] > 1
            and attention_mask is None
            and getattr(module, "is_causal", True)
        )

    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
    # We convert it to a bool for the SDPA kernel that only accepts bools.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    # attn_output = torch.nn.functional.scaled_dot_product_attention(
    # breakpoint()
    # attn_output = scaled_dot_product_attention(
    #    query.float(),
    #    key.float(),
    #    value.float(),
    #    attn_mask=attention_mask,
    #    dropout_p=dropout,
    #    scale=scaling,
    #    is_causal=is_causal,
    #    **sdpa_kwargs,
    # )
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=dropout,
            scale=scaling,
            is_causal=is_causal,
            **sdpa_kwargs,
        )
    # breakpoint()
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None


def hooked_forward_attention(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value=None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states = self.hook_query_states_pre_rope(query_states)
    key_states = self.hook_key_states_pre_rope(key_states)

    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    key_states = self.hook_key_states(key_states)
    query_states = self.hook_query_states(query_states)
    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = sdpa_attention_forward

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    B, H, S, D = query_states.shape

    attn_weights = self.hook_attn_pattern(attn_weights)

    attn4d = attn_output  # must be the tensor *before* flatten+o_proj

    # Ensure attn4d is [B, S, H, D]
    if attn4d.shape[1] == H:
        attn4d = attn4d.transpose(1, 2) # -> [B, S, H, D]

    _B, _S, _H, _D = attn4d.shape
    assert H == _H and D == _D
    assert self.o_proj.weight.shape[1] == H * D

    # W: [d_model_out, H*D] -> [H, D, d_model_out]
    W = self.o_proj.weight
    W_per_head = W.view(-1, H, D).permute(1, 2, 0).contiguous()  # [H, D, d_model_out]

    # per-head contributions in d_model space: [B, S, H, d_model_out]
    per_head = torch.einsum("b s h d, h d m -> b s h m", attn4d, W_per_head)

    # sum of head contributions
    out = per_head.sum(dim=2)  # [B, S, d_model_out]

    # add bias if it exists (often None in LLaMA)
    if self.o_proj.bias is not None:
        out = out + self.o_proj.bias

    # attn_weights: [batch, heads, seq (query), seq (key)]
    # attn_output: [batch, seq, heads, head_dim]
    # attn_output_reshaped: [batch, seq, d_model (heads * head_dim)]
    #attn_output_reshaped = attn_output.reshape(*input_shape, -1).contiguous()

    W_O = self.o_proj.weight #.clone()
    ## [heads, d_head, d_model]
    W_O = einops.rearrange(W_O, "m (n h)->n h m", n=self.config.num_attention_heads)
    ## self.o_proj: [d_model, d_model]
    #attn_output_final = self.hook_o_proj(self.o_proj(attn_output_reshaped))
    attn_output_per_head = einsum(
        "batch seq heads d_head, heads d_head d_model -> batch seq heads d_model",
        attn_output,
        W_O,
    )
    ## [batch seq n_heads d_model]
    attn_output_per_head = self.hook_attn_out_per_head(attn_output_per_head)
    attn_output_final = attn_output_per_head.sum(dim=2)
    #return attn_output_final, attn_weights

    attn_output2 = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output2 = self.o_proj(attn_output2)
    #breakpoint()
    return attn_output2, attn_weights


def hooked_forward_mlp(self, x):
    self.mlp_mid = self.hook_mlp_mid(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
    down_proj = self.down_proj(self.mlp_mid)
    return down_proj


def hooked_forward_decoder_layer(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value=None,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[
        Tuple[torch.Tensor, torch.Tensor]
    ] = None,  # necessary, but kept here for BC
    **kwargs,
) -> torch.Tensor:

    residual = hidden_states
    hidden_states = self.hook_resid_pre(hidden_states)
    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, _ = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
        **kwargs,
    )
    hidden_states = residual + hidden_states.to(residual.device)
    hidden_states = self.hook_resid_mid(hidden_states)

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states.to(residual.device)
    hidden_states = self.hook_resid_post(hidden_states)
    return hidden_states


def _convert_to_hooked_model(module):
    for child in module.children():

        if isinstance(child, LlamaAttention):
            child.forward = hooked_forward_attention.__get__(child, LlamaAttention)

        if isinstance(child, LlamaMLP):
            child.forward = hooked_forward_mlp.__get__(child, LlamaMLP)

        if isinstance(child, LlamaDecoderLayer):
            child.forward = hooked_forward_decoder_layer.__get__(
                child, LlamaDecoderLayer
            )

        _convert_to_hooked_model(child)


def convert_to_hooked_model_llama(model):
    """
    This function sets up a hook for the model's forward pass.
    """
    for layer in model.model.layers:
        layer.hook_resid_pre = HookPoint()
        layer.hook_resid_mid = HookPoint()
        layer.hook_resid_post = HookPoint()

        layer.self_attn.hook_key_states_pre_rope = HookPoint()
        layer.self_attn.hook_query_states_pre_rope = HookPoint()

        layer.self_attn.hook_key_states = HookPoint()
        layer.self_attn.hook_query_states = HookPoint()

        layer.self_attn.hook_attn_pattern = HookPoint()
        layer.self_attn.hook_qk_logits = HookPoint()
        layer.self_attn.hook_value_states_post_attn = HookPoint()
        layer.self_attn.hook_o_proj = HookPoint()
        layer.self_attn.hook_attn_out_per_head = HookPoint()

        layer.mlp.hook_mlp_mid = HookPoint()

    _convert_to_hooked_model(model)
