"""Primitive sequence diversity analysis.

Two approaches:
1. k-mer frequency vectors (bigram + trigram) for clustering/distance
2. Directly-follows graphs per prompt for aggregate transition analysis
"""
from __future__ import annotations

import json
from collections import Counter
from itertools import combinations, product
from typing import Optional

import numpy as np
from scipy.spatial.distance import cosine, jensenshannon
from scipy.stats import entropy as scipy_entropy

from .primitive_classification import PRIMITIVES, LEARNED_PRIMITIVES

# Active label set read by the k-mer vocabulary and DFG functions in this
# module. Default: PRIMITIVES, the legacy heuristic label set (noted as
# including OTHER). Call use_learned_labels() to switch to LEARNED_PRIMITIVES
# when using the BERT ensemble.
_PRIM_LIST = list(PRIMITIVES)
_PRIM_IDX = {p: i for i, p in enumerate(_PRIM_LIST)}  # label -> row/col index
_N = len(_PRIM_LIST)


def use_learned_labels():
    """Make ``LEARNED_PRIMITIVES`` the active label set for this module.

    Rebinds the module-level ``_PRIM_LIST``, ``_PRIM_IDX`` and ``_N``, which
    the k-mer vocabulary and DFG functions here read at call time. Call it
    once, before any analysis, when using the BERT ensemble. Vectors and DFG
    matrices built before the switch use a different size and index mapping
    and are not comparable with ones built after it.
    """
    global _PRIM_LIST, _PRIM_IDX, _N
    labels = list(LEARNED_PRIMITIVES)
    _PRIM_LIST = labels
    _PRIM_IDX = dict(zip(labels, range(len(labels))))
    _N = len(labels)


# =====================================================================
# k-mer frequency vectors
# =====================================================================

def kmer_counts(seq: list[str], k: int) -> Counter:
    """Tally every contiguous length-``k`` window of a primitive sequence.

    Windows are keyed as tuples. A sequence shorter than ``k`` yields an
    empty Counter.
    """
    tally = Counter()
    for start in range(len(seq) - k + 1):
        tally[tuple(seq[start:start + k])] += 1
    return tally


def kmer_vector(seq: list[str], k: int, vocab: list[tuple]) -> np.ndarray:
    """Return the k-mer frequency vector of a sequence over a fixed vocabulary.

    Entry ``i`` is the count of ``vocab[i]`` in *seq* divided by the summed
    count of all vocabulary k-mers. K-mers outside *vocab* are ignored and
    do not contribute to that total. If no vocabulary k-mer occurs, the
    result is all zeros.
    """
    observed = kmer_counts(seq, k)
    freqs = np.fromiter((observed[kmer] for kmer in vocab), dtype=np.float64)
    in_vocab_total = freqs.sum()
    return freqs / in_vocab_total if in_vocab_total > 0 else freqs


def build_kmer_vocab(k: int, max_k: int = 3) -> list[tuple]:
    """Build the full vocabulary of k-mers over the active primitive alphabet.

    The vocabulary contains every length-``k`` tuple over ``_PRIM_LIST``, in
    lexicographic order of label position (the first position varies
    slowest), so it has ``_N**k`` entries. It is rebuilt on every call
    because ``use_learned_labels()`` can change the alphabet.

    Args:
        k: k-mer length; must satisfy ``1 <= k <= max_k``.
        max_k: largest k allowed. Guards against accidentally building a huge
            vocabulary, whose size grows as ``_N**k``. Defaults to the
            previous hard limit of 3.

    Raises:
        ValueError: if ``k < 1`` or ``k > max_k``.
    """
    if k < 1:
        # Previously fell through to the "too large" message (e.g. "k=0 too
        # large, vocab would be 1"), which was misleading.
        raise ValueError(f"k must be >= 1, got k={k}")
    if k > max_k:
        raise ValueError(f"k={k} too large, vocab would be {_N**k}")
    # product(..., repeat=k) yields the same order as the nested
    # comprehensions it replaces.
    return list(product(_PRIM_LIST, repeat=k))


