"""Heuristic and learned primitive classifier for reasoning trace spans.

Classifies each span into reasoning primitive categories.
- Heuristic: regex/keyword pattern matching (10 legacy labels)
- Learned: 3-way BERT ensemble (9 labels: CHECK replaces VERIFY+ERROR_DETECT,
  SETUP replaces DECOMPOSE)
"""
from __future__ import annotations

import re
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Callable

import numpy as np

from .segmentation import Span

# ---------------------------------------------------------------------------
# Learned ensemble classifier
# ---------------------------------------------------------------------------

# Module-level cache of loaded classifier bundles, filled lazily by
# load_ensemble_classifier (None before the first call).
_ENSEMBLE_BUNDLES = None  # lazy-loaded

# Label set of the learned classifiers (9 labels).
# NOTE(review): load_ensemble_classifier reports labels from the first active
# bundle's own "label_order", not from this list — keep the two consistent.
LEARNED_PRIMITIVES = [
    "PLAN", "SETUP", "ENUMERATE", "HYPOTHESIZE", "COMPUTE",
    "CHECK", "BACKTRACK", "SUMMARIZE", "OTHER",
]

# Directory holding the .joblib classifier bundles: three `.parent` hops up
# from this file, then results/exploration_analysis/llm_validation/classifier.
_CLASSIFIER_DIR = Path(__file__).parent.parent.parent / \
    "results/exploration_analysis/llm_validation/classifier"

# Bundle files in priority order. load_ensemble_classifier ensembles the first
# n_models entries, so this order decides which models are used.
_BUNDLE_NAMES = [
    "classifier_v90_distilled_qwen3_full.joblib",  # best single: V2 puz 0.80, V2 math 0.80
    "classifier_v87_distilled_qwen3_1p7b.joblib",
    "classifier_v68_roberta_v6_seed42.joblib",
    "classifier_v70_deberta_v6.joblib",
    "classifier_v71_electra_v6.joblib",
]


def load_ensemble_classifier(
    batch_size: int = 256,
    multi_gpu: bool = False,
    use_fp16: bool = True,
    n_models: int = 3,
) -> Callable:
    """Load the first ``n_models`` classifier bundles and return a classify function.

    Bundles are taken from ``_BUNDLE_NAMES`` in order and cached in the
    module-level ``_ENSEMBLE_BUNDLES`` list. A later call that needs more
    bundles loads only the missing ones. The returned function averages the
    bundles' class-probability arrays and takes the argmax per row.

    Args:
        batch_size: Inference batch size, set on every active classifier.
            Default 256.
        multi_gpu: If True and at least as many CUDA devices as active
            bundles are visible, place bundle ``i`` on GPU ``i`` and run the
            bundles in parallel threads.
        use_fp16: Set as ``use_fp16`` on every active classifier (intended
            for faster fp16 GPU inference).
        n_models: Number of bundles from ``_BUNDLE_NAMES`` to ensemble.
            Values larger than ``len(_BUNDLE_NAMES)`` are clamped to it.

    Returns:
        A callable ``fn(rows) -> list[str]`` where
        ``rows = [{"span_text": ..., "preceding_context": ...}, ...]``.
        Labels are read from the first active bundle's ``"label_order"``.
        An empty ``rows`` returns ``[]``.

    Raises:
        ValueError: If ``n_models < 1``.

    Note:
        Cached bundles are shared between calls, and this function mutates
        their classifiers (``batch_size``, ``use_fp16``, and ``device_id`` in
        multi-GPU mode). A later call with different settings therefore also
        affects callables returned by earlier calls.
    """
    import torch
    global _ENSEMBLE_BUNDLES

    if n_models < 1:
        raise ValueError(f"n_models must be >= 1, got {n_models}")
    # The cache can never hold more than len(_BUNDLE_NAMES) bundles; clamping
    # prevents an oversized request from reloading everything on every call.
    n_models = min(n_models, len(_BUNDLE_NAMES))

    # Lazy-load only the first n_models bundles, so callers that need just the
    # top model (n_models=1) don't require every bundle on disk. Already
    # cached bundles are kept; only the missing tail is read from disk.
    if _ENSEMBLE_BUNDLES is None:
        _ENSEMBLE_BUNDLES = []
    if len(_ENSEMBLE_BUNDLES) < n_models:
        import joblib
        missing = _BUNDLE_NAMES[len(_ENSEMBLE_BUNDLES):n_models]
        for name in missing:
            _ENSEMBLE_BUNDLES.append(joblib.load(_CLASSIFIER_DIR / name))

    active_bundles = _ENSEMBLE_BUNDLES[:n_models]
    if n_models < len(_BUNDLE_NAMES):
        print(f"  Using {n_models}/{len(_BUNDLE_NAMES)} models: {_BUNDLE_NAMES[:n_models]}")

    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    use_multi = multi_gpu and n_gpus >= len(active_bundles)

    from .llm_validation.classifier.ensemble import predict_proba as _ensemble_predict_proba

    for gpu_id, bundle in enumerate(active_bundles):
        clf = bundle["classifier"]
        clf.batch_size = batch_size
        clf.use_fp16 = use_fp16
        if use_multi:
            clf.device_id = gpu_id
            # Force reload on new device next time
            clf._model = None
            clf._tok = None
            clf._device = None

    label_order = active_bundles[0]["label_order"]

    if use_multi:
        from concurrent.futures import ThreadPoolExecutor

        # Pre-warm models sequentially (from_pretrained not thread-safe)
        for bundle in active_bundles:
            bundle["classifier"]._load()

        def _predict_all(rows):
            # Inference in parallel — CUDA kernels release GIL so threads work
            with ThreadPoolExecutor(max_workers=len(active_bundles)) as ex:
                futures = [ex.submit(_ensemble_predict_proba, b, rows) for b in active_bundles]
                return [f.result() for f in futures]
    else:
        def _predict_all(rows):
            return [_ensemble_predict_proba(b, rows) for b in active_bundles]

    def _classify(rows):
        if not rows:
            return []
        # Average the per-model probability arrays, then argmax per row.
        # Assumes predict_proba returns (n_rows, n_labels) arrays — TODO confirm.
        probs = np.mean(_predict_all(rows), axis=0)
        return [label_order[i] for i in probs.argmax(axis=1)]

    return _classify

