import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb, _flash_attention_forward


def _flash_forward_logger():
    """Return the transformers logger for this module (it provides ``warning_once``)."""
    # Imported lazily so the lookup only happens on the (rare) warning paths.
    from transformers.utils import logging as hf_logging

    return hf_logging.get_logger(__name__)


def llama_flashattention2_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Attention forward that uses flash attention where possible and feeds a score-tracking cache.

    Two execution paths exist:

    * Flash path: taken when the key length equals the query length (prefill) or when the
      cache does not evict during decoding. For a non-``DynamicCache`` cache in prefill the
      cache's ``score_tracker`` is updated with Q, K, V (and O when ``use_k_score`` is set)
      and ``past_key_value.evict`` is called afterwards.
    * Eager path: taken for decode steps of an evicting cache. Attention weights are
      materialised so they can be handed to ``score_tracker.update`` before eviction.

    Args:
        hidden_states: Input of size ``(bsz, q_len, hidden)`` (it is unpacked into three dims).
        attention_mask: Forwarded to flash attention; on the eager path it is sliced along
            its fourth dimension and added to the logits.
        position_ids: Used for RoPE when ``position_embeddings`` is not given and forwarded
            to flash attention.
        past_key_value: Optional cache. ``None`` and ``DynamicCache`` disable score tracking.
        output_attentions: Ignored; attention weights are never returned by this function.
        use_cache: Unused here.
        cache_position: Forwarded to ``past_key_value.update``.
        position_embeddings: Optional precomputed ``(cos, sin)`` pair.
        **kwargs: Forwarded to ``_flash_attention_forward``.

    Returns:
        ``(attn_output, None, past_key_value)``.

    Raises:
        ValueError: If a ``StaticCache`` is passed, or if the eager attention output has an
            unexpected size.
    """
    if isinstance(past_key_value, StaticCache):
        raise ValueError(
            "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
            "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Split heads and move them in front of the sequence axis: RoPE and the cache work on
    # [bsz, heads, seq, head_dim]; the flash path transposes back further below.
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    if position_embeddings is None:
        _flash_forward_logger().warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    # KV heads are expanded *before* caching, so the cache stores one entry per query head.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    kv_len = key_states.size(-2)
    is_prefill = kv_len == q_len
    # Only caches that are neither absent nor a plain DynamicCache are expected to expose
    # `score_tracker` / `evict`. Guarding on None avoids dereferencing a missing cache.
    tracks_scores = past_key_value is not None and not isinstance(past_key_value, DynamicCache)
    decode_evict = getattr(past_key_value, "decode_evict", True) if tracks_scores else False

    if is_prefill or not decode_evict:  # prefill (or non-evicting decode) w/ flashattn
        # Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            _flash_forward_logger().warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
            **kwargs,
        )

        if tracks_scores and is_prefill:
            tracker = past_key_value.score_tracker
            # Flash attention exposes no attention matrix, so A and qK are passed as None and
            # the tracker receives Q/K instead (transposed back to [bsz, heads, seq, head_dim]).
            tracker.update(
                A=None,
                V=past_key_value.value_cache[self.layer_idx].detach(),
                qK=None,
                O=attn_output.transpose(1, 2).contiguous().detach() if tracker.use_k_score else None,
                Q=query_states.transpose(1, 2).contiguous().detach(),
                K=key_states.transpose(1, 2).contiguous().detach(),
                layer_idx=self.layer_idx,
            )
            past_key_value.evict(self.layer_idx)

    else:  # decoding w/ eager attn (dynamic pruning cache)
        # This branch is only reachable with a score-tracking cache: a missing cache always
        # has kv_len == q_len and a DynamicCache forces decode_evict to False.
        tracker = past_key_value.score_tracker
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        # 1. keep the pre-softmax (and pre-mask) logits, [bsz, num_heads, q_len, kv_len]
        qK = attn_weights.detach() if tracker.use_k_score else None

        if attention_mask is not None:  # no matter the length, we just slice it
            # NOTE(review): this indexing assumes a 4-D additive mask - confirm what the
            # caller passes when flash attention is the configured implementation.
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        # 2. keep the pre-o-proj attention output, [bsz, num_heads, q_len, head_dim]
        O = attn_output.detach() if tracker.use_k_score else None

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        tracker.update(
            A=attn_weights.detach(),
            V=past_key_value.value_cache[self.layer_idx].detach(),
            qK=qK,
            O=O,
            layer_idx=self.layer_idx,
        )
        past_key_value.evict(self.layer_idx)

    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    # Attention weights are never returned (`output_attentions` is not supported here).
    return attn_output, None, past_key_value


def llama_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Eager attention forward that reports attention statistics to a score-tracking cache.

    When ``past_key_value`` is neither ``None`` nor a ``DynamicCache``, its ``score_tracker``
    is updated with the attention weights, the cached values and (when ``use_k_score`` is
    set) the pre-softmax logits and the pre-projection output. ``past_key_value.evict`` is
    then called in prefill (key length equals query length) or whenever the cache's
    ``decode_evict`` attribute (default ``True``) is truthy.

    Args:
        hidden_states: Input of size ``(bsz, q_len, hidden)`` (it is unpacked into three dims).
        attention_mask: Optional additive mask; sliced on its last dimension to the key length.
        position_ids: Used for RoPE when ``position_embeddings`` is not given.
        past_key_value: Optional cache.
        output_attentions: Whether to return the attention weights.
        use_cache: Unused here.
        cache_position: Forwarded to ``past_key_value.update``.
        position_embeddings: Optional precomputed ``(cos, sin)`` pair.
        **kwargs: Ignored.

    Returns:
        ``(attn_output, attn_weights or None, past_key_value)``.

    Raises:
        ValueError: If the attention output has an unexpected size.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    # KV heads are expanded *before* caching, so the cache stores one entry per query head.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # A missing cache or a plain DynamicCache has no score tracker; guarding on None keeps
    # the cache-less call (already allowed by the update above) from crashing.
    tracker = None
    if past_key_value is not None and not isinstance(past_key_value, DynamicCache):
        tracker = past_key_value.score_tracker
    keep_k_inputs = tracker is not None and tracker.use_k_score

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    # 1. keep the pre-softmax (and pre-mask) logits, [bsz, num_heads, q_len, kv_len]
    qK = attn_weights.detach() if keep_k_inputs else None

    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    # 2. keep the pre-o-proj attention output, [bsz, num_heads, q_len, head_dim]
    O = attn_output.detach() if keep_k_inputs else None

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, -1)
    attn_output = self.o_proj(attn_output)

    if tracker is not None:
        tracker.update(
            A=attn_weights.detach(),
            V=past_key_value.value_cache[self.layer_idx].detach(),
            qK=qK,
            O=O,
            layer_idx=self.layer_idx,
        )
        # Always evict after prefill; during decoding only if the cache asks for it.
        kv_len = key_states.size(-2)
        if kv_len == q_len or getattr(past_key_value, "decode_evict", True):
            past_key_value.evict(self.layer_idx)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def llama_attention_streaming_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Eager attention forward that updates the cache's score tracker but never evicts.

    When ``past_key_value`` is neither ``None`` nor a ``DynamicCache``, its ``score_tracker``
    is updated with the attention weights, the cached values and (when ``use_k_score`` is
    set) the pre-softmax logits and the pre-projection output. Unlike
    ``llama_attention_forward`` this function does not call ``past_key_value.evict``.

    Args:
        hidden_states: Input of size ``(bsz, q_len, hidden)`` (it is unpacked into three dims).
        attention_mask: Optional additive mask; sliced on its last dimension to the key length.
        position_ids: Used for RoPE when ``position_embeddings`` is not given.
        past_key_value: Optional cache.
        output_attentions: Whether to return the attention weights.
        use_cache: Unused here.
        cache_position: Forwarded to ``past_key_value.update``.
        position_embeddings: Optional precomputed ``(cos, sin)`` pair.
        **kwargs: Ignored.

    Returns:
        ``(attn_output, attn_weights or None, past_key_value)``.

    Raises:
        ValueError: If the attention output has an unexpected size.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    # KV heads are expanded *before* caching, so the cache stores one entry per query head.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # A missing cache or a plain DynamicCache has no score tracker; guarding on None keeps
    # the cache-less call (already allowed by the update above) from crashing.
    tracker = None
    if past_key_value is not None and not isinstance(past_key_value, DynamicCache):
        tracker = past_key_value.score_tracker
    keep_k_inputs = tracker is not None and tracker.use_k_score

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    # 1. keep the pre-softmax (and pre-mask) logits, [bsz, num_heads, q_len, kv_len]
    qK = attn_weights.detach() if keep_k_inputs else None

    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    # 2. keep the pre-o-proj attention output, [bsz, num_heads, q_len, head_dim]
    O = attn_output.detach() if keep_k_inputs else None

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, -1)
    attn_output = self.o_proj(attn_output)

    if tracker is not None:
        tracker.update(
            A=attn_weights.detach(),
            V=past_key_value.value_cache[self.layer_idx].detach(),
            qK=qK,
            O=O,
            layer_idx=self.layer_idx,
        )

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def _report_autodiff_verification(self, past_key_value, query_states, key_states, value_states, causal_mask):
    """Compare autodiff scores with the closed-form and the tracker's scores, printing both reports."""
    tracker = past_key_value.score_tracker
    verifier = AutoDiffScoreVerifier()

    # The verifier differentiates through attention, but this forward usually runs with
    # gradients disabled (e.g. during generation), so grad mode is re-enabled locally.
    with torch.enable_grad():
        S_value_ad, S_key_ad, S_cross_ad, S_value_form, S_key_form, S_cross_form = verifier.verify(
            Q=query_states.detach().clone(),
            K=key_states.detach().clone(),
            V=value_states.detach().clone(),
            attn_mask=causal_mask,
        )

    formula_results = AutoDiffResults(
        S_value_ad=S_value_ad,
        S_key_ad=S_key_ad,
        S_cross_ad=S_cross_ad,
        S_value_formula=S_value_form,
        S_key_formula=S_key_form,
        S_cross_formula=S_cross_form,
    )
    verifier.report(formula_results, name=f"Layer {self.layer_idx} AutoDiff vs Formula", verbose=True)

    v_score = tracker.all_a_normp[self.layer_idx] * tracker.all_v_normp[self.layer_idx]
    k_score = tracker.all_k_scores[self.layer_idx]
    # With all three terms enabled the cross term is added to the autodiff key score before
    # comparing against the tracker's key score (presumably the tracker folds it in as
    # well - TODO confirm against the tracker implementation).
    if tracker.use_k_score and tracker.use_v_score and tracker.use_cross:
        S_key_ad = S_key_ad + S_cross_ad
    print()
    implementation_results = AutoDiffResults(
        S_value_ad=S_value_ad,
        S_key_ad=S_key_ad,
        S_cross_ad=None,
        S_value_formula=v_score,
        S_key_formula=k_score,
        S_cross_formula=None,
    )
    verifier.report(implementation_results, name=f"Layer {self.layer_idx} AutoDiff vs Implementation", verbose=True)


