"""Compute primitive-based reasoning metrics from classified traces.

Produces per-trace metrics (counts, entropy, bigrams, transitions)
and aggregates them per-prompt and per-checkpoint.
"""
from __future__ import annotations

from collections import Counter
from typing import Optional

import numpy as np

from .primitive_classification import PRIMITIVES, Episode

# Every primitive label except the catch-all "OTHER" bucket.
_PRIM_LABELS = [label for label in PRIMITIVES if label != "OTHER"]
# Size of the full primitive vocabulary (OTHER included); used as the side
# length of transition matrices.
_N_PRIMITIVES = len(PRIMITIVES)
# Row/column position of each primitive (OTHER included) in those matrices.
_PRIM_TO_IDX = dict(zip(PRIMITIVES, range(_N_PRIMITIVES)))


# ---------------------------------------------------------------------------
# Per-trace metrics
# ---------------------------------------------------------------------------

def trace_primitive_counts(
    episodes: list[Episode],
    total_tokens: int,
) -> dict[str, float]:
    """Length-normalized primitive frequencies: episodes per 1,000 tokens.

    One entry per primitive in ``PRIMITIVES``, keyed ``'<PRIMITIVE>_per_1k'``
    (e.g. ``'VERIFY_per_1k'``). Primitives that never occur map to ``0``.
    A non-positive ``total_tokens`` yields ``0.0`` for every key.
    """
    keys = [f"{p}_per_1k" for p in PRIMITIVES]
    if total_tokens <= 0:
        return dict.fromkeys(keys, 0.0)

    tally = Counter(episode.label for episode in episodes)
    per_1k = 1000.0 / total_tokens
    return {key: tally.get(p, 0) * per_1k for key, p in zip(keys, PRIMITIVES)}


def trace_primitive_entropy(episodes: list[Episode]) -> float:
    """Shannon entropy (in bits) of the primitive label distribution of one trace.

    The distribution is the empirical frequency of each ``ep.label`` among
    ``episodes``. Only labels that actually occur are summed over, since
    unobserved labels contribute ``0 * log 0 = 0``. This guarantees the
    probabilities sum to 1 even if a label falls outside ``PRIMITIVES``.

    Returns:
        Entropy in bits. ``0.0`` (never ``-0.0``) for an empty trace or a
        trace with a single distinct label.
    """
    if not episodes:
        return 0.0
    label_counts = Counter(ep.label for ep in episodes)
    counts = np.array(list(label_counts.values()), dtype=np.float64)
    probs = counts / counts.sum()
    # Every term p*log2(p) is <= 0, so the negated sum is >= 0. With a single
    # label it is exactly -0.0, and ``+ 0.0`` normalizes that to +0.0.
    return float(-np.sum(probs * np.log2(probs))) + 0.0


def trace_primitive_bigram_entropy(episodes: list[Episode]) -> float:
    """Shannon entropy (in bits) over consecutive-label bigrams of one trace.

    Each adjacent pair ``(episodes[i].label, episodes[i + 1].label)`` is one
    observation; the entropy is taken over the empirical distribution of the
    distinct pairs that occur.

    Returns:
        Entropy in bits. ``0.0`` (never ``-0.0``) when there are fewer than
        two episodes or when every bigram is the same.
    """
    if len(episodes) < 2:
        return 0.0
    labels = [ep.label for ep in episodes]
    bigram_counts = Counter(zip(labels, labels[1:]))
    counts = np.array(list(bigram_counts.values()), dtype=np.float64)
    probs = counts / counts.sum()
    # A single distinct bigram gives -(1 * log2 1) == -0.0; ``+ 0.0`` fixes the sign.
    return float(-np.sum(probs * np.log2(probs))) + 0.0