# ---------------------------------------------------------------------------
# Primitive taxonomy
# ---------------------------------------------------------------------------

# Legacy heuristic label set (10 labels). _RAW_PATTERNS defines patterns for
# every label here except OTHER, which classify_span returns as its fallback.
# The learned ensemble uses LEARNED_PRIMITIVES instead (per the module
# docstring: CHECK replaces VERIFY + ERROR_DETECT, SETUP replaces DECOMPOSE).
PRIMITIVES = [
    "PLAN",
    "DECOMPOSE",
    "ENUMERATE",
    "HYPOTHESIZE",
    "COMPUTE",
    "VERIFY",
    "ERROR_DETECT",
    "BACKTRACK",
    "SUMMARIZE",
    "OTHER",
]

# ---------------------------------------------------------------------------
# Pattern definitions — each primitive except OTHER (the fallback label) gets
# a list of (pattern, weight) pairs.
# Weight allows stronger signals (multi-word phrases) to count more.
# ---------------------------------------------------------------------------

# Heuristic pattern table consumed by _build_patterns() / classify_span().
# All regexes are compiled with re.IGNORECASE (see _build_patterns), so
# character classes such as [A-Z] also match lowercase letters.
# Every match adds its weight to that label's score. A phrase caught by several
# patterns counts once per pattern (e.g. "but wait" also matches r"\bwait\b").
_RAW_PATTERNS: dict[str, list[tuple[str, float]]] = {
    "PLAN": [
        (r"\blet me plan\b", 2.0),
        (r"\bmy approach\b", 2.0),
        (r"\bmy strategy\b", 2.0),
        (r"\bthe plan is\b", 2.0),
        (r"\bfirst I will\b", 1.5),
        (r"\bI('ll| will) start by\b", 1.5),
        (r"\bhere('s| is) (?:my |the )?(?:plan|approach|strategy)\b", 2.0),
        (r"\blet me outline\b", 1.5),
        (r"\bstep[\s-]by[\s-]step\b", 1.0),
        (r"\boverall approach\b", 1.5),
    ],
    "DECOMPOSE": [
        (r"\bbreak (?:this |it )?(?:down|into)\b", 2.0),
        (r"\bsub[\s-]?problem\b", 2.0),
        (r"\breduce (?:this )?to\b", 1.5),
        (r"\bit suffices to\b", 2.0),
        (r"\bfirst (?:we )?(?:need to )?solve\b", 1.5),
        (r"\bsplit (?:this |the )?(?:into|problem)\b", 1.5),
        (r"\bseparately\b", 1.0),
        (r"\bhandl(?:e|ing) (?:each|the) .{0,20}separately\b", 1.5),
        (r"\bfocus(?:ing)? on\b", 0.8),
        (r"\bconsider (?:the |each )?\w+ (?:part|component|piece)\b", 1.5),
    ],
    "ENUMERATE": [
        # NOTE(review): only single-digit case numbers match ("case 10" does not).
        (r"\bcase\s+[1-9]\b", 2.0),
        # NOTE(review): with IGNORECASE this also matches e.g. "in case a ...";
        # confirm intended.
        (r"\bcase\s+[A-Z]\b", 1.5),
        (r"\boption\s+[1-9]\b", 1.5),
        (r"\bpossibilit(?:y|ies)\b", 1.5),
        (r"\beither\b.*\bor\b", 1.0),
        (r"\blet('s| us) list\b", 2.0),
        (r"\bthe (?:possible |only )?(?:options|choices|values) are\b", 2.0),
        (r"\benumerat(?:e|ing)\b", 2.0),
        (r"\bfor each\b", 1.0),
        (r"\bconsider (?:all |the )?(?:cases|possibilities|options)\b", 1.5),
        (r"\bthere are \d+ (?:cases|options|possibilities)\b", 1.5),
        # NOTE(review): a space is required before "then"/":", and the trailing
        # \b means ":" only matches when a word character follows it directly
        # ("if x :y"). So the usual "if x: ..." form is not caught; confirm.
        (r"\b(?:if|when) .{0,30} (?:then|:)\b.*\b(?:if|when) .{0,30} (?:then|:)\b", 1.0),
    ],
    "HYPOTHESIZE": [
        (r"\bsuppose\b", 1.5),
        (r"\bassume\b", 1.5),
        (r"\bwhat if\b", 2.0),
        (r"\blet('s| us| me) try\b", 1.5),
        (r"\bif we (?:set|assume|suppose|try|let)\b", 1.5),
        (r"\bhypothes(?:is|ize)\b", 2.0),
        (r"\bconjecture\b", 2.0),
        (r"\bguess\b", 1.0),
        (r"\bperhaps\b", 0.8),
        (r"\bmaybe\b", 0.8),
        (r"\btest(?:ing)? (?:the |this |whether )\b", 1.0),
    ],
    "COMPUTE": [
        (r"[=]\s*\d", 1.5),
        # NOTE(review): "a / b" between integers matches both this pattern and
        # the [/×÷·] pattern below, so division is counted twice.
        (r"\d+\s*[+\-*/]\s*\d+", 1.0),
        (r"\bcalculat(?:e|ing)\b", 1.5),
        (r"\bsolv(?:e|ing)\b", 1.0),
        (r"\bsubstitut(?:e|ing)\b", 1.5),
        (r"\bplugg?(?:ing)? (?:in|into|back)\b", 1.0),
        (r"\bevaluat(?:e|ing)\b", 1.0),
        (r"\bcomput(?:e|ing)\b", 1.5),
        (r"\bsimplif(?:y|ying|ies)\b", 1.0),
        (r"\bexpand(?:ing)?\b", 0.8),
        (r"\bfactor(?:ing|ize)?\b", 0.8),
        (r"\b\d+\s*[/×÷·]\s*\d+\b", 1.0),
        (r"\b\d+\^", 0.8),
        (r"\bmod\s+\d+\b", 1.0),
    ],
    "VERIFY": [
        (r"\bcheck(?:ing)?\b", 1.5),
        (r"\bverif(?:y|ying|ied|ication)\b", 2.0),
        (r"\bsatisf(?:y|ies|ied)\b", 1.5),
        (r"\bconfirm(?:ing|ed|s)?\b", 1.5),
        (r"\bplug(?:ging)? (?:it )?back\b", 2.0),
        (r"\bdoes this (?:work|satisfy|hold)\b", 2.0),
        (r"\blet('s| me| us) (?:verify|check|confirm)\b", 2.0),
        (r"\bmake sure\b", 1.0),
        (r"\bvalidat(?:e|ing)\b", 1.5),
        (r"\bconsistent\b", 1.0),
        (r"\bmatches?\b", 0.8),
        (r"\bcorrect(?:ly|ness)?\b", 0.8),
    ],
    "ERROR_DETECT": [
        (r"\bwait\b", 1.5),
        (r"\bactually\b", 1.0),
        (r"\bthat('s| is) wrong\b", 2.0),
        (r"\bcontradiction\b", 2.0),
        (r"\b(?:this |that |it )(?:doesn't|does not|can't|cannot) work\b", 2.0),
        (r"\bmistake\b", 2.0),
        (r"\berror\b", 1.5),
        (r"\binconsisten(?:t|cy)\b", 2.0),
        (r"\bimpossible\b", 1.5),
        (r"\bwrong\b", 1.0),
        (r"\bcan't be (?:right|correct|true)\b", 2.0),
        (r"\bthis fails\b", 1.5),
        (r"\bdoesn't (?:add up|make sense|hold)\b", 2.0),
        (r"\bhmm\b", 0.8),
        (r"\bbut wait\b", 2.0),
        (r"\bhold on\b", 1.5),
    ],
    "BACKTRACK": [
        (r"\binstead\b", 1.0),
        (r"\brestart(?:ing)?\b", 2.0),
        (r"\btry (?:a )?(?:different|another|new) (?:approach|way|method|path)\b", 2.0),
        (r"\bgo(?:ing)? back\b", 1.5),
        (r"\bstart(?:ing)? over\b", 2.0),
        (r"\blet me reconsider\b", 2.0),
        (r"\babort(?:ing)?\b", 2.0),
        (r"\bscrap(?:ping)? (?:this|that)\b", 2.0),
        (r"\bthis (?:path|approach|method) (?:fails|doesn't work|won't work)\b", 2.0),
        (r"\blet('s| me) try (?:something else|again|a different)\b", 2.0),
        (r"\bback to\b", 1.0),
        (r"\balternativ(?:e|ely)\b", 1.0),
        (r"\bundo\b", 1.5),
        (r"\brethink\b", 1.5),
    ],
    "SUMMARIZE": [
        (r"\btherefore\b", 1.5),
        (r"\bhence\b", 1.5),
        (r"\bthus\b", 1.0),
        (r"\bso the answer\b", 2.0),
        (r"\bin summary\b", 2.0),
        (r"\bwe conclude\b", 2.0),
        (r"\bin conclusion\b", 2.0),
        (r"\bto summarize\b", 2.0),
        (r"\boverall\b", 0.8),
        (r"\bso far\b", 1.0),
        (r"\bwe(?:'ve| have) (?:shown|established|found|determined)\b", 1.5),
        (r"\bthe (?:answer|result|solution) is\b", 2.0),
        (r"\bputting (?:it |this )?(?:all )?together\b", 1.5),
        (r"\bfinal answer\b", 2.0),
    ],
}