def llama_attention_forward_autodiff(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Debug variant of the eager forward that cross-checks tracker scores with autograd.

    Behaves like the score-tracking eager forward (tracker update, then eviction in prefill
    or when the cache's ``decode_evict`` is truthy). In addition, for the layer with
    ``layer_idx == 1`` it runs ``AutoDiffScoreVerifier`` on the current Q/K/V *after* the
    tracker update and *before* eviction, and prints two comparison reports.

    Args:
        hidden_states: Input of size ``(bsz, q_len, hidden)`` (it is unpacked into three dims).
        attention_mask: Optional additive mask; sliced on its last dimension to the key length.
        position_ids: Used for RoPE when ``position_embeddings`` is not given.
        past_key_value: Optional cache. ``None`` and ``DynamicCache`` disable tracking,
            verification and eviction.
        output_attentions: Whether to return the attention weights.
        use_cache: Unused here.
        cache_position: Forwarded to ``past_key_value.update``.
        position_embeddings: Optional precomputed ``(cos, sin)`` pair.
        **kwargs: Ignored.

    Returns:
        ``(attn_output, attn_weights or None, past_key_value)``.

    Raises:
        ValueError: If the attention output has an unexpected size.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    # KV heads are expanded *before* caching, so the cache stores one entry per query head.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # A missing cache or a plain DynamicCache has no score tracker; guarding on None keeps
    # the cache-less call (already allowed by the update above) from crashing.
    tracker = None
    if past_key_value is not None and not isinstance(past_key_value, DynamicCache):
        tracker = past_key_value.score_tracker
    keep_k_inputs = tracker is not None and tracker.use_k_score

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    # 1. keep the pre-softmax (and pre-mask) logits, [bsz, num_heads, q_len, kv_len]
    qK = attn_weights.detach() if keep_k_inputs else None

    causal_mask = None
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    # 2. keep the pre-o-proj attention output, [bsz, num_heads, q_len, head_dim]
    O = attn_output.detach() if keep_k_inputs else None

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, -1)
    attn_output = self.o_proj(attn_output)

    if tracker is not None:
        tracker.update(
            A=attn_weights.detach(),
            V=past_key_value.value_cache[self.layer_idx].detach(),
            qK=qK,
            O=O,
            layer_idx=self.layer_idx,
        )

        # Verification is restricted to a single layer; it must run before eviction so the
        # tracker's per-token scores still line up with the un-pruned keys/values.
        if self.layer_idx == 1:
            _report_autodiff_verification(
                self, past_key_value, query_states, key_states, value_states, causal_mask
            )

        # Always evict after prefill; during decoding only if the cache asks for it.
        kv_len = key_states.size(-2)
        if kv_len == q_len or getattr(past_key_value, "decode_evict", True):
            past_key_value.evict(self.layer_idx)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value




import math
from dataclasses import dataclass

@dataclass
class AutoDiffResults:
    """Pairs of scores to compare: autodiff (``*_ad``) against a reference (``*_formula``).

    The two cross-term fields may be ``None`` when no cross comparison is wanted;
    ``AutoDiffScoreVerifier.report`` skips the cross report when ``S_cross_formula`` is ``None``.
    """

    S_value_ad: torch.Tensor
    S_key_ad: torch.Tensor
    S_cross_ad: Optional[torch.Tensor]
    S_value_formula: torch.Tensor
    S_key_formula: torch.Tensor
    S_cross_formula: Optional[torch.Tensor]

class AutoDiffScoreVerifier:
    """
    Verifies OBD scores via directional second derivatives using PyTorch autograd.
      Q,K,V: [bsz, head, seq_len, head_dim]
      A: [bsz, head, q_len, kv_len]
      O: [bsz, head, q_len, head_dim]
    verify per (bsz, head, pruning_idx p). Vectorized over (bsz, head).
    """
    def __init__(self):
        pass

    @staticmethod
    def _softmax_logits(Q, K, dk, attn_mask=None):
        logits = torch.matmul(Q, K.transpose(-2, -1))
        if dk:
            logits = logits / math.sqrt(dk)
        if attn_mask is not None:
            logits = logits + attn_mask[..., :K.size(-2)]
        A = torch.softmax(logits, dim=-1)
        return A

    def _make_O(self, Q, K, V, dk, attn_mask=None):
        A = self._softmax_logits(Q, K, dk, attn_mask)
        O = torch.matmul(A, V)
        return O, A

    def _directional_second_derivative_per_head(self, E_bh, vars_tuple, dirs_tuple):
        """
        E_bh: tensor [B, H]  (per-(batch, head) scalar objective)
        returns d2_map: [B, H] where each entry is the directional second derivative
                        d^2/dε^2 E_bh[b,h](vars + ε*dir) |_{ε=0}
        """
        B, H = E_bh.shape
        device = E_bh.device
        dtype  = E_bh.dtype
        d2_map = torch.empty(B, H, device=device, dtype=dtype)

        for b in range(B):
            for h in range(H):
                go = torch.zeros_like(E_bh)
                go[b, h] = 1.0

                grad1 = torch.autograd.grad(
                    E_bh, vars_tuple, grad_outputs=go, create_graph=True, retain_graph=True
                )
                phi_bh = 0.0
                for g, d in zip(grad1, dirs_tuple):
                    if g is None: continue
                    phi_bh = phi_bh + (g * d).sum()

                grad2 = torch.autograd.grad(phi_bh, vars_tuple, retain_graph=True)
                d2_bh = 0.0
                for g2, d in zip(grad2, dirs_tuple):
                    if g2 is None: continue
                    d2_bh = d2_bh + (g2 * d).sum()

                d2_map[b, h] = d2_bh
        
        return d2_map

    def verify(self, Q, K, V, attn_mask=None):
        B, H, Tq, Dk = Q.shape
        Tk = K.size(-2)
        Dv = V.size(-1)

        O_base, A_base = self._make_O(Q, K, V, Dk, attn_mask)
        O_target = O_base.detach()

        def E_bh_of(K_, V_):
            O_star, _ = self._make_O(Q, K_, V_, Dk, attn_mask)
            # keep per-(B,H) ptb
            return ((O_star - O_target) ** 2).sum(dim=(-2, -1))

        S_value_ad   = torch.zeros(B, H, Tk, device=K.device, dtype=K.dtype)
        S_key_ad     = torch.zeros_like(S_value_ad)
        S_cross_ad   = torch.zeros_like(S_value_ad)
        S_value_form = torch.zeros_like(S_value_ad)
        S_key_form   = torch.zeros_like(S_value_ad)
        S_cross_form = torch.zeros_like(S_value_ad)

        K = K.clone().requires_grad_(True)
        V = V.clone().requires_grad_(True)

        with torch.no_grad():
            A = A_base.detach()
            O = O_base.detach()
            Z = (Q @ K.detach().transpose(-2, -1)) / math.sqrt(Dk)

        eye_Tk = torch.eye(Tk, device=K.device, dtype=K.dtype)
        for p in range(Tk):
            row_mask = eye_Tk[p].view(1,1,Tk,1)

            dV = (-V.detach()) * row_mask
            dK = (-K.detach()) * row_mask

            # Value curvature (per head): directions (0, dV)
            E_bh = E_bh_of(K, V)  # [B,H]
            d2_val_map = self._directional_second_derivative_per_head(
                E_bh, (K, V), (torch.zeros_like(K), dV)
            )
            S_value_ad[..., p] = 0.5 * d2_val_map

            # Key curvature (per head): directions (dK, 0)
            d2_key_map = self._directional_second_derivative_per_head(
                E_bh, (K, V), (dK, torch.zeros_like(V))
            )
            S_key_ad[..., p] = 0.5 * d2_key_map

            # Both curvature for cross
            d2_both_map = self._directional_second_derivative_per_head(
                E_bh, (K, V), (dK, dV)
            )
            S_cross_ad[..., p] = 0.5 * (d2_both_map - d2_key_map - d2_val_map)

            # ===== Closed-form checks =====
            A_col_p = A[..., :, p]                      
            V_row_p = V.detach()[..., p, :]             
            Z_ip    = Z[..., :, p]                      
            # diff    = V_row_p.unsqueeze(-2) - O 
            diff2 = V_row_p.unsqueeze(-2)**2 + O**2 - 2.0 * V_row_p.unsqueeze(-2) * O

            S_value_form[..., p] = (A_col_p**2).sum(-1) * (V_row_p**2).sum(-1)
            S_key_form[..., p]   = ((A_col_p**2) * (Z_ip**2) * diff2.sum(-1)).sum(-1)
            V_norm2 = (V_row_p**2).sum(-1, keepdim=True)
            V_dot_O = (V_row_p.unsqueeze(-2) * O).sum(-1)
            S_cross_form[..., p] = 2.0 * ((A_col_p**2) * Z_ip * (V_norm2 - V_dot_O)).sum(-1)

        return S_value_ad, S_key_ad, S_cross_ad, S_value_form, S_key_form, S_cross_form
    
    @staticmethod
    def report(res: AutoDiffResults, eps: float = 1e-6, name: str = "Scores", verbose: bool = True):
        def rel_err(a, b):
            denom = b.abs().clamp_min(eps)
            return ( (a - b).abs() / denom ).nan_to_num()
        
        if verbose:
            print("S_value_autodiff: ", res.S_value_ad[0, 0], "\nS_value_formula: ", res.S_value_formula[0, 0])
            print("S_key_autodiff: ", res.S_key_ad[0, 0], "\nS_key_formula: ", res.S_key_formula[0, 0])
            if res.S_cross_formula is not None:
                print("S_cross_autodiff: ", res.S_cross_ad[0, 0], "\nS_cross_formula: ", res.S_cross_formula[0, 0])

        rv = rel_err(res.S_value_ad, res.S_value_formula)
        rk = rel_err(res.S_key_ad,   res.S_key_formula)
        
        print(f"[{name}] relerr value: mean={rv.mean().item():.3e}, max={rv.max().item():.3e}")
        print(f"[{name}] relerr key  : mean={rk.mean().item():.3e}, max={rk.max().item():.3e}")


        if res.S_cross_formula is not None:
            rc = rel_err(res.S_cross_ad, res.S_cross_formula)
            print(f"[{name}] relerr cross: mean={rc.mean().item():.3e}, max={rc.max().item():.3e}")