from functools import partial
from torch.nn import Dropout
import torch
import torch.nn as nn
from typing import Optional
from lxt.efficient.patches import (
    patch_method,
    rms_norm_forward,
    gated_mlp_forward,
    dropout_forward,
    check_already_patched,
    wrap_attention_forward,
    patch_attention,
)
from lxt.efficient.models.gemma3 import gemma3_norm
from transformers.models.llama import modeling_llama
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, repeat_kv
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2RMSNorm
from transformers.models.gemma3 import modeling_gemma3
from transformers.models.gemma3.modeling_gemma3 import Gemma3MLP, Gemma3RMSNorm
from transformers.integrations.flash_attention import flash_attention_forward
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# Patch map for the IxG setting (presumably Input x Gradient -- confirm):
# only dropout layers are patched, via patch_method + dropout_forward.
IxG_map = {
    Dropout: partial(patch_method, dropout_forward),
}


def eager_attention_forward_with_attn_weights(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    softcap: Optional[float] = None,
    **kwargs,
):
    """Eager scaled-dot-product attention that records intermediates on *module*.

    Computes the attention the same way as the plain eager path. It also stores
    the post-softmax / pre-dropout attention probabilities as
    ``module.attn_weights`` and the attention output as ``module.attn_output``.
    When autograd is tracking them, ``retain_grad()`` is called so their
    gradients remain readable after ``backward()``.

    Args:
        module: Attention module. ``num_key_value_groups`` is read to expand the
            key/value heads. ``head_dim`` is read only when *scaling* is None.
        query, key, value: Projected attention inputs. Assumed to be
            (batch, heads, seq, head_dim) as in HF attention -- TODO confirm.
        attention_mask: Optional additive mask, indexed as a 4-D tensor. Its last
            axis is cropped to the key length.
        scaling: Multiplier applied to ``query @ key^T``. Defaults to
            ``module.head_dim ** -0.5`` when None.
        dropout: Dropout probability on the attention probabilities. It only has
            an effect when ``module.training`` is True.
        softcap: When given, logits are soft-capped as
            ``tanh(logits / softcap) * softcap`` before masking.
        **kwargs: Extra backend arguments. Accepted and ignored.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` has axes 1 and 2
        swapped relative to the matmul result. ``attn_weights`` is the
        post-dropout probability tensor.
    """
    if scaling is None:
        # Standard 1/sqrt(head_dim) scaled-dot-product default.
        scaling = module.head_dim**-0.5

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if softcap is not None:
        # Previously a caller-supplied softcap was swallowed by **kwargs and
        # silently ignored; apply it explicitly.
        attn_weights = torch.tanh(attn_weights / softcap) * softcap
    if attention_mask is not None:
        # Crop the mask to the actual key length before adding it.
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Softmax in float32 for stability, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # Record the pre-dropout probabilities. Keep their .grad for inspection.
    module.attn_weights = attn_weights
    if attn_weights.requires_grad:
        attn_weights.retain_grad()

    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    # Record the attention output. Keep its .grad for inspection.
    module.attn_output = attn_output
    if attn_output.requires_grad:
        attn_output.retain_grad()

    return attn_output, attn_weights


def patch_attention_with_attn_weights(module):
    """Install the weight-recording eager attention on a modeling module.

    Replaces ``module.eager_attention_forward`` with
    ``eager_attention_forward_with_attn_weights`` wrapped by
    ``wrap_attention_forward``. The replacement is skipped when
    ``check_already_patched`` reports the patch is already in place.

    Returns:
        True if the attribute was replaced by this call, False if it was
        already patched.
    """
    wrapped = wrap_attention_forward(eager_attention_forward_with_attn_weights)
    already_patched = check_already_patched(module.eager_attention_forward, wrapped)
    if not already_patched:
        module.eager_attention_forward = wrapped
    return not already_patched


def get_llama_eager():
    """Build the Llama patch map for the eager attention backend.

    Keys are Llama classes (or the ``modeling_llama`` module). Values are the
    callables that patch them. The attention entry installs
    ``eager_attention_forward_with_attn_weights`` via
    ``patch_attention_with_attn_weights``.
    """
    patches = {}
    patches[LlamaMLP] = partial(patch_method, gated_mlp_forward)
    patches[LlamaRMSNorm] = partial(patch_method, rms_norm_forward)
    patches[Dropout] = partial(patch_method, dropout_forward)
    patches[modeling_llama] = patch_attention_with_attn_weights
    return patches