# ---------------------------------------------------------------------------
# Compiled patterns
# ---------------------------------------------------------------------------

# Lazily-filled cache of compiled (pattern, weight) pairs, keyed by label.
_COMPILED_PATTERNS: dict[str, list[tuple[re.Pattern, float]]] = {}


def _build_patterns() -> dict[str, list[tuple[re.Pattern, float]]]:
    """Return the compiled pattern table, compiling it on first use.

    On the first call every regex in ``_RAW_PATTERNS`` is compiled
    case-insensitively and stored in the module-level ``_COMPILED_PATTERNS``
    dict. Later calls find it non-empty and return that same dict object.
    """
    if not _COMPILED_PATTERNS:
        for label, raw_pairs in _RAW_PATTERNS.items():
            compiled = []
            for regex, weight in raw_pairs:
                compiled.append((re.compile(regex, re.IGNORECASE), weight))
            _COMPILED_PATTERNS[label] = compiled
    return _COMPILED_PATTERNS


# ---------------------------------------------------------------------------
# Episode dataclass
# ---------------------------------------------------------------------------

@dataclass
class Episode:
    """A run of consecutive spans that share one primitive label.

    Fields:
        label: The primitive label shared by the spans.
        spans: The member spans, in the order they were supplied.
        total_tokens: Token count for the episode, as passed to the
            constructor (not recomputed from ``spans``).
    """
    label: str
    spans: list[Span]
    total_tokens: int

    @property
    def text(self) -> str:
        """Member span texts joined with newline separators."""
        pieces = [member.text for member in self.spans]
        return "\n".join(pieces)


