import multiprocessing as mp
import os
from pathlib import Path
from typing import Generator, cast

import psutil
import torch
from loguru import logger
from torch import Tensor
from tqdm.auto import tqdm
from transformers import AutoTokenizer, TokenizersBackend

from hallucinations.dirs import DatasetDir
from hallucinations.features.lookback_lens import compute_lookback_lens
from hallucinations.features.mtopdiv import compute_mtopdiv
from hallucinations.features.processing import remove_padding_from_intermediate_states
from hallucinations.features.sink_scores import compute_sink_score_per_token_from_attention_matrix
from hallucinations.utils.misc import load_and_resolve_config

# Default weight for "vertical" edges.
# NOTE(review): no function in this module reads this constant; presumably callers pass it
# as `vertical_edge_weight` to `laplacian_diagonal_from_attn` -- confirm against call sites.
DEFAULT_VERTICAL_EDGE_WEIGHT = 1.0


def attention_diagonal(item_attn: list[Tensor] | Tensor) -> Tensor:
    """Computes attention diagonal for single example from dataset.
    Input shape of item_attn is [#layers, [#heads x seq_length x seq_length]]
    Output shape is [#heads x (#layers * seq_length)]
    """
    if isinstance(item_attn, Tensor):
        return torch.diagonal(item_attn, dim1=-2, dim2=-1).clone()
    else:
        return torch.stack(
            [torch.diagonal(layer_attn, dim1=-2, dim2=-1) for layer_attn in item_attn]
        )


def random_walk_laplacian_diagonal_from_attn(item_attn: list[Tensor]) -> Tensor:
    """Computes the diagonal of the random-walk laplacian for a single example.

    For every layer the diagonal entry is L_rw[i, i] = 1 - a_ii / d_ii, where a_ii is the
    attention diagonal and d_ii is the degree obtained by summing the attention matrix over dim 1.

    Input shape of item_attn is [#layers, [#heads x seq_length x seq_length]]
    Output shape is [#layers x #heads x seq_length]
    """

    def _layer_rw_diag(layer_attn: Tensor) -> Tensor:
        # a_ii: self-attention weights on the main diagonal
        self_weight = layer_attn.diagonal(offset=0, dim1=1, dim2=2)
        # d_ii: degree, i.e. the attention matrix summed over dim 1
        degree = layer_attn.sum(dim=1)
        return 1 - self_weight / degree

    return torch.stack([_layer_rw_diag(layer_attn) for layer_attn in item_attn])


def laplacian_diagonal_from_attn(
    item_attn: list[Tensor] | Tensor,
    vertical_edges: bool,
    vertical_edge_weight: float | None = None,
) -> Tensor:
    """Computes laplacian diagonal for single example from dataset.
    Input shape of item_attn is [#layers, [#heads x seq_length x seq_length]]
    Output shape is [#heads x (#layers * seq_length)]
    """
    device = item_attn[0].device
    if vertical_edges:
        assert vertical_edge_weight is not None
    # we treat attention matrix as weighted adjacency matrix
    # to obtain the laplacian we need to substract diagonal degree matrix from the adjacency matrix
    # I guess, we can ignore self-loops and use diagonal degree matrix only
    # to account for vertical edges, we add one to layers from the second onwards
    fst_layer_attn = item_attn[0]
    fst_nom = fst_layer_attn.sum(dim=1)
    fst_denom = torch.arange(1, fst_layer_attn.size(1) + 1, device=device).flip(dims=[0])
    # D := weighted out-degree
    fst_weighted_degree = fst_nom / fst_denom
    # L := D - A
    fst_lap = fst_weighted_degree - torch.diagonal(fst_layer_attn, offset=0, dim1=1, dim2=2)

    per_layer_laplacian_diags = [fst_lap]
    for layer_attn in item_attn[1:]:
        # per-layer weighted out-degree
        # for vertical edges, we set weight to constant
        if vertical_edges:
            assert vertical_edge_weight is not None
            nom = layer_attn.sum(dim=1) + vertical_edge_weight
            denom = torch.arange(1, layer_attn.size(1) + 1, device=device).flip(dims=[0]) + 1
        else:
            nom = layer_attn.sum(dim=1)
            denom = torch.arange(1, layer_attn.size(1) + 1, device=device).flip(dims=[0])

        layer_weighted_degree = nom / denom
        layer_lap_diag = layer_weighted_degree - torch.diagonal(
            layer_attn, offset=0, dim1=1, dim2=2
        )
        per_layer_laplacian_diags.append(layer_lap_diag)

    laplacian_diags = torch.stack(per_layer_laplacian_diags)

    return laplacian_diags


def log_det_attnn_over_dataset(attn_scores: list[list[Tensor]]) -> list[Tensor]:
    """Applies `log_det_attn` to every example of the dataset (no aggregation), with a progress bar.

    Dimensions of the input are [#examples, #layers, [#heads x sequence_length x sequence_length]].
    Dimensions of the output are [#examples, [#layers x #heads]].
    """
    examples = tqdm(attn_scores, desc="log-det(atnn)", leave=False)
    return [log_det_attn(example_attn) for example_attn in examples]


