"""MTopDiv (Manifold Topology Divergence) feature computation.

Adapted from toha_baseline repository:
- src/methods/mtopdiv/utils.py (transform_attention_scores_to_distances, transform_distances_to_mtopdiv)
- src/methods/mtopdiv/mtopdiv.py (get_mtopdivs, evaluation protocol)
"""

from __future__ import annotations

import itertools
import multiprocessing as mp

import numpy as np
import torch
from ripser import ripser
from torch import Tensor
from tqdm.auto import tqdm


def _run_parallel(
    pool: mp.pool.Pool, distance_list: list[np.ndarray], show_progress: bool
) -> list[float]:
    """Score every distance matrix with ``transform_distances_to_mtopdiv`` on ``pool``.

    Results are returned in the same order as ``distance_list``. With
    ``show_progress`` the lazy ``pool.imap`` iterator is wrapped in a tqdm bar so
    progress is visible as results arrive; otherwise the blocking ``pool.map`` is used.
    """
    worker = transform_distances_to_mtopdiv
    if not show_progress:
        return pool.map(worker, distance_list)

    lazy_results = pool.imap(worker, distance_list)
    progress = tqdm(lazy_results, total=len(distance_list), desc="MTopDiv")
    return list(progress)


def compute_mtopdiv(
    item_attn: list[Tensor] | Tensor,
    response_length: int,
    n_jobs: int = 1,
    show_progress: bool = False,
    pool: mp.pool.Pool | None = None,
) -> Tensor:
    """Compute MTopDiv scores for all layer-head pairs.

    Input shape: [#layers, #heads, seq_len, seq_len]
    Output shape: [#layers, #heads]

    Args:
        item_attn: Attention maps, either one 4-D tensor or a list of tensors that
            ``torch.stack`` combines along a new leading (layer) axis.
        response_length: Number of trailing tokens that form the response; the first
            ``seq_len - response_length`` tokens are treated as the prompt. Must be in
            ``[1, seq_len]``.
        n_jobs: Worker count for a temporary process pool, used only when ``pool`` is
            None. Values <= 1 run sequentially in the current process.
        show_progress: Show a tqdm progress bar over the layer-head pairs.
        pool: Optional existing multiprocessing pool; takes precedence over ``n_jobs``.

    Returns:
        A float32 CPU tensor of shape [#layers, #heads].

    Raises:
        ValueError: If the attention tensor is not 4-D, or if ``response_length`` is
            outside ``[1, seq_len]`` (zero would divide by zero during normalization).

    Evaluation protocol (from toha_baseline):
    1. Feature extraction: For each (layer, head) pair:
       - Convert attention to distance: distance = 1 - clip(attention, 0, inf),
         with the diagonal set to 0
       - Symmetrize: distance = min(distance, distance.T)
       - Zero out prompt-prompt region: distance[:prompt_len, :prompt_len] = 0
       - Compute persistent homology H0 barcode using ripser
       - MTopDiv = sum(barcode_lengths) / response_length
    2. Feature selection (hyperparameter tuning):
       - Supervised: RFE with LogisticRegression, select top N heads (1-6) by ROC-AUC
       - Unsupervised: Select heads with largest mean difference between hallucinated and grounded
         samples
    3. Prediction:
       - Supervised: LogisticRegression probability on selected head features
       - Unsupervised: Mean MTopDiv across selected heads
    4. Evaluation: K-fold cross-validation (default k=5) with train/val/test splits
    """
    attn = torch.stack(item_attn) if isinstance(item_attn, list) else item_attn
    if attn.dim() != 4:
        raise ValueError(
            "Expected attention of shape [#layers, #heads, seq_len, seq_len], "
            f"got {tuple(attn.shape)}"
        )
    num_layers, num_heads, seq_len, _ = attn.shape
    # Reject lengths that would divide by zero or normalize by a nonsensical count.
    if not 1 <= response_length <= seq_len:
        raise ValueError(
            f"response_length must be in [1, {seq_len}], got {response_length}"
        )
    prompt_len = seq_len - response_length

    # Single host transfer for the entire tensor. Upcast to float32 first: NumPy has
    # no bfloat16, so `.numpy()` on a bf16 tensor raises, and the distance transform
    # casts to float32 anyway, so this does not change results for other dtypes.
    attn_np = attn.detach().to(device="cpu", dtype=torch.float32).numpy()

    # Vectorized distance transformation (batched over layers and heads).
    distance_matrices = transform_attention_scores_to_distances(attn_np)

    # Vectorized prompt masking (once for all heads): with zero mutual distance the
    # prompt tokens merge at filtration value 0, so the H0 bars come from how
    # response tokens connect to them and to each other.
    if prompt_len > 0:
        distance_matrices[:, :, :prompt_len, :prompt_len] = 0.0

    # List of 2-D matrices for ripser (views into distance_matrices, not copies).
    distance_list = [
        distance_matrices[layer, head]
        for layer, head in itertools.product(range(num_layers), range(num_heads))
    ]

    # Run ripser: reuse a caller-provided pool, spin up a temporary one, or go sequential.
    if pool is not None:
        raw_mtopdivs = _run_parallel(pool, distance_list, show_progress)
    elif n_jobs > 1:
        with mp.Pool(n_jobs) as temp_pool:
            raw_mtopdivs = _run_parallel(temp_pool, distance_list, show_progress)
    else:
        iterator = tqdm(distance_list, desc="MTopDiv") if show_progress else distance_list
        raw_mtopdivs = [transform_distances_to_mtopdiv(d) for d in iterator]

    # Results follow itertools.product order (layer-major), so a plain reshape
    # restores the [#layers, #heads] grid; then normalize by response length.
    mtopdiv_array = np.array(raw_mtopdivs, dtype=np.float32).reshape(num_layers, num_heads)
    mtopdiv_array /= response_length
    return torch.from_numpy(mtopdiv_array)


