"""Compute diversity metrics from clustering results.

Provides per-prompt and per-checkpoint diversity measures based on
cluster membership distributions, including successful-only variants.
"""
from __future__ import annotations

from typing import Optional

import numpy as np


def cluster_entropy(labels: np.ndarray) -> float:
    """Shannon entropy (in bits) over cluster membership proportions.

    Args:
        labels: Cluster label of each trace. Every distinct value is
            treated as its own cluster.

    Returns:
        The entropy ``-sum(p_j * log2(p_j))`` where ``p_j`` is the fraction
        of traces in cluster ``j``. Returns ``0.0`` for empty input and when
        all traces share one cluster.
    """
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    entropy = float(-np.sum(probs * np.log2(probs)))
    # With a single cluster the sum is 0.0 and its negation is -0.0, which
    # would show up as "-0.0" in reports and serialized output. Every term
    # p * log2(p) is <= 0, so the entropy can never be truly negative and
    # mapping anything that is not strictly positive to 0.0 is safe.
    return entropy if entropy > 0.0 else 0.0


def effective_num_paths(labels: np.ndarray) -> float:
    """Inverse Simpson index of the cluster distribution: 1 / sum(p_j^2).

    The value is 1 when every trace falls in the same cluster and equals
    the number of distinct clusters when traces are spread evenly across
    them. Empty input yields 0.0.
    """
    if len(labels) == 0:
        return 0.0

    cluster_sizes = np.unique(labels, return_counts=True)[1]
    shares = cluster_sizes / cluster_sizes.sum()
    concentration = float(np.sum(shares ** 2))

    # Guard against division by zero before inverting.
    if concentration > 0:
        return 1.0 / concentration
    return 0.0


def top_cluster_mass(labels: np.ndarray) -> float:
    """Share of all traces that belong to the most populated cluster.

    Returns 0.0 when there are no traces.
    """
    if len(labels) == 0:
        return 0.0

    cluster_sizes = np.unique(labels, return_counts=True)[1]
    largest = cluster_sizes.max()
    total = cluster_sizes.sum()
    return float(largest / total)


def num_clusters(labels: np.ndarray) -> int:
    """Count how many different cluster labels occur in ``labels``.

    Returns 0 when there are no labels.
    """
    # Every distinct label value counts as one cluster.
    return len(np.unique(labels)) if len(labels) else 0


def prompt_diversity_metrics(
    labels: np.ndarray,
    success_mask: Optional[np.ndarray] = None,
    *,
    min_successful: int = 3,
) -> dict:
    """Compute all diversity metrics for a single prompt.

    Args:
        labels: Cluster labels for all traces (shape N,). Any array-like
            is accepted and converted with ``np.asarray``.
        success_mask: Mask marking the successful traces (shape N,).
            Truthy entries mean "successful"; 0/1 integer masks are
            accepted and interpreted as booleans.
        min_successful: Minimum number of successful traces required
            before the ``successful_*`` metrics are computed.

    Returns:
        Dict with ``num_clusters``, ``cluster_entropy``,
        ``effective_num_paths``, ``top_cluster_mass`` and ``n_traces``.
        When ``success_mask`` is given, ``n_successful`` is added, and if
        at least ``min_successful`` traces succeeded, the ``successful_*``
        variant of each diversity metric is added as well. Below that
        threshold the ``successful_*`` keys are omitted.

    Raises:
        ValueError: If ``success_mask`` does not have the same shape as
            ``labels``.
    """
    # Convert up front so list input supports mask indexing below.
    labels = np.asarray(labels)

    result = {
        "num_clusters": num_clusters(labels),
        "cluster_entropy": cluster_entropy(labels),
        "effective_num_paths": effective_num_paths(labels),
        "top_cluster_mass": top_cluster_mass(labels),
        "n_traces": len(labels),
    }

    if success_mask is None:
        return result

    # Force a boolean dtype: indexing with an integer 0/1 array would be
    # interpreted as positions (selecting elements 0 and 1 repeatedly)
    # instead of filtering, silently producing wrong metrics.
    success_mask = np.asarray(success_mask, dtype=bool)
    if success_mask.shape != labels.shape:
        raise ValueError(
            f"success_mask shape {success_mask.shape} does not match "
            f"labels shape {labels.shape}"
        )

    n_success = int(success_mask.sum())
    result["n_successful"] = n_success

    # Only report the successful-only variants when enough traces succeeded.
    if n_success >= min_successful:
        success_labels = labels[success_mask]
        result["successful_num_clusters"] = num_clusters(success_labels)
        result["successful_cluster_entropy"] = cluster_entropy(success_labels)
        result["successful_effective_num_paths"] = effective_num_paths(success_labels)
        result["successful_top_cluster_mass"] = top_cluster_mass(success_labels)

    return result


def aggregate_diversity(
    per_prompt_metrics: list[dict],
) -> dict:
    """Aggregate per-prompt diversity metrics across prompts.

    For every metric key that holds a numeric value in at least one prompt,
    computes the mean and (population) standard deviation over the prompts
    that provide a numeric value for it. Prompts where the key is missing
    or non-numeric are left out of that key's statistics.

    Args:
        per_prompt_metrics: One metrics dict per prompt.

    Returns:
        Empty dict for empty input. Otherwise a dict with ``n_prompts``
        plus ``<key>_mean`` and ``<key>_std`` for each numeric key, added
        in sorted key order.
    """
    if not per_prompt_metrics:
        return {}

    # NumPy scalar types are listed explicitly because np.integer values
    # (e.g. np.int64) are not instances of the builtin int.
    numeric_types = (int, float, np.integer, np.floating)

    # Filter per value rather than per key: a key that is numeric in one
    # prompt may be None or a string in another, and feeding such a value
    # to np.mean would raise.
    values_by_key = {}
    for metrics in per_prompt_metrics:
        for key, value in metrics.items():
            if isinstance(value, numeric_types):
                values_by_key.setdefault(key, []).append(value)

    result = {"n_prompts": len(per_prompt_metrics)}

    for key in sorted(values_by_key):
        values = values_by_key[key]
        result[f"{key}_mean"] = float(np.mean(values))
        result[f"{key}_std"] = float(np.std(values))

    return result