def combined_kmer_vector(seq: list[str], k_values: tuple = (2, 3)) -> np.ndarray:
    """Concatenate the normalized k-mer vectors of *seq* for each k in *k_values*.

    Segments appear in *k_values* order, and each has
    ``len(build_kmer_vocab(k))`` entries. A sequence shorter than k
    contributes an all-zero segment. An empty *k_values* makes
    ``np.concatenate`` raise ``ValueError``.
    """
    segments = []
    for k in k_values:
        vocab = build_kmer_vocab(k)
        too_short = len(seq) < k
        segments.append(np.zeros(len(vocab)) if too_short else kmer_vector(seq, k, vocab))
    return np.concatenate(segments)


# --- k-mer based diversity metrics per prompt ---

def kmer_diversity_per_prompt(
    sequences: list[list[str]],
    k_values: tuple = (2, 3),
    max_pairs: int = 200,
    seed: int = 42,
) -> dict:
    """Compute k-mer based diversity metrics for a set of sequences (one prompt).

    Each sequence is embedded as its concatenated, normalized k-mer vectors
    (see ``combined_kmer_vector``). Pairwise statistics use all sequence
    pairs. When there are more than *max_pairs* pairs, they use a sample of
    *max_pairs* pairs drawn without replacement with
    ``np.random.RandomState(seed)``.

    Args:
        sequences: primitive-label sequences, one per rollout.
        k_values: k-mer sizes concatenated into the cosine-distance vectors.
        max_pairs: cap on the number of pairs evaluated.
        seed: RNG seed for the pair subsample.

    Returns dict with:
        kmer_mean_cosine_dist: mean pairwise cosine distance between combined
            k-mer vectors. A pair where either vector is all-zero counts as 1.0.
        kmer_std_cosine_dist: population std of those pairwise distances.
        kmer_mean_js_divergence: NaN-ignoring mean of the pairwise
            Jensen-Shannon *distance* between bigram distributions. This is
            scipy's ``jensenshannon``: the square root of the divergence,
            natural log. A pair where either sequence has no in-vocabulary
            bigram counts as 1.0.
        kmer_n_unique_bigrams_union: distinct bigrams seen across all sequences.
        kmer_n_unique_trigrams_union: distinct trigrams seen across all sequences.
        n_traces: number of sequences.
    With fewer than two sequences, the pairwise metrics are reported as 0.0.
    The union counts include k-mers containing labels outside the active
    label set, because ``kmer_counts`` does not filter them.
    """
    # Distinct bigrams/trigrams are well defined for any number of sequences,
    # so compute them before the pairwise early exit. Previously a
    # single-sequence prompt reported 0 for both.
    all_bigrams: set = set()
    all_trigrams: set = set()
    for seq in sequences:
        # kmer_counts yields nothing for sequences shorter than k.
        all_bigrams.update(kmer_counts(seq, 2))
        all_trigrams.update(kmer_counts(seq, 3))

    n_traces = len(sequences)
    if n_traces < 2:
        # Pairwise metrics are undefined. Report 0.0 with the same key set as
        # the full path so per-prompt dicts aggregate consistently.
        return {
            "kmer_mean_cosine_dist": 0.0,
            "kmer_std_cosine_dist": 0.0,
            "kmer_mean_js_divergence": 0.0,
            "kmer_n_unique_bigrams_union": len(all_bigrams),
            "kmer_n_unique_trigrams_union": len(all_trigrams),
            "n_traces": n_traces,
        }

    # Combined k-mer vectors for cosine distance.
    vecs = [combined_kmer_vector(seq, k_values) for seq in sequences]
    # Bigram distributions for JS are built independently of k_values. The old
    # slice of the first _N*_N entries was the bigram block only when
    # k_values started with 2.
    bigram_vocab = build_kmer_vocab(2)
    bigram_vecs = [kmer_vector(seq, 2, bigram_vocab) for seq in sequences]

    pairs = list(combinations(range(n_traces), 2))
    if len(pairs) > max_pairs:
        # Fixed-seed subsample keeps the cost bounded and the results reproducible.
        rng = np.random.RandomState(seed)
        pairs = [pairs[i] for i in rng.choice(len(pairs), max_pairs, replace=False)]

    cos_dists = []
    js_divs = []
    for i, j in pairs:
        if np.any(vecs[i]) and np.any(vecs[j]):
            cos_dists.append(cosine(vecs[i], vecs[j]))
        else:
            # Cosine distance is undefined for an all-zero vector; treat the
            # pair as maximally distant.
            cos_dists.append(1.0)

        bi_i, bi_j = bigram_vecs[i], bigram_vecs[j]
        s_i, s_j = bi_i.sum(), bi_j.sum()
        if s_i > 0 and s_j > 0:
            js_divs.append(jensenshannon(bi_i / s_i, bi_j / s_j))
        else:
            # NOTE(review): jensenshannon uses the natural log by default, so
            # its maximum is sqrt(ln 2) ~= 0.833. This 1.0 sentinel therefore
            # lies above any real value. It is kept so reported numbers stay
            # comparable with earlier runs.
            js_divs.append(1.0)

    return {
        "kmer_mean_cosine_dist": float(np.mean(cos_dists)),
        "kmer_std_cosine_dist": float(np.std(cos_dists)),
        "kmer_mean_js_divergence": float(np.nanmean(js_divs)),
        "kmer_n_unique_bigrams_union": len(all_bigrams),
        "kmer_n_unique_trigrams_union": len(all_trigrams),
        "n_traces": n_traces,
    }