def transform_attention_scores_to_distances(attention_weights: np.ndarray) -> np.ndarray:
    """Turn attention scores into a symmetric token-to-token distance matrix.

    Steps, applied over the last two axes (leading axes are treated as a batch):
    1. Cast to float32 (always a fresh array, so the input is never mutated).
    2. Clip negative scores to 0 and map score -> 1 - score.
    3. Zero the diagonal by multiplying with an off-diagonal mask.
    4. Symmetrize by taking the element-wise minimum with the transpose.

    Input shape: [*, seq_len, seq_len] (supports batched input)
    Output shape: [*, seq_len, seq_len]
    """
    size = attention_weights.shape[-1]
    scores = attention_weights.astype(np.float32)
    distances = 1 - np.clip(scores, a_min=0.0, a_max=None)

    # Multiply rather than assign so non-finite diagonal entries behave exactly as
    # before (e.g. NaN * 0 stays NaN). The 2-D mask broadcasts over batch axes.
    off_diagonal = 1.0 - np.eye(size)
    distances *= off_diagonal

    transposed = np.swapaxes(distances, -1, -2)
    return np.minimum(transposed, distances)


def transform_distances_to_mtopdiv(distance_mx: np.ndarray) -> float:
    """Compute MTopDiv from a distance matrix using H0 persistent homology.

    The matrix is passed to ripser as a precomputed distance matrix and only the
    H0 diagram is requested. Bars with an infinite death (components that never
    merge) are excluded, and each remaining bar contributes ``death - birth``.

    Previously the code summed the death column of every bar except the last one,
    which is only correct when there is exactly one infinite bar, it is listed last,
    and all births are 0. NOTE(review): for the zero-diagonal matrices built by
    ``transform_attention_scores_to_distances`` this is expected to hold, so results
    there should be unchanged.

    Input shape: [seq_len, seq_len]
    Output: scalar (sum of finite H0 barcode lengths)
    """
    diagrams = ripser(distance_mx, distance_matrix=True, maxdim=0)["dgms"]
    if len(diagrams) == 0:
        return 0.0
    h0 = np.asarray(diagrams[0], dtype=np.float64).reshape(-1, 2)
    births, deaths = h0[:, 0], h0[:, 1]
    finite = np.isfinite(deaths)
    return float((deaths[finite] - births[finite]).sum())
