"""Lookback-Lens feature computation for hallucination detection.

Adapted from Lookback-Lens repository (EMNLP 2024).

Reference: "Detecting and Mitigating Contextual Hallucinations in Large Language Models
Using Only Attention Maps" - Chuang et al., 2024
"""

from __future__ import annotations

import torch
from torch import Tensor


def compute_lookback_lens(
    item_attn: list[Tensor] | Tensor,
    input_length: int,
) -> Tensor:
    """Average the per-token lookback ratios of every layer-head pair.

    Delegates to ``compute_lookback_lens_per_token`` and reduces its result
    over the trailing (generated-token) axis.

    Args:
        item_attn: Attention matrices [#layers, #heads, seq_len, seq_len]
            or list of [#heads, seq_len, seq_len] tensors.
        input_length: Boundary between context and generated tokens.

    Returns:
        Mean lookback ratios [#layers, #heads].

    Raises:
        ValueError: Propagated from ``compute_lookback_lens_per_token`` when
            there are no generated tokens or a generated-token row has a
            zero attention sum.
    """
    token_ratios = compute_lookback_lens_per_token(item_attn, input_length)
    # Collapse the last axis (one entry per generated token) into its mean.
    return torch.mean(token_ratios, dim=-1)


def compute_lookback_lens_per_token(
    item_attn: list[Tensor] | Tensor,
    input_length: int,
) -> Tensor:
    """Compute lookback ratios for all layer-head pairs, per generated token.

    For each generated-token row (rows ``input_length:``), the lookback ratio
    is the attention mass on the first ``input_length`` columns (the context)
    divided by the total attention mass of that row.

    Args:
        item_attn: Attention matrices [#layers, #heads, seq_len, seq_len]
            or list of [#heads, seq_len, seq_len] tensors (stacked along a
            new leading layer axis).
        input_length: Boundary between context and generated tokens. Must be
            non-negative and smaller than ``seq_len``.

    Returns:
        Lookback ratios [#layers, #heads, num_generated_tokens].

    Raises:
        ValueError: If the attention is not 4-dimensional, ``input_length`` is
            negative, there are no generated tokens, or any generated-token
            row has a zero attention sum.
    """
    attn = torch.stack(item_attn) if isinstance(item_attn, list) else item_attn
    if attn.dim() != 4:
        raise ValueError(
            "Expected attention of shape [#layers, #heads, seq_len, seq_len], "
            f"got {attn.dim()} dimension(s)"
        )
    seq_len = attn.shape[2]

    # A negative boundary would pass the check below and then silently slice
    # from the end of the sequence, producing ratios for the wrong rows/columns.
    if input_length < 0:
        raise ValueError(f"input_length must be non-negative, got {input_length}")

    if input_length >= seq_len:
        raise ValueError(
            f"No generated tokens: input_length ({input_length}) >= seq_len ({seq_len})"
        )

    gen_rows = attn[:, :, input_length:, :]  # [L, H, num_gen, seq_len]
    attn_ctx = gen_rows[:, :, :, :input_length].sum(dim=-1)  # [L, H, num_gen]
    attn_total = gen_rows.sum(dim=-1)  # [L, H, num_gen]

    # Refuse to divide by zero; report the first offending (layer, head, token).
    zero_mask = attn_total == 0
    if zero_mask.any():
        layer, head, token = (int(idx[0]) for idx in zero_mask.nonzero(as_tuple=True))
        raise ValueError(
            f"Zero attention sum at layer {layer}, head {head}, position {input_length + token}: "
            "cannot compute lookback ratio"
        )

    return attn_ctx / attn_total


def compute_lookback_ratio_per_token(
    attention: Tensor,
    input_length: int,
) -> Tensor:
    """Compute lookback ratio for each generated token from a single head's attention.

    Lookback ratio = attn_on_context / (attn_on_context + attn_on_generated)

    The denominator is computed as the full row sum of each generated-token
    row; the numerator is the sum over the first ``input_length`` columns.

    Args:
        attention: Attention matrix [seq_len, seq_len]. Expected to be lower
            triangular (causal); this is not checked.
        input_length: Boundary between context and generated tokens. Must be
            non-negative and smaller than ``seq_len``.

    Returns:
        Lookback ratios for each generated token [num_generated_tokens].

    Raises:
        ValueError: If the attention is not 2-dimensional, ``input_length`` is
            negative, there are no generated tokens, or any generated-token
            row has a zero attention sum.
    """
    # A higher-rank input would be sliced along the wrong axes and return a
    # broadcast result instead of failing, so reject it up front.
    if attention.dim() != 2:
        raise ValueError(
            "Expected attention of shape [seq_len, seq_len], "
            f"got {attention.dim()} dimension(s)"
        )

    # A negative boundary would pass the check below and then silently slice
    # from the end of the sequence, producing ratios for the wrong rows/columns.
    if input_length < 0:
        raise ValueError(f"input_length must be non-negative, got {input_length}")

    seq_len = attention.shape[0]
    if input_length >= seq_len:
        raise ValueError(
            f"No generated tokens: input_length ({input_length}) >= seq_len ({seq_len})"
        )

    gen_rows = attention[input_length:]  # [num_gen, seq_len]
    attn_ctx = gen_rows[:, :input_length].sum(dim=-1)  # [num_gen]
    attn_total = gen_rows.sum(dim=-1)  # [num_gen]

    # Refuse to divide by zero; report the first offending generated position.
    zero_mask = attn_total == 0
    if zero_mask.any():
        zero_pos = int(zero_mask.nonzero(as_tuple=True)[0][0])
        raise ValueError(
            f"Zero attention sum at position {input_length + zero_pos}: "
            "cannot compute lookback ratio"
        )

    return attn_ctx / attn_total
