from __future__ import annotations

"""Attention-head feature extraction for vStream.

For each target token range (query tokens) and each source (a group of visual tokens),
we compute a 1152-dim feature vector (32 layers x 36 heads for Qwen3-VL-8B-Thinking).

Feature definition (readable version):
- Take the attention matrix A for each layer/head.
- For each token t in the target range, read the attention row of query
  position t-1 (the position whose output predicts t) over the visual key tokens.
- Average over those query rows.
- Pool visual tokens into sources with the source membership matrix (a sum for
  0/1 membership, a weighted average if its rows are normalized).

This yields (num_targets, num_sources, num_layers*num_heads).
"""

from typing import Optional

import numpy as np
import torch


def _query_ranges(
    target_token_ranges: list[tuple[int, int]], seq_len: int
) -> list[tuple[int, int]]:
    """Map target token ranges [t0, t1) to attention query rows [t0-1, t1-1).

    In a causal LM the attention row at position t-1 is the one used when
    predicting token t, so features for a target token are read one row earlier.

    Raises:
        ValueError: if a range starts before token 1 or its query rows run past
            ``seq_len``.
    """
    query_ranges: list[tuple[int, int]] = []
    for t0, t1 in target_token_ranges:
        if t0 < 1:
            raise ValueError("All targets must start at token index >= 1")
        q0, q1 = int(t0 - 1), int(t1 - 1)
        if q1 > seq_len:
            raise ValueError("Target ends beyond attention sequence length")
        query_ranges.append((q0, q1))
    return query_ranges


def features_from_attentions(
    *,
    attentions: tuple[torch.Tensor, ...],
    target_token_ranges: list[tuple[int, int]],
    vision_token_positions: list[int],
    source_membership: np.ndarray,
) -> torch.Tensor:
    """Compute per-(target, source) attention-head features.

    Args:
        attentions: One tensor per layer, each of shape
            ``(1, num_heads, query_len, key_len)``; all layers must share a shape.
        target_token_ranges: Half-open token ranges ``[t0, t1)`` with ``t0 >= 1``.
            Empty ranges (``t1 <= t0``) yield all-zero features.
        vision_token_positions: Key indices of the visual tokens, in the column
            order of ``source_membership``.
        source_membership: ``(num_sources, num_vision_tokens)`` matrix used to
            pool per-visual-token attention into per-source attention.

    Returns:
        float32 tensor of shape
        ``(len(target_token_ranges), num_sources, num_layers * num_heads)`` on
        the device of ``attentions[0]``. Feature index
        ``layer * num_heads + head`` holds that layer/head's pooled attention.
        If there are no vision tokens, the result is all zeros.

    Raises:
        ValueError: on empty inputs, non-4-D or mismatched layer shapes,
            batch size != 1, invalid target ranges, out-of-range vision
            positions, or a membership matrix whose width does not match the
            number of vision tokens.
    """
    if not attentions:
        raise ValueError("attentions is empty")
    if not target_token_ranges:
        raise ValueError("target_token_ranges cannot be empty")

    first = attentions[0]
    if first.ndim != 4:
        raise ValueError("attentions must have shape (batch, heads, seq, seq)")
    if int(first.shape[0]) != 1:
        raise ValueError("Only batch=1 is supported in the submission demo")

    num_layers = len(attentions)
    num_heads = int(first.shape[1])
    seq_len = int(first.shape[2])
    key_len = int(first.shape[3])
    device = first.device

    # Validate targets before any early return so bad ranges are always rejected.
    query_ranges = _query_ranges(target_token_ranges, seq_len)

    membership = torch.tensor(source_membership, device=device, dtype=torch.float32)
    num_sources = int(membership.shape[0])
    out = torch.zeros(
        (len(target_token_ranges), num_sources, num_layers * num_heads),
        device=device,
        dtype=torch.float32,
    )

    key_positions = torch.tensor(vision_token_positions, device=device, dtype=torch.long)
    if key_positions.numel() == 0:
        return out

    if membership.ndim != 2 or int(membership.shape[1]) != int(key_positions.numel()):
        raise ValueError(
            "source_membership must have shape (num_sources, num_vision_tokens)"
        )
    if int(key_positions.min()) < 0 or int(key_positions.max()) >= key_len:
        # index_select would otherwise fail with an opaque (on CUDA, async) error.
        raise ValueError("vision_token_positions must lie in [0, key_len)")

    # Attentions from a model split across devices may live on different
    # devices; keep one copy of the small index/pooling tensors per device.
    per_device: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {
        device: (key_positions, membership.T)
    }

    for layer_idx, attn in enumerate(attentions):
        if attn.ndim != 4:
            raise ValueError("attentions must have shape (batch, heads, seq, seq)")
        if attn.shape != first.shape:
            raise ValueError("All attention layers must share the same shape")

        if attn.device not in per_device:
            per_device[attn.device] = (
                key_positions.to(attn.device),
                membership.T.to(attn.device),
            )
        keys, pool = per_device[attn.device]

        a = attn[0]  # drop the batch dimension: (heads, query_len, key_len)
        feat_off = layer_idx * num_heads

        for t_idx, (q0, q1) in enumerate(query_ranges):
            if q1 <= q0:
                continue  # empty target range: its features stay zero
            # (heads, n_query, num_vis). Upcast before averaging so the mean is
            # not rounded to the attention's (possibly half) precision.
            a_vis = a[:, q0:q1, :].index_select(dim=2, index=keys).float()
            # Average over query rows, then pool visual tokens into sources:
            # (heads, num_vis) @ (num_vis, num_sources) -> (heads, num_sources)
            a_src = a_vis.mean(dim=1) @ pool
            out[t_idx, :, feat_off : feat_off + num_heads] = a_src.T.to(device)

    return out
