"""Cluster traces per prompt using semantic + primitive representations.

Supports agglomerative clustering with a distance threshold (no fixed k),
combining sentence embeddings (semantic) and primitive n-gram vectors.
"""
from __future__ import annotations

import sys
from pathlib import Path
from typing import Optional

import numpy as np
from scipy.spatial.distance import cdist
from sklearn.cluster import AgglomerativeClustering

from .primitive_classification import PRIMITIVES, Episode

# Make the existing trace-similarity code importable. The helpers below run
# `from metric import ...` lazily, so the directory expected to hold that
# module is pushed onto sys.path when this module is imported.
# NOTE(review): inserting at index 0 gives this directory priority over every
# other sys.path entry, so a same-named `metric` module elsewhere would be
# shadowed -- confirm that is intended.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(_PROJECT_ROOT / "scripts" / "analysis" / "trace_similarity"))

# Module-wide cache for the embedder; filled on first use by _get_embedder().
_embedder_instance = None


def _get_embedder():
    """Return the shared MiniLMEmbedder, constructing it on the first call.

    The instance is stored in the module-level ``_embedder_instance`` so that
    later calls reuse it instead of building a new embedder.
    """
    global _embedder_instance
    if _embedder_instance is not None:
        return _embedder_instance

    # Deferred import: `metric` is only needed once an embedder is requested.
    from metric import MiniLMEmbedder

    _embedder_instance = MiniLMEmbedder()
    return _embedder_instance


# ---------------------------------------------------------------------------
# Semantic embeddings
# ---------------------------------------------------------------------------

def embed_traces(
    traces: list[str],
    embedder=None,
) -> np.ndarray:
    """Compute one mean-pooled, L2-normalized sentence embedding per trace.

    Each trace is split into sentences, the sentences are encoded, and the
    sentence vectors are averaged and re-normalized to unit length. A trace
    that yields no sentences is represented by an all-zero row.

    Args:
        traces: Trace texts, one per row of the result.
        embedder: Object exposing ``encode(list_of_sentences)``. When ``None``
            the shared embedder from ``_get_embedder()`` is used.

    Returns:
        A float32 array of shape ``(N, D)``. ``D`` is the width of the
        embeddings produced by ``embedder`` and falls back to 384 when no
        trace produced an embedding (including ``N == 0``), so the result is
        always 2-D and safe to pass to ``cdist``.
    """
    # Fallback width used only when no embedding was produced at all.
    default_dim = 384

    traces = list(traces)
    if not traces:
        # Nothing to embed: return a well-formed empty 2-D array without
        # loading the model or importing the sentence splitter.
        return np.zeros((0, default_dim), dtype=np.float32)

    if embedder is None:
        embedder = _get_embedder()

    from metric import split_sentences

    pooled = []  # one entry per trace; None marks "no sentences"
    dim = None  # embedding width, learned from the first real embedding
    for trace in traces:
        sents = split_sentences(trace)
        if not sents:
            pooled.append(None)
            continue
        # Presumably (n_sents, D) with L2-normalized rows -- TODO confirm
        # against the embedder implementation.
        sent_embs = embedder.encode(sents)
        # Mean-pool and re-normalize
        mean_emb = sent_embs.mean(axis=0)
        norm = np.linalg.norm(mean_emb)
        if norm > 0:
            mean_emb /= norm
        if dim is None:
            dim = mean_emb.shape[-1]
        pooled.append(mean_emb)

    if dim is None:
        dim = default_dim

    # Size the zero placeholder to the observed width so rows never end up
    # ragged when a non-384-dimensional embedder is supplied.
    zero_row = np.zeros(dim, dtype=np.float32)
    rows = [zero_row if emb is None else emb for emb in pooled]
    return np.array(rows, dtype=np.float32)


# ---------------------------------------------------------------------------
# Primitive n-gram vectors
# ---------------------------------------------------------------------------

# Feature vocabulary: every primitive label on its own, followed by every
# ordered pair of labels (len(PRIMITIVES) ** 2 pairs; 10 labels -> 100 pairs).
# NOTE(review): bigram features are spelled "<first>_<second>"; if any label
# itself contains "_" two features could share a name -- verify PRIMITIVES.
_UNIGRAMS = PRIMITIVES  # single-label features
_BIGRAMS = [(first, second) for first in PRIMITIVES for second in PRIMITIVES]  # ordered pairs
_VOCAB = _UNIGRAMS + [f"{first}_{second}" for first, second in _BIGRAMS]  # unigrams, then bigrams
_VOCAB_IDX = {feature: position for position, feature in enumerate(_VOCAB)}  # feature name -> column


