from typing import TypedDict

import torch
from torch import Tensor


class InternalStatesDict(TypedDict):
    """Typed mapping of the internal-state data consumed by the sink-score helpers.

    Of the four fields, only ``input_lengths`` and ``attn_metrics`` are accessed by the
    functions visible in this part of the file.
    """

    # Copied unchanged into the result of the top-k Laplacian/attention-diag helper.
    input_lengths: Tensor
    # One dict per entry. The helpers read the "laplacian_diags" and "attn_diags" tensors
    # from each dict; the top-k helper also stores "sink_score_per_token" and
    # "sink_score_per_token_sorted_idx" back into it as a cache.
    attn_metrics: list[dict[str, Tensor]]
    # Not referenced in the visible code; presumably hidden-state metrics - TODO confirm.
    hs_metrics: list[dict[str, Tensor]]
    # Not referenced in the visible code; presumably generated token ids - TODO confirm.
    generated_tokens: list[Tensor]


class _SinkScoresRequiredDict(TypedDict):
    """Keys that every sink-score result in this module provides."""

    # The ``top_k`` value the caller asked for.
    top_k: int
    # Sink scores sorted in descending order and truncated to ``top_k`` along the last axis.
    sink_scores_per_token_top_k: Tensor
    # Indices produced by the same sort, truncated the same way.
    sink_sorted_idx_top_k: Tensor


class SinkScoresDict(_SinkScoresRequiredDict, total=False):
    """Top-k sink scores plus the indices of the selected tokens.

    ``input_lengths`` is optional: the helper that starts from ``InternalStatesDict``
    forwards it, while the helper that starts from a raw attention matrix has no such
    value and returns the mapping without it.
    """

    # Forwarded from ``InternalStatesDict["input_lengths"]`` when available.
    input_lengths: Tensor


def compute_sink_scores_from_laplacian_and_attention_diags(
    data: InternalStatesDict,
) -> list[Tensor]:
    """Return one sink-score tensor per entry of ``data["attn_metrics"]``.

    Each score is the element-wise sum of that entry's ``"laplacian_diags"`` and
    ``"attn_diags"`` tensors. The result list follows the order of ``attn_metrics``.
    """
    return [
        metrics["laplacian_diags"] + metrics["attn_diags"]
        for metrics in data["attn_metrics"]
    ]


def compute_topk_sink_scores_from_laplacian_and_attention_diags(
    data: InternalStatesDict,
    top_k: int,
    max_layer: int | None = None,
) -> SinkScoresDict:
    """Collect the ``top_k`` largest sink scores (and their indices) for every entry.

    For each dict in ``data["attn_metrics"]`` the sink score is
    ``laplacian_diags + attn_diags``, sorted in descending order along the last axis.

    Side effect: the full sorted scores and sort indices are written back into each
    entry under ``"sink_score_per_token"`` and ``"sink_score_per_token_sorted_idx"``
    so later calls can reuse them. They are recomputed whenever either key is missing.

    Args:
        data: Internal states; only ``attn_metrics`` and ``input_lengths`` are used.
        top_k: Number of leading (largest) scores kept along the last axis.
        max_layer: Upper bound of the slice applied to the first axis; ``None`` keeps all.

    Returns:
        Mapping with ``top_k``, the forwarded ``input_lengths`` and the per-entry
        results stacked along a new leading axis.

    Raises:
        KeyError: If scores must be computed and an entry lacks ``"laplacian_diags"``
            or ``"attn_diags"``.
    """
    cache_keys = ("sink_score_per_token", "sink_score_per_token_sorted_idx")
    features = []
    sorted_idx = []
    for item in data["attn_metrics"]:
        # Both cached tensors are read below, so a half-populated cache must be rebuilt
        # rather than trusted.
        if any(key not in item for key in cache_keys):
            combined = item["laplacian_diags"] + item["attn_diags"]
            item["sink_score_per_token"], item["sink_score_per_token_sorted_idx"] = torch.sort(
                combined, dim=-1, descending=True
            )
        # The indexing needs at least three axes; presumably [layers, heads, tokens],
        # matching the ``max_layer`` parameter name - confirm against the producer.
        features.append(item["sink_score_per_token"][:max_layer, :, :top_k])
        sorted_idx.append(item["sink_score_per_token_sorted_idx"][:max_layer, :, :top_k])

    # torch.stack needs a non-empty list whose slices all share one shape.
    return {
        "top_k": top_k,
        "input_lengths": data["input_lengths"],
        "sink_scores_per_token_top_k": torch.stack(features),
        "sink_sorted_idx_top_k": torch.stack(sorted_idx),
    }


def compute_sink_score_per_token_from_attention_matrix(
    attn_matrix: Tensor,
) -> Tensor:
    """Turn an attention matrix into one sink score per token.

    The attention weights are summed over the second-to-last axis and each resulting
    entry is divided by a count that runs from ``#tokens`` at the first position down
    to ``1`` at the last position.

    attn_matrix dimensions: [#layers, #heads, #tokens, #tokens]
    output dimensions: [#layers, #heads, #tokens]
    """
    seq_len = attn_matrix.size(-2)
    # Descending counts seq_len, ..., 1 on the input's device; presumably the number of
    # rows that can attend to each column under a causal mask - confirm with callers.
    position_counts = torch.arange(seq_len, 0, -1, device=attn_matrix.device)
    column_totals = attn_matrix.sum(dim=-2)
    return column_totals / position_counts


def compute_topk_sink_scores_per_token_from_attention_matrix(
    attn_matrix: Tensor,
    top_k: int,
) -> SinkScoresDict:
    """Rank per-token sink scores of an attention matrix and keep the ``top_k`` largest.

    The scores come from ``compute_sink_score_per_token_from_attention_matrix`` and are
    sorted in descending order along the last axis before truncation. The returned
    mapping holds ``top_k``, the truncated scores and the matching sort indices; it
    does not contain an ``input_lengths`` entry.

    attn_matrix dimensions: [#layers, #heads, #tokens, #tokens]
    output dimensions: [#layers, #heads, #top_k]
    """
    per_token_scores = compute_sink_score_per_token_from_attention_matrix(attn_matrix)
    ranked = torch.sort(per_token_scores, dim=-1, descending=True)
    top_values = ranked.values[:, :, :top_k]
    top_indices = ranked.indices[:, :, :top_k]
    return {
        "top_k": top_k,
        "sink_scores_per_token_top_k": top_values,
        "sink_sorted_idx_top_k": top_indices,
    }