# =====================================================================
# Directly-follows graph
# =====================================================================

def build_dfg(sequences: list[list[str]]) -> np.ndarray:
    """Count directly-follows transitions across a set of primitive sequences.

    Args:
        sequences: primitive-label sequences.

    Returns:
        A (_N, _N) float64 matrix. Entry [i, j] is the number of times label
        ``_PRIM_LIST[i]`` is immediately followed by ``_PRIM_LIST[j]``,
        summed over all sequences. Transitions where either label is outside
        the active label set are skipped.
    """
    lookup = _PRIM_IDX
    counts = np.zeros((_N, _N), dtype=np.float64)
    for seq in sequences:
        for src, dst in zip(seq, seq[1:]):
            row, col = lookup.get(src), lookup.get(dst)
            if row is None or col is None:
                continue
            counts[row, col] += 1
    return counts


def dfg_to_distribution(mat: np.ndarray) -> np.ndarray:
    """Scale a DFG count matrix so its entries sum to 1.

    Returns a new array when the matrix has positive total mass. Otherwise
    (e.g. an empty graph) returns *mat* itself, unmodified.
    """
    mass = mat.sum()
    return mat / mass if mass > 0 else mat


def dfg_outgoing_entropy(mat: np.ndarray) -> dict:
    """Per-node outgoing transition entropy plus its mass-weighted mean.

    For node i, ``H_out(i)`` is the Shannon entropy (bits) of
    ``P(next=j | current=i) = mat[i, j] / mat[i].sum()``. The aggregate
    weights each node by its share of all outgoing transitions. This gives
    the conditional entropy ``H(next | current)``, a single scalar for how
    predictable transitions are overall.

    Args:
        mat: (_N, _N) raw transition-count matrix whose rows follow
            ``_PRIM_LIST`` order.

    Returns dict with:
        outgoing_entropy_<PRIM>  — H_out for that node in bits; NaN when the
                                   node has no outgoing transitions
        mean_outgoing_entropy    — H(next | current) in bits; NaN when the
                                   matrix has no transitions at all
    """
    out_mass = mat.sum(axis=1)       # outgoing transition mass per node
    grand_total = out_mass.sum()

    by_label: dict = {}
    for idx, label in enumerate(_PRIM_LIST):
        if out_mass[idx] == 0:
            by_label[label] = float("nan")
            continue
        cond = mat[idx] / out_mass[idx]
        by_label[label] = float(scipy_entropy(cond[cond > 0], base=2))

    result = {f"outgoing_entropy_{label}": h for label, h in by_label.items()}

    if not grand_total > 0:
        result["mean_outgoing_entropy"] = float("nan")
        return result

    shares = out_mass / grand_total
    acc = 0
    for idx, label in enumerate(_PRIM_LIST):
        h = by_label[label]
        if np.isnan(h):
            # Unseen nodes carry zero weight anyway; skip them to avoid NaN.
            continue
        acc += shares[idx] * h
    result["mean_outgoing_entropy"] = float(acc)
    return result