def log_det_attn(attn_scores: list[Tensor]) -> Tensor:
    """Computes the per-head log-det(attn) score of every layer for a single example.

    AttnScore was proposed in  https://openreview.net/forum?id=LYx4w3CAgy,
    Implementation: https://github.com/GaurangSriramanan/LLM_Check_Hallucination_Detection/blob/2f3bf9ea6db19e60a416090f77694816c92a9146/common_utils.py#L319

    The score used here is mean(log(diag(A))), where A is the attention matrix of one head.
    In the original paper, the scores were additionally summed over heads and one layer was used;
    this function keeps one value per (layer, head).

    Dimensions of the input are [#layers, [num_heads x sequence_length x sequence_length]].
    Dimensions of the output Tensor is [#layers x #heads].
    """
    return torch.stack(
        [
            # diagonal -> [#heads x seq_length]; averaging over dim 1 leaves one value per head
            layer_attn.diagonal(dim1=1, dim2=2).log().mean(dim=1)
            for layer_attn in attn_scores
        ]
    )


def yield_stacked_attentions(
    dataset_dir: Path | DatasetDir,
    attentions_dir: Path | None = None,
    remove_padding: bool = False,
) -> Generator[list[list[Tensor]] | list[Tensor], None, None]:
    """Returns a generator over the stacked attention scores of every `*.pt` shard.

    Shards are taken from `attentions_dir` when given, otherwise from the dataset's own
    attentions directory. When `remove_padding` is set, the tokenizer named under
    `config["llm"]["name"]` in the dataset config is loaded and handed to the shard generator.

    Dimensions of each shard is [#examples, [#layers, [num_heads x sequence_length x sequence_length]]].
    """
    # A plain Path is wrapped; anything else is used as the DatasetDir it already is.
    ds_dir = dataset_dir if not isinstance(dataset_dir, Path) else DatasetDir(dataset_dir)

    tokenizer: TokenizersBackend | None = None
    if remove_padding:
        config = load_and_resolve_config(ds_dir.config_file)
        tokenizer = cast(
            TokenizersBackend, AutoTokenizer.from_pretrained(config["llm"]["name"])
        )

    shards_root = ds_dir.attentions_dir if attentions_dir is None else attentions_dir
    # NOTE(review): the glob result is used as-is (not sorted), so shard order follows
    # whatever order the filesystem listing returns.
    data_shards = list(shards_root.glob("*.pt"))

    return yield_stacked_attentions_from_shard_list(data_shards, remove_padding, tokenizer)


def yield_stacked_attentions_from_shard_list(
    data_shards: list[Path],
    remove_padding: bool = False,
    tokenizer: TokenizersBackend | None = None,
) -> Generator[list[list[Tensor]] | list[Tensor], None, None]:
    """Loads the given shard files one by one and yields their stacked attention scores.

    The progress bar shows the resident memory (RSS) of the current process, sampled before
    each shard is loaded. With `remove_padding` set, each shard is passed through
    `remove_padding_from_intermediate_states` (together with its generated tokens and
    `tokenizer`) before being yielded; otherwise the stacked attentions are yielded as loaded.
    """
    this_process = psutil.Process(os.getpid())

    with tqdm(data_shards, desc="Loading attentions", total=len(data_shards)) as progress:
        for shard_path in progress:
            rss_bytes = this_process.memory_info().rss
            progress.set_postfix(dict(memory=f"{rss_bytes / 1024**3:.2f} GB"))

            shard_attn, shard_tokens = load_and_stack_attentions_shard(shard_path)
            if not remove_padding:
                yield shard_attn
                continue

            yield remove_padding_from_intermediate_states(
                per_layer_batched_data=shard_attn,
                data_type="attn",
                generated_tokens=shard_tokens,
                tokenizer=tokenizer,
            )


def load_and_stack_attentions_shard(shard_file: Path) -> tuple[list[Tensor], Tensor]:
    """Loads one shard file and returns (stacked attentions, generated tokens).

    The shard is read onto the CPU with memory mapping and `weights_only=True`. Its
    "attentions" entry is stacked with `stack_attention_matrix`; its "generated_tokens"
    entry is returned unchanged.
    """
    shard = torch.load(shard_file, weights_only=True, mmap=True, map_location="cpu")
    stacked_attn = stack_attention_matrix(shard["attentions"])
    return stacked_attn, shard["generated_tokens"]


