# -*- coding: utf-8 -*-
import re
from collections import defaultdict
from typing import Optional, Dict, Any, List, Tuple, Callable

import torch
from verl import DataProto
from verl.workers.reward_manager import register

# ------------------------- Configuration & constants -------------------------

# Control tokens the policy may emit as its first response token; <COMP_xx> asks for a
# reasoning length of about xx% of the reference full length (see comp_token_to_ratio).
DEFAULT_ALLOWED_COMP_TOKENS = ["<COMP_20>", "<COMP_40>", "<COMP_60>", "<COMP_80>", "<COMP_100>"]
# Token strings skipped when locating the first "real" response token; resolved to ids
# via safe_convert_ids in CDRPOStyleRewardManager.__init__.
DEFAULT_IGNORE_SPECIALS = [
    "<|im_start|>", "<|im_end|>", "<|assistant|>", "<|user|>",
    "</s>", "<s>"
]

# Matches a <COMP_xx> control token; group(1) is the 1-3 digit percentage.
COMP_RE = re.compile(r"<COMP_(\d{1,3})>")

# Fallback "answer prefix" ("The final answer is:", "Final answer:", "Answer:"), used when no
# \boxed{} is found. Case-insensitive; group(1) captures everything after the colon up to the
# end of the text (DOTALL + $), not just the rest of the line.
FINAL_ANSWER_RE = re.compile(r"(?:the\s+final\s+answer\s+is|final\s+answer|answer)\s*:\s*(.*)$",
                             re.IGNORECASE | re.DOTALL)

# Robust match for <think ...> ... </think>: case-insensitive, allows attributes/whitespace in
# the tags; group(1) is the (non-greedy) inner text of one block.
THINK_BLOCK_RE = re.compile(r"(?is)<\s*think\b[^>]*>(.*?)</\s*think\s*>")

# ------------------------- Small helpers -------------------------

def comp_token_to_ratio(tok: str) -> Optional[float]:
    """Translate a ``<COMP_xx>`` control token into a ratio in [0, 1].

    The token is stripped and must match ``<COMP_xx>`` exactly (1-3 digits).
    The percentage is divided by 100 and clamped, so ``<COMP_250>`` -> 1.0.
    Returns ``None`` for empty input or any other string.
    """
    match = COMP_RE.fullmatch(tok.strip()) if tok else None
    if match is None:
        return None
    percent = int(match.group(1))
    return min(1.0, max(0.0, percent / 100.0))

def safe_convert_ids(tokenizer, toks: List[str]) -> List[int]:
    """Convert token strings to ids, dropping ones the tokenizer does not know.

    A token is dropped when the tokenizer maps it to ``None`` or to its
    ``unk_token_id`` (if it has one). Input order is preserved.
    """
    unk_id = getattr(tokenizer, "unk_token_id", None)
    converted = (tokenizer.convert_tokens_to_ids(tok) for tok in toks)
    return [token_id for token_id in converted if token_id is not None and token_id != unk_id]

def _token_is_whitespace(tokenizer, tid: int) -> bool:
    try:
        s = tokenizer.convert_tokens_to_string([tokenizer.convert_ids_to_tokens(tid)])
    except Exception:
        s = tokenizer.decode([tid], skip_special_tokens=True)
    s = s or ""
    return s.strip() == ""

def get_first_non_ignored_token_info(tokenizer, response_ids: torch.Tensor, ignore_token_ids: set,
                                     skip_ws: bool = True) -> Tuple[Optional[int], Optional[str], int]:
    """Find the first response token that is not ignored (and, optionally, not whitespace).

    Args:
        tokenizer: provides ``convert_ids_to_tokens`` (and what _token_is_whitespace needs).
        response_ids: token ids; converted with ``.tolist()``.
        ignore_token_ids: ids that are always skipped.
        skip_ws: also skip tokens that render as whitespace only.

    Returns:
        ``(token_id, token_string, position)`` of the first kept token, or
        ``(None, None, -1)`` when every token is skipped.
    """
    for position, token_id in enumerate(response_ids.tolist()):
        skipped = token_id in ignore_token_ids or (skip_ws and _token_is_whitespace(tokenizer, token_id))
        if skipped:
            continue
        return token_id, tokenizer.convert_ids_to_tokens(token_id), position
    return None, None, -1

# ------------------------- \boxed extraction & answer equivalence -------------------------

def _extract_last_boxed_content(text: str) -> Optional[str]:
    if not text:
        return None
    needle = "\\boxed{"
    last = text.rfind(needle)
    if last == -1:
        last = text.rfind(r"\boxed{")
        if last == -1:
            return None
        start = last + len(r"\boxed{")
    else:
        start = last + len(needle)
    depth = 1; i = start
    while i < len(text):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i].strip()
        i += 1
    return None

