from __future__ import annotations

from dataclasses import dataclass
import os
from typing import Any, Dict, List, Optional, Sequence, Tuple

# Default the OpenMP / MKL / OpenBLAS thread-count variables to "1".
# `setdefault` leaves any value the caller already exported untouched.
# These lines run before `import numpy` below — presumably so the native
# numeric libraries read the limits when they are first loaded.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")

import numpy as np

# NOTE: We intentionally avoid heavy clustering backends by default.
# Some environments can hang due to OpenMP/threading interactions in certain
# sklearn clusterers. CE-Graph does not require a specific clustering algorithm;
# it only requires *consistent* bucketing into recurring failure modes.

from .counterexamples import ExecutionTrace


@dataclass
class FailureSignature:
    r"""Failure signature \phi(\tau).

    Keep it simple and auditable: a short text string + structured flags.

    The docstring is a raw string so the LaTeX-style backslashes are kept
    literally instead of being read as escape sequences (``\t`` would
    otherwise become a TAB character).
    """

    # Compact, human-readable description of the failure.
    text: str
    # Structured indicators about the failure, keyed by flag name.
    flags: Dict[str, Any]


@dataclass
class FailureMode:
    """A clustered failure mode b in the paper.

    Instances are produced by ``FailureClustering.cluster`` (and its
    single-mode fallback); one instance describes one bucket of traces.
    """

    # Identifier of the mode; ``cluster`` numbers modes 0, 1, 2, ... in
    # order of decreasing bucket size.
    mode_id: int
    # Number of traces that fell into this mode.
    size: int
    # Indices (into the clustered trace sequence) of up to five member
    # traces, kept as representative examples.
    exemplar_indices: List[int]
    # Most frequent tokens found in the members' signature texts.
    keywords: List[str]


