#!/usr/bin/env python3
"""Stage-2 MYA hard-filter baseline runner (parallel to continuous_decay runner)."""

from __future__ import annotations

import argparse
import json
import logging
import traceback
import random
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import sys
import torch

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from controllers import MYAHardFilterController
from detectors.base import StepObservation
from detectors.registry import load_detectors_config
from models import BaseLanguageModel
from scripts.run_stage2_joint_cbf import Stage2Runner
from scripts.stage2_utils import distribution_metrics
from scorers import AegisMultiLabelScorer

LOGGER = logging.getLogger(__name__)


@dataclass
class RegenResult:
    """Container for everything ``MYAHardFilterRunner._run_mya_filter`` returns."""

    # Token ids appended during regeneration (forced prefix plus generated tokens).
    tokens: List[int]
    # Decoded form of ``tokens`` (empty string when no token was produced).
    text: str
    # One dict per emitted token: index, token_id, token_str, topk, obs,
    # prefix_text and an (empty) detectors mapping.
    records: List[dict]
    # Intervention list; ``_run_mya_filter`` creates it empty and never appends.
    interventions: List[dict]
    # One log dict per controlled step (``control_path`` is "mya_hard_filter").
    stage_logs: List[dict]
    # Step indices at which the hard filter was applied.
    ctrl_steps: List[int]
    # Scale value recorded for each entry of ``ctrl_steps``.
    scale_list: List[float]
    # Per-controlled-step summary dicts (t, scale, tv, kl_ref, hmin, kmin, ...).
    continuous_steps: List[dict]
    # Per-controlled-step diagnostic dicts (solver_status, top1 flip, ...).
    per_step_diag: List[dict]