def trace_transition_matrix(episodes: list[Episode]) -> np.ndarray:
    """Row-normalized first-order transition matrix between primitives.

    The result is a float64 ``_N_PRIMITIVES`` x ``_N_PRIMITIVES`` array
    indexed through ``_PRIM_TO_IDX``: rows are the "from" label, columns the
    "to" label of each adjacent episode pair. Each non-empty row holds
    conditional probabilities P(to | from). Rows with no outgoing transitions
    stay all zeros. A label missing from ``_PRIM_TO_IDX`` raises ``KeyError``.
    """
    transitions = np.zeros((_N_PRIMITIVES, _N_PRIMITIVES), dtype=np.float64)
    for prev_ep, next_ep in zip(episodes, episodes[1:]):
        transitions[_PRIM_TO_IDX[prev_ep.label], _PRIM_TO_IDX[next_ep.label]] += 1

    outgoing = transitions.sum(axis=1, keepdims=True)
    # Divide empty rows by 1 so they stay zero instead of becoming NaN.
    outgoing[outgoing == 0] = 1
    return transitions / outgoing


# Key bigram transitions to track.
# Each entry is an ordered (from_label, to_label) pair of consecutive episode
# labels. trace_key_bigram_counts() reports each pair's occurrences per 1k
# tokens under the key "<from>-><to>_per_1k". Pairs that never occur report 0.
# The list order determines the key order of that function's output.
_KEY_BIGRAMS = [
    ("VERIFY", "BACKTRACK"),
    ("ENUMERATE", "VERIFY"),
    ("BACKTRACK", "COMPUTE"),
    ("DECOMPOSE", "ENUMERATE"),
    ("ERROR_DETECT", "BACKTRACK"),
    ("HYPOTHESIZE", "VERIFY"),
    ("COMPUTE", "VERIFY"),
]


def trace_key_bigram_counts(
    episodes: list[Episode],
    total_tokens: int,
) -> dict[str, float]:
    """Occurrences of each ``_KEY_BIGRAMS`` pair per 1,000 tokens.

    Keys are ``'<from>-><to>_per_1k'``, in ``_KEY_BIGRAMS`` order. If
    ``total_tokens`` is non-positive or the trace has fewer than two
    episodes, every value is ``0.0``.
    """
    if total_tokens <= 0 or len(episodes) < 2:
        return {f"{src}->{dst}_per_1k": 0.0 for src, dst in _KEY_BIGRAMS}

    labels = [ep.label for ep in episodes]
    pair_counts = Counter(zip(labels, labels[1:]))
    per_1k = 1000.0 / total_tokens
    return {
        f"{src}->{dst}_per_1k": pair_counts[(src, dst)] * per_1k
        for src, dst in _KEY_BIGRAMS
    }


def trace_primitive_counts_raw(
    episodes: list[Episode],
) -> dict[str, int]:
    """Absolute number of episodes per primitive in one trace.

    Keys are ``'<PRIMITIVE>_count'`` for every entry of ``PRIMITIVES``.
    Primitives that never occur map to ``0``. No length normalization is applied.
    """
    tally = Counter(episode.label for episode in episodes)
    # Counter returns 0 for missing keys without inserting them.
    return {f"{p}_count": tally[p] for p in PRIMITIVES}


def trace_primitive_summary(
    episodes: list[Episode],
    total_tokens: int,
) -> dict:
    """Bundle every per-trace primitive metric into a single flat dict.

    Contents, in insertion order:
      * per-1k-token primitive counts (``trace_primitive_counts``),
      * raw primitive counts (``trace_primitive_counts_raw``),
      * ``primitive_entropy`` and ``primitive_bigram_entropy``,
      * structural stats: ``n_unique_primitives``, ``n_episodes``, and
        ``n_transitions``,
      * per-1k-token key-bigram counts (``trace_key_bigram_counts``),
      * ``total_tokens`` echoed back for reference.
    """
    n_eps = len(episodes)
    return {
        **trace_primitive_counts(episodes, total_tokens),
        **trace_primitive_counts_raw(episodes),
        "primitive_entropy": trace_primitive_entropy(episodes),
        "primitive_bigram_entropy": trace_primitive_bigram_entropy(episodes),
        "n_unique_primitives": len({ep.label for ep in episodes}),
        "n_episodes": n_eps,
        "n_transitions": max(0, n_eps - 1),
        **trace_key_bigram_counts(episodes, total_tokens),
        "total_tokens": total_tokens,
    }


