import multiprocessing
import os
import sys
from pathlib import Path
from typing import Annotated

import torch
import typer
from loguru import logger
from torch import Tensor
from tqdm.auto import tqdm

from hallucinations.dirs import DatasetDir
from hallucinations.features.attention_weights import (
    compute_attention_metrics,
    stack_attention_matrix,
)
from hallucinations.features.hidden_states import yield_hidden_states

# Drop the handlers loguru currently has and log to stderr only, at the level
# named by the LOGURU_LEVEL environment variable (falling back to "INFO").
logger.remove()
logger.add(sys.stderr, level=os.environ.get("LOGURU_LEVEL", "INFO"))


def main(
    dataset_dir: Path = typer.Argument(..., help="Path to the dataset directory"),
    n_jobs: Annotated[int, typer.Option(help="Number of parallel jobs for mtopdiv")] = 4,
    force: Annotated[
        bool,
        typer.Option("--force", "-f", help="Overwrite existing metrics file"),
    ] = False,
    svd_metrics: Annotated[
        bool,
        typer.Option(
            "--svd-metrics/--no-svd-metrics",
            help="Compute SVD-based metrics (matrix_entropy, anisotropy)",
        ),
    ] = False,
    prefetch_shards: Annotated[int, typer.Option(help="Number of shards to prefetch")] = 1,
) -> None:
    # CLI entry point: compute internal-state metrics for one dataset directory
    # and save them with ``torch.save`` to the dataset's metrics file.
    #
    # Input selection: the full hidden-states-with-attentions shard directory is
    # preferred; if it is absent, the two legacy feature files (attention
    # diagonals and Laplacian diagonals) are used instead; if neither is
    # present, FileNotFoundError is raised.
    #
    # NOTE: documented with comments rather than a docstring on purpose, since
    # Typer shows a command function's docstring as its --help text.
    ds_dir = DatasetDir(dataset_dir)
    metrics_file = ds_dir.internal_states_metrics_file

    # Refuse to clobber a previous result unless the caller passed --force.
    if metrics_file.exists() and not force:
        raise FileExistsError(
            f"Metrics file already exists: {metrics_file}. Use --force to overwrite."
        )

    shards_available = ds_dir.full_hidden_states_with_attentions_dir.exists()
    legacy_available = ds_dir.attn_diags_file.exists() and ds_dir.laplacian_diags_file.exists()

    # Guard clause: nothing to compute from.
    if not (shards_available or legacy_available):
        raise FileNotFoundError(
            f"Neither full shards ({ds_dir.full_hidden_states_with_attentions_dir}) "
            f"nor legacy features ({ds_dir.attn_diags_file}, {ds_dir.laplacian_diags_file}) found."
        )

    if shards_available:
        logger.info("Computing metrics from full hidden states with attentions shards...")
        metrics = compute_from_full_shards(ds_dir, n_jobs, svd_metrics, prefetch_shards)
    else:
        logger.warning(
            "Full shards not found. Loading from legacy feature files. "
            "Lookback lens and mtopdiv will be skipped."
        )
        metrics = compute_from_legacy_features(ds_dir)

    torch.save(metrics, metrics_file)
    logger.info(f"Saved metrics to {metrics_file}")


def compute_from_full_shards(
    ds_dir: DatasetDir,
    n_jobs: int,
    svd_metrics: bool,
    prefetch_shards: int,
) -> dict:
    """Compute hidden-state and attention metrics shard by shard.

    Every shard yielded by ``yield_hidden_states`` for
    ``ds_dir.full_hidden_states_with_attentions_dir`` contributes one entry to
    each of the returned collections.

    Args:
        ds_dir: Dataset directory wrapper providing the shard directory.
        n_jobs: Size of the worker pool; also forwarded to
            ``compute_attention_metrics``.
        svd_metrics: Forwarded to ``compute_hs_metrics`` to enable the
            SVD-based metrics.
        prefetch_shards: Forwarded as ``prefetch_buffer_size`` to
            ``yield_hidden_states``.

    Returns:
        Dict with keys ``input_lengths`` (tensor built from the per-shard
        ints), ``generated_tokens``, ``hs_metrics`` and ``attn_metrics``
        (lists with one item per shard).
    """
    per_shard_hs = []
    per_shard_attn = []
    per_shard_tokens = []
    prompt_lengths = []

    # One pool is created up front and reused for every shard.
    with multiprocessing.Pool(n_jobs) as pool:
        shards = yield_hidden_states(
            ds_dir.full_hidden_states_with_attentions_dir,
            hs_selection=None,
            prefetch_buffer_size=prefetch_shards,
            dtype=torch.float32,
        )
        for shard in shards:
            logger.debug("Computing hidden state metrics...")
            shard_hs_metrics = compute_hs_metrics(shard["hidden_states"], svd_metrics)
            per_shard_hs.append(shard_hs_metrics)

            logger.debug("Preparing attention matrix....")
            # Each stacked element loses its leading axis (presumably a
            # size-1 batch axis; confirm against stack_attention_matrix).
            squeezed = [part.squeeze(0) for part in stack_attention_matrix(shard["attentions"])]
            attn = torch.stack(squeezed).float()

            prompt_len = int(shard["input_length"])
            prompt_lengths.append(prompt_len)

            logger.debug("Computing attention metrics...")
            shard_attn_metrics = compute_attention_metrics(attn, prompt_len, n_jobs, pool=pool)
            per_shard_attn.append(shard_attn_metrics)

            # Keep an independent copy of the tokens rather than a view of the shard.
            tokens = shard["generated_tokens"].squeeze(0).clone()
            per_shard_tokens.append(tokens)

    return {
        "input_lengths": torch.tensor(prompt_lengths),
        "generated_tokens": per_shard_tokens,
        "hs_metrics": per_shard_hs,
        "attn_metrics": per_shard_attn,
    }