def extract_final_boxed_answer(text: str) -> Optional[str]:
    r"""Extract the final answer from a model response.

    Preference order:
      1. the contents of the last ``\boxed{...}`` (balanced braces);
      2. otherwise the stripped text after the LAST "The final answer is:" /
         "Final answer:" / "Answer:" prefix (case-insensitive).

    Returns ``None`` when neither is present or the captured tail is empty.
    """
    boxed = _extract_last_boxed_content(text)
    if boxed is not None:
        return boxed

    text = text or ""
    # FINAL_ANSWER_RE captures everything after the colon up to the end of the
    # text (DOTALL + $), so a single search() yields the FIRST prefix plus the
    # whole rest of the response. Walk forward to the last prefix instead: it is
    # the one that actually introduces the final answer.
    last_match = None
    m = FINAL_ANSWER_RE.search(text)
    while m is not None:
        last_match = m
        # Any later prefix must start at or after the captured tail's start.
        m = FINAL_ANSWER_RE.search(text, m.start(1))
    if last_match is None:
        return None
    # No \boxed{} retry on the tail: it is a suffix of *text*, on which the
    # boxed extraction above already failed, so a retry cannot succeed.
    tail = (last_match.group(1) or "").strip()
    return tail or None

def verify_math_equivalence(pred: Optional[str], gt: Optional[str]) -> bool:
    """Return True when the predicted answer *pred* matches the ground truth *gt*.

    Checks, in order:
      1. a missing/empty *pred* or *gt* never matches;
      2. identical non-empty strings (after cleaning) always match;
      3. ``math_verify`` if importable -- its verdict is final when it runs
         without raising; it is called as ``verify(gold, target)``;
      4. exact numeric comparison (``Fraction`` for "p/q" literals, ``Decimal``
         otherwise);
      5. case-insensitive string comparison.
    Cleaning strips surrounding whitespace and trailing ASCII/CJK punctuation.
    """
    if not pred or not gt:
        return False

    def _clean(s) -> str:
        return str(s).strip().rstrip(".。，；;：:")

    pred_clean = _clean(pred)
    gt_clean = _clean(gt)

    # Identical answers are equivalent; don't let a parser failure below turn
    # them into a miss.
    if pred_clean and pred_clean == gt_clean:
        return True

    # Optional strict verification; any failure (missing package, parse error,
    # timeout) falls through to the local checks below.
    try:
        from math_verify import parse, verify  # optional external dependency
        # Gold answer first: verify(gold, target) is not symmetric in general.
        return bool(verify(parse(gt_clean), parse(pred_clean)))
    except Exception:
        pass

    # Exact numeric fallback. Decimal(str) and Decimal/Fraction comparisons are
    # exact regardless of context precision, so the thread-global decimal
    # context is not modified.
    from decimal import Decimal
    from fractions import Fraction

    def _to_number(s: str):
        # Fraction only for "p/q" literals: Fraction("1e999999999") would build
        # 10**999999999 as an int, while Decimal just records the exponent.
        return Fraction(s) if "/" in s else Decimal(s)

    try:
        return _to_number(pred_clean) == _to_number(gt_clean)
    except (ValueError, ArithmeticError):  # decimal.InvalidOperation is an ArithmeticError
        pass
    return pred_clean.lower() == gt_clean.lower()

# ------------------------- Count only <think> content (tiktoken-based) -------------------------