def primitive_ngram_vector(episode_labels: list[str]) -> np.ndarray:
    """Build a unit-length unigram + bigram count vector for one label sequence.

    Every label and every adjacent label pair (joined with ``_``) that appears
    in the vocabulary increments its column; anything not in the vocabulary is
    ignored. The counts are then scaled to unit L2 norm, unless all are zero.

    Returns a float32 vector with one entry per vocabulary feature
    (110 = 10 unigrams + 100 bigrams for the expected 10 primitives).
    """
    counts = np.zeros(len(_VOCAB), dtype=np.float32)

    # Adjacent pairs, spelled the same way as the bigram vocabulary entries.
    bigram_keys = [
        f"{left}_{right}"
        for left, right in zip(episode_labels, episode_labels[1:])
    ]

    # Unigrams first, then bigrams; unknown keys are skipped.
    for key in [*episode_labels, *bigram_keys]:
        column = _VOCAB_IDX.get(key)
        if column is None:
            continue
        counts[column] += 1

    # Scale to unit length (TF-IDF style without the IDF term).
    length = np.linalg.norm(counts)
    if length > 0:
        counts /= length
    return counts


def primitive_ngram_vectors(
    episode_sequences: list[list[str]],
) -> np.ndarray:
    """Compute primitive n-gram vectors for multiple traces.

    Args:
        episode_sequences: List of primitive label sequences, one per trace.

    Returns:
        A float32 array of shape ``(N, len(_VOCAB))`` whose rows are the
        outputs of ``primitive_ngram_vector``. For an empty input the result
        is an empty array of shape ``(0, len(_VOCAB))``, so it stays 2-D and
        can be passed straight to ``cdist``.
    """
    vectors = [primitive_ngram_vector(seq) for seq in episode_sequences]
    if not vectors:
        # np.array([]) would collapse to shape (0,), losing the feature axis.
        return np.zeros((0, len(_VOCAB)), dtype=np.float32)
    return np.array(vectors, dtype=np.float32)


# ---------------------------------------------------------------------------
# Distance matrix
# ---------------------------------------------------------------------------

def combined_distance_matrix(
    semantic_embs: np.ndarray,
    primitive_vecs: np.ndarray,
    semantic_weight: float = 0.7,
) -> np.ndarray:
    """Compute a weighted combination of two pairwise cosine-distance matrices.

    Args:
        semantic_embs: ``(N, D_sem)`` semantic embeddings, one row per trace.
        primitive_vecs: ``(N, D_prim)`` primitive n-gram vectors, one row per
            trace, in the same order as ``semantic_embs``.
        semantic_weight: Weight applied to the semantic distance; the
            primitive distance is weighted by ``1 - semantic_weight``. The
            value is used as given (no range check).

    Returns:
        ``(N, N)`` matrix ``semantic_weight * d_sem + (1 - semantic_weight) * d_prim``.

    Raises:
        ValueError: If the two inputs do not have the same number of rows.
    """
    n_sem = len(semantic_embs)
    n_prim = len(primitive_vecs)
    if n_sem != n_prim:
        # Without this check a single-row input yields a 1x1 matrix that
        # silently broadcasts against the NxN one and gives a wrong result.
        raise ValueError(
            "semantic_embs and primitive_vecs must describe the same traces: "
            f"got {n_sem} and {n_prim} rows"
        )

    # Cosine distance = 1 - cosine_similarity
    sem_dist = cdist(semantic_embs, semantic_embs, metric="cosine")
    prim_dist = cdist(primitive_vecs, primitive_vecs, metric="cosine")

    # Any NaN entries (e.g. from all-zero rows) are treated as distance 1.0.
    sem_dist = np.nan_to_num(sem_dist, nan=1.0)
    prim_dist = np.nan_to_num(prim_dist, nan=1.0)

    return semantic_weight * sem_dist + (1 - semantic_weight) * prim_dist


# ---------------------------------------------------------------------------
# Clustering
# ---------------------------------------------------------------------------

def cluster_traces(
    distance_matrix: np.ndarray,
    distance_threshold: float = 0.3,
    linkage: str = "average",
) -> np.ndarray:
    """Group traces by agglomerative clustering over a precomputed distance matrix.

    The number of clusters is not fixed in advance; ``distance_threshold``
    is passed to the clusterer and decides how fine the grouping is.

    Returns an array of cluster labels with one entry per row of
    ``distance_matrix``. With zero or one rows no clustering is run and every
    row gets label 0.
    """
    num_traces = distance_matrix.shape[0]
    if num_traces > 1:
        model = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=distance_threshold,
            metric="precomputed",
            linkage=linkage,
        )
        return model.fit_predict(distance_matrix)

    # Too few rows to cluster: return the trivial labelling.
    return np.zeros(num_traces, dtype=int)


def cluster_traces_semantic_only(
    traces: list[str],
    distance_threshold: float = 0.3,
    embedder=None,
) -> tuple[np.ndarray, np.ndarray]:
    """Embed traces and cluster them on semantic cosine distance alone.

    Shortcut that chains ``embed_traces``, a pairwise cosine-distance matrix
    (NaN entries replaced by 1.0) and ``cluster_traces`` with its default
    linkage.

    Returns (cluster_labels, embeddings).
    """
    embeddings = embed_traces(traces, embedder)
    pairwise = np.nan_to_num(
        cdist(embeddings, embeddings, metric="cosine"),
        nan=1.0,
    )
    return cluster_traces(pairwise, distance_threshold), embeddings
