#!/usr/bin/env python3
"""Stage-2 continuous_decay baseline runner (parallel to run_stage2_joint_cbf.py)."""

from __future__ import annotations

import argparse
import json
import logging
import traceback
import random
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import sys
import torch

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from controllers import JointCBFKLController  # noqa: F401
from detectors.base import StepObservation
from detectors.registry import load_detectors_config
from models import BaseLanguageModel
from scripts.run_stage2_joint_cbf import Stage2Runner
from scripts.stage2_utils import distribution_metrics
from scorers import AegisMultiLabelScorer

LOGGER = logging.getLogger(__name__)


@dataclass
class RegenResult:
    """Everything ``ContinuousDecayRunner._run_continuous_decay`` hands back to its caller."""

    # Token ids appended during the run: the replayed forced prefix, then the
    # sampled (or controller-chosen) tokens.
    tokens: List[int]
    # ``self.lm.decode(tokens)``, or the empty string when no token was produced.
    text: str
    # One dict per emitted token (index, token id/str, top-k info, observation, prefix text).
    records: List[dict]
    # ``ctrl.get("log") or {}`` for every step that went through the controller.
    interventions: List[dict]
    # Controller stage logs, annotated with baseline metadata (t_star, L, scale, ...).
    stage_logs: List[dict]
    # Step indices at which the controller was invoked.
    ctrl_steps: List[int]
    # Decay scale used at each controlled step (parallel to ``ctrl_steps``).
    scale_list: List[float]
    # Compact per-controlled-step summaries (tv / kl / hmin / replacement counters).
    continuous_steps: List[dict]
    # Detailed per-controlled-step diagnostics (solver status, margins, top-1 flip, ...).
    per_step_diag: List[dict]
    # Number of controlled steps whose stage log did not report a feasible projection.
    n_solver_fail: int