# ---------------------------------------------------------------------------
# Classification
# ---------------------------------------------------------------------------

def classify_span(
    span_text: str,
    confidence_threshold: float = 0.01,
) -> tuple[str, float]:
    """Classify a single span into a heuristic primitive label.

    Every compiled pattern is run over ``span_text``. A label's score is the
    sum of (match count x pattern weight) over its patterns, divided by the
    span's whitespace-separated word count (minimum 1).

    Args:
        span_text: Text of the span to classify.
        confidence_threshold: Minimum score required for a non-OTHER label.

    Returns:
        ``(label, score)``. The highest-scoring label wins; ties go to the
        label that comes first in the pattern table. ``"OTHER"`` is returned
        when no pattern matched at all (score 0) or the best score is below
        ``confidence_threshold``. The best score is reported in either case.
    """
    patterns = _build_patterns()
    n_words = max(len(span_text.split()), 1)

    scores: dict[str, float] = {
        label: sum(len(pat.findall(span_text)) * weight for pat, weight in pats) / n_words
        for label, pats in patterns.items()
    }
    if not scores:
        return "OTHER", 0.0

    best_label = max(scores, key=scores.get)
    best_score = scores[best_label]

    # A non-positive score means no pattern matched. Never attribute a label
    # to zero evidence, even when the caller passes a threshold <= 0;
    # otherwise the first label in the table would win by default.
    if best_score <= 0.0 or best_score < confidence_threshold:
        return "OTHER", best_score

    return best_label, best_score