def get_qwen_eager():
    """Build the Qwen2 patch map for the eager attention backend.

    Keys are Qwen2 classes (or the ``modeling_qwen2`` module). Values are the
    callables that patch them. The attention entry installs
    ``eager_attention_forward_with_attn_weights`` via
    ``patch_attention_with_attn_weights``.
    """
    patches = {}
    patches[Qwen2MLP] = partial(patch_method, gated_mlp_forward)
    patches[Qwen2RMSNorm] = partial(patch_method, rms_norm_forward)
    patches[Dropout] = partial(patch_method, dropout_forward)
    patches[modeling_qwen2] = patch_attention_with_attn_weights
    return patches


def get_gemma_eager():
    """Build the Gemma3 patch map for the eager attention backend.

    Keys are Gemma3 classes (or the ``modeling_gemma3`` module). Values are the
    callables that patch them. Gemma3RMSNorm is patched on its ``_norm`` method
    with ``gemma3_norm`` rather than on ``forward``.
    """
    patches = {}
    patches[Gemma3MLP] = partial(patch_method, gated_mlp_forward)
    patches[Gemma3RMSNorm] = partial(patch_method, gemma3_norm, method_name="_norm")
    patches[Dropout] = partial(patch_method, dropout_forward)
    # NOTE(review): unlike the Llama/Qwen eager maps, this uses plain
    # patch_attention, so attention weights are not recorded on the module.
    # Confirm this is intentional.
    patches[modeling_gemma3] = patch_attention
    return patches


def flash_attention_forward_with_output(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
):
    """Run transformers' FlashAttention-2 forward and record its output on *module*.

    All arguments are forwarded unchanged to the imported
    ``flash_attention_forward``. The resulting attention output is stored as
    ``module.attn_output``. When autograd tracks it, ``retain_grad()`` is called
    so its gradient stays readable after ``backward()``.

    The original function is called directly, not through
    ``ALL_ATTENTION_FUNCTIONS``. Registering this wrapper under
    "flash_attention_2" therefore does not cause recursion.

    Returns:
        ``(attn_output, None)``. Whatever second value the underlying call
        returns is discarded, and no attention weights are exposed.
    """
    attn_output, _ = flash_attention_forward(
        module,
        query,
        key,
        value,
        attention_mask,
        dropout=dropout,
        scaling=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        **kwargs,
    )

    # Record the attention output. Keep its .grad for inspection.
    module.attn_output = attn_output
    if attn_output.requires_grad:
        attn_output.retain_grad()

    return attn_output, None


def get_llama_flash():
    """Build the Llama patch map for the FlashAttention-2 backend.

    Side effect: on every call, registers
    ``flash_attention_forward_with_output`` under the "flash_attention_2" key of
    ``ALL_ATTENTION_FUNCTIONS``, so the attention output is recorded on the
    module.
    """
    ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", flash_attention_forward_with_output)

    patches = {}
    patches[LlamaMLP] = partial(patch_method, gated_mlp_forward)
    patches[LlamaRMSNorm] = partial(patch_method, rms_norm_forward)
    patches[Dropout] = partial(patch_method, dropout_forward)
    patches[modeling_llama] = patch_attention
    return patches


def get_qwen_flash():
    """Build the Qwen2 patch map for the FlashAttention-2 backend.

    Side effect: on every call, registers
    ``flash_attention_forward_with_output`` under the "flash_attention_2" key of
    ``ALL_ATTENTION_FUNCTIONS``, so the attention output is recorded on the
    module.
    """
    ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", flash_attention_forward_with_output)

    patches = {}
    patches[Qwen2MLP] = partial(patch_method, gated_mlp_forward)
    patches[Qwen2RMSNorm] = partial(patch_method, rms_norm_forward)
    patches[Dropout] = partial(patch_method, dropout_forward)
    patches[modeling_qwen2] = patch_attention
    return patches


def get_gemma_flash():
    """Build the Gemma3 patch map for the FlashAttention-2 backend.

    Side effect: on every call, registers
    ``flash_attention_forward_with_output`` under the "flash_attention_2" key of
    ``ALL_ATTENTION_FUNCTIONS``. Gemma3RMSNorm is patched on its ``_norm``
    method with ``gemma3_norm``.
    """
    ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", flash_attention_forward_with_output)

    patches = {}
    patches[Gemma3MLP] = partial(patch_method, gated_mlp_forward)
    patches[Gemma3RMSNorm] = partial(patch_method, gemma3_norm, method_name="_norm")
    patches[Dropout] = partial(patch_method, dropout_forward)
    patches[modeling_gemma3] = patch_attention
    return patches