class ContinuousDecayRunner(Stage2Runner):
    def __init__(self, *args, **kwargs) -> None:
        self.continuous_steps = kwargs.pop("continuous_steps", None)
        super().__init__(*args, **kwargs)
        self._forcing_scale: Optional[float] = None

    def _apply_forcing_bounds(
        self,
        constraints: Dict[str, object],
        ref_probs: Optional[torch.Tensor],
        policy_context: Optional[Dict[str, object]],
        direction_enabled: bool,
    ) -> Dict[str, object]:
        cfg = self.forcing_cfg
        meta = constraints.get("meta") or {}
        meta["b_used_source"] = meta.get("b_used_source") or "base"
        scale = 1.0 if self._forcing_scale is None else float(self._forcing_scale)
        forcing_meta = {
            "configured": bool(cfg.get("enabled")),
            "applied": False,
            "mode": cfg.get("mode"),
            "rho": float(cfg.get("rho", 0.1)) * scale,
            "eta_min": cfg.get("eta_min"),
            "eta_max": cfg.get("eta_max"),
            "apply_when": cfg.get("apply_when"),
            "reference": "tilde" if direction_enabled else "p",
            "reason": None,
            "rho_scale": scale,
        }
        meta["forcing_info"] = forcing_meta
        if not cfg.get("enabled"):
            forcing_meta["reason"] = "disabled"
            meta["b_used_source"] = "base"
            return forcing_meta
        if cfg.get("apply_when") == "direction_enabled_only" and not direction_enabled:
            forcing_meta["reason"] = "direction_disabled"
            meta["b_used_source"] = "base"
            return forcing_meta
        increments = constraints.get("increments")
        bounds = constraints.get("bounds")
        active_rules = meta.get("active_rules") or []
        if (
            increments is None
            or bounds is None
            or increments.numel() == 0
            or bounds.numel() == 0
            or not active_rules
        ):
            forcing_meta["reason"] = "no_constraints"
            meta["b_used_source"] = "base"
            return forcing_meta
        if ref_probs is None or ref_probs.numel() != increments.shape[1]:
            forcing_meta["reason"] = "invalid_reference"
            meta["b_used_source"] = "base"
            return forcing_meta
        headroom_map = (policy_context or {}).get("headroom") or {}
        row_rule_indices = meta.get("active_rules_rows") or active_rules
        if not row_rule_indices:
            row_rule_indices = active_rules
        headroom_vals: List[float] = []
        for idx in range(increments.shape[0]):
            rule_idx = row_rule_indices[idx] if idx < len(row_rule_indices) else active_rules[idx]
            headroom_vals.append(max(0.0, float(headroom_map.get(rule_idx, 0.0))))
        forcing_meta["headroom_per_rule"] = headroom_vals
        if not any(val > 0 for val in headroom_vals):
            forcing_meta["reason"] = "no_headroom"
            forcing_meta["eta_per_rule"] = [0.0 for _ in headroom_vals]
            forcing_meta["eta_sum"] = 0.0
            forcing_meta["margin_ref_under_bprime"] = []
            forcing_meta["b_ref_dot"] = []
            forcing_meta["clip_hit_low_count"] = 0
            forcing_meta["clip_hit_high_count"] = 0
            meta["b_used_source"] = "base"
            return forcing_meta
        rho = forcing_meta["rho"]
        eta_min = cfg.get("eta_min", 1e-4)
        eta_max = cfg.get("eta_max", 5e-2)
        eta_vals: List[float] = []
        clip_low = 0
        clip_high = 0
        for val in headroom_vals:
            raw = rho * val
            if raw <= 0:
                eta = 0.0
            else:
                eta = raw
                if eta < eta_min:
                    eta = eta_min
                    clip_low += 1
                if eta > eta_max:
                    eta = eta_max
                    clip_high += 1
            eta_vals.append(float(eta))
        if cfg.get("mode") == "top1" and eta_vals:
            best_idx = max(range(len(headroom_vals)), key=lambda i: headroom_vals[i])
            for idx in range(len(eta_vals)):
                if idx != best_idx:
                    eta_vals[idx] = 0.0
        inc64 = increments.to(dtype=torch.float64)
        ref_vec = ref_probs.detach().to(dtype=torch.float64)
        b_ref = torch.mv(inc64, ref_vec)
        eta_tensor = torch.tensor(eta_vals, dtype=torch.float64)
        b_prime = b_ref + eta_tensor
        constraints["bounds"] = b_prime.to(dtype=bounds.dtype)
        meta["b_used_source"] = "forced"
        forcing_meta.update(
            {
                "applied": True,
                "eta_per_rule": [float(val) for val in eta_vals],
                "eta_sum": float(sum(eta_vals)),
                "b_ref_dot": [float(val) for val in b_ref.detach().cpu().tolist()],
                "margin_ref_under_bprime": [
                    float(val) for val in (b_ref - b_prime).detach().cpu().tolist()
                ],
                "clip_hit_low_count": clip_low,
                "clip_hit_high_count": clip_high,
            }
        )
        return forcing_meta

    def _score_prefix_step(
        self,
        prompt: str,
        prefix_text: str,
        token_id: int,
        token_str: str,
        logprob: float,
        metrics: Dict[str, float],
        token_index: int,
        obs_step_index: int,
        prev_margins: Optional[List[float]],
    ) -> Tuple[Optional[StepObservation], Optional[List[float]]]:
        if token_index < self.skip_initial_tokens:
            return None, prev_margins
        try:
            scores = self.scorer.score_prefix(prompt, prefix_text + token_str)
        except Exception as exc:  # pragma: no cover - defensive logging
            LOGGER.warning("Streaming scorer failed at token %s: %s", token_str, exc)
            return None, prev_margins
        margins = [float(scores.margins.get(label, 0.0)) for label in self.labels]
        if prev_margins is None:
            prev = [0.0 for _ in margins]
        else:
            prev = prev_margins
        deltas = [cur - prev_val for cur, prev_val in zip(margins, prev)]
        delta_minus = [max(-delta, 0.0) for delta in deltas]
        probs = [float(scores.probabilities.get(label, 0.0)) for label in self.labels]
        obs = StepObservation(
            step_index=obs_step_index,
            token_id=int(token_id),
            token=token_str,
            h=margins,
            delta_minus=delta_minus,
            probabilities=probs,
            logprob=float(logprob),
            entropy=float(metrics.get("entropy", 0.0)),
            top1_top2_margin=float(metrics.get("top1_top2_margin", 0.0)),
            metadata={
                "token_index": int(token_index),
                "skip_initial_tokens": self.skip_initial_tokens,
            },
            p_any_gate=getattr(scores, "p_any_gate", None),
            p_any_or=getattr(scores, "p_any_or", None),
            logit_any=getattr(scores, "logit_any", None),
            gate_source="any_head",
        )
        obs.tau_any = self.tau_any
        return obs, margins

    @staticmethod
    def _sample_from_probs(
        probs: torch.Tensor,
        top_p: float,
        top_k: int,
    ) -> int:
        if top_k and top_k > 0:
            k = min(int(top_k), probs.numel())
            topk_vals, topk_idx = torch.topk(probs, k=k)
            mask = torch.zeros_like(probs, dtype=torch.bool)
            mask[topk_idx] = True
            probs = torch.where(mask, probs, torch.zeros_like(probs))
            total = float(probs.sum())
            if total > 0:
                probs = probs / total
        if top_p < 1.0:
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            keep = cumulative <= top_p
            if keep.numel() > 0:
                keep[0] = True
            mask = torch.zeros_like(probs, dtype=torch.bool)
            mask[sorted_idx[keep]] = True
            probs = torch.where(mask, probs, torch.zeros_like(probs))
            total = float(probs.sum())
            if total > 0:
                probs = probs / total
        if float(probs.sum()) <= 0:
            return int(torch.argmax(probs).item())
        return int(torch.multinomial(probs, num_samples=1).item())

    def _run_continuous_decay(
        self,
        prompt: str,
        forced_prefix: Sequence[int],
        baseline_tokens: Sequence[int],
        max_new_tokens: int,
        t_star_token: int,
        t_star_obs: Optional[int],
        L: int,
        k_star: int,
        policy_context_base: Dict[str, object],
        policy_meta: Dict[str, object],
        selected_candidate: dict,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> RegenResult:
        encoding = self.lm.encode_prompt(prompt)
        input_len = int(encoding.input_ids.shape[-1])
        past = None
        tokens: List[int] = []
        text_buffer = ""
        records: List[dict] = []
        interventions: List[dict] = []
        stage_logs: List[dict] = []
        ctrl_steps: List[int] = []
        scale_list: List[float] = []
        continuous_steps: List[dict] = []
        per_step_diag: List[dict] = []
        n_solver_fail = 0
        prev_margins: Optional[List[float]] = None
        obs_step_index = 0
        forced_len = len(forced_prefix)
        model_max = self.lm.max_positions()
        effective_max_new = int(max_new_tokens)
        if model_max is not None:
            remaining = model_max - input_len
            if remaining < effective_max_new:
                effective_max_new = max(0, remaining)
        log_probs = None
        for step in range(effective_max_new):
            outputs = self.lm.step(
                input_ids=encoding.input_ids,
                attention_mask=encoding.attention_mask,
                past_key_values=past,
            )
            logits = outputs["logits"]
            while logits.dim() > 1:
                logits = logits.squeeze(0)
            past = outputs["past_key_values"]
            metrics = distribution_metrics(logits)
            log_probs = torch.log_softmax(logits, dim=-1)
            prefix_text = text_buffer
            topk_info = self._topk(logits)
            control_active = step >= t_star_token and (step - t_star_token) < L
            if step < forced_len:
                chosen_token_id = int(forced_prefix[step])
                chosen_token_str = self.lm.tokenizer.decode([chosen_token_id], skip_special_tokens=False)
                logprob = float(log_probs[chosen_token_id].item())
                obs, margins = self._score_prefix_step(
                    prompt=prompt,
                    prefix_text=prefix_text,
                    token_id=chosen_token_id,
                    token_str=chosen_token_str,
                    logprob=logprob,
                    metrics=metrics,
                    token_index=len(tokens),
                    obs_step_index=obs_step_index,
                    prev_margins=prev_margins,
                )
                if obs is not None and margins is not None:
                    prev_margins = margins
                    obs_step_index += 1
                records.append(
                    {
                        "index": len(tokens),
                        "token_id": chosen_token_id,
                        "token_str": chosen_token_str,
                        "topk": topk_info,
                        "obs": obs,
                        "prefix_text": prefix_text,
                        "detectors": {},
                    }
                )
                tokens.append(chosen_token_id)
                text_buffer += chosen_token_str
                encoding = self.lm.expand_inputs(encoding, chosen_token_id)
                if chosen_token_id == self.lm.tokenizer.eos_token_id:
                    break
                continue
            if temperature is None or temperature <= 0:
                probs = torch.softmax(logits, dim=-1)
            else:
                probs = torch.softmax(logits / float(temperature), dim=-1)
            base_token_id = self._sample_from_probs(probs, top_p=top_p, top_k=top_k)
            base_token_str = self.lm.tokenizer.decode([base_token_id], skip_special_tokens=False)
            base_logprob = float(log_probs[base_token_id].item())
            base_obs, _ = self._score_prefix_step(
                prompt=prompt,
                prefix_text=prefix_text,
                token_id=base_token_id,
                token_str=base_token_str,
                logprob=base_logprob,
                metrics=metrics,
                token_index=len(tokens),
                obs_step_index=obs_step_index,
                prev_margins=prev_margins,
            )
            chosen_token_id = base_token_id
            chosen_token_str = base_token_str
            step_entry: Optional[dict] = None
            diag_entry: Optional[dict] = None
            if control_active and base_obs is not None:
                i = step - t_star_token
                scale = max(1.0 - (i / float(L)), 1.0 / float(L))
                policy_context = dict(policy_context_base or {})
                policy_context["k_term"] = [int(k_star)]
                headroom, ok = self._estimate_headroom(
                    prompt=prompt,
                    record={
                        "token_id": base_token_id,
                        "token_str": base_token_str,
                        "topk": topk_info,
                        "obs": base_obs,
                        "prefix_text": prefix_text,
                    },
                    candidate_context=None,
                    active_indices=[int(k_star)],
                )
                if ok:
                    policy_context["headroom"] = headroom
                self._forcing_scale = scale
                info = {
                    "record": {
                        "token_id": base_token_id,
                        "token_str": base_token_str,
                        "topk": topk_info,
                        "obs": base_obs,
                        "prefix_text": prefix_text,
                    },
                    "detector": selected_candidate.get("detector"),
                    "detector_info": selected_candidate.get("detector_info"),
                    "step_index": step,
                    "candidate_context": selected_candidate.get("candidate_context") if step == t_star_token else None,
                    "policy": policy_context,
                    "policy_meta": policy_meta,
                }
                ctrl = self._apply_intervention(prompt, info, regen_tokens=0)
                self._forcing_scale = None
                ctrl_steps.append(step)
                scale_list.append(scale)
                interventions.append(ctrl.get("log") or {})
                stage_log = ctrl.get("stage_log")
                if stage_log:
                    stage_log["baseline_name"] = "continuous_decay"
                    if t_star_obs is not None:
                        stage_log["t_star"] = int(t_star_obs)
                    else:
                        stage_log["t_star"] = int(t_star_token)
                    stage_log["t_u"] = policy_context_base.get("t_u_step")
                    stage_log["L"] = int(L)
                    stage_log["k_term_top1"] = int(k_star)
                    stage_log["ctrl_index"] = int(i)
                    stage_log["scale"] = float(scale)
                    stage_logs.append(stage_log)
                chosen_token_id = ctrl.get("token_id")
                if chosen_token_id is None:
                    chosen_token_id = base_token_id
                chosen_token_id = int(chosen_token_id)
                chosen_token_str = self.lm.tokenizer.decode([chosen_token_id], skip_special_tokens=False)
                tv_val = None
                kl_ref = None
                if stage_log:
                    tv_val = stage_log.get("tv_q_p")
                    if tv_val is None:
                        tv_val = stage_log.get("tv")
                    kl_ref = stage_log.get("kl_q_ref")
                base_token_at_step = (
                    int(baseline_tokens[step])
                    if step < len(baseline_tokens)
                    else None
                )
                replaced = int(
                    base_token_at_step is not None
                    and chosen_token_id is not None
                    and int(chosen_token_id) != int(base_token_at_step)
                )
                suffix_now = tokens[t_star_token:] + [int(chosen_token_id)]
                baseline_suffix = (
                    list(baseline_tokens[t_star_token : step + 1])
                    if step < len(baseline_tokens)
                    else []
                )
                min_len = min(len(suffix_now), len(baseline_suffix))
                mismatch = sum(
                    1 for idx in range(min_len) if int(suffix_now[idx]) != int(baseline_suffix[idx])
                )
                replaced_hamming = mismatch + abs(len(suffix_now) - len(baseline_suffix))
                step_entry = {
                    "t": int(step),
                    "scale": float(scale),
                    "tv": None if tv_val is None else float(tv_val),
                    "kl_ref": None if tv_val is not None else (None if kl_ref is None else float(kl_ref)),
                    "hmin": None,
                    "kmin": None,
                    "replaced": int(replaced),
                    "replaced_deprecated": True,
                    "replaced_len": int(len(suffix_now)),
                    "replaced_hamming": int(replaced_hamming),
                    "replaced_hamming_ref": "baseline_suffix",
                }
                continuous_steps.append(step_entry)
                solver_attempted = bool(stage_log and stage_log.get("projection_attempted"))
                solver_feasible = bool(stage_log and stage_log.get("projection_feasible"))
                solver_status = "success" if solver_feasible else "fail"
                if solver_status == "fail":
                    n_solver_fail += 1
                fallback_reason = ""
                if stage_log:
                    fallback_reason = (
                        stage_log.get("fallback_reason")
                        or stage_log.get("infeasible_reason")
                        or stage_log.get("skip_reason")
                        or ""
                    )
                if solver_status == "fail" and not fallback_reason:
                    fallback_reason = "solver_not_called" if not solver_attempted else "solver_infeasible"
                slack_used = None
                if stage_log:
                    slack_used = stage_log.get("slack_used_sum")
                    if slack_used is None:
                        slack_used = stage_log.get("soft_xi_sum_weighted")
                if slack_used is None and not self.slack_enabled:
                    slack_used = 0.0
                rule_ids = []
                if stage_log and isinstance(stage_log.get("active_rules"), list):
                    rule_ids = list(stage_log.get("active_rules") or [])
                rule_idx = None
                if rule_ids:
                    if int(k_star) in rule_ids:
                        rule_idx = rule_ids.index(int(k_star))
                    elif len(rule_ids) == 1:
                        rule_idx = 0
                margin_ref_k = None
                margin_q_k = None
                if stage_log and rule_idx is not None:
                    ref_list = stage_log.get("margin_ref_under_bused") or []
                    if isinstance(ref_list, list) and rule_idx < len(ref_list):
                        margin_ref_k = ref_list[rule_idx]
                    q_list = stage_log.get("constraint_margins_under_q")
                    if q_list is None:
                        q_list = (stage_log.get("margins_detail") or {}).get("q")
                    if isinstance(q_list, list) and rule_idx < len(q_list):
                        margin_q_k = q_list[rule_idx]
                top1_before = None
                top1_after = None
                top1_flip = None
                if stage_log:
                    top1_before = stage_log.get("base_top1_token_id")
                    top1_after = stage_log.get("chosen_token_id")
                    top1_flip = stage_log.get("flip_from_base_top1")
                if top1_flip is None and top1_before is not None and top1_after is not None:
                    top1_flip = bool(int(top1_before) != int(top1_after))
                diag_entry = {
                    "i": int(i),
                    "t": int(step),
                    "scale": float(scale),
                    "solver_status": solver_status,
                    "fallback_reason": fallback_reason,
                    "tv_q_p": None if tv_val is None else float(tv_val),
                    "kl_q_ref": None if tv_val is not None else (None if kl_ref is None else float(kl_ref)),
                    "slack_used": slack_used,
                    "margin_ref_k": margin_ref_k,
                    "margin_q_k": margin_q_k,
                    "top1_before": top1_before,
                    "top1_after": top1_after,
                    "top1_flip": top1_flip,
                    "kmin": None,
                    "hmin": None,
                    "replaced_len": int(len(suffix_now)),
                    "replaced_hamming": int(replaced_hamming),
                    "replaced_hamming_ref": "baseline_suffix",
                }
                per_step_diag.append(diag_entry)
            logprob = float(log_probs[chosen_token_id].item())
            if chosen_token_id == base_token_id and base_obs is not None:
                obs_actual = base_obs
                margins_actual = base_obs.h
            else:
                obs_actual, margins_actual = self._score_prefix_step(
                    prompt=prompt,
                    prefix_text=prefix_text,
                    token_id=chosen_token_id,
                    token_str=chosen_token_str,
                    logprob=logprob,
                    metrics=metrics,
                    token_index=len(tokens),
                    obs_step_index=obs_step_index,
                    prev_margins=prev_margins,
                )
            if obs_actual is not None and margins_actual is not None:
                prev_margins = margins_actual
                obs_step_index += 1
            if step_entry is not None:
                hmin = None
                kmin = None
                if obs_actual is not None and obs_actual.h:
                    hmin = float(min(obs_actual.h))
                    try:
                        kmin = int(min(range(len(obs_actual.h)), key=lambda idx: obs_actual.h[idx]))
                    except Exception:
                        kmin = None
                step_entry["hmin"] = hmin
                step_entry["kmin"] = kmin
            if diag_entry is not None:
                hmin = None
                kmin = None
                if obs_actual is not None and obs_actual.h:
                    hmin = float(min(obs_actual.h))
                    try:
                        kmin = int(min(range(len(obs_actual.h)), key=lambda idx: obs_actual.h[idx]))
                    except Exception:
                        kmin = None
                diag_entry["hmin"] = hmin
                diag_entry["kmin"] = kmin
            records.append(
                {
                    "index": len(tokens),
                    "token_id": chosen_token_id,
                    "token_str": chosen_token_str,
                    "topk": topk_info,
                    "obs": obs_actual,
                    "prefix_text": prefix_text,
                    "detectors": {},
                }
            )
            tokens.append(chosen_token_id)
            text_buffer += chosen_token_str
            encoding = self.lm.expand_inputs(encoding, chosen_token_id)
            if chosen_token_id == self.lm.tokenizer.eos_token_id:
                break
        text = self.lm.decode(tokens) if tokens else ""
        return RegenResult(
            tokens=tokens,
            text=text,
            records=records,
            interventions=interventions,
            stage_logs=stage_logs,
            ctrl_steps=ctrl_steps,
            scale_list=scale_list,
            continuous_steps=continuous_steps,
            per_step_diag=per_step_diag,
            n_solver_fail=n_solver_fail,
        )

    def generate(
        self,
        prompt: str,
        max_new_tokens: int,
        prompt_id: Optional[str] = None,
        prompt_index: Optional[int] = None,
        prompt_meta: Optional[Dict[str, object]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 0,
    ) -> dict:
        self._current_prompt_meta = {
            "prompt_id": prompt_id or f"{self.dataset_tag}:{prompt_index}",
            "prompt_index": prompt_index,
            "dataset_tag": self.dataset_tag,
        }
        if isinstance(prompt_meta, dict):
            if prompt_meta.get("gt_type") is not None:
                self._current_prompt_meta["gt_type"] = prompt_meta.get("gt_type")
            if prompt_meta.get(self.replay_field) is not None:
                self._current_prompt_meta[self.replay_field] = prompt_meta.get(self.replay_field)
            if prompt_meta.get("meta") is not None:
                self._current_prompt_meta["meta"] = prompt_meta.get("meta")
        forced_tokens: List[int] = []
        replay_ctx: Optional[Dict[str, object]] = None
        reference_response = None
        if isinstance(prompt_meta, dict):
            reference_response = prompt_meta.get(self.replay_field)
        if self.replay_reference_response and isinstance(reference_response, str) and reference_response.strip():
            token_ids = self.lm.tokenizer(
                reference_response,
                add_special_tokens=False,
                truncation=True,
                max_length=self.replay_max_ref_tokens if self.replay_max_ref_tokens > 0 else None,
            ).input_ids
            forced_tokens = list(token_ids)
            replay_ctx = {
                "enabled": True,
                "gt_type": self._current_prompt_meta.get("gt_type"),
                "force_arm_on_unsafe": self.force_arm_on_unsafe,
                "replay_backoff_tokens": self.replay_backoff_tokens,
            }
        self._assert_baselines()
        run = self._run_once(prompt, forced_tokens, max_new_tokens, replay_ctx=replay_ctx)
        last_qwen_meta = None
        if self.qwen_hx_enabled:
            last_qwen_meta = run.get("qwen_hx")
        selected_candidate = run.get("intervention")
        policy_meta = run.get("policy_meta") or {}
        policy_context = selected_candidate.get("policy") if isinstance(selected_candidate, dict) else {}
        t_u = policy_meta.get("t_u_step")
        t_star_obs = None
        t_star_token = None
        if isinstance(selected_candidate, dict):
            t_star_obs = policy_context.get("candidate_obs_step")
            if t_star_obs is None:
                t_star_obs = selected_candidate.get("obs_step_index")
            t_star_token = selected_candidate.get("step_index")
        k_star = None
        k_pairs = policy_meta.get("terminal_h_values_for_k_term") or []
        if k_pairs:
            k_star = int(min(k_pairs, key=lambda item: item[1])[0])
        else:
            k_term = policy_context.get("k_term") or policy_meta.get("k_term") or []
            if k_term:
                k_star = int(k_term[0])
        L = 1
        if self.continuous_steps is not None:
            L = max(1, int(self.continuous_steps))
        elif t_star_obs is not None and t_u is not None:
            L = max(1, int(t_u) - int(t_star_obs))
        stage_logs: List[dict] = []
        interventions: List[dict] = []
        ctrl_steps: List[int] = []
        scale_list: List[float] = []
        final_tokens = run.get("tokens") or []
        final_text = run.get("text") or ""
        final_records = run.get("records") or []
        baseline_tokens = list(final_tokens)
        skip_info = run.get("skip_info")
        continuous_steps: List[dict] = []
        per_step_diag: List[dict] = []
        n_solver_fail = 0
        if selected_candidate and k_star is not None and t_star_token is not None and t_u is not None:
            forced_prefix = list(final_tokens[: int(t_star_token)])
            regen = self._run_continuous_decay(
                prompt=prompt,
                forced_prefix=forced_prefix,
                baseline_tokens=baseline_tokens,
                max_new_tokens=max_new_tokens,
                t_star_token=int(t_star_token),
                t_star_obs=int(t_star_obs) if t_star_obs is not None else None,
                L=int(L),
                k_star=int(k_star),
                policy_context_base=dict(policy_context or {}),
                policy_meta=dict(policy_meta or {}),
                selected_candidate=selected_candidate,
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k),
            )
            final_tokens = regen.tokens
            final_text = regen.text
            final_records = regen.records
            interventions = regen.interventions
            stage_logs = regen.stage_logs
            ctrl_steps = regen.ctrl_steps
            scale_list = regen.scale_list
            continuous_steps = regen.continuous_steps
            per_step_diag = regen.per_step_diag
            n_solver_fail = regen.n_solver_fail
        stage2_meta = {
            "baseline_name": "continuous_decay",
            "policy_name": self.policy_name,
            "interventions": stage_logs,
            "total_interventions": len(stage_logs),
            "t_star": int(t_star_obs) if t_star_obs is not None else None,
            "t_u": int(t_u) if t_u is not None else None,
            "L_planned": int(L),
            "L": int(L),
            "k_term_top1": int(k_star) if k_star is not None else None,
            "n_ctrl_applied": len(ctrl_steps),
            "n_solver_fail": int(n_solver_fail),
            "ctrl_steps": list(ctrl_steps),
            "scale_list": list(scale_list),
            "skip_info": dict(skip_info) if isinstance(skip_info, dict) else skip_info,
            "records_available": bool(final_records),
            "continuous": {"steps": list(continuous_steps)},
            "per_step_diag": list(per_step_diag),
            "forcing_enabled": bool(self.forcing_cfg.get("enabled")),
            "slack_enabled": bool(self.slack_control_cfg.get("enabled")),
            "soft_projection_enabled": bool(self.soft_projection_enabled),
        }
        result_meta = {
            "sample_id": self._current_prompt_meta.get("prompt_id"),
            "prompt_index": prompt_index,
            "dataset_tag": self.dataset_tag,
        }
        if self.qwen_hx_enabled:
            t_star_qwen = None
            if isinstance(policy_meta, dict):
                t_star_qwen = policy_meta.get("t_star_qwen")
            t_u_qwen = -1
            hx_history = []
            llm_eval_count = 0
            rollback_reset_count = 0
            fallback_reason = None
            if isinstance(last_qwen_meta, dict):
                t_u_qwen = int(last_qwen_meta.get("t_u_qwen", -1))
                hx_history = last_qwen_meta.get("hx_history") or []
                llm_eval_count = int(last_qwen_meta.get("llm_eval_count", 0) or 0)
                rollback_reset_count = int(last_qwen_meta.get("rollback_reset_count", 0) or 0)
                fallback_reason = last_qwen_meta.get("fallback_reason")
            assert_tstar_lt_tu = None
            if t_star_qwen is not None and t_u_qwen != -1:
                try:
                    assert_tstar_lt_tu = bool(int(t_star_qwen) < int(t_u_qwen))
                except Exception:
                    assert_tstar_lt_tu = None
            result_meta.update(
                {
                    "scorer_used": "qwen_hx",
                    "qwen_hx_enabled": True,
                    "stride": self.qwen_hx_cfg.get("stride"),
                    "eps": self.qwen_hx_cfg.get("eps"),
                    "enter_steps": self.qwen_hx_cfg.get("enter_steps"),
                    "refine": self.qwen_hx_cfg.get("refine"),
                    "refine_radius": self.qwen_hx_cfg.get("refine_radius"),
                    "max_input_tokens": self.qwen_hx_cfg.get("max_input_tokens"),
                    "t_u_qwen": t_u_qwen,
                    "t_star_qwen": int(t_star_qwen) if t_star_qwen is not None else -1,
                    "hx_history": hx_history,
                    "llm_eval_count": llm_eval_count,
                    "rollback_reset_count": rollback_reset_count,
                    "assert_tstar_lt_tu": assert_tstar_lt_tu,
                    "fallback_reason": fallback_reason,
                }
            )
        if self._current_prompt_meta.get("gt_type") is not None:
            result_meta["gt_type"] = self._current_prompt_meta.get("gt_type")
        if isinstance(self._current_prompt_meta.get("meta"), dict):
            result_meta.update(self._current_prompt_meta.get("meta") or {})
        result = {
            "prompt": prompt,
            "prompt_id": self._current_prompt_meta.get("prompt_id"),
            "prompt_index": prompt_index,
            "dataset_tag": self.dataset_tag,
            "generated_text": final_text,
            "token_ids": final_tokens,
            "interventions": interventions,
            "logs": self._serialize_records(final_records),
            "stage2": stage2_meta,
            "meta": result_meta,
        }
        self._current_prompt_meta = None
        return result


def _normalize_gt_type(val: object) -> Optional[str]:
    if val is None:
        return None
    text = str(val).strip().lower()
    if not text:
        return None
    if "harm" in text or "unsafe" in text or "violation" in text:
        return "harmful"
    if "safe" in text:
        return "safe"
    return None


def _normalize_prompt_entry(entry: object) -> Dict[str, object]:
    if isinstance(entry, str):
        return {"prompt": entry, "gt_type": None}
    if isinstance(entry, dict):
        prompt_text = entry.get("prompt") or entry.get("text") or entry.get("input") or ""
        meta = entry.get("meta") or {}
        gt_type = _normalize_gt_type(
            entry.get("gt_type")
            or meta.get("gt_type")
            or entry.get("response_label")
            or entry.get("label")
            or entry.get("safety")
        )
        if gt_type is None:
            vc = entry.get("violated_categories") or entry.get("violations")
            if isinstance(vc, list) and len(vc) > 0:
                gt_type = "harmful"
        if gt_type is None and "is_safe" in entry:
            gt_type = "safe" if bool(entry.get("is_safe")) else "harmful"
        if gt_type is None and "is_harmful" in entry:
            gt_type = "harmful" if bool(entry.get("is_harmful")) else "safe"
        normalized = {"prompt": str(prompt_text), "gt_type": gt_type}
        if "control_paths" in entry:
            normalized["control_paths"] = entry.get("control_paths")
        if "meta" in entry:
            normalized["meta"] = entry.get("meta")
        if "reference_response" in entry:
            normalized["reference_response"] = entry.get("reference_response")
        return normalized
    return {"prompt": str(entry), "gt_type": None}


def _set_seed(seed: int) -> None:
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except Exception:
        pass
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def main() -> None:
    """Command-line entry point for the Stage-2 ``continuous_decay`` baseline.

    Parses the CLI arguments, seeds the RNGs, loads the prompts and controller
    config, builds either the ``qwen_hx`` scorer or the Aegis multi-label
    scorer, then calls ``ContinuousDecayRunner.generate`` once per prompt and
    writes one JSON object per line to the output file.  A ``.run_args.json``
    snapshot of the run arguments is written next to it.

    A prompt whose generation raises does not abort the run: the exception is
    logged and a placeholder record carrying the exception details is written.

    Raises:
        RuntimeError: if the ``qwen_hx`` scorer cannot be initialised.
        FileNotFoundError: if a required scorer artifact is missing.
        KeyError: if the selected ``tau_any`` key is absent from the train report.
        ValueError: if the ``--selected_dims`` payload is neither a list nor a
            dict, or neither ``--output_path`` nor ``--out_dir`` is given.
    """
    # Function-scope import: only needed for the run timestamp below.
    from datetime import timezone

    parser = argparse.ArgumentParser(description="Stage-2 continuous_decay baseline runner")
    parser.add_argument("--base_model", default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--prompts_file", required=True)
    parser.add_argument("--output_path", default=None)
    parser.add_argument("--out_dir", default=None)
    parser.add_argument("--max_prompts", type=int, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=64)
    parser.add_argument("--continuous_steps", type=int, default=None)
    parser.add_argument(
        "--policy_name",
        type=str,
        default="score_max",
        choices=["score_max", "minS", "mid_window", "peak_u", "latest_pre_tu"],
    )
    parser.add_argument("--forcing_enabled", type=int, default=1)
    parser.add_argument("--slack_enabled", type=int, default=1)
    parser.add_argument("--detectors_config", required=True)
    parser.add_argument("--controller_config", required=True)
    parser.add_argument("--selected_dims", required=True)
    parser.add_argument("--classifier_path", default="artifacts/aegis_scorer/classifier.pt")
    parser.add_argument("--label_map_path", default="artifacts/aegis_scorer/label_map.json")
    parser.add_argument("--thresholds_path", default="artifacts/aegis_scorer/thresholds.json")
    parser.add_argument("--scorer_dir", help="If set, load encoder/ + classifier_linear.pt + label_map/thresholds from this dir")
    parser.add_argument("--scorer_skip_initial_tokens", type=int, default=4)
    parser.add_argument("--any_gate_enabled", action="store_true", help="use gate_head + tau_any to arm control flow")
    parser.add_argument("--gate_fpr", type=str, choices=["0.05", "0.20"], default="0.05", help="tau_any selection from train_report")
    parser.add_argument("--min_tstar_resp", type=int, default=12, help="min response tokens for t_star clamp")
    # The gate defaults to on.  A store_true flag with default=True can never be
    # switched off, so a store_false companion writes to the same destination.
    parser.add_argument(
        "--refusal_gate_enabled",
        dest="refusal_gate_enabled",
        action="store_true",
        default=True,
        help="enable refusal fail-safe gate",
    )
    parser.add_argument(
        "--no_refusal_gate_enabled",
        dest="refusal_gate_enabled",
        action="store_false",
        help="disable refusal fail-safe gate",
    )
    parser.add_argument("--replay_reference_response", action="store_true", help="replay reference_response prefixes to force t_u")
    parser.add_argument("--replay_field", default="reference_response", help="field name holding reference response")
    parser.add_argument("--replay_max_ref_tokens", type=int, default=192, help="max tokens to replay from reference response")
    parser.add_argument("--replay_backoff_tokens", type=int, default=16, help="t_star fallback backoff from t_u")
    parser.add_argument("--force_arm_on_unsafe", action="store_true", help="force armed in replay for unsafe prompts")
    parser.add_argument("--terminal_thresholds_path", default=None, help="override thresholds for terminal detection")
    parser.add_argument("--kterm_topm", type=int, default=6, help="top-M dims for K_term when triggered")
    parser.add_argument("--kterm_min_prob", type=float, default=0.0, help="min prob threshold for K_term dims")
    parser.add_argument("--debug_kv_align", action="store_true", help="log kv cache alignment diagnostics")
    parser.add_argument("--debug_risk_update", action="store_true", help="log streaming risk update counts")
    parser.add_argument("--device", default=None)
    parser.add_argument("--run_id", default=None)
    parser.add_argument("--dataset_tag", default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
    _set_seed(int(args.seed))

    # --- Prompts: accept a bare JSON list or a {"prompts": [...]} wrapper. ---
    prompts_raw = json.loads(Path(args.prompts_file).read_text(encoding="utf-8"))
    if isinstance(prompts_raw, dict) and "prompts" in prompts_raw:
        prompts_raw = prompts_raw.get("prompts") or []
    prompts = [_normalize_prompt_entry(entry) for entry in (prompts_raw or [])]
    # Drop entries whose prompt text is empty or whitespace-only.
    prompts = [p for p in prompts if isinstance(p.get("prompt"), str) and p.get("prompt").strip()]
    if args.max_prompts is not None:
        prompts = prompts[: args.max_prompts]

    # --- Model and controller config; CLI switches override the config file. ---
    lm = BaseLanguageModel(model_name=args.base_model, device=args.device)
    controller_cfg = json.loads(Path(args.controller_config).read_text(encoding="utf-8"))
    controller_cfg["policy"] = args.policy_name
    forcing_enabled = bool(int(args.forcing_enabled))
    slack_enabled = bool(int(args.slack_enabled))
    forcing_cfg = controller_cfg.get("forcing") or {}
    forcing_cfg["enabled"] = forcing_enabled
    controller_cfg["forcing"] = forcing_cfg
    slack_cfg = controller_cfg.get("slack_control") or {}
    slack_cfg["enabled"] = slack_enabled
    controller_cfg["slack_control"] = slack_cfg
    if not slack_enabled:
        # Turning slack control off also turns soft projection off.
        soft_cfg = controller_cfg.get("soft_projection") or {}
        soft_cfg["enabled"] = False
        controller_cfg["soft_projection"] = soft_cfg
    qwen_hx_cfg = controller_cfg.get("qwen_hx") or {}
    qwen_hx_enabled = bool(qwen_hx_cfg.get("enabled", False))

    # --- Scorer: qwen_hx margin judge, or the Aegis multi-label classifier. ---
    detectors_cfg = load_detectors_config(args.detectors_config)
    scorer = None
    labels = None
    tau_any = None
    any_gate_enabled = bool(args.any_gate_enabled)
    train_report_path = None
    if qwen_hx_enabled:
        try:
            from llm_judges.qwen_margin import QwenMarginJudge
            from scorers import QwenMarginScorer

            judge_kwargs = {
                "model_path": str(qwen_hx_cfg.get("model_path", "Qwen/Qwen2.5-7B-Instruct")),
                "label_safe": str(qwen_hx_cfg.get("label_safe", " SAFE")),
                "label_unsafe": str(qwen_hx_cfg.get("label_unsafe", " UNSAFE")),
                "use_chat_template": bool(qwen_hx_cfg.get("use_chat_template", False)),
                "max_input_tokens": int(qwen_hx_cfg.get("max_input_tokens", 1024)),
                "dtype": "bfloat16",
                "device_map": "auto",
            }
            if qwen_hx_cfg.get("prompt_template"):
                judge_kwargs["prompt_template"] = qwen_hx_cfg.get("prompt_template")
            judge = QwenMarginJudge(**judge_kwargs)
            scorer = QwenMarginScorer(judge=judge, label="qwen_hx")
            labels = ["qwen_hx"]
            # The any-gate is always off in qwen_hx mode, whatever the CLI says.
            any_gate_enabled = False
            if args.any_gate_enabled:
                LOGGER.info("qwen_hx enabled: disabling any_gate for this run")
            LOGGER.info("scorer_used=qwen_hx (roberta_disabled)")
        except Exception as exc:
            raise RuntimeError(f"Failed to init qwen_hx scorer: {exc}") from exc
    else:
        scorer_encoder = "roberta-base"
        classifier_path = Path(args.classifier_path)
        label_map_path = Path(args.label_map_path)
        thresholds_path = Path(args.thresholds_path)
        gate_head_path = None
        tokenizer_path = None
        if args.scorer_dir:
            # A scorer directory overrides the individual artifact paths.
            scorer_dir = Path(args.scorer_dir)
            scorer_encoder = scorer_dir / "encoder"
            classifier_path = scorer_dir / "classifier_linear.pt"
            label_map_path = scorer_dir / "label_map.json"
            if not label_map_path.exists():
                alt_label = scorer_dir / "label_space.json"
                if alt_label.exists():
                    label_map_path = alt_label
            thresholds_path = scorer_dir / "thresholds.json"
            if not thresholds_path.exists():
                alt_thr = scorer_dir / "thresholds_fpr0p05.json"
                if alt_thr.exists():
                    thresholds_path = alt_thr
            gate_head_path = scorer_dir / "gate_head.pt"
            tokenizer_path = scorer_dir / "tokenizer"
            train_report_path = scorer_dir / "train_report.json"
        else:
            gate_head_path = classifier_path.parent / "gate_head.pt"
            train_report_path = classifier_path.parent / "train_report.json"
        if args.any_gate_enabled:
            if train_report_path is None or not train_report_path.exists():
                raise FileNotFoundError(f"train_report.json not found at {train_report_path}")
        # Fail early on missing artifacts; tokenizer_path stays None (and is
        # skipped) unless --scorer_dir was given.
        for path_label, path_val in [
            ("classifier_path", classifier_path),
            ("label_map_path", label_map_path),
            ("thresholds_path", thresholds_path),
            ("tokenizer_path", tokenizer_path),
        ]:
            if path_val is None:
                continue
            if not Path(path_val).exists():
                raise FileNotFoundError(f"{path_label} not found at {path_val}")
        if args.any_gate_enabled:
            if gate_head_path is None or not Path(gate_head_path).exists():
                raise FileNotFoundError(f"gate_head_path not found at {gate_head_path}")
            train_report = json.loads(train_report_path.read_text(encoding="utf-8"))
            tau_key = "tau_any_fpr0p05" if args.gate_fpr == "0.05" else "tau_any_fpr0p20"
            if tau_key not in train_report:
                raise KeyError(f"{tau_key} missing in train_report {train_report_path}")
            tau_any = float(train_report[tau_key])
            LOGGER.info(
                "Loaded tau_any=%.6f (key=%s gate_fpr=%s) from %s",
                tau_any,
                tau_key,
                args.gate_fpr,
                train_report_path,
            )
        else:
            tau_any = None
            LOGGER.info("any_gate disabled: skipping gate_head/tau_any checks")

        scorer = AegisMultiLabelScorer(
            encoder_name=str(scorer_encoder),
            classifier_path=classifier_path,
            label_map_path=label_map_path,
            thresholds_path=thresholds_path,
            include_prompt_as_context=False,
            device=args.device,
            gate_head_path=gate_head_path,
            tokenizer_path=tokenizer_path,
        )
        LOGGER.info(
            "scorer loaded gate_head=%s path=%s",
            bool(getattr(scorer, "gate_head_loaded", False)),
            getattr(scorer, "gate_head_path", None),
        )
        LOGGER.info("scorer tokenizer_path=%s", tokenizer_path if tokenizer_path is not None else scorer_encoder)
        # selected_dims may be a bare list or a dict with "selected_dims"/"labels".
        selected_payload = json.loads(Path(args.selected_dims).read_text(encoding="utf-8"))
        if isinstance(selected_payload, list):
            labels = selected_payload
        elif isinstance(selected_payload, dict):
            labels = selected_payload.get("selected_dims")
            if labels is None:
                labels = selected_payload.get("labels")
        else:
            raise ValueError(f"Unsupported selected_dims format: {type(selected_payload)}")
        if not labels:
            # Nothing selected: fall back to every label the scorer exposes.
            labels = scorer.labels

    # --- Output location and run id. ---
    dataset_tag = args.dataset_tag or Path(args.prompts_file).stem
    run_id = args.run_id
    output_path = None
    if args.output_path:
        output_path = Path(args.output_path)
        if run_id is None:
            run_id = output_path.stem
    else:
        if args.out_dir is None:
            raise ValueError("Provide --output_path or --out_dir")
        out_dir = Path(args.out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        output_path = out_dir / f"{dataset_tag}.jsonl"
        if run_id is None:
            run_id = out_dir.name

    runner = ContinuousDecayRunner(
        lm=lm,
        scorer=scorer,
        detector_cfg=detectors_cfg,
        controller_cfg=controller_cfg,
        labels=labels,
        controller_config_path=args.controller_config,
        skip_initial_tokens=args.scorer_skip_initial_tokens,
        target_detector=controller_cfg.get("target_detector"),
        run_id=run_id,
        dataset_tag=dataset_tag,
        run_seed=args.seed,
        tau_any=tau_any,
        gate_fpr=args.gate_fpr,
        any_gate_enabled=bool(any_gate_enabled),
        refusal_gate_enabled=bool(args.refusal_gate_enabled),
        min_tstar_resp=args.min_tstar_resp,
        kterm_topm=args.kterm_topm,
        kterm_min_prob=args.kterm_min_prob,
        tau_source_path=str(train_report_path) if train_report_path is not None else None,
        replay_reference_response=bool(args.replay_reference_response),
        replay_field=args.replay_field,
        replay_max_ref_tokens=args.replay_max_ref_tokens,
        replay_backoff_tokens=args.replay_backoff_tokens,
        force_arm_on_unsafe=bool(args.force_arm_on_unsafe),
        terminal_thresholds_path=args.terminal_thresholds_path,
        debug_kv_align=bool(args.debug_kv_align),
        debug_risk_update=bool(args.debug_risk_update),
        continuous_steps=args.continuous_steps,
    )

    # --- Snapshot of the run arguments, written next to the output file. ---
    output_path.parent.mkdir(parents=True, exist_ok=True)
    run_args = {
        "output_path": str(output_path),
        "controller_config": str(Path(args.controller_config).resolve()),
        "prompts_file": str(Path(args.prompts_file).resolve()),
        "dataset_tag": dataset_tag,
        "run_id": run_id,
        "max_prompts": args.max_prompts,
        "max_new_tokens": args.max_new_tokens,
        "continuous_steps": args.continuous_steps,
        "policy_name": args.policy_name,
        "forcing_enabled": forcing_enabled,
        "slack_enabled": slack_enabled,
        "seed": args.seed,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        # datetime.utcnow() is deprecated; take an aware UTC time and drop the
        # tzinfo so the serialized string keeps its offset-free format.
        "time_utc": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(timespec="seconds"),
    }
    run_args_path = output_path.with_suffix(".run_args.json")
    run_args_path.write_text(json.dumps(run_args, ensure_ascii=False, indent=2), encoding="utf-8")

    # --- Generation loop: one JSON line per prompt. ---
    with output_path.open("w", encoding="utf-8") as writer:
        for prompt_index, entry in enumerate(prompts):
            prompt = entry.get("prompt")
            gt_type = entry.get("gt_type")
            prompt_identifier = f"{dataset_tag}:{prompt_index}"
            try:
                result = runner.generate(
                    prompt,
                    args.max_new_tokens,
                    prompt_id=prompt_identifier,
                    prompt_index=prompt_index,
                    prompt_meta=entry,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    top_k=args.top_k,
                )
            except Exception as exc:
                # Keep the run going, but surface the failure in the log as well
                # as in the placeholder record written below.
                LOGGER.exception("generate failed for prompt %s; writing placeholder record", prompt_identifier)
                debug = {
                    "debug_schema_version": "PRELOOP_DIAG_V3",
                    "exception_type": type(exc).__name__,
                    "exception_msg": str(exc)[:500],
                    "traceback": traceback.format_exc(limit=20)[:2000],
                    "entered_decode_loop": False,
                    "decode_steps": 0,
                    "first_token_written_step": -1,
                    "pre_loop_exit_reason": "exception_pre_loop",
                    "pre_loop_stop_cause": None,
                    "after_tokenize": False,
                    "before_decode_loop": False,
                    "after_decode_loop": False,
                }
                result = {
                    "prompt": prompt,
                    "prompt_id": prompt_identifier,
                    "prompt_index": prompt_index,
                    "dataset_tag": dataset_tag,
                    "generated_text": "",
                    "token_ids": [],
                    "interventions": [],
                    "logs": [],
                    "stage2": {
                        "baseline_name": "continuous_decay",
                        "skip_info": "exception",
                        "records_available": False,
                        "debug": debug,
                    },
                    "meta": {
                        "sample_id": prompt_identifier,
                        "prompt_index": prompt_index,
                        "dataset_tag": dataset_tag,
                    },
                }
            stage2 = result.setdefault("stage2", {})
            stage2_debug = stage2.setdefault("debug", {})
            token_ids = result.get("token_ids") or []
            generated_text = result.get("generated_text") or ""
            # Replaces any existing "after_tokenize" entry (the placeholder sets
            # it to False) with a size/preview snapshot of the output.
            after_tokenize_snapshot = {
                "token_ids_out_len": len(token_ids),
                "generated_text_len": len(generated_text),
                "head_text": generated_text[:80],
                "tail_text": generated_text[-80:] if generated_text else "",
            }
            stage2_debug["after_tokenize"] = after_tokenize_snapshot
            if gt_type:
                meta = result.get("meta") or {}
                meta["gt_type"] = gt_type
                result["meta"] = meta
            writer.write(json.dumps(result, ensure_ascii=False) + "\n")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
