"""Joint SFT+RL clustering to detect novel reasoning paths.

Clusters traces from both the SFT baseline and an RL checkpoint jointly,
then identifies clusters that are novel to RL (low SFT mass).
"""
from __future__ import annotations

from typing import Optional

import numpy as np

from .clustering import cluster_traces, combined_distance_matrix, embed_traces, primitive_ngram_vectors
from .diversity_metrics import cluster_entropy, effective_num_paths


def joint_cluster_novelty(
    sft_traces: list[str],
    rl_traces: list[str],
    sft_primitive_seqs: Optional[list[list[str]]] = None,
    rl_primitive_seqs: Optional[list[list[str]]] = None,
    tau: float = 0.1,
    distance_threshold: float = 0.3,
    semantic_weight: float = 0.7,
    embedder=None,
) -> dict:
    """Compute cluster novelty by jointly clustering SFT + RL traces.

    All traces are pooled and clustered together. A cluster is "novel" if
    the fraction of its members that come from SFT is strictly below
    ``tau``, i.e. RL found a reasoning path that SFT (almost) never uses.

    Args:
        sft_traces: List of SFT trace texts.
        rl_traces: List of RL trace texts.
        sft_primitive_seqs: Optional primitive label sequences, one per SFT
            trace.
        rl_primitive_seqs: Optional primitive label sequences, one per RL
            trace. Primitive n-gram distance is mixed in only when BOTH
            primitive lists are given; otherwise semantic cosine distance
            alone is used.
        tau: Threshold for novel cluster detection (SFT mass < tau).
        distance_threshold: Passed to ``cluster_traces`` as the clustering
            distance threshold.
        semantic_weight: Weight for semantic vs primitive distance (only used
            when primitive sequences are given).
        embedder: Optional pre-loaded embedder forwarded to ``embed_traces``.

    Returns:
        Dict with keys:
          - ``n_clusters``: number of distinct cluster labels.
          - ``n_sft_traces`` / ``n_rl_traces``: input sizes.
          - ``novel_cluster_mass_rl`` and ``fraction_rl_in_novel_clusters``:
            fraction of RL traces that fall in novel clusters (both keys hold
            the same value; 0.0 when there are no RL traces).
          - ``n_novel_clusters``: number of clusters flagged as novel.
          - ``cluster_labels``: labels from ``cluster_traces``; SFT traces
            first, then RL traces.
          - ``novel_mask``: bool array, True for traces in novel clusters.
          - ``is_rl``: bool array, True for RL traces.
        When both trace lists are empty, ``_empty_novelty()`` is returned.

    Raises:
        ValueError: If primitive sequences are given but their counts do not
            match the corresponding trace counts.
    """
    n_sft = len(sft_traces)
    n_rl = len(rl_traces)
    n_total = n_sft + n_rl

    if n_total == 0:
        return _empty_novelty()

    use_primitives = sft_primitive_seqs is not None and rl_primitive_seqs is not None
    # Misaligned primitive lists would silently pair the wrong n-gram vector
    # with each embedding, so fail loudly instead.
    if use_primitives and (
        len(sft_primitive_seqs) != n_sft or len(rl_primitive_seqs) != n_rl
    ):
        raise ValueError(
            "primitive sequences must align with traces: got "
            f"{len(sft_primitive_seqs)} SFT seqs for {n_sft} traces and "
            f"{len(rl_primitive_seqs)} RL seqs for {n_rl} traces"
        )

    # Joint index space: SFT traces occupy [0, n_sft), RL traces [n_sft, n_total).
    all_traces = list(sft_traces) + list(rl_traces)
    is_rl = np.zeros(n_total, dtype=bool)
    is_rl[n_sft:] = True

    # Compute semantic embeddings
    all_embs = embed_traces(all_traces, embedder)

    if use_primitives:
        all_prim_seqs = list(sft_primitive_seqs) + list(rl_primitive_seqs)
        prim_vecs = primitive_ngram_vectors(all_prim_seqs)
        dist_mat = combined_distance_matrix(all_embs, prim_vecs, semantic_weight)
    else:
        # Semantic-only distance
        from scipy.spatial.distance import cdist
        dist_mat = cdist(all_embs, all_embs, metric="cosine")
        # Cosine distance is NaN for zero-norm embeddings; treat such pairs
        # as orthogonal (distance 1.0) so clustering still works.
        dist_mat = np.nan_to_num(dist_mat, nan=1.0)

    # Cluster jointly; asarray so `labels == cl` below is elementwise.
    labels = np.asarray(cluster_traces(dist_mat, distance_threshold))

    unique_labels = np.unique(labels)
    novel_mask = np.zeros(n_total, dtype=bool)
    n_novel_clusters = 0
    for cl in unique_labels:
        cl_mask = labels == cl
        # np.unique only yields labels that occur, so cl_mask.sum() >= 1.
        sft_mass = (cl_mask & ~is_rl).sum() / cl_mask.sum()
        if sft_mass < tau:
            novel_mask |= cl_mask
            n_novel_clusters += 1

    # Metrics are reported over RL traces only.
    rl_in_novel = int((is_rl & novel_mask).sum())
    rl_novel_fraction = float(rl_in_novel / n_rl) if n_rl > 0 else 0.0

    return {
        "n_clusters": len(unique_labels),
        "n_sft_traces": n_sft,
        "n_rl_traces": n_rl,
        "novel_cluster_mass_rl": rl_novel_fraction,
        "fraction_rl_in_novel_clusters": rl_novel_fraction,
        "n_novel_clusters": n_novel_clusters,
        "cluster_labels": labels,
        "novel_mask": novel_mask,
        "is_rl": is_rl,
    }