class MYAHardFilterRunner(Stage2Runner):
    def __init__(self, *args, **kwargs) -> None:
        """Peel off the MYA-specific options, then initialise the base runner.

        Keyword options consumed here (everything else is forwarded untouched):
            continuous_steps: optional control-window length; defaults to None.
            mya_mode: defaults to "post_rollback_continuous".
            kterm_source: defaults to "obs_kmin".
        """
        # These keys are removed from kwargs so the parent constructor never
        # sees them.
        window = kwargs.pop("continuous_steps", None)
        mode = kwargs.pop("mya_mode", "post_rollback_continuous")
        source = kwargs.pop("kterm_source", "obs_kmin")
        self.continuous_steps = window
        self.mya_mode = mode
        self.kterm_source = source
        super().__init__(*args, **kwargs)

    def _score_prefix_step(
        self,
        prompt: str,
        prefix_text: str,
        token_id: int,
        token_str: str,
        logprob: float,
        metrics: Dict[str, float],
        token_index: int,
        obs_step_index: int,
        prev_margins: Optional[List[float]],
    ) -> Tuple[Optional[StepObservation], Optional[List[float]]]:
        if token_index < self.skip_initial_tokens:
            return None, prev_margins
        try:
            scores = self.scorer.score_prefix(prompt, prefix_text + token_str)
        except Exception as exc:  # pragma: no cover - defensive logging
            LOGGER.warning("Streaming scorer failed at token %s: %s", token_str, exc)
            return None, prev_margins
        margins = [float(scores.margins.get(label, 0.0)) for label in self.labels]
        if prev_margins is None:
            prev = [0.0 for _ in margins]
        else:
            prev = prev_margins
        deltas = [cur - prev_val for cur, prev_val in zip(margins, prev)]
        delta_minus = [max(-delta, 0.0) for delta in deltas]
        probs = [float(scores.probabilities.get(label, 0.0)) for label in self.labels]
        obs = StepObservation(
            step_index=obs_step_index,
            token_id=int(token_id),
            token=token_str,
            h=margins,
            delta_minus=delta_minus,
            probabilities=probs,
            logprob=float(logprob),
            entropy=float(metrics.get("entropy", 0.0)),
            top1_top2_margin=float(metrics.get("top1_top2_margin", 0.0)),
            metadata={
                "token_index": int(token_index),
                "skip_initial_tokens": self.skip_initial_tokens,
            },
            p_any_gate=getattr(scores, "p_any_gate", None),
            p_any_or=getattr(scores, "p_any_or", None),
            logit_any=getattr(scores, "logit_any", None),
            gate_source="any_head",
        )
        obs.tau_any = self.tau_any
        return obs, margins

    @staticmethod
    def _sample_from_probs(
        probs: torch.Tensor,
        top_p: float,
        top_k: int,
    ) -> int:
        if top_k and top_k > 0:
            k = min(int(top_k), probs.numel())
            topk_vals, topk_idx = torch.topk(probs, k=k)
            mask = torch.zeros_like(probs, dtype=torch.bool)
            mask[topk_idx] = True
            probs = torch.where(mask, probs, torch.zeros_like(probs))
            total = float(probs.sum())
            if total > 0:
                probs = probs / total
        if top_p < 1.0:
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            keep = cumulative <= top_p
            if keep.numel() > 0:
                keep[0] = True
            mask = torch.zeros_like(probs, dtype=torch.bool)
            mask[sorted_idx[keep]] = True
            probs = torch.where(mask, probs, torch.zeros_like(probs))
            total = float(probs.sum())
            if total > 0:
                probs = probs / total
        if float(probs.sum()) <= 0:
            return int(torch.argmax(probs).item())
        return int(torch.multinomial(probs, num_samples=1).item())

    @staticmethod
    def _tv_kl(q_probs: torch.Tensor, p_probs: torch.Tensor) -> Tuple[float, float]:
        q = q_probs.detach().double()
        p = p_probs.detach().double()
        tv = 0.5 * torch.sum(torch.abs(q - p)).item()
        eps = 1e-8
        kl = torch.sum(q * (torch.log(q + eps) - torch.log(p + eps))).item()
        return float(tv), float(kl)

    def _run_mya_filter(
        self,
        prompt: str,
        forced_prefix: Sequence[int],
        baseline_tokens: Sequence[int],
        max_new_tokens: int,
        t_star_token: int,
        t_star_obs: Optional[int],
        L: int,
        k_star: int,
        policy_context_base: Dict[str, object],
        policy_meta: Dict[str, object],
        selected_candidate: dict,
        temperature: float,
        top_p: float,
        top_k: int,
        mya_mode: str,
        kterm_source: str,
    ) -> RegenResult:
        """Regenerate a response, applying the MYA hard filter on controlled steps.

        The first ``len(forced_prefix)`` steps replay ``forced_prefix`` verbatim.
        Afterwards a base token is sampled each step; when control is active and
        the base token could be scored, the support subset of the top-k
        candidates is filtered through ``MYAHardFilterController`` and the
        emitted token is drawn from the filtered distribution.

        Args:
            prompt: Prompt text passed to the LM and to the scorer.
            forced_prefix: Token ids replayed before free generation starts.
            baseline_tokens: Tokens of the uncontrolled run; used only for the
                ``replaced_hamming`` statistic.
            max_new_tokens: Upper bound on steps (further capped by
                ``self.lm.max_positions()`` when that is not None).
            t_star_token: Token step at which the control window opens.
            t_star_obs: Accepted for signature parity; not read in this method.
            L: Control-window length for the non-"from_start" mode.
            k_star: Active constraint index; ``None`` or a negative value means
                "unknown" and triggers resolution via ``kterm_source``.
            policy_context_base: Stored policy context; its ``"k_term"`` entry is
                consulted when ``kterm_source`` is not ``"obs_kmin"``.
            policy_meta: Accepted for signature parity; not read in this method.
            selected_candidate: Candidate dict; its ``"candidate_context"`` is
                forwarded to the controller on the step equal to ``t_star_token``.
            temperature: Softmax temperature; ``None`` or ``<= 0`` disables scaling.
            top_p: Nucleus threshold for base-token sampling.
            top_k: Top-k cutoff for base-token sampling.
            mya_mode: ``"from_start"`` keeps control active on every step;
                otherwise control is active for ``L`` steps from ``t_star_token``.
            kterm_source: ``"obs_kmin"`` derives the active index from the
                current observation; any other value uses the stored ``k_term``.

        Returns:
            A ``RegenResult`` with tokens, decoded text, records and per-step logs.
        """

        def _hmin_kmin(obs: Optional[StepObservation]) -> Tuple[Optional[float], Optional[int]]:
            # Smallest margin and its index, or (None, None) without margins.
            if obs is None or not obs.h:
                return None, None
            values = obs.h
            lowest = float(min(values))
            try:
                arg_lowest: Optional[int] = int(min(range(len(values)), key=values.__getitem__))
            except (TypeError, ValueError):
                arg_lowest = None
            return lowest, arg_lowest

        encoding = self.lm.encode_prompt(prompt)
        input_len = int(encoding.input_ids.shape[-1])
        past = None
        tokens: List[int] = []
        text_buffer = ""
        records: List[dict] = []
        interventions: List[dict] = []
        stage_logs: List[dict] = []
        ctrl_steps: List[int] = []
        scale_list: List[float] = []
        continuous_steps: List[dict] = []
        per_step_diag: List[dict] = []
        prev_margins: Optional[List[float]] = None
        obs_step_index = 0
        forced_len = len(forced_prefix)

        # Never request more steps than the model's position budget allows.
        model_max = self.lm.max_positions()
        effective_max_new = int(max_new_tokens)
        if model_max is not None:
            remaining = model_max - input_len
            if remaining < effective_max_new:
                effective_max_new = max(0, remaining)

        # Callers use a negative k_star as an "unknown" sentinel. It must never
        # reach the controller as an index, so normalise it to None here.
        fixed_k: Optional[int] = None
        if k_star is not None and int(k_star) >= 0:
            fixed_k = int(k_star)
        stored_k_term = (policy_context_base or {}).get("k_term") or []
        candidate_context = (selected_candidate or {}).get("candidate_context")

        for step in range(effective_max_new):
            outputs = self.lm.step(
                input_ids=encoding.input_ids,
                attention_mask=encoding.attention_mask,
                past_key_values=past,
            )
            logits = outputs["logits"]
            if logits.dim() > 1:
                # Reduce to a 1-D vocabulary vector. Taking the last row equals
                # repeated squeeze(0) when every leading dim is 1, and cannot
                # loop forever when a leading dim is larger than 1.
                logits = logits.reshape(-1, logits.shape[-1])[-1]
            past = outputs["past_key_values"]
            metrics = distribution_metrics(logits)
            log_probs = torch.log_softmax(logits, dim=-1)
            prefix_text = text_buffer
            topk_info = self._topk(logits)
            control_active = (
                mya_mode == "from_start"
                or (step >= t_star_token and (step - t_star_token) < L)
            )

            # ---- Forced-prefix replay: emit the given token, no control ----
            if step < forced_len:
                chosen_token_id = int(forced_prefix[step])
                chosen_token_str = self.lm.tokenizer.decode([chosen_token_id], skip_special_tokens=False)
                logprob = float(log_probs[chosen_token_id].item())
                obs, margins = self._score_prefix_step(
                    prompt=prompt,
                    prefix_text=prefix_text,
                    token_id=chosen_token_id,
                    token_str=chosen_token_str,
                    logprob=logprob,
                    metrics=metrics,
                    token_index=len(tokens),
                    obs_step_index=obs_step_index,
                    prev_margins=prev_margins,
                )
                if obs is not None and margins is not None:
                    prev_margins = margins
                    obs_step_index += 1
                records.append(
                    {
                        "index": len(tokens),
                        "token_id": chosen_token_id,
                        "token_str": chosen_token_str,
                        "topk": topk_info,
                        "obs": obs,
                        "prefix_text": prefix_text,
                        "detectors": {},
                    }
                )
                tokens.append(chosen_token_id)
                text_buffer += chosen_token_str
                encoding = self.lm.expand_inputs(encoding, chosen_token_id)
                if chosen_token_id == self.lm.tokenizer.eos_token_id:
                    break
                continue

            # ---- Free generation: sample and score the base token ----
            if temperature is None or temperature <= 0:
                probs = torch.softmax(logits, dim=-1)
            else:
                probs = torch.softmax(logits / float(temperature), dim=-1)
            base_token_id = self._sample_from_probs(probs, top_p=top_p, top_k=top_k)
            base_token_str = self.lm.tokenizer.decode([base_token_id], skip_special_tokens=False)
            base_logprob = float(log_probs[base_token_id].item())
            base_obs, _ = self._score_prefix_step(
                prompt=prompt,
                prefix_text=prefix_text,
                token_id=base_token_id,
                token_str=base_token_str,
                logprob=base_logprob,
                metrics=metrics,
                token_index=len(tokens),
                obs_step_index=obs_step_index,
                prev_margins=prev_margins,
            )
            chosen_token_id = base_token_id
            chosen_token_str = base_token_str
            step_entry: Optional[dict] = None
            diag_entry: Optional[dict] = None
            stage_entry: Optional[dict] = None

            if control_active and base_obs is not None:
                if mya_mode == "from_start":
                    i = step
                    scale = 1.0
                else:
                    i = step - t_star_token
                    scale = max(1.0 - (i / float(L)), 1.0 / float(L))

                # Resolve the active constraint index for this step.
                k_active = fixed_k
                if k_active is None:
                    if kterm_source == "obs_kmin":
                        k_active = _hmin_kmin(base_obs)[1]
                    elif stored_k_term and int(stored_k_term[0]) >= 0:
                        k_active = int(stored_k_term[0])
                active_indices = [int(k_active)] if k_active is not None else []

                ctrl_steps.append(step)
                scale_list.append(scale)
                raw_tokens = topk_info.get("token_ids") or []
                raw_probs = topk_info.get("probs") or []
                raw_logits = topk_info.get("logits") or []
                candidate_tokens, limited_probs, limited_logits = self._select_support_subset(
                    raw_tokens, raw_probs, raw_logits
                )
                topk_ids = candidate_tokens or raw_tokens
                base_top1_id = int(topk_ids[0]) if topk_ids else base_token_id
                mya_allowed_size = 0
                mya_empty_set = 1
                if not candidate_tokens or not limited_probs:
                    # No usable support: keep the sampled base token.
                    tv_val = 0.0
                    kl_val = 0.0
                    fallback_reason = "empty_support"
                    top1_before = int(base_top1_id)
                    top1_after = int(base_top1_id)
                    top1_flip = False
                    chosen_token_id = base_token_id
                    chosen_token_str = base_token_str
                else:
                    base_probs_t = torch.tensor(limited_probs, dtype=torch.float32)
                    total = float(base_probs_t.sum())
                    if total > 0:
                        base_probs_t = base_probs_t / total
                    constraints = self.controller.build_constraints(
                        obs=base_obs,
                        prompt_text=prompt,
                        prefix_text=prefix_text,
                        candidate_tokens=candidate_tokens,
                        active_indices=active_indices,
                        candidate_context=candidate_context if step == t_star_token else None,
                        forced_active_indices=active_indices,
                    )
                    result = MYAHardFilterController.apply_from_constraints(base_probs_t, constraints)
                    q_probs_t = result.q_topv
                    mya_allowed_size = int(result.allowed_size)
                    mya_empty_set = int(result.empty_set)
                    tv_val = float(result.tv_q_p)
                    kl_val = float(result.kl_q_p)
                    top1_before = int(candidate_tokens[int(torch.argmax(base_probs_t).item())])
                    top1_after = int(candidate_tokens[int(torch.argmax(q_probs_t).item())])
                    top1_flip = bool(top1_before != top1_after)
                    if float(q_probs_t.sum()) <= 0:
                        idx = int(torch.argmax(q_probs_t).item())
                    else:
                        idx = int(torch.multinomial(q_probs_t, num_samples=1).item())
                    chosen_token_id = int(candidate_tokens[idx])
                    chosen_token_str = self.lm.tokenizer.decode(
                        [chosen_token_id], skip_special_tokens=False
                    )
                    fallback_reason = "empty_set" if mya_empty_set else ""

                # Hamming-style distance between the regenerated suffix and the
                # baseline suffix over the same token window.
                suffix_now = tokens[t_star_token:] + [int(chosen_token_id)]
                baseline_suffix = (
                    list(baseline_tokens[t_star_token : step + 1])
                    if step < len(baseline_tokens)
                    else []
                )
                mismatch = sum(
                    1 for ours, theirs in zip(suffix_now, baseline_suffix) if int(ours) != int(theirs)
                )
                replaced_hamming = mismatch + abs(len(suffix_now) - len(baseline_suffix))

                step_entry = {
                    "t": int(step),
                    "scale": float(scale),
                    "tv": float(tv_val),
                    "kl_ref": float(kl_val),
                    "hmin": None,
                    "kmin": None,
                    "replaced_len": int(len(suffix_now)),
                    "replaced_hamming": int(replaced_hamming),
                    "replaced_hamming_ref": "baseline_suffix",
                    "mya_allowed_size": int(mya_allowed_size),
                    "mya_empty_set": int(mya_empty_set),
                }
                continuous_steps.append(step_entry)
                diag_entry = {
                    "i": int(i),
                    "t": int(step),
                    "scale": float(scale),
                    "solver_status": "success",
                    "fallback_reason": fallback_reason,
                    "tv_q_p": float(tv_val),
                    "kl_q_ref": float(kl_val),
                    "slack_used": None,
                    "margin_ref_k": None,
                    "margin_q_k": None,
                    "top1_before": int(top1_before) if top1_before is not None else None,
                    "top1_after": int(top1_after) if top1_after is not None else None,
                    "top1_flip": bool(top1_flip),
                    "kmin": None,
                    "hmin": None,
                    "replaced_len": int(len(suffix_now)),
                    "replaced_hamming": int(replaced_hamming),
                    "replaced_hamming_ref": "baseline_suffix",
                    "mya_allowed_size": int(mya_allowed_size),
                    "mya_empty_set": int(mya_empty_set),
                }
                per_step_diag.append(diag_entry)
                stage_entry = {
                    "step": int(step),
                    "control_path": "mya_hard_filter",
                    "base_token": base_token_str,
                    "chosen_token": {"token_str": chosen_token_str},
                    "hmin_before": _hmin_kmin(base_obs)[0],
                    "hmin_after": None,
                    "delta_hmin": None,
                    "tv_q_p": float(tv_val),
                    "kl_q_ref": float(kl_val),
                    "prefix_text": prefix_text,
                    "fallback_reason": fallback_reason,
                    "mya_allowed_size": int(mya_allowed_size),
                    "mya_empty_set": int(mya_empty_set),
                    "mya_mode": str(mya_mode),
                }
                stage_logs.append(stage_entry)

            # ---- Score the token that is actually emitted ----
            logprob = float(log_probs[chosen_token_id].item())
            if chosen_token_id == base_token_id and base_obs is not None:
                # The base-token observation already describes the emitted token.
                obs_actual = base_obs
                margins_actual = base_obs.h
            else:
                obs_actual, margins_actual = self._score_prefix_step(
                    prompt=prompt,
                    prefix_text=prefix_text,
                    token_id=chosen_token_id,
                    token_str=chosen_token_str,
                    logprob=logprob,
                    metrics=metrics,
                    token_index=len(tokens),
                    obs_step_index=obs_step_index,
                    prev_margins=prev_margins,
                )
            if obs_actual is not None and margins_actual is not None:
                prev_margins = margins_actual
                obs_step_index += 1

            # Back-fill the post-choice margin statistics for this step's logs.
            if step_entry is not None and diag_entry is not None and stage_entry is not None:
                hmin_after, kmin_after = _hmin_kmin(obs_actual)
                step_entry["hmin"] = hmin_after
                step_entry["kmin"] = kmin_after
                diag_entry["hmin"] = hmin_after
                diag_entry["kmin"] = kmin_after
                stage_entry["hmin_after"] = hmin_after
                if stage_entry.get("hmin_before") is not None and hmin_after is not None:
                    stage_entry["delta_hmin"] = hmin_after - stage_entry["hmin_before"]

            records.append(
                {
                    "index": len(tokens),
                    "token_id": chosen_token_id,
                    "token_str": chosen_token_str,
                    "topk": topk_info,
                    "obs": obs_actual,
                    "prefix_text": prefix_text,
                    "detectors": {},
                }
            )
            tokens.append(chosen_token_id)
            text_buffer += chosen_token_str
            encoding = self.lm.expand_inputs(encoding, chosen_token_id)
            if chosen_token_id == self.lm.tokenizer.eos_token_id:
                break

        text = self.lm.decode(tokens) if tokens else ""
        return RegenResult(
            tokens=tokens,
            text=text,
            records=records,
            interventions=interventions,
            stage_logs=stage_logs,
            ctrl_steps=ctrl_steps,
            scale_list=scale_list,
            continuous_steps=continuous_steps,
            per_step_diag=per_step_diag,
        )

    def generate(
        self,
        prompt: str,
        max_new_tokens: int,
        prompt_id: Optional[str] = None,
        prompt_index: Optional[int] = None,
        prompt_meta: Optional[Dict[str, object]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 0,
    ) -> dict:
        """Run the baseline pass and, when applicable, the MYA hard-filter regeneration.

        The method first performs ``self._run_once`` (optionally replaying a
        reference response taken from ``prompt_meta[self.replay_field]``), reads
        the selected intervention candidate and policy metadata from that run,
        and then calls ``_run_mya_filter`` when either ``self.mya_mode`` is
        ``"from_start"`` or a candidate, an active index, a token step and a
        ``t_u_step`` are all available.

        Args:
            prompt: Prompt text.
            max_new_tokens: Generation budget; also used as ``L`` in
                ``"from_start"`` mode.
            prompt_id: Optional identifier; defaults to
                ``"<dataset_tag>:<prompt_index>"``.
            prompt_index: Optional index echoed into the result.
            prompt_meta: Optional dict; ``gt_type``, ``self.replay_field`` and
                ``meta`` entries are read when present.
            temperature: Forwarded to ``_run_mya_filter``.
            top_p: Forwarded to ``_run_mya_filter``.
            top_k: Forwarded to ``_run_mya_filter``.

        Returns:
            A dict with the prompt, generated text, token ids, serialized
            records, ``meta`` and ``stage2`` sub-dicts, and ``interventions``.
        """
        self._current_prompt_meta = {
            "prompt_id": prompt_id or f"{self.dataset_tag}:{prompt_index}",
            "prompt_index": prompt_index,
            "dataset_tag": self.dataset_tag,
        }
        reference_response = None
        if isinstance(prompt_meta, dict):
            if prompt_meta.get("gt_type") is not None:
                self._current_prompt_meta["gt_type"] = prompt_meta.get("gt_type")
            if prompt_meta.get(self.replay_field) is not None:
                self._current_prompt_meta[self.replay_field] = prompt_meta.get(self.replay_field)
            if prompt_meta.get("meta") is not None:
                self._current_prompt_meta["meta"] = prompt_meta.get("meta")
            reference_response = prompt_meta.get(self.replay_field)

        # Optional replay: force the baseline run through a reference response.
        forced_tokens: List[int] = []
        replay_ctx: Optional[Dict[str, object]] = None
        if self.replay_reference_response and isinstance(reference_response, str) and reference_response.strip():
            token_ids = self.lm.tokenizer(
                reference_response,
                add_special_tokens=False,
                truncation=True,
                max_length=self.replay_max_ref_tokens if self.replay_max_ref_tokens > 0 else None,
            ).input_ids
            forced_tokens = list(token_ids)
            replay_ctx = {
                "enabled": True,
                "gt_type": self._current_prompt_meta.get("gt_type"),
                "force_arm_on_unsafe": self.force_arm_on_unsafe,
                "replay_backoff_tokens": self.replay_backoff_tokens,
            }

        self._assert_baselines()
        run = self._run_once(prompt, forced_tokens, max_new_tokens, replay_ctx=replay_ctx)
        last_qwen_meta = None
        if self.qwen_hx_enabled:
            last_qwen_meta = run.get("qwen_hx")
        selected_candidate = run.get("intervention")
        policy_meta = run.get("policy_meta") or {}

        # `.get` is called on policy_context below, so it must always be a
        # dict, even when the candidate carries no "policy" mapping.
        policy_context: Dict[str, object] = {}
        if isinstance(selected_candidate, dict):
            candidate_policy = selected_candidate.get("policy")
            if isinstance(candidate_policy, dict):
                policy_context = candidate_policy

        t_u = policy_meta.get("t_u_step")
        t_star_obs = None
        t_star_token = None
        if isinstance(selected_candidate, dict):
            t_star_obs = policy_context.get("candidate_obs_step")
            if t_star_obs is None:
                t_star_obs = selected_candidate.get("obs_step_index")
            t_star_token = selected_candidate.get("step_index")

        # Active index: prefer the pair with the smallest terminal h value,
        # then fall back to the first stored k_term entry.
        k_star = None
        k_pairs = policy_meta.get("terminal_h_values_for_k_term") or []
        if k_pairs:
            k_star = int(min(k_pairs, key=lambda item: item[1])[0])
        else:
            k_term = policy_context.get("k_term") or policy_meta.get("k_term") or []
            if k_term:
                k_star = int(k_term[0])

        # Control-window length.
        L = 1
        if self.continuous_steps is not None:
            L = max(1, int(self.continuous_steps))
        elif t_star_obs is not None and t_u is not None:
            L = max(1, int(t_u) - int(t_star_obs))

        stage_logs: List[dict] = []
        interventions: List[dict] = []
        ctrl_steps: List[int] = []
        scale_list: List[float] = []
        final_tokens = run.get("tokens") or []
        final_text = run.get("text") or ""
        final_records = run.get("records") or []
        baseline_tokens = list(final_tokens)
        skip_info = run.get("skip_info")
        continuous_steps: List[dict] = []
        per_step_diag: List[dict] = []

        should_run = False
        if self.mya_mode == "from_start":
            should_run = True
            if t_star_token is None:
                t_star_token = 0
            if t_star_obs is None:
                t_star_obs = 0
            # The regeneration loop calls `.get` on the candidate, so anything
            # that is not a dict (including None) is replaced by an empty one.
            if not isinstance(selected_candidate, dict):
                selected_candidate = {}
            L = int(max_new_tokens)
        elif selected_candidate and k_star is not None and t_star_token is not None and t_u is not None:
            should_run = True

        if should_run:
            forced_prefix = list(final_tokens[: int(t_star_token)])
            regen = self._run_mya_filter(
                prompt=prompt,
                forced_prefix=forced_prefix,
                baseline_tokens=baseline_tokens,
                max_new_tokens=max_new_tokens,
                t_star_token=int(t_star_token),
                t_star_obs=int(t_star_obs) if t_star_obs is not None else None,
                L=int(L),
                k_star=int(k_star) if k_star is not None else -1,
                policy_context_base=dict(policy_context or {}),
                policy_meta=dict(policy_meta or {}),
                selected_candidate=selected_candidate,
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k),
                mya_mode=str(self.mya_mode),
                kterm_source=str(self.kterm_source),
            )
            final_tokens = regen.tokens
            final_text = regen.text
            final_records = regen.records
            interventions = regen.interventions
            stage_logs = regen.stage_logs
            ctrl_steps = regen.ctrl_steps
            scale_list = regen.scale_list
            continuous_steps = regen.continuous_steps
            per_step_diag = regen.per_step_diag

        stage2_meta = {
            "baseline_name": "mya_hard_filter",
            "mya_mode": str(self.mya_mode),
            "interventions": stage_logs,
            "total_interventions": len(stage_logs),
            "t_star": int(t_star_obs) if t_star_obs is not None else None,
            "t_u": int(t_u) if t_u is not None else None,
            "L_planned": int(L),
            "L": int(L),
            "k_term_top1": int(k_star) if k_star is not None else None,
            "n_ctrl_applied": len(ctrl_steps),
            "ctrl_steps": list(ctrl_steps),
            "scale_list": list(scale_list),
            "skip_info": dict(skip_info) if isinstance(skip_info, dict) else skip_info,
            "records_available": bool(final_records),
            "continuous": {"steps": list(continuous_steps)},
            "per_step_diag": list(per_step_diag),
        }
        result_meta = {
            "sample_id": self._current_prompt_meta.get("prompt_id"),
            "prompt_index": prompt_index,
            "dataset_tag": self.dataset_tag,
        }
        if self.qwen_hx_enabled:
            t_star_qwen = None
            if isinstance(policy_meta, dict):
                t_star_qwen = policy_meta.get("t_star_qwen")
            t_u_qwen = -1
            hx_history = []
            llm_eval_count = 0
            rollback_reset_count = 0
            fallback_reason = None
            if isinstance(last_qwen_meta, dict):
                t_u_qwen = int(last_qwen_meta.get("t_u_qwen", -1))
                hx_history = last_qwen_meta.get("hx_history") or []
                llm_eval_count = int(last_qwen_meta.get("llm_eval_count", 0) or 0)
                rollback_reset_count = int(last_qwen_meta.get("rollback_reset_count", 0) or 0)
                fallback_reason = last_qwen_meta.get("fallback_reason")
            # Sanity flag: t* should precede t_u; None when it cannot be evaluated.
            assert_tstar_lt_tu = None
            if t_star_qwen is not None and t_u_qwen != -1:
                try:
                    assert_tstar_lt_tu = bool(int(t_star_qwen) < int(t_u_qwen))
                except Exception:
                    assert_tstar_lt_tu = None
            result_meta.update(
                {
                    "scorer_used": "qwen_hx",
                    "qwen_hx_enabled": True,
                    "stride": self.qwen_hx_cfg.get("stride"),
                    "eps": self.qwen_hx_cfg.get("eps"),
                    "enter_steps": self.qwen_hx_cfg.get("enter_steps"),
                    "refine": self.qwen_hx_cfg.get("refine"),
                    "refine_radius": self.qwen_hx_cfg.get("refine_radius"),
                    "max_input_tokens": self.qwen_hx_cfg.get("max_input_tokens"),
                    "t_u_qwen": t_u_qwen,
                    "t_star_qwen": int(t_star_qwen) if t_star_qwen is not None else -1,
                    "hx_history": hx_history,
                    "llm_eval_count": llm_eval_count,
                    "rollback_reset_count": rollback_reset_count,
                    "assert_tstar_lt_tu": assert_tstar_lt_tu,
                    "fallback_reason": fallback_reason,
                }
            )
        if self._current_prompt_meta.get("gt_type") is not None:
            result_meta["gt_type"] = self._current_prompt_meta.get("gt_type")
        if isinstance(self._current_prompt_meta.get("meta"), dict):
            result_meta.update(self._current_prompt_meta.get("meta") or {})
        result = {
            "prompt": prompt,
            "prompt_id": self._current_prompt_meta.get("prompt_id"),
            "prompt_index": prompt_index,
            "dataset_tag": self.dataset_tag,
            "generated_text": final_text,
            "token_ids": final_tokens,
            "logs": self._serialize_records(final_records),
            "meta": result_meta,
            "stage2": stage2_meta,
            "interventions": interventions,
        }
        return result