class FailureClustering:
    """Cluster execution failures into recurring modes.

    The implementation is deliberately rule-based: each trace is bucketed by
    the lower-cased, truncated first line of its error message together with
    the ``tool_error`` / ``invalid_json`` flags of its signature. This is
    fast and deterministic and needs no clustering backend. It can be
    swapped for TF-IDF- or embedding-based clustering if finer-grained modes
    are desired.
    """

    def __init__(
        self,
        max_modes: int = 12,
        min_cluster_size: int = 10,
        random_state: int = 0,
    ):
        """Store the clustering configuration.

        Args:
            max_modes: Maximum number of modes ``cluster`` reports; only the
                largest buckets are kept.
            min_cluster_size: Stored for API compatibility. The rule-based
                bucketing in this class does not consult it.
            random_state: Stored for API compatibility. The rule-based
                bucketing is deterministic and does not consult it.
        """
        self.max_modes = int(max_modes)
        self.min_cluster_size = int(min_cluster_size)
        self.random_state = int(random_state)

    def signature(self, trace: ExecutionTrace) -> FailureSignature:
        """Convert a trace into a stable, low-variance textual signature.

        The text is a `` ; ``-joined list of pieces built from the trace's
        error message, per-step tool errors, invalid-JSON count, scalar
        metrics and final output. Customize this to your environments/tools.

        Args:
            trace: Object exposing ``error``, ``node_outputs``, ``metrics``
                and ``final_output``. Any of them may be ``None``/empty.

        Returns:
            A ``FailureSignature`` whose ``flags`` holds the booleans
            ``has_error``, ``tool_error`` and ``invalid_json``.
        """
        pieces: List[str] = []
        flags: Dict[str, Any] = {}

        err = (trace.error or "").strip()
        flags["has_error"] = bool(err)
        if err:
            pieces.append(f"error:{err[:200]}")

        # Common structure: tool failures / invalid JSON / verifier failed.
        tool_errors, invalid_json = self._scan_node_outputs(trace.node_outputs)

        flags["tool_error"] = bool(tool_errors)
        if tool_errors:
            # Only the first three tool errors are kept to bound the text.
            pieces.append("tool:" + " | ".join(tool_errors[:3]))

        flags["invalid_json"] = bool(invalid_json)
        if invalid_json:
            pieces.append(f"invalid_json:{invalid_json}")

        # Add a compact view of scalar metrics (sorted for stable output).
        if trace.metrics:
            for name in sorted(trace.metrics.keys()):
                value = trace.metrics[name]
                if isinstance(value, (int, float)):
                    pieces.append(f"m:{name}={float(value):.3f}")

        # Fallback: include last output snippet for discriminating modes.
        if trace.final_output is not None:
            pieces.append("out:" + str(trace.final_output)[:200])

        text = " ; ".join(pieces) if pieces else "unknown_failure"
        return FailureSignature(text=text, flags=flags)

    @staticmethod
    def _scan_node_outputs(
        node_outputs: Optional[Sequence[Dict[str, Any]]],
    ) -> Tuple[List[str], int]:
        """Collect tool-error strings and count invalid-JSON steps.

        Args:
            node_outputs: Per-step output mappings, or ``None``. A missing
                step list is treated as empty, matching how the other
                optional trace fields are handled.

        Returns:
            ``(tool_errors, invalid_json_count)`` where each tool error is
            stringified and truncated to 120 characters.
        """
        tool_errors: List[str] = []
        invalid_json = 0
        for step in node_outputs or ():
            tool_error = step.get("tool_error")
            if tool_error:
                tool_errors.append(str(tool_error)[:120])
            if step.get("invalid_json"):
                invalid_json += 1
        return tool_errors, invalid_json

    def cluster(self, traces: Sequence[ExecutionTrace]) -> Tuple[List[FailureMode], np.ndarray]:
        """Return (modes, assignments).

        assignments[i] = mode_id for traces[i]. Traces whose bucket did not
        make it into the ``max_modes`` largest buckets keep the label ``-1``.

        If no bucket survives the cap (e.g. ``max_modes == 0``), all traces
        are put into a single mode with id 0.
        """
        if len(traces) == 0:
            return [], np.array([], dtype=np.int64)

        sigs = [self.signature(t) for t in traces]
        texts = [s.text for s in sigs]

        # Rule-based bucketing by stable keys (fast, deterministic, no hangs).
        buckets: Dict[Tuple[str, bool, bool], List[int]] = {}
        for i, (trace, sig) in enumerate(zip(traces, sigs)):
            first_line = (trace.error or "").split("\n")[0].strip().lower()
            key = (
                first_line[:60] if first_line else "no_error",
                bool(sig.flags.get("tool_error")),
                bool(sig.flags.get("invalid_json")),
            )
            buckets.setdefault(key, []).append(i)

        # Largest buckets first; the sort is stable, so equally sized buckets
        # keep first-seen order. Then cap to max_modes.
        ranked = sorted(buckets.values(), key=lambda members: -len(members))
        ranked = ranked[: self.max_modes]

        labels = np.full(len(traces), -1, dtype=np.int64)
        modes: List[FailureMode] = []
        for mode_id, members in enumerate(ranked):
            labels[members] = mode_id

            # Keywords: simple token frequency within the bucket.
            keywords = self._keywords_from_texts([texts[i] for i in members], topn=8)
            modes.append(
                FailureMode(
                    mode_id=mode_id,
                    size=len(members),
                    exemplar_indices=members[:5],
                    keywords=keywords,
                )
            )

        if not modes:
            everything = list(range(len(texts)))
            return [self._singleton_mode(texts, 0, everything)], np.zeros(len(texts), dtype=np.int64)

        return modes, labels

    def _singleton_mode(self, texts: List[str], mode_id: int, indices: List[int]) -> FailureMode:
        """Build one mode covering ``indices``, with crude keywords.

        Keywords are the eight most common whitespace-separated tokens of at
        least three characters across ``texts``.
        """
        keywords = self._rank_tokens(" ".join(texts).split(), min_len=3, topn=8)
        return FailureMode(
            mode_id=mode_id,
            size=len(indices),
            exemplar_indices=indices[:5],
            keywords=keywords,
        )

    def _keywords_from_texts(self, texts: List[str], topn: int = 8) -> List[str]:
        """Return the ``topn`` most frequent tokens across signature texts.

        The ``;`` and ``|`` separators used by ``signature`` are treated as
        whitespace. Tokens shorter than four characters, and tokens starting
        with the ``m:`` (metric) or ``out:`` (output) prefixes, are ignored.
        """
        tokens = " ".join(texts).replace(";", " ").replace("|", " ").split()
        return self._rank_tokens(tokens, min_len=4, topn=topn, skip_prefixes=("m:", "out:"))

    @staticmethod
    def _rank_tokens(
        tokens: List[str],
        min_len: int,
        topn: int,
        skip_prefixes: Tuple[str, ...] = (),
    ) -> List[str]:
        """Rank lower-cased tokens by frequency, most frequent first.

        Tokens shorter than ``min_len`` or starting with any of
        ``skip_prefixes`` are dropped. Ties keep first-seen order because
        the sort is stable and dicts preserve insertion order.
        """
        freq: Dict[str, int] = {}
        for raw in tokens:
            token = raw.strip().lower()
            # str.startswith with an empty tuple is always False.
            if len(token) < min_len or token.startswith(skip_prefixes):
                continue
            freq[token] = freq.get(token, 0) + 1
        ranked = sorted(freq.items(), key=lambda item: -item[1])
        return [token for token, _ in ranked[:topn]]

    # Keeping TF-IDF helpers for optional future use.