# ---------------------------------------------------------------------------
# Aggregation
# ---------------------------------------------------------------------------

def _subgroup_stats(
    prefix: str,
    values: dict[str, np.ndarray],
    selector: np.ndarray,
) -> dict[str, float]:
    """Return ``<prefix>_<metric>_mean`` / ``_std`` over the rows picked by *selector*."""
    stats: dict[str, float] = {}
    for key, column in values.items():
        subset = column[selector]
        stats[f"{prefix}_{key}_mean"] = float(np.mean(subset))
        stats[f"{prefix}_{key}_std"] = float(np.std(subset))
    return stats


def aggregate_primitive_metrics(
    trace_summaries: list[dict],
    success_mask: Optional[list[bool]] = None,
    *,
    min_group_size: int = 3,
) -> dict:
    """Aggregate per-trace primitive metrics across traces.

    Args:
        trace_summaries: List of dicts from trace_primitive_summary(). The
            numeric-valued keys of the first summary (except
            ``total_tokens``) define which metrics are aggregated. Every
            summary must contain those keys plus ``total_tokens``.
        success_mask: Optional per-trace success flags aligned with
            ``trace_summaries``. Values are coerced to bool, so ``[1, 0]``
            works like ``[True, False]``. When given, the result also holds
            ``n_successful_traces`` / ``n_failed_traces`` and, for subgroups
            with at least ``min_group_size`` traces,
            ``successful_<metric>_{mean,std}`` / ``failed_<metric>_{mean,std}``.
        min_group_size: Minimum number of traces a successful/failed
            subgroup needs before its mean/std are reported (default 3).

    Returns:
        Dict with ``<metric>_mean`` and ``<metric>_std`` (population std,
        ``ddof=0``), ``n_traces``, ``avg_trace_length`` and the optional
        success/failure breakdown. Empty dict if ``trace_summaries`` is empty.

    Raises:
        ValueError: If ``success_mask`` is not one flag per trace summary.
    """
    if not trace_summaries:
        return {}

    first = trace_summaries[0]
    # Numeric metrics only (exclude lists, strings, etc.); total_tokens is
    # reported separately as avg_trace_length.
    metric_keys = [
        k for k in first
        if k != "total_tokens"
        and isinstance(first[k], (int, float, np.integer, np.floating))
    ]
    values = {
        k: np.array([s[k] for s in trace_summaries], dtype=np.float64)
        for k in metric_keys
    }

    result = {}
    for k in metric_keys:
        result[f"{k}_mean"] = float(np.mean(values[k]))
        result[f"{k}_std"] = float(np.std(values[k]))

    result["n_traces"] = len(trace_summaries)
    result["avg_trace_length"] = float(
        np.mean([s["total_tokens"] for s in trace_summaries])
    )

    if success_mask is not None:
        # Force a boolean dtype. An int mask would otherwise trigger integer
        # fancy-indexing, and ``~`` would become bitwise NOT (~1 == -2).
        mask = np.asarray(success_mask, dtype=bool)
        if mask.shape != (len(trace_summaries),):
            raise ValueError(
                f"success_mask has shape {mask.shape}, expected "
                f"({len(trace_summaries)},) to match trace_summaries"
            )

        n_success = int(mask.sum())
        result["n_successful_traces"] = n_success
        if n_success >= min_group_size:  # minimum for meaningful stats
            result.update(_subgroup_stats("successful", values, mask))

        failed = ~mask
        n_failed = int(failed.sum())
        result["n_failed_traces"] = n_failed
        if n_failed >= min_group_size:
            result.update(_subgroup_stats("failed", values, failed))

    return result