def dfg_metrics(mat: np.ndarray) -> dict:
    """Compute summary metrics of a directly-follows graph.

    Args:
        mat: (_N, _N) raw transition-count matrix, as produced by ``build_dfg``.

    Returns dict with:
        dfg_n_edges: number of distinct transitions (non-zero cells).
        dfg_edge_entropy: Shannon entropy (bits) of the edge distribution.
        dfg_max_edge_weight: fraction of all transitions taken by the most
            frequent edge.
        dfg_total_transitions: total transition count.
        outgoing_entropy_<PRIM>: per-node outgoing entropy in bits; NaN for
            nodes with no outgoing transitions.
        mean_outgoing_entropy: H(next | current) in bits.

    An empty graph reports zeros for the scalar metrics (including
    ``mean_outgoing_entropy``) and NaN for every per-node entropy. Its key
    set is therefore the same as for a non-empty graph.
    """
    total = mat.sum()
    if total == 0:
        result = {
            "dfg_n_edges": 0,
            "dfg_edge_entropy": 0.0,
            "dfg_max_edge_weight": 0.0,
            "dfg_total_transitions": 0,
        }
        # Emit the same keys as the non-empty path, with every node unseen
        # (NaN, the same convention dfg_outgoing_entropy uses). Aggregation
        # builds its key list from the per-prompt dicts, so a missing key
        # could silently drop these metrics.
        for prim in _PRIM_LIST:
            result[f"outgoing_entropy_{prim}"] = float("nan")
        # Kept at 0.0 (not NaN) for backward compatibility.
        result["mean_outgoing_entropy"] = 0.0
        return result

    # Number of distinct edges (non-zero entries)
    n_edges = int((mat > 0).sum())

    # Edge entropy over the flattened, normalized edge distribution
    dist = mat.flatten()
    dist = dist[dist > 0]
    dist = dist / dist.sum()
    edge_entropy = float(scipy_entropy(dist, base=2))

    # Share of the most common transition
    max_weight = float(mat.max() / total)

    result = {
        "dfg_n_edges": n_edges,
        "dfg_edge_entropy": edge_entropy,
        "dfg_max_edge_weight": max_weight,
        "dfg_total_transitions": int(total),
    }
    result.update(dfg_outgoing_entropy(mat))
    return result


def dfg_novel_edges(
    sft_mat: np.ndarray,
    gspo_mat: np.ndarray,
    threshold: float = 0.0,
) -> dict:
    """List transitions that GSPO uses but SFT never or rarely uses.

    Both count matrices are normalized to edge fractions. An edge (i, j) is
    *novel* when its GSPO fraction is positive and its SFT fraction is
    ``<= threshold``. With the default ``threshold=0.0``, that means the edge
    never occurs in SFT. If SFT has no transitions at all, every GSPO edge
    is novel.

    Args:
        sft_mat: SFT's raw count DFG.
        gspo_mat: GSPO's raw count DFG.
        threshold: largest SFT edge fraction still treated as novel. An edge
            is NOT novel only if its SFT fraction strictly exceeds this.

    Returns dict with:
        n_novel_edges   — number of novel edges
        novel_edge_mass — total GSPO fraction carried by novel edges
        novel_edges     — (from_prim, to_prim, gspo_fraction) tuples,
                          largest fraction first
    An empty GSPO DFG yields no novel edges.
    """
    sft_total = sft_mat.sum()
    gspo_total = gspo_mat.sum()
    if gspo_total == 0:
        return {"n_novel_edges": 0, "novel_edge_mass": 0.0, "novel_edges": []}

    sft_frac = sft_mat / sft_total if sft_total > 0 else sft_mat
    gspo_frac = gspo_mat / gspo_total

    found = []
    mass = 0.0
    cells = ((r, c) for r in range(_N) for c in range(_N))
    for r, c in cells:
        share = gspo_frac[r, c]
        if share > 0 and sft_frac[r, c] <= threshold:
            found.append((_PRIM_LIST[r], _PRIM_LIST[c], float(share)))
            mass += share
    found.sort(key=lambda edge: -edge[2])

    return {
        "n_novel_edges": len(found),
        "novel_edge_mass": float(mass),
        "novel_edges": found,
    }


