"""Analysis-only attention forward hooks to trigger vector analysis mid-run."""

from __future__ import annotations

from typing import Callable

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward

from ..analysis.vector_analyzer import dist_QK_hist_from_states, plot_QK_hist, plot_QK_pca2d


def _run_qk_analysis(query_states, key_states, *, device_for_fit, hist_path, pca_path):
    """Save a Q/K distance histogram and a 2-D PCA scatter of Q/K to disk.

    Args:
        query_states: Post-RoPE queries as produced by the attention forward,
            laid out by ``view(...).transpose(1, 2)`` as (batch, heads, seq, head_dim).
        key_states: Post-RoPE keys with the same layout (may have fewer heads).
        device_for_fit: Device string forwarded to ``dist_QK_hist_from_states``.
        hist_path: Output path for the histogram figure.
        pca_path: Output path for the PCA figure.
    """
    # Analysis is observational only; keep it out of the autograd graph.
    q = query_states.detach()
    k = key_states.detach()

    # NOTE(review): keys are passed before queries, exactly as in the original call;
    # confirm this matches dist_QK_hist_from_states' parameter order.
    dQ, dK, dQQ = dist_QK_hist_from_states(
        k, q, q_reduce="mean", diagonal=True, device_for_fit=device_for_fit
    )
    fig = plot_QK_hist(dQ, dK, dQQ, title="LLM Q/K", bins=50)
    fig.savefig(hist_path, dpi=300, bbox_inches="tight")

    # Take the first batch element and flatten (heads, seq) into rows of head_dim features.
    Q = q[0].reshape(-1, q.shape[-1]).to("cpu")
    K = k[0].reshape(-1, k.shape[-1]).to("cpu")
    fig2 = plot_QK_pca2d(Q, K, title="LLM Q/K", max_points_per_class=5000)
    fig2.savefig(pca_path, dpi=300, bbox_inches="tight")
    # NOTE(review): if the plot helpers create figures through pyplot, those figures stay
    # registered and accumulate across calls; consider closing them after saving.


def get_llama_forward_with_analysis(
    attn_module,
    *,
    layer_to_analyze: int = 10,
    device_for_fit: str = "cuda",
    hist_path: str = "./qk_hist.png",
    pca_path: str = "./qk_pca2d.png",
) -> Callable:
    """Build a replacement ``forward`` for a Llama attention module that also runs Q/K analysis.

    The returned function reproduces the eager attention path: project Q/K/V, apply RoPE,
    update the KV cache if one is passed, run ``eager_attention_forward``, then ``o_proj``.
    When ``attn_module.layer_idx == layer_to_analyze``, each call additionally saves Q/K
    analysis plots. The plots are computed on this call's post-RoPE states, before any cache
    concatenation, and they overwrite the same files on every call.

    Args:
        attn_module: The attention module whose projections and config are used.
        layer_to_analyze: Layer index whose forward triggers the analysis.
        device_for_fit: Device string forwarded to the histogram fitting helper.
        hist_path: Output path for the Q/K distance histogram.
        pca_path: Output path for the Q/K PCA scatter plot.

    Returns:
        A callable with the attention ``forward`` keyword interface that returns
        ``(attn_output, attn_weights)``.
    """

    def custom_forward(
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, attn_module.head_dim)
        query_states = attn_module.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = attn_module.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = attn_module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cache_kwargs = {"cache_position": cache_position}
        if position_embeddings is not None:
            cos, sin = position_embeddings
            # Upstream LlamaAttention has no `apply_rotary_pos_emb` attribute (it is a
            # module-level function in modeling_llama); honour one only if a caller attached it.
            rotary = getattr(attn_module, "apply_rotary_pos_emb", apply_rotary_pos_emb)
            query_states, key_states = rotary(query_states, key_states, cos, sin)
            cache_kwargs.update(sin=sin, cos=cos)

        if getattr(attn_module, "layer_idx", -1) == layer_to_analyze:
            _run_qk_analysis(
                query_states,
                key_states,
                device_for_fit=device_for_fit,
                hist_path=hist_path,
                pca_path=pca_path,
            )

        # Newer transformers versions pass the cache as `past_key_values`; accept either name.
        cache = past_key_value if past_key_value is not None else kwargs.get("past_key_values")
        if cache is not None:
            # Without this, cached generation would attend only to the current step's K/V.
            key_states, value_states = cache.update(
                key_states, value_states, attn_module.layer_idx, cache_kwargs
            )

        attn_output, attn_weights = eager_attention_forward(
            attn_module,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not attn_module.training else attn_module.attention_dropout,
            scaling=attn_module.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = attn_module.o_proj(attn_output)
        return attn_output, attn_weights

    return custom_forward