def _normalize_prompt_entry(entry: object) -> dict:
    if isinstance(entry, dict):
        prompt = entry.get("prompt") or entry.get("text") or entry.get("input")
        meta = entry.get("meta") or {}
        return {"prompt": prompt, "meta": meta, "control_paths": entry.get("control_paths")}
    if isinstance(entry, str):
        return {"prompt": entry, "meta": {}, "control_paths": None}
    return {"prompt": "", "meta": {}, "control_paths": None}


def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: run the MYA hard-filter baseline over a prompts file.

    Parses the command line, loads the prompts (a JSON list, or a JSON object
    with a ``"prompts"`` list), builds the base language model, the scorer and
    a ``MYAHardFilterRunner``, then writes one JSON object per prompt to the
    output ``.jsonl`` file. The scorer is the Qwen margin judge when
    ``controller_config["qwen_hx"]["enabled"]`` is truthy, otherwise an
    ``AegisMultiLabelScorer``. A prompt whose generation raises is
    recorded as a stub row carrying the exception details instead of aborting
    the whole run.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a plain
            ``main()`` call.

    Raises:
        FileNotFoundError: A required scorer artifact (classifier, label map,
            thresholds, tokenizer, and with ``--any_gate_enabled`` the gate
            head / train report) does not exist.
        KeyError: The selected ``tau_any`` key is missing from the train report.
        ValueError: ``--selected_dims`` has an unsupported JSON shape, or
            neither ``--output_path`` nor ``--out_dir`` was given.
        RuntimeError: The qwen_hx scorer could not be initialised.
    """
    # ------------------------------------------------------------------ CLI
    parser = argparse.ArgumentParser(description="Stage-2 MYA hard-filter baseline runner")
    parser.add_argument("--base_model", default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--prompts_file", required=True)
    parser.add_argument("--output_path", default=None)
    parser.add_argument("--out_dir", default=None)
    parser.add_argument("--max_prompts", type=int, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=64)
    parser.add_argument("--continuous_steps", type=int, default=None)
    parser.add_argument(
        "--mya_mode",
        choices=["from_start", "post_rollback_continuous"],
        default="post_rollback_continuous",
        help="MYA hard-filter mode (from_start applies at every step; post_rollback applies from t* for N steps)",
    )
    parser.add_argument(
        "--kterm_source",
        choices=["obs_kmin", "policy_context"],
        default="obs_kmin",
        help="Source for K_term when missing (obs_kmin uses current obs.kmin; policy_context uses stored k_term)",
    )
    parser.add_argument("--detectors_config", required=True)
    parser.add_argument("--controller_config", required=True)
    parser.add_argument("--selected_dims", required=True)
    parser.add_argument("--classifier_path", default="artifacts/aegis_scorer/classifier.pt")
    parser.add_argument("--label_map_path", default="artifacts/aegis_scorer/label_map.json")
    parser.add_argument("--thresholds_path", default="artifacts/aegis_scorer/thresholds.json")
    parser.add_argument("--scorer_dir", help="If set, load encoder/ + classifier_linear.pt + label_map/thresholds from this dir")
    parser.add_argument("--scorer_skip_initial_tokens", type=int, default=4)
    parser.add_argument("--any_gate_enabled", action="store_true", help="use gate_head + tau_any to arm control flow")
    parser.add_argument("--gate_fpr", type=str, choices=["0.05", "0.20"], default="0.05", help="tau_any selection from train_report")
    parser.add_argument("--min_tstar_resp", type=int, default=12, help="min response tokens for t_star clamp")
    parser.add_argument("--refusal_gate_enabled", action="store_true", default=True, help="enable refusal fail-safe gate")
    # `store_true` with default=True can never yield False, so the gate could
    # not be switched off from the CLI. This companion flag writes False into
    # the same destination while keeping the old flag and its default intact.
    parser.add_argument(
        "--no_refusal_gate",
        dest="refusal_gate_enabled",
        action="store_false",
        help="disable refusal fail-safe gate (enabled by default)",
    )
    parser.add_argument("--replay_reference_response", action="store_true", help="replay reference_response prefixes to force t_u")
    parser.add_argument("--replay_field", default="reference_response", help="field name holding reference response")
    parser.add_argument("--replay_max_ref_tokens", type=int, default=192, help="max tokens to replay from reference response")
    parser.add_argument("--replay_backoff_tokens", type=int, default=16, help="t_star fallback backoff from t_u")
    parser.add_argument("--force_arm_on_unsafe", action="store_true", help="force armed in replay for unsafe prompts")
    parser.add_argument("--terminal_thresholds_path", default=None, help="override thresholds for terminal detection")
    parser.add_argument("--kterm_topm", type=int, default=6, help="top-M dims for K_term when triggered")
    parser.add_argument("--kterm_min_prob", type=float, default=0.0, help="min prob threshold for K_term dims")
    parser.add_argument("--debug_kv_align", action="store_true", help="log kv cache alignment diagnostics")
    parser.add_argument("--debug_risk_update", action="store_true", help="log streaming risk update counts")
    parser.add_argument("--device", default=None)
    parser.add_argument("--run_id", default=None)
    parser.add_argument("--dataset_tag", default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    args = parser.parse_args(argv)

    logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")

    # -------------------------------------------------------------- prompts
    prompts_raw = json.loads(Path(args.prompts_file).read_text(encoding="utf-8"))
    if isinstance(prompts_raw, dict) and "prompts" in prompts_raw:
        prompts_raw = prompts_raw.get("prompts") or []
    prompts = [_normalize_prompt_entry(entry) for entry in (prompts_raw or [])]
    # Keep only entries whose prompt is a non-blank string.
    prompts = [p for p in prompts if isinstance(p.get("prompt"), str) and p.get("prompt").strip()]
    if args.max_prompts is not None:
        prompts = prompts[: args.max_prompts]

    # ------------------------------------------------------ model + configs
    lm = BaseLanguageModel(model_name=args.base_model, device=args.device)
    controller_cfg = json.loads(Path(args.controller_config).read_text(encoding="utf-8"))
    qwen_hx_cfg = controller_cfg.get("qwen_hx") or {}
    qwen_hx_enabled = bool(qwen_hx_cfg.get("enabled", False))

    detectors_cfg = load_detectors_config(args.detectors_config)
    scorer = None
    labels = None
    tau_any = None
    any_gate_enabled = bool(args.any_gate_enabled)
    train_report_path = None

    # --------------------------------------------------------------- scorer
    if qwen_hx_enabled:
        # Qwen margin judge is used as the only scorer; the any-gate is
        # switched off for this path regardless of --any_gate_enabled.
        try:
            from llm_judges.qwen_margin import QwenMarginJudge
            from scorers import QwenMarginScorer

            judge_kwargs = {
                "model_path": str(qwen_hx_cfg.get("model_path", "Qwen/Qwen2.5-7B-Instruct")),
                "label_safe": str(qwen_hx_cfg.get("label_safe", " SAFE")),
                "label_unsafe": str(qwen_hx_cfg.get("label_unsafe", " UNSAFE")),
                "use_chat_template": bool(qwen_hx_cfg.get("use_chat_template", False)),
                "max_input_tokens": int(qwen_hx_cfg.get("max_input_tokens", 1024)),
                "dtype": "bfloat16",
                "device_map": "auto",
            }
            if qwen_hx_cfg.get("prompt_template"):
                judge_kwargs["prompt_template"] = qwen_hx_cfg.get("prompt_template")
            judge = QwenMarginJudge(**judge_kwargs)
            scorer = QwenMarginScorer(judge=judge, label="qwen_hx")
            labels = ["qwen_hx"]
            any_gate_enabled = False
            if args.any_gate_enabled:
                LOGGER.info("qwen_hx enabled: disabling any_gate for this run")
            LOGGER.info("scorer_used=qwen_hx (roberta_disabled)")
        except Exception as exc:
            raise RuntimeError(f"Failed to init qwen_hx scorer: {exc}") from exc
    else:
        # Resolve artifact paths: either the individual --*_path flags, or a
        # single --scorer_dir with a fixed layout (with fallback file names).
        scorer_encoder = "roberta-base"
        classifier_path = Path(args.classifier_path)
        label_map_path = Path(args.label_map_path)
        thresholds_path = Path(args.thresholds_path)
        gate_head_path = None
        tokenizer_path = None
        if args.scorer_dir:
            scorer_dir = Path(args.scorer_dir)
            scorer_encoder = scorer_dir / "encoder"
            classifier_path = scorer_dir / "classifier_linear.pt"
            label_map_path = scorer_dir / "label_map.json"
            if not label_map_path.exists():
                alt_label = scorer_dir / "label_space.json"
                if alt_label.exists():
                    label_map_path = alt_label
            thresholds_path = scorer_dir / "thresholds.json"
            if not thresholds_path.exists():
                alt_thr = scorer_dir / "thresholds_fpr0p05.json"
                if alt_thr.exists():
                    thresholds_path = alt_thr
            gate_head_path = scorer_dir / "gate_head.pt"
            tokenizer_path = scorer_dir / "tokenizer"
            train_report_path = scorer_dir / "train_report.json"
        else:
            gate_head_path = classifier_path.parent / "gate_head.pt"
            train_report_path = classifier_path.parent / "train_report.json"

        # Fail fast on missing artifacts before constructing the scorer.
        if args.any_gate_enabled:
            if train_report_path is None or not train_report_path.exists():
                raise FileNotFoundError(f"train_report.json not found at {train_report_path}")
        for path_label, path_val in [
            ("classifier_path", classifier_path),
            ("label_map_path", label_map_path),
            ("thresholds_path", thresholds_path),
            ("tokenizer_path", tokenizer_path),
        ]:
            if path_val is None:
                continue
            if not Path(path_val).exists():
                raise FileNotFoundError(f"{path_label} not found at {path_val}")
        if args.any_gate_enabled:
            if gate_head_path is None or not Path(gate_head_path).exists():
                raise FileNotFoundError(f"gate_head_path not found at {gate_head_path}")
            train_report = json.loads(train_report_path.read_text(encoding="utf-8"))
            tau_key = "tau_any_fpr0p05" if args.gate_fpr == "0.05" else "tau_any_fpr0p20"
            if tau_key not in train_report:
                raise KeyError(f"{tau_key} missing in train_report {train_report_path}")
            tau_any = float(train_report[tau_key])
            LOGGER.info(
                "Loaded tau_any=%.6f (key=%s gate_fpr=%s) from %s",
                tau_any,
                tau_key,
                args.gate_fpr,
                train_report_path,
            )
        else:
            tau_any = None
            LOGGER.info("any_gate disabled: skipping gate_head/tau_any checks")

        scorer = AegisMultiLabelScorer(
            encoder_name=str(scorer_encoder),
            classifier_path=classifier_path,
            label_map_path=label_map_path,
            thresholds_path=thresholds_path,
            include_prompt_as_context=False,
            device=args.device,
            gate_head_path=gate_head_path,
            tokenizer_path=tokenizer_path,
        )
        LOGGER.info(
            "scorer loaded gate_head=%s path=%s",
            bool(getattr(scorer, "gate_head_loaded", False)),
            getattr(scorer, "gate_head_path", None),
        )
        LOGGER.info("scorer tokenizer_path=%s", tokenizer_path if tokenizer_path is not None else scorer_encoder)

        # --selected_dims is either a JSON list of labels or an object with a
        # "selected_dims" (preferred) or "labels" list; empty -> all scorer labels.
        selected_payload = json.loads(Path(args.selected_dims).read_text(encoding="utf-8"))
        if isinstance(selected_payload, list):
            labels = selected_payload
        elif isinstance(selected_payload, dict):
            labels = selected_payload.get("selected_dims")
            if labels is None:
                labels = selected_payload.get("labels")
        else:
            raise ValueError(f"Unsupported selected_dims format: {type(selected_payload)}")
        if not labels:
            labels = scorer.labels

    # --------------------------------------------------------------- output
    dataset_tag = args.dataset_tag or Path(args.prompts_file).stem
    run_id = args.run_id
    output_path = None
    if args.output_path:
        output_path = Path(args.output_path)
    elif args.out_dir:
        output_path = Path(args.out_dir) / f"mya_hard_filter_{dataset_tag}.jsonl"
    if output_path is None:
        raise ValueError("Provide --output_path or --out_dir")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if args.seed is not None:
        seed = int(args.seed)
        # Seed Python's RNG as well as torch's so --seed covers both sources.
        random.seed(seed)
        torch.manual_seed(seed)

    runner = MYAHardFilterRunner(
        lm=lm,
        scorer=scorer,
        detector_cfg=detectors_cfg,
        controller_cfg=controller_cfg,
        labels=list(labels),
        controller_config_path=str(Path(args.controller_config).resolve()),
        skip_initial_tokens=int(args.scorer_skip_initial_tokens),
        target_detector=str(detectors_cfg.get("target_detector") or "delta_hmin"),
        run_id=run_id,
        dataset_tag=dataset_tag,
        run_seed=args.seed,
        tau_any=tau_any,
        gate_fpr=str(args.gate_fpr),
        any_gate_enabled=any_gate_enabled,
        refusal_gate_enabled=bool(args.refusal_gate_enabled),
        min_tstar_resp=int(args.min_tstar_resp),
        kterm_topm=int(args.kterm_topm),
        kterm_min_prob=float(args.kterm_min_prob),
        replay_reference_response=bool(args.replay_reference_response),
        replay_field=str(args.replay_field),
        replay_max_ref_tokens=int(args.replay_max_ref_tokens),
        replay_backoff_tokens=int(args.replay_backoff_tokens),
        force_arm_on_unsafe=bool(args.force_arm_on_unsafe),
        terminal_thresholds_path=args.terminal_thresholds_path,
        debug_kv_align=bool(args.debug_kv_align),
        debug_risk_update=bool(args.debug_risk_update),
        continuous_steps=args.continuous_steps,
        mya_mode=args.mya_mode,
        kterm_source=args.kterm_source,
    )

    # ------------------------------------------------------ generation loop
    with output_path.open("w", encoding="utf-8") as f:
        for idx, entry in enumerate(prompts):
            prompt = entry.get("prompt") or ""
            if not prompt.strip():
                continue
            prompt_meta = entry.get("meta")
            if not isinstance(prompt_meta, dict):
                # Guard here too: this lookup sits outside the per-prompt
                # try/except, so a malformed "meta" would abort the whole run.
                prompt_meta = {}
            prompt_id = prompt_meta.get("prompt_id") or f"{dataset_tag}:{idx}"
            try:
                result = runner.generate(
                    prompt=prompt,
                    max_new_tokens=int(args.max_new_tokens),
                    prompt_id=prompt_id,
                    prompt_index=idx,
                    prompt_meta=entry,
                    temperature=float(args.temperature),
                    top_p=float(args.top_p),
                    top_k=int(args.top_k),
                )
            except Exception as exc:
                # Best-effort: log with traceback and emit a stub row so one
                # bad prompt does not lose the rest of the run.
                LOGGER.exception("Failed prompt %s: %s", prompt_id, exc)
                result = {
                    "prompt": prompt,
                    "prompt_id": prompt_id,
                    "prompt_index": idx,
                    "dataset_tag": dataset_tag,
                    "generated_text": "",
                    "token_ids": [],
                    "logs": [],
                    "meta": {
                        "sample_id": prompt_id,
                        "exception_type": type(exc).__name__,
                        "exception": str(exc),
                        "traceback": traceback.format_exc(limit=3),
                    },
                    "stage2": {
                        "baseline_name": "mya_hard_filter",
                        "interventions": [],
                        "total_interventions": 0,
                        "continuous": {"steps": []},
                        "per_step_diag": [],
                    },
                    # Successful rows carry a top-level "interventions" list;
                    # keep failed rows schema-compatible for downstream readers.
                    "interventions": [],
                }
            f.write(json.dumps(result, ensure_ascii=False) + "\n")
            # Flush per row so finished results are not left in the buffer if
            # the run is interrupted.
            f.flush()
            if (idx + 1) % 10 == 0:
                LOGGER.info("Processed %d / %d prompts", idx + 1, len(prompts))
    LOGGER.info("Wrote %s", output_path)


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