def dfg_divergence(sft_mat: np.ndarray, gspo_mat: np.ndarray) -> dict:
    """Compare the SFT and GSPO edge distributions with Jensen-Shannon.

    Both count matrices are flattened into edge distributions. Every edge
    gets a 1e-10 floor, the distributions are renormalized, and they are
    compared with ``scipy.spatial.distance.jensenshannon``. That function
    returns the JS *distance* (square root of the divergence, natural log by
    default), even though the key is ``dfg_js_divergence``. An empty DFG
    becomes the uniform distribution after smoothing.
    """
    eps = 1e-10  # floor so every edge has non-zero probability

    def smoothed(mat):
        probs = dfg_to_distribution(mat).flatten() + eps
        return probs / probs.sum()

    p_sft = smoothed(sft_mat)
    p_gspo = smoothed(gspo_mat)
    return {"dfg_js_divergence": float(jensenshannon(p_sft, p_gspo))}


# --- Combined per-prompt analysis ---

def prompt_sequence_analysis(
    sequences: list[list[str]],
    sft_sequences: Optional[list[list[str]]] = None,
    k_values: tuple = (2, 3),
) -> dict:
    """Run the full sequence-diversity analysis for one prompt.

    Args:
        sequences: primitive sequences for this checkpoint's rollouts.
        sft_sequences: optional SFT sequences for the same prompt. When
            given, the result also includes novelty and divergence of this
            checkpoint's DFG against the SFT DFG, plus the SFT DFG metrics
            under an ``sft_`` prefix.
        k_values: k-mer sizes for the k-mer diversity metrics.

    Returns:
        One flat dict merging every metric dict. Later sections overwrite
        earlier ones on key collisions.
    """
    metrics = dict(kmer_diversity_per_prompt(sequences, k_values))
    checkpoint_dfg = build_dfg(sequences)
    metrics.update(dfg_metrics(checkpoint_dfg))

    if sft_sequences is None:
        return metrics

    reference_dfg = build_dfg(sft_sequences)
    metrics.update(dfg_novel_edges(reference_dfg, checkpoint_dfg))
    metrics.update(dfg_divergence(reference_dfg, checkpoint_dfg))
    metrics.update(dfg_metrics_prefixed(reference_dfg, prefix="sft_"))
    return metrics


def dfg_metrics_prefixed(mat: np.ndarray, prefix: str) -> dict:
    """Return ``dfg_metrics(mat)`` with *prefix* prepended to every key."""
    prefixed = {}
    for name, value in dfg_metrics(mat).items():
        prefixed[f"{prefix}{name}"] = value
    return prefixed


# --- Aggregate across prompts ---

def aggregate_sequence_diversity(per_prompt: list[dict]) -> dict:
    """Aggregate per-prompt sequence diversity metrics.

    For every scalar (Python or NumPy int/float) metric that appears in any
    per-prompt dict, report its mean and population std across the prompts
    that have it. NaN entries are skipped, so one undefined value does not
    poison the aggregate; an example is ``outgoing_entropy_<PRIM>`` for a
    node a prompt never visits. A metric that is NaN for every prompt is
    reported as NaN. Non-scalar values, such as a list of edges, are ignored.

    Args:
        per_prompt: one metrics dict per prompt.

    Returns:
        ``{"n_prompts": len(per_prompt), "<key>_mean": ..., "<key>_std": ...}``,
        or ``{}`` when *per_prompt* is empty.
    """
    if not per_prompt:
        return {}

    scalar_types = (int, float, np.integer, np.floating)

    # Union of scalar-valued keys over all prompts, in first-seen order.
    # Per-prompt dicts need not share one key set, and taking keys from the
    # first prompt alone silently dropped metrics that prompt lacked.
    keys: list = []
    seen: set = set()
    for metrics in per_prompt:
        for key, value in metrics.items():
            if key not in seen and isinstance(value, scalar_types):
                seen.add(key)
                keys.append(key)

    result = {"n_prompts": len(per_prompt)}
    for key in keys:
        vals = np.array(
            [float(p[key]) for p in per_prompt
             if isinstance(p.get(key), scalar_types)],
            dtype=np.float64,
        )
        # Skip NaN entries so one undefined value does not turn the
        # aggregate into NaN.
        valid = vals[~np.isnan(vals)]
        if valid.size:
            result[f"{key}_mean"] = float(valid.mean())
            result[f"{key}_std"] = float(valid.std())
        else:
            result[f"{key}_mean"] = float("nan")
            result[f"{key}_std"] = float("nan")

    return result