def build_token_counter(prefer_tiktoken: bool = True,
                        tiktoken_encoding: str = "cl100k_base",
                        hf_tokenizer=None) -> Callable[[str], int]:
    """Return a ``count(text) -> int`` token-counting function.

    Tries, in order: the tiktoken encoding *tiktoken_encoding* (only when
    *prefer_tiktoken*), then *hf_tokenizer* (without special tokens), then a
    rough ``len(text) // 4`` estimate. Every counter treats ``None`` as "".
    """
    encoding = None
    if prefer_tiktoken:
        try:
            import tiktoken  # type: ignore
            encoding = tiktoken.get_encoding(tiktoken_encoding)
        except Exception:
            encoding = None

    if encoding is not None:
        return lambda s: len(encoding.encode(s or ""))
    if hf_tokenizer is not None:
        return lambda s: len(hf_tokenizer.encode(s or "", add_special_tokens=False))
    return lambda s: max(0, len(s or "") // 4)

def extract_think_spans(text: str) -> List[Tuple[Tuple[int,int], str]]:
    """Return ``((start, end), inner_text)`` for every complete <think>...</think> block.

    Blocks are listed in order of appearance; empty/None text yields ``[]``.
    """
    if not text:
        return []
    return [(match.span(), match.group(1)) for match in THINK_BLOCK_RE.finditer(text)]

def select_think_content(spans: List[Tuple[Tuple[int,int], str]],
                         strategy: str = "last") -> Optional[str]:
    """Pick the <think> text to measure from ``extract_think_spans`` output.

    strategy:
      - 'last'    : the final block (closest to the answer); also the default
                    for unrecognized values
      - 'longest' : the block with the most characters (first one on ties)
      - 'concat'  : all blocks joined with newlines
    Returns ``None`` when *spans* is empty.
    """
    if not spans:
        return None
    bodies = [span[1] for span in spans]
    if strategy == "longest":
        return max(bodies, key=len)
    if strategy == "concat":
        return "\n".join(bodies)
    return bodies[-1]

def compute_reason_len_think_only(
    *,
    response_text: str,
    count_tokens: Callable[[str], int],
    think_pick: str = "last",
) -> Tuple[int, bool, str, int]:
    """Measure reasoning length strictly inside <think>...</think>.

    Returns ``(reason_len, has_full_think, strategy_used, think_blocks_count)``.
    Only complete blocks count. With none, the result is
    ``(0, False, think_pick, 0)`` and there is no other fallback. Otherwise
    *think_pick* chooses the block(s) (see select_think_content), and the count
    from *count_tokens* is clamped to be non-negative.
    """
    spans = extract_think_spans(response_text or "")
    if spans:
        picked = select_think_content(spans, think_pick) or ""
        return int(max(0, count_tokens(picked))), True, think_pick, len(spans)
    return 0, False, think_pick, 0

# ------------------------- Reward shaping functions -------------------------

def huber(delta: float, k: float = 0.1) -> float:
    """Piecewise-linear band reward for the error *delta* (not the Huber loss).

      * |delta| <= k : 1 - |delta| / k  (1.0 at delta == 0, 0.0 at the band edge)
      * |delta| >  k : -min(1, (|delta| - k) / (1 - k))  (reaches -1 at |delta| == 1)

    A non-positive *k* is a zero-width band: only an exact hit (delta == 0)
    scores 1.0 instead of dividing 0 by 0.
    """
    ad = abs(delta)
    if ad <= k:
        # With k <= 0 this branch is only reachable when ad == 0 (exact hit).
        return 1.0 - ad / k if k > 0 else 1.0
    # max(1e-6, ...) guards k >= 1; the result then saturates at -1.
    return -min(1.0, (ad - k) / max(1e-6, (1.0 - k)))

def band_reward(delta: float, band: float = 0.1) -> float:
    """Calibration reward for ratio error *delta* with tolerance half-width *band*.

    Thin alias of ``huber`` with ``k=band``.
    """
    return huber(k=band, delta=delta)

# ------------------------- Bucket snapping → index helpers -------------------------

def snap_to_bucket(x: Optional[float], buckets: List[float]) -> Optional[float]:
    """Return the element of *buckets* closest to *x* (the earliest one on ties).

    ``None`` passes through unchanged; *x* is compared after ``float(x)``.
    """
    if x is None:
        return None

    def _distance(bucket):
        return abs(bucket - float(x))

    return min(buckets, key=_distance)

def bucket_index_map(buckets: List[float]) -> Dict[float, int]:
    """Map each bucket value to its position in *buckets* (a later duplicate wins)."""
    index_of: Dict[float, int] = {}
    for position, value in enumerate(buckets):
        index_of[value] = position
    return index_of

# ------------------------- Main reward construction (<think>-only length + missing-think penalty) -------------------------

def build_reward_v2(
    *,
    response_text: str,
    response_ids: torch.Tensor,
    tokenizer,
    ground_truth: str,
    allowed_comp_tokens: List[str],
    tokens_full_ref: Optional[int],
    gt_ratio: Optional[float],
    allowed_ratios: Optional[List[float]],
    comp_tolerance: float = 0.07,
    ignore_token_ids: set = frozenset(),
    # Main-score weights
    w_acc: float = 0.7,
    w_cal: float = 0.2,
    # "Mode term" coefficients (chosen bucket vs. ground-truth bucket)
    mode_gain_correct: float = 0.9,
    mode_penalty_correct_over: float = 0.6,
    mode_penalty_wrong_under: float = 0.7,
    mode_penalty_wrong_over: float = 0.2,
    # Control-head score
    ctrl_hit_gt_reward: float = 0.5,
    ctrl_short_step_gain: float = 0.25,
    ctrl_short_cap: float = 0.95,
    ctrl_confidence: Optional[float] = None,
    # <think> block selection / missing-think penalties
    think_pick: str = "last",
    require_full_think: bool = True,
    no_think_main_penalty: float = 0.4,
    no_think_ctrl_penalty: float = 0.2,
    # Length counter (the default counts characters; the reward manager passes a tiktoken-based one)
    count_tokens: Callable[[str], int] = lambda s: len(s or "")
) -> Dict[str, Any]:
    """Compute the main score and the control-head score for one response.

    Steps:
      1. Control head: the first token of *response_ids* that is neither in
         *ignore_token_ids* nor whitespace. It counts as a choice only if it is
         a <COMP_xx> token listed in *allowed_comp_tokens* (or in
         DEFAULT_ALLOWED_COMP_TOKENS when that list is empty). The chosen ratio
         is xx / 100, clamped to [0, 1].
      2. Reasoning length: *count_tokens* over the selected <think>...</think>
         content only. It is 0 with has_full_think=False when there is no
         complete block.
      3. r_hat = reasoning length / reference full length. When
         *tokens_full_ref* is missing or non-positive, the reasoning length
         itself is the reference, so r_hat is 1.0 (0.0 for an empty think).
      4. Accuracy: extract_final_boxed_answer() checked against *ground_truth*
         with verify_math_equivalence(); acc_reward is +1.0 / -1.0.
      5. Calibration: band_reward(r_hat - chosen_ratio, band=comp_tolerance),
         or -1.0 when no valid <COMP_xx> was chosen.
      6. Mode term: only when both a ground-truth ratio and a chosen ratio
         exist. Both are snapped to the sorted, de-duplicated *allowed_ratios*
         (default 0.2..1.0). steps_gap = chosen bucket index - GT bucket index.
         - correct, gap <= 0 : + mode_gain_correct * min(2, -gap) / 2
         - correct, gap > 0  : - mode_penalty_correct_over * gap
         - wrong, gap < 0    : - mode_penalty_wrong_under * |gap|
         - wrong, gap > 0    : - mode_penalty_wrong_over * gap
      7. score_main = w_acc * acc_reward + w_cal * cal_reward + mode term,
         minus |no_think_main_penalty| when *require_full_think* is set and no
         full think exists, clipped to [-1, 1].
      8. score_ctrl, when both snapped ratios exist:
         - correct, gap < 0 : ctrl_hit_gt_reward
           + (0.5 + 0.5 * clip(conf, 0, 1)) * ctrl_short_step_gain * min(2, -gap),
           capped at ctrl_short_cap (conf = *ctrl_confidence*, default 1.0)
         - correct, gap == 0 : ctrl_hit_gt_reward
         - correct, gap > 0  : -0.5 * gap
         - wrong: -0.7 * |gap| (gap < 0), -0.2 (gap == 0), -0.2 * gap (gap > 0)
         Otherwise it is +1.0 if a valid <COMP_xx> was emitted first, else
         -1.0. The -0.5 / -0.7 / -0.2 factors are hard-coded, unlike the
         main-score mode coefficients. The same missing-think rule subtracts
         |no_think_ctrl_penalty|, and the result is clipped to [-1, 1].

    Args:
        response_text: decoded response text (special tokens kept).
        response_ids: response token ids; the reward manager passes the valid,
            unpadded slice.
        tokens_full_ref: reference full reasoning length; may be None.
        gt_ratio: ground-truth ratio; may be None.

    Returns:
        A dict with the two scores ("score_main", "score_ctrl") plus
        components and diagnostics. Missing values are encoded as -1 / -1.0 /
        "[INVALID]" / "[NONE]".
    """

    # 1) First token (control head)
    first_tid, first_tok_str, first_pos = get_first_non_ignored_token_info(
        tokenizer, response_ids, ignore_token_ids, skip_ws=True
    )
    first_is_comp = False; chosen_tok = None; chosen_ratio = None
    if first_tok_str is not None and COMP_RE.fullmatch(first_tok_str or ""):
        if first_tok_str in (allowed_comp_tokens or DEFAULT_ALLOWED_COMP_TOKENS):
            first_is_comp = True
            chosen_tok = first_tok_str
            chosen_ratio = comp_token_to_ratio(chosen_tok)

    # 2) Reasoning length (strictly the <think> content; no think -> 0 and has_full_think=False)
    reason_len, has_full_think, think_pick_used, think_blocks = compute_reason_len_think_only(
        response_text=response_text, count_tokens=count_tokens, think_pick=think_pick
    )

    # 3) Reference full length (should follow the dataset's convention, e.g. comp100_think_tokens);
    #    when missing, the own reasoning length is used, making r_hat 1.0 (or 0.0 if empty)
    full_ref = int(tokens_full_ref) if (tokens_full_ref and int(tokens_full_ref) > 0) else max(1, reason_len)
    r_hat = reason_len / float(max(1, full_ref))

    # 4) Correctness
    pred_ans = extract_final_boxed_answer(response_text)
    acc = verify_math_equivalence(pred_ans, ground_truth)
    acc_reward = 1.0 if acc else -1.0

    # 5) Following term (calibration of actual vs. chosen ratio)
    if chosen_ratio is None:
        cal_reward = -1.0  # no valid <COMP_xx> was chosen
    else:
        delta = r_hat - chosen_ratio
        cal_reward = band_reward(delta, band=comp_tolerance)

    # 6) Mode term (relative to the ground-truth bucket)
    buckets = sorted(set(allowed_ratios or [0.2, 0.4, 0.6, 0.8, 1.0]))
    idx_map = bucket_index_map(buckets)
    gt_snap = snap_to_bucket(gt_ratio, buckets) if gt_ratio is not None else None
    ch_snap = snap_to_bucket(chosen_ratio, buckets) if chosen_ratio is not None else None

    steps_gap = None
    mode_term_main = 0.0
    if (gt_snap is not None) and (ch_snap is not None):
        steps_gap = idx_map[ch_snap] - idx_map[gt_snap]  # < 0: shorter than GT; > 0: longer
        if acc:
            if steps_gap <= 0:
                mode_term_main += mode_gain_correct * min(2, -steps_gap) / 2.0
            else:
                mode_term_main -= mode_penalty_correct_over * steps_gap
        else:
            if steps_gap < 0:
                mode_term_main -= mode_penalty_wrong_under * (-steps_gap)
            elif steps_gap > 0:
                mode_term_main -= mode_penalty_wrong_over * steps_gap

    # 7) Main score + missing-think penalty, clipped to [-1, 1]
    score_main_pre = (w_acc * acc_reward) + (w_cal * cal_reward) + mode_term_main
    if require_full_think and not has_full_think:
        score_main_pre -= abs(float(no_think_main_penalty))
    score_main = float(max(-1.0, min(1.0, score_main_pre)))

    # 8) Control-head score (also penalized when think is missing), clipped to [-1, 1]
    if (gt_snap is not None) and (ch_snap is not None):
        if acc:
            if steps_gap is not None and steps_gap < 0:
                steps = min(2, -steps_gap)
                conf = float(ctrl_confidence) if ctrl_confidence is not None else 1.0
                conf_boost = 0.5 + 0.5 * max(0.0, min(1.0, conf))
                score_ctrl = ctrl_hit_gt_reward + conf_boost * (ctrl_short_step_gain * steps)
                score_ctrl = min(ctrl_short_cap, score_ctrl)
            elif steps_gap == 0:
                score_ctrl = ctrl_hit_gt_reward
            else:
                score_ctrl = -0.5 * (steps_gap if steps_gap is not None else 0)
        else:
            if steps_gap is not None and steps_gap < 0:
                score_ctrl = -0.7 * (-steps_gap)
            elif steps_gap == 0:
                score_ctrl = -0.2
            else:
                score_ctrl = -0.2 * (steps_gap if steps_gap is not None else 0)
    else:
        # No GT bucket or no valid choice: reward only emitting a valid <COMP_xx>
        score_ctrl = 1.0 if first_is_comp else -1.0
    if require_full_think and not has_full_think:
        score_ctrl -= abs(float(no_think_ctrl_penalty))
    score_ctrl = float(max(-1.0, min(1.0, score_ctrl)))

    return {
        "acc": bool(acc),
        "pred": pred_ans if pred_ans is not None else "[INVALID]",
        "extracted": pred_ans is not None,

        "first_token_id": int(first_tid) if first_tid is not None else -1,
        "first_token_str": first_tok_str or "[NONE]",
        "first_token_pos": int(first_pos),
        "first_token_is_comp": bool(first_is_comp),

        "chosen_comp_token": chosen_tok,
        "chosen_ratio": chosen_ratio if chosen_ratio is not None else -1.0,
        "r_hat": float(r_hat),

        "tokens_reason": int(reason_len),            # tokens strictly inside <think>, as measured by count_tokens
        "tokens_full_ref_used": int(full_ref),

        # Components
        "acc_reward": float(acc_reward),
        "cal_reward": float(cal_reward if chosen_ratio is not None else -1.0),
        "mode_term_main": float(mode_term_main),

        "gt_ratio": float(gt_snap) if gt_snap is not None else -1.0,
        "steps_gap": int(steps_gap) if steps_gap is not None else 0,

        # The two scores
        "score_main": score_main,
        "score_ctrl": score_ctrl,

        # Diagnostics / logging
        "has_full_think": bool(has_full_think),
        "think_pick": think_pick_used,
        "think_blocks": int(think_blocks),
        "gt_confidence": float(ctrl_confidence) if ctrl_confidence is not None else -1.0,
    }

# ------------------------- Reward manager class (two-channel injection, DeGRPO-style, tiktoken length counting) -------------------------

@register("dapo_custom_cdrpo")
class CDRPOStyleRewardManager:
    """Reward manager that injects two rewards per response (DeGRPO-style).

      - ``score_ctrl`` is added at the first effective response token (the
        "control head"), scaled by ``alpha_first_token``; it trains only the
        choice of compression bucket.
      - ``score_main`` (plus the optional overlong penalty) is added at the
        last valid response position; it trains solving + following the
        chosen ratio.

    Details:
      * Reasoning length counts only tokens inside <think>...</think>, using
        build_token_counter (tiktoken ``cl100k_base`` by default), which should
        match how ``ref_full_len`` was computed for the data.
      * A response without a complete <think> block is penalized (configurable).
      * Logged prompt/response text is truncated only when ``max_log_chars`` > 0,
        so truncated logs are not mistaken for the model stopping early.
    """

    def __init__(
        self,
        tokenizer,
        num_examine: int,
        compute_score=None,
        reward_fn_key: str = "data_source",
        max_resp_len=None,
        overlong_buffer_cfg=None,
        allowed_comp_tokens: Optional[List[str]] = None,
        special_ignore: Optional[List[str]] = None,
        # Main-score / mode-term weights
        w_acc: float = 0.7,
        w_cal: float = 0.2,
        comp_tolerance: float = 0.07,
        mode_gain_correct: float = 0.9,
        mode_penalty_correct_over: float = 0.6,
        mode_penalty_wrong_under: float = 0.7,
        mode_penalty_wrong_over: float = 0.2,
        # Control-head injection scale (alpha)
        alpha_first_token: float = 0.2,
        # Control-head score parameters
        ctrl_hit_gt_reward: float = 0.5,
        ctrl_short_step_gain: float = 0.25,
        ctrl_short_cap: float = 0.95,
        # <think> handling
        think_pick: str = "last",             # 'last' | 'longest' | 'concat'
        require_full_think: bool = True,      # a missing <think> block counts as a bad sample and is penalized
        no_think_main_penalty: float = 0.4,
        no_think_ctrl_penalty: float = 0.2,
        # Length counting
        use_tiktoken_for_len: bool = True,
        tiktoken_encoding: str = "cl100k_base",
        # Logging
        max_log_chars: int = 0,               # 0 = never truncate; > 0 = truncate to this many characters
    ) -> None:
        """Store configuration, resolve ignored token ids and build the length counter.

        ``compute_score`` is accepted for interface compatibility but not used.
        Raises AssertionError when ``overlong_buffer_cfg`` is set without a
        ``max_resp_len`` at least as large as ``overlong_buffer_cfg.len``.
        """
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.reward_fn_key = reward_fn_key
        self.overlong_buffer_cfg = overlong_buffer_cfg
        self.max_resp_len = max_resp_len

        self.allowed_comp_tokens = allowed_comp_tokens or DEFAULT_ALLOWED_COMP_TOKENS
        self.ignored_specials = special_ignore or DEFAULT_IGNORE_SPECIALS

        # Ids skipped when locating the control-head token: eos/bos/pad plus the
        # configured special strings (unknown ones are dropped).
        ignore_ids = set()
        ignore_ids.update([tid for tid in [
            getattr(tokenizer, "eos_token_id", None),
            getattr(tokenizer, "bos_token_id", None),
            getattr(tokenizer, "pad_token_id", None),
        ] if tid is not None and tid != getattr(tokenizer, "unk_token_id", None)])
        ignore_ids.update(safe_convert_ids(tokenizer, self.ignored_specials))
        self.ignore_token_ids = ignore_ids

        # Weights
        self._w_acc = w_acc
        self._w_cal = w_cal
        self._comp_tolerance = comp_tolerance
        self._mode_gain_correct = mode_gain_correct
        self._mode_penalty_correct_over = mode_penalty_correct_over
        self._mode_penalty_wrong_under = mode_penalty_wrong_under
        self._mode_penalty_wrong_over = mode_penalty_wrong_over

        self._alpha_first = alpha_first_token

        # Control-head parameters
        self._ctrl_hit_gt_reward = ctrl_hit_gt_reward
        self._ctrl_short_step_gain = ctrl_short_step_gain
        self._ctrl_short_cap = ctrl_short_cap

        # <think> handling
        self._think_pick = think_pick
        self._require_full_think = require_full_think
        self._no_think_main_penalty = no_think_main_penalty
        self._no_think_ctrl_penalty = no_think_ctrl_penalty

        # Token counter; should match how the data's reference lengths were computed.
        self._count_tokens = build_token_counter(prefer_tiktoken=use_tiktoken_for_len,
                                                 tiktoken_encoding=tiktoken_encoding,
                                                 hf_tokenizer=self.tokenizer)

        # Logging control
        self._max_log_chars = int(max(0, max_log_chars))

        if self.overlong_buffer_cfg is not None:
            assert self.max_resp_len is not None, "max_resp_len must be provided if overlong_buffer_cfg is set"
            assert self.max_resp_len >= self.overlong_buffer_cfg.len, "max_resp_len must be larger than overlong_buffer.len"

    @staticmethod
    def _maybe_get(ntb: Dict[str, Any], key: str, default=None):
        """Look up *key* at the top level of *ntb*, then in ``ntb["extra_info"]``.

        A key that is present with value ``None`` is treated as missing. Such a
        value must neither hide a real value in ``extra_info`` nor override
        *default*. For example, ``float(None)`` for ``comp_tolerance`` would
        otherwise crash.
        """
        value = ntb.get(key)
        if value is not None:
            return value
        ext = ntb.get("extra_info")
        if isinstance(ext, dict):
            value = ext.get(key)
            if value is not None:
                return value
        return default

    @staticmethod
    def _as_plain(value):
        """Return ``value.tolist()`` for array-likes (numpy arrays, tensors), else *value*.

        Array values would make the ``x or default`` fallbacks in __call__ raise
        ("truth value of an array is ambiguous"); plain lists behave as intended.
        """
        tolist = getattr(value, "tolist", None)
        return tolist() if callable(tolist) else value

    @staticmethod
    def _extract_ground_truth(ntb: Dict[str, Any]) -> str:
        """Return ``ntb["reward_model"]["ground_truth"]`` as a stripped string ("" if absent).

        Non-string values (e.g. numeric answers, including 0) are converted with
        ``str`` instead of raising or being dropped.
        """
        reward_model = ntb.get("reward_model")
        gt = reward_model.get("ground_truth") if hasattr(reward_model, "get") else None
        return "" if gt is None else str(gt).strip()

    def _print_example(self, prompt_str: str, response_str: str, ground_truth: str,
                       result: Dict[str, Any]) -> None:
        """Print one scored example; texts are cut to ``max_log_chars`` when it is > 0."""
        print("[AUTO/STRICT v2 | think-only(tiktoken)]")
        limit = self._max_log_chars
        if limit > 0:
            print("[prompt]", (prompt_str[:limit] + ("…[TRUNC]" if len(prompt_str) > limit else "")))
            print("[response]", (response_str[:limit] + ("…[TRUNC]" if len(response_str) > limit else "")))
        else:
            print("[prompt]", prompt_str)
            print("[response]", response_str)
        print("[ground_truth]", ground_truth)
        for k, v in result.items():
            print(f"[{k}]", v)

    def __call__(self, data: DataProto, return_dict: bool = False):
        """Score a batch and place the rewards at token level.

        Returns the float32 reward tensor shaped like ``data.batch["responses"]``.
        When *return_dict* is True, returns instead
        ``{"reward_tensor": ..., "reward_extra_info": ...}``, where
        ``reward_extra_info`` maps each diagnostic key to a per-sample list.
        Precomputed ``rm_scores`` in the batch are returned unchanged.
        """
        if "rm_scores" in data.batch.keys():
            return {"reward_tensor": data.batch["rm_scores"]} if return_dict else data.batch["rm_scores"]

        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_extra_info = defaultdict(list)
        already_print_data_sources = {}
        eos_token = getattr(self.tokenizer, "eos_token", None)

        for i in range(len(data)):
            item = data[i]

            prompt_ids = item.batch["prompts"]
            prompt_length = prompt_ids.shape[-1]
            valid_prompt_length = int(item.batch["attention_mask"][:prompt_length].sum())
            # Valid prompt tokens are the trailing ones. Slice from an explicit
            # start index: prompt_ids[-0:] would return the whole padded prompt.
            valid_prompt_ids = prompt_ids[prompt_length - valid_prompt_length:]

            response_ids = item.batch["responses"]
            valid_response_length = int(item.batch["attention_mask"][prompt_length:].sum())
            valid_response_ids = response_ids[:valid_response_length]

            # skip_special_tokens=False is required to keep <COMP_xx> visible.
            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=False)
            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=False)
            if eos_token and response_str.endswith(eos_token):
                response_str = response_str[: -len(eos_token)]

            ntb = item.non_tensor_batch
            ground_truth = self._extract_ground_truth(ntb)

            # Reference full length (same convention as the data: <think> only, tiktoken).
            tokens_full_ref = self._maybe_get(ntb, "ref_full_len") or self._maybe_get(ntb, "orig_cot_tokens") or None

            # Allowed buckets, tolerance band and allowed control tokens.
            allowed_ratios = self._as_plain(self._maybe_get(ntb, "allowed_ratios")) or [0.2, 0.4, 0.6, 0.8, 1.0]
            comp_tol = float(self._maybe_get(ntb, "comp_tolerance", self._comp_tolerance))
            allowed = self._as_plain(self._maybe_get(ntb, "allowed_comp_tokens")) or self.allowed_comp_tokens

            # Ground-truth ratio, by priority: gt_ratio -> r_star -> chosen_ratio.
            gt_ratio = None
            for ratio_key in ("gt_ratio", "r_star", "chosen_ratio"):
                gt_ratio = self._maybe_get(ntb, ratio_key)
                if gt_ratio is not None:
                    break

            # Optional confidence in r*.
            gt_conf = self._maybe_get(ntb, "gt_confidence")

            result = build_reward_v2(
                response_text=response_str,
                response_ids=valid_response_ids,
                tokenizer=self.tokenizer,
                ground_truth=ground_truth,
                allowed_comp_tokens=allowed,
                tokens_full_ref=tokens_full_ref,
                gt_ratio=gt_ratio,
                allowed_ratios=allowed_ratios,
                comp_tolerance=comp_tol,
                ignore_token_ids=self.ignore_token_ids,
                w_acc=self._w_acc,
                w_cal=self._w_cal,
                mode_gain_correct=self._mode_gain_correct,
                mode_penalty_correct_over=self._mode_penalty_correct_over,
                mode_penalty_wrong_under=self._mode_penalty_wrong_under,
                mode_penalty_wrong_over=self._mode_penalty_wrong_over,
                ctrl_hit_gt_reward=self._ctrl_hit_gt_reward,
                ctrl_short_step_gain=self._ctrl_short_step_gain,
                ctrl_short_cap=self._ctrl_short_cap,
                ctrl_confidence=gt_conf,
                think_pick=self._think_pick,
                require_full_think=self._require_full_think,
                no_think_main_penalty=self._no_think_main_penalty,
                no_think_ctrl_penalty=self._no_think_ctrl_penalty,
                count_tokens=self._count_tokens,   # same counting convention as the data
            )

            # 1) Main score (placed on the last valid token below).
            reward = float(result["score_main"])

            # 2) Control-head score on the first effective token.
            first_pos = int(result["first_token_pos"])
            if 0 <= first_pos < valid_response_length:
                ctrl_reward = float(result["score_ctrl"]) * float(self._alpha_first)
                reward_tensor[i, first_pos] += ctrl_reward
                result["ctrl_injected"] = True
                result["ctrl_reward"] = float(ctrl_reward)
            else:
                result["ctrl_injected"] = False
                result["ctrl_reward"] = 0.0

            # Optional overlong penalty (linear inside the buffer zone).
            if self.overlong_buffer_cfg and getattr(self.overlong_buffer_cfg, "enable", False):
                overlong_buffer_len = self.overlong_buffer_cfg.len
                expected_len = self.max_resp_len - overlong_buffer_len
                exceed_len = valid_response_length - int(expected_len)
                overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
                overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
                reward += overlong_reward
                result["overlong_reward"] = float(overlong_reward)
                result["overlong"] = bool(overlong_reward < 0)

            # Main score goes on the last valid response token.
            if valid_response_length > 0:
                reward_tensor[i, valid_response_length - 1] += float(reward)
            result["score"] = float(reward)  # legacy field

            for k, v in result.items():
                reward_extra_info[k].append(v)

            # Print up to num_examine examples per data source.
            ds = ntb.get(self.reward_fn_key, "unknown_source")
            shown = already_print_data_sources.get(ds, 0)
            if shown < self.num_examine:
                already_print_data_sources[ds] = shown + 1
                self._print_example(prompt_str, response_str, ground_truth, result)

        return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info} if return_dict else reward_tensor