def classify_trace_spans(
    spans: list[Span],
    confidence_threshold: float = 0.01,
    classifier=None,
) -> list[tuple[Span, str, float]]:
    """Classify all spans of a trace.

    Args:
        spans: list of Span objects (only ``.text`` is read here).
        confidence_threshold: for the heuristic classifier only.
        classifier: if provided, use this learned classifier instead of the
            heuristic. It must be a callable ``classifier(rows) -> list[str]``
            where ``rows = [{"span_text": ..., "preceding_context": ...}, ...]``
            and ``preceding_context`` is the previous span's text ("" for the
            first span). It must return exactly one label per row.

    Returns:
        List of ``(span, label, confidence)`` tuples, one per span, in input
        order. Labels from ``classifier`` always get confidence 1.0.

    Raises:
        ValueError: if ``classifier`` returns a different number of labels
            than there are spans; silently truncating would misalign labels.
    """
    if classifier is None:
        return [
            (span, *classify_span(span.text, confidence_threshold))
            for span in spans
        ]

    texts = [span.text for span in spans]
    # Pair each span with its predecessor's text ("" for the first span).
    rows = [
        {"span_text": text, "preceding_context": prev_text}
        for text, prev_text in zip(texts, [""] + texts[:-1])
    ]
    labels = list(classifier(rows))
    if len(labels) != len(spans):
        raise ValueError(
            f"classifier returned {len(labels)} labels for {len(spans)} spans"
        )
    return [(span, label, 1.0) for span, label in zip(spans, labels)]


def merge_episodes(
    labeled_spans: list[tuple[Span, str, float]],
    min_episode_tokens: int = 30,
) -> list[Episode]:
    """Merge runs of adjacent spans that share a primitive label into episodes.

    Each maximal run of consecutive equal labels becomes one ``Episode``.
    Its ``total_tokens`` is the sum of its spans' ``n_tokens``. Every label
    change starts a new episode, and no label is treated specially.
    Per-span confidences are ignored.

    Args:
        labeled_spans: ``(span, label, confidence)`` tuples in trace order,
            e.g. as returned by ``classify_trace_spans``.
        min_episode_tokens: Accepted for backward compatibility; it currently
            has no effect. A pass that absorbed short episodes into a
            *same-label* predecessor could never fire, because after
            run-merging, adjacent episodes always have different labels.
            Absorbing short episodes across different labels would need an
            explicit policy for which label survives.

    Returns:
        Episodes in trace order; empty input gives ``[]``.
    """
    from itertools import groupby

    episodes: list[Episode] = []
    # groupby yields maximal runs of consecutive items with equal labels.
    for label, run in groupby(labeled_spans, key=lambda item: item[1]):
        run_spans = [item[0] for item in run]
        episodes.append(Episode(
            label=label,
            spans=run_spans,
            total_tokens=sum(span.n_tokens for span in run_spans),
        ))
    return episodes


def extract_primitive_sequence(episodes: list[Episode]) -> list[str]:
    """Return each episode's label, in the same order as ``episodes``.

    The result is a new list with one entry per episode.
    """
    sequence: list[str] = []
    for episode in episodes:
        sequence.append(episode.label)
    return sequence