def joint_cluster_novelty_with_success(
    sft_traces: list[str],
    rl_traces: list[str],
    rl_success_mask: list[bool],
    sft_primitive_seqs: Optional[list[list[str]]] = None,
    rl_primitive_seqs: Optional[list[list[str]]] = None,
    tau: float = 0.1,
    distance_threshold: float = 0.3,
    semantic_weight: float = 0.7,
    embedder=None,
) -> dict:
    """Like joint_cluster_novelty but also computes successful-only metrics.

    Args:
        rl_success_mask: One truthy/falsy flag per RL trace, True when that
            trace succeeded. All other arguments are forwarded unchanged to
            ``joint_cluster_novelty``.

    Returns:
        The dict from ``joint_cluster_novelty`` extended with:
          - ``successful_novel_cluster_mass_rl`` and
            ``fraction_successful_rl_in_novel_clusters``: fraction of
            successful RL traces that fall in novel clusters (same value;
            0.0 when no RL trace succeeded).
          - ``n_successful_rl_traces``: number of successful RL traces.

    Raises:
        ValueError: If ``rl_success_mask`` does not have exactly one entry
            per RL trace (checked before any embedding/clustering work).
    """
    # Coerce to bool so non-bool flags (e.g. ints > 1) are not over-counted.
    rl_success = np.asarray(rl_success_mask, dtype=bool)
    if rl_success.shape != (len(rl_traces),):
        raise ValueError(
            f"rl_success_mask must have one entry per RL trace: got shape "
            f"{rl_success.shape} for {len(rl_traces)} traces"
        )

    result = joint_cluster_novelty(
        sft_traces,
        rl_traces,
        sft_primitive_seqs=sft_primitive_seqs,
        rl_primitive_seqs=rl_primitive_seqs,
        tau=tau,
        distance_threshold=distance_threshold,
        semantic_weight=semantic_weight,
        embedder=embedder,
    )

    # Select RL positions from the joint mask; they are in the same order
    # as rl_traces, so they line up with rl_success element-wise.
    rl_novel = result["novel_mask"][result["is_rl"]]
    rl_success_in_novel = int((rl_novel & rl_success).sum())
    n_rl_success = int(rl_success.sum())

    success_fraction = (
        float(rl_success_in_novel / n_rl_success) if n_rl_success > 0 else 0.0
    )
    result["successful_novel_cluster_mass_rl"] = success_fraction
    result["fraction_successful_rl_in_novel_clusters"] = success_fraction
    result["n_successful_rl_traces"] = n_rl_success

    return result


def aggregate_novelty(per_prompt_results: list[dict]) -> dict:
    """Summarize per-prompt novelty results with mean and std per metric.

    For each known scalar metric, collect its values from the prompts that
    report it and emit ``<metric>_mean`` / ``<metric>_std`` (population std).
    Metrics absent from every prompt are omitted. An empty input yields ``{}``.
    """
    if not per_prompt_results:
        return {}

    metric_names = (
        "novel_cluster_mass_rl",
        "fraction_rl_in_novel_clusters",
        "successful_novel_cluster_mass_rl",
        "fraction_successful_rl_in_novel_clusters",
        "n_clusters",
    )

    summary = {"n_prompts": len(per_prompt_results)}
    for name in metric_names:
        samples = [res[name] for res in per_prompt_results if name in res]
        if not samples:
            continue
        summary[f"{name}_mean"] = float(np.mean(samples))
        summary[f"{name}_std"] = float(np.std(samples))
    return summary


def _empty_novelty() -> dict:
    """Return empty novelty metrics."""
    return {
        "n_clusters": 0,
        "n_sft_traces": 0,
        "n_rl_traces": 0,
        "novel_cluster_mass_rl": 0.0,
        "fraction_rl_in_novel_clusters": 0.0,
        "n_novel_clusters": 0,
        "cluster_labels": np.array([], dtype=int),
        "novel_mask": np.array([], dtype=bool),
        "is_rl": np.array([], dtype=bool),
    }