def compute_from_legacy_features(ds_dir: DatasetDir) -> dict:
    """Build sink-score attention metrics from the legacy feature files.

    Loads the lists stored in ``ds_dir.attn_diags_file`` and
    ``ds_dir.laplacian_diags_file`` (memory-mapped, on CPU), pairs them up
    element by element, and for each pair computes the per-token sink score
    ``laplacian_diags + attn_diags`` together with its descending sort along
    the last dimension.

    Args:
        ds_dir: Dataset directory wrapper providing the two legacy file paths.

    Returns:
        Dict with the same keys as ``compute_from_full_shards``. Only
        ``attn_metrics`` is populated; ``input_lengths`` is an empty integer
        tensor (matching the integer dtype of the full-shard path), and
        ``generated_tokens`` / ``hs_metrics`` are empty lists.

    Raises:
        TypeError: If either file does not contain a ``list``.
        ValueError: If the two lists have different lengths.
    """
    logger.info(f"\tLoading attn_diags from {ds_dir.attn_diags_file}")
    attn_diags_list = torch.load(ds_dir.attn_diags_file, mmap=True, map_location="cpu")

    logger.info(f"\tLoading laplacian_diags from {ds_dir.laplacian_diags_file}")
    laplacian_diags_list = torch.load(ds_dir.laplacian_diags_file, mmap=True, map_location="cpu")

    # Explicit checks instead of ``assert``: asserts are removed under
    # ``python -O``, which would let malformed files slip through silently.
    if not isinstance(attn_diags_list, list):
        raise TypeError(
            f"Expected a list in {ds_dir.attn_diags_file}, "
            f"got {type(attn_diags_list).__name__}"
        )
    if not isinstance(laplacian_diags_list, list):
        raise TypeError(
            f"Expected a list in {ds_dir.laplacian_diags_file}, "
            f"got {type(laplacian_diags_list).__name__}"
        )
    if len(attn_diags_list) != len(laplacian_diags_list):
        raise ValueError(
            "Legacy feature files disagree on the number of samples: "
            f"{len(attn_diags_list)} attn_diags vs "
            f"{len(laplacian_diags_list)} laplacian_diags"
        )

    attn_metrics = []
    # ``total`` is passed explicitly because ``zip`` has no length for tqdm to read.
    paired = zip(attn_diags_list, laplacian_diags_list)
    for raw_attn, raw_laplacian in tqdm(
        paired, total=len(attn_diags_list), desc="Computing sink scores from legacy"
    ):
        attn_diags = raw_attn.float()
        laplacian_diags = raw_laplacian.float()

        sink_score_per_token = laplacian_diags + attn_diags
        sink_sorted, sink_sorted_idx = torch.sort(sink_score_per_token, dim=-1, descending=True)

        attn_metrics.append(
            {
                "attn_diags": attn_diags,
                "laplacian_diags": laplacian_diags,
                "sink_score_per_token": sink_score_per_token,
                "sink_score_per_token_sorted": sink_sorted,
                "sink_score_per_token_sorted_idx": sink_sorted_idx,
            }
        )

    return {
        # Integer dtype so both code paths save ``input_lengths`` with the same dtype.
        "input_lengths": torch.tensor([], dtype=torch.long),
        "generated_tokens": [],
        "hs_metrics": [],
        "attn_metrics": attn_metrics,
    }


def compute_hs_metrics(hidden_states: list[Tensor], compute_svd: bool) -> dict[str, Tensor]:
    """Compute per-layer hidden-state metrics.

    Each element of ``hidden_states`` is squeezed along dim 0 and cast to
    float32 before use (assumes each layer is ``(1, tokens, dim)`` — confirm
    against the shard producer).

    Args:
        hidden_states: One tensor per layer.
        compute_svd: When true, also compute the SVD-based metrics.

    Returns:
        Dict with:
          * ``token_norms``: L2 norm over the last dimension, stacked over
            layers.
          * ``matrix_entropy`` (only if ``compute_svd``): per layer, the
            Shannon entropy (natural log) of the squared singular values
            normalised to sum to one.
          * ``anisotropy`` (only if ``compute_svd``): per layer, the share of
            the first (largest) squared singular value in that sum.

        Zero singular values contribute 0 to the entropy, so rank-deficient
        layers give a finite value. An all-zero layer still yields NaN for
        both SVD metrics because the normalisation is 0/0.
    """
    token_norms = []
    matrix_entropies = []
    anisotropies = []

    for layer_hs in hidden_states:
        layer_hs = layer_hs.squeeze(0).float()
        token_norms.append(torch.linalg.vector_norm(layer_hs, dim=-1))

        if not compute_svd:
            continue

        sigma_squared = torch.linalg.svdvals(layer_hs).pow(2)
        # Squared singular values sum to the squared Frobenius norm.
        spectrum = sigma_squared / sigma_squared.sum()

        # entr(p) = -p * log(p) with entr(0) = 0; the naive product evaluates
        # 0 * log(0) = 0 * -inf = NaN whenever a singular value is exactly 0.
        matrix_entropies.append(torch.special.entr(spectrum).sum())
        # svdvals returns values in descending order, so index 0 is the largest.
        anisotropies.append(spectrum[0])

    result: dict[str, Tensor] = {"token_norms": torch.stack(token_norms)}
    if compute_svd:
        result["matrix_entropy"] = torch.tensor(matrix_entropies)
        result["anisotropy"] = torch.tensor(anisotropies)

    return result


# Script entry point: hand ``main`` to Typer, which runs it with the parsed
# command-line arguments; nothing runs when the module is merely imported.
if __name__ == "__main__":
    typer.run(main)