def stack_attention_matrix(attentions: tuple[tuple[Tensor, ...], ...]) -> list[Tensor]:
    """Stacks the attention scores of all generation steps into one zero-padded tensor per layer.

    Dimensions of the input are (#num_gen_tokens, #num_layers, [batch_size x num_heads x seq_rows x seq_cols]).
    First token has full context (seq_rows == seq_cols), subsequent tokens have single rows.
    Dimensions of the output are (#num_layers, [batch_size x num_heads x total_rows x max_seq_cols]).

    The rows of consecutive steps are placed below each other. The column count is taken from
    the last step; columns a step does not cover stay zero. Shapes, dtype and device are read
    from the first layer of the relevant step.
    """
    reference = attentions[0][0]
    num_layers = len(attentions[0])
    batch_size = reference.size(0)
    num_heads = reference.size(1)
    # The last generation step determines the output width.
    max_seq_cols = attentions[-1][0].size(-1)
    # Each step contributes as many rows as its first layer has.
    total_rows = sum(step[0].size(-2) for step in attentions)

    # Zero-initialised buffers, one per layer; the zeros double as padding.
    stacked_per_layer: list[Tensor] = [
        torch.zeros(
            batch_size,
            num_heads,
            total_rows,
            max_seq_cols,
            dtype=reference.dtype,
            device=reference.device,
        )
        for _ in range(num_layers)
    ]

    # Write every step's rows straight into the buffers, without intermediate padded tensors.
    row_start = 0
    for step in attentions:
        step_rows = step[0].size(-2)
        step_cols = step[0].size(-1)
        row_end = row_start + step_rows
        for layer_idx, layer_buffer in enumerate(stacked_per_layer):
            layer_buffer[:, :, row_start:row_end, :step_cols] = step[layer_idx]
        row_start = row_end

    return stacked_per_layer


def compute_attention_metrics(
    attn: Tensor,
    input_length: int | None = None,
    n_jobs: int = 4,
    # String annotation on purpose: `import multiprocessing as mp` does not load the
    # `multiprocessing.pool` submodule, so evaluating `mp.pool.Pool` while the function is
    # being defined only works if some other import happened to load it first.
    pool: "mp.pool.Pool | None" = None,
) -> dict[str, Tensor]:
    """Computes attention-based metrics.

    Args:
        attn: Attention matrix [#layers, #heads, #tokens, #tokens]
        input_length: Length of input/prompt tokens. If provided (and the response is
            non-empty), enables computation of lookback_lens and mtopdiv.
        n_jobs: Number of parallel jobs for mtopdiv computation.
        pool: Optional multiprocessing pool forwarded to the mtopdiv computation.

    Returns:
        Dictionary with attention metrics. Always present: "laplacian_diags", "attn_diags"
        and "sink_score_per_token", plus a descending sort of each along the last dim
        (values under "*_sorted", positions under "*_sorted_idx"; the sorted diagonals use
        the "attention_eigvals_*" / "laplacian_eigvals_*" keys). "lookback_lens" and
        "mtopdiv" are added only when `input_length` is given and fewer than #tokens.
    """
    logger.debug("Computing eigenvalue features...")
    laplacian_diags = laplacian_diagonal_from_attn(attn, vertical_edges=False)
    attn_diags = attention_diagonal(attn)

    # The sorted diagonals are what the result reports under the "eigvals" keys.
    attn_eigvals_sorted, attn_eigvals_sorted_idx = torch.sort(attn_diags, dim=-1, descending=True)
    laplacian_eigvals_sorted, laplacian_eigvals_sorted_idx = torch.sort(
        laplacian_diags, dim=-1, descending=True
    )

    logger.debug("Computing sink score features...")
    sink_score_per_token = compute_sink_score_per_token_from_attention_matrix(attn)
    sink_sorted, sink_sorted_idx = torch.sort(sink_score_per_token, dim=-1, descending=True)

    result = {
        "laplacian_diags": laplacian_diags,
        "attn_diags": attn_diags,
        "sink_score_per_token": sink_score_per_token,
        "sink_score_per_token_sorted": sink_sorted,
        "sink_score_per_token_sorted_idx": sink_sorted_idx,
        "attention_eigvals_sorted": attn_eigvals_sorted,
        "attention_eigvals_sorted_idx": attn_eigvals_sorted_idx,
        "laplacian_eigvals_sorted": laplacian_eigvals_sorted,
        "laplacian_eigvals_sorted_idx": laplacian_eigvals_sorted_idx,
    }

    if input_length is None:
        logger.warning("Input length not provided. Lookback lens and mtopdiv will not be computed.")
        return result

    # Everything after the prompt counts as the response.
    seq_len = attn.shape[-1]
    response_length = seq_len - input_length
    if response_length <= 0:
        # Previously this case was skipped without any trace in the logs.
        logger.debug(
            "Response length is {} (<= 0). Lookback lens and mtopdiv will not be computed.",
            response_length,
        )
        return result

    logger.debug("Computing lookback lens...")
    result["lookback_lens"] = compute_lookback_lens(attn, input_length)
    logger.debug("Computing mtopdiv...")
    result["mtopdiv"] = compute_mtopdiv(attn, response_length, n_jobs=n_jobs, pool=pool)

    return result
