# -*- coding: utf-8 -*-
"""
grpo_hh_variance_v1_from_csv.py

- GRPO training (requires a TRL release that ships GRPOTrainer, i.e. TRL >= 0.14).
- Constructs prompts from the first human turn of the Anthropic HH dataset (can be replaced with your own prompt pool).
- Reward: By default, uses only the reward model (z-score + clip[-5,5]), can be combined with format/length rewards.
- After training, retains "joint analysis" artifacts for easy aggregation and comparison with PPO results.

Key points for alignment with the original PPO script:
* Still supports StrongRM (8/4bit quantization), defaults to running on cuda:1.
* Still writes run_meta.json and outputs to joint/{joint_metrics.json,csv}.
* Still allows reading RM metrics like RSI_IQR_med, nGap_med, SEI_med from a CSV for MAD-z composite score calculation.

Note: this script does not record KL divergence from the GRPOTrainer logs, so the 'kl' column is NaN and the KL-threshold metrics are -1. AUC and slope are computed from reward_mean.
"""
import os, re, json, argparse, random, time, math, csv, shutil
from dataclasses import dataclass
from typing import List, Optional, Dict, Any, Tuple

import numpy as np
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForSequenceClassification, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
import dataclasses as dc # Used to get available field names from GRPOConfig

# =========================================================
# Utils (mostly consistent with the PPO version, with minor changes)
# =========================================================
def set_seed_all(seed:int):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus all CUDA devices when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def leftpad_tokenizer(name:str):
    """Load the fast tokenizer for *name* with left padding.

    If the tokenizer defines no pad token, the EOS token is reused as padding.
    """
    tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True, trust_remote_code=True)
    missing_pad = tokenizer.pad_token_id is None
    if missing_pad:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return tokenizer

def first_human_turn(text:str)->Optional[str]:
    """Extract the first "Human:" utterance of an HH-style transcript.

    The utterance is stripped and internal whitespace runs are collapsed to one space.
    Non-string input yields None. When no "Human: ... Assistant:" pair is found,
    the stripped input truncated to 800 characters is returned instead.
    """
    if isinstance(text, str):
        found = re.search(r"Human:\s*(.*?)\n+Assistant:", text, flags=re.S)
        if found is None:
            return text.strip()[:800]
        return re.sub(r"\s+", " ", found.group(1).strip())
    return None

def get_hh_prompts(n=4096, seed=42):
    """Return up to *n* first-human-turn prompts from Anthropic/hh-rlhf (train split).

    For each example the "chosen" transcript is tried first, then "rejected".
    Empty extractions are dropped. The pool is shuffled with a `random.Random(seed)`
    instance, so the selection is reproducible for a given seed.
    """
    train_split = load_dataset("Anthropic/hh-rlhf", split="train")
    pool = [
        text
        for example in train_split
        if (text := first_human_turn(example.get("chosen","")) or first_human_turn(example.get("rejected","")))
    ]
    rng = random.Random(seed)
    rng.shuffle(pool)
    return pool[:n]

def sanitize(mid: str) -> str:
    """Replace every '/' and ':' in a model id with '_' (for use in keys and paths)."""
    return mid.translate(str.maketrans({"/": "_", ":": "_"}))

def p90_p10(vals):
    """Return the spread between the 90th and 10th percentiles of *vals*.

    Percentiles interpolate linearly between order statistics at rank (n-1)*q,
    which is NumPy's default "linear" percentile method.

    Args:
        vals: any sequence or array of numbers (lists and NumPy arrays both work;
            multi-dimensional input is flattened).

    Returns:
        P90 - P10 as a float, or NaN for None/empty input.
    """
    if vals is None:
        return float("nan")
    # Convert first: `not vals` is ambiguous (raises) for NumPy arrays with >1 element.
    v = np.sort(np.asarray(vals, float).ravel())
    if v.size == 0:
        return float("nan")

    def lerp(k):
        f = int(np.floor(k)); c = int(np.ceil(k))
        return float(v[f] if f == c else v[f]*(c-k) + v[c]*(k-f))

    last = v.size - 1
    return lerp(last*0.9) - lerp(last*0.1)

# =========================================================
# RM Registry and Scorer (consistent with the PPO version)
# =========================================================
# Short reward-model names accepted on the CLI / RM_NAME env var -> HuggingFace hub ids.
# Names not listed here are passed through unchanged by resolve_rm_name().
RM_ALIASES = {
    "Skywork-Reward-Llama-3.1-8B":      "Skywork/Skywork-Reward-Llama-3.1-8B",
    # Previously pointed at the V1 repo, silently loading the wrong RM when V2 was requested.
    "Skywork-Reward-V2-Llama-3.1-8B":   "Skywork/Skywork-Reward-V2-Llama-3.1-8B",
    "tulu-v2.5-13b-uf-rm":              "allenai/tulu-v2.5-13b-uf-rm",
    "beaver-7b-v2.0-reward":            "PKU-Alignment/beaver-7b-v2.0-reward",
    "Skywork-Reward-V2-Qwen3-1.7B":     "Skywork/Skywork-Reward-V2-Qwen3-1.7B",
    "Skywork-Reward-V2-Qwen3-8B":       "Skywork/Skywork-Reward-V2-Qwen3-8B",
    "Skywork-Reward-V2-Qwen3-4B":       "Skywork/Skywork-Reward-V2-Qwen3-4B",
    "RM-Mistral-7B":                    "weqweasdas/RM-Mistral-7B",
    "GRM-Llama3-8B-rewardmodel-ft":     "Ray2333/GRM-Llama3-8B-rewardmodel-ft",
    "BTRM_Qwen2_7b_0613":               "CIR-AMS/BTRM_Qwen2_7b_0613",
}
def resolve_rm_name(name:str)->str:
    """Map a short RM alias to its hub id; unknown names (e.g. full ids, local paths) pass through."""
    return RM_ALIASES.get(name, name)

class StrongRM:
    """Wrapper around a HuggingFace sequence-classification reward model.

    `score` returns rewards z-scored *within the given batch* and clipped to
    [-5, 5], as a float32 tensor on `device`.

    Args:
        rm_model: hub id or local path of the reward model.
        device: device the RM and its inputs are placed on.
        max_len: truncation length (tokens) of the "Human/Assistant" text fed to the RM.
        label_index: logit column to use when the RM emits more than one logit.
        quant: "8bit" or "4bit" for bitsandbytes quantization; anything else loads unquantized.
        torch_dtype: "bfloat16" | "float16" | "float32"; weight dtype when unquantized,
            and the 4-bit compute dtype (float16 is used if float32 is requested).
    """
    def __init__(self, rm_model:str, device:str="cuda:1", max_len:int=512, label_index:int=0,
                 quant:str="8bit", torch_dtype:str="bfloat16"):
        self.device = device
        self.max_len = max_len
        self.label_index = label_index
        self.tok = AutoTokenizer.from_pretrained(rm_model, use_fast=True, trust_remote_code=True)
        # Batched scoring pads; some RM tokenizers ship without a pad token.
        if self.tok.pad_token_id is None and self.tok.eos_token is not None:
            self.tok.pad_token = self.tok.eos_token

        dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
        hf_dtype = dtype_map.get(torch_dtype, torch.bfloat16)

        qconf = None
        if quant == "8bit":
            qconf = BitsAndBytesConfig(load_in_8bit=True)
        elif quant == "4bit":
            qconf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True,
                                       bnb_4bit_quant_type="nf4",
                                       bnb_4bit_compute_dtype=hf_dtype if hf_dtype in (torch.float16, torch.bfloat16) else torch.float16)

        kwargs = dict(trust_remote_code=True)
        if qconf is not None:
            # Quantized weights are placed directly on the target device at load time.
            kwargs["quantization_config"] = qconf
            kwargs["device_map"] = {"": device}
        else:
            kwargs["torch_dtype"] = hf_dtype

        self.rm = AutoModelForSequenceClassification.from_pretrained(rm_model, **kwargs)
        if qconf is None: self.rm = self.rm.to(device)
        # Sequence-classification heads use config.pad_token_id to find the last real
        # token; without it, batches larger than one are rejected.
        if getattr(self.rm.config, "pad_token_id", None) is None and self.tok.pad_token_id is not None:
            self.rm.config.pad_token_id = self.tok.pad_token_id
        self.rm.eval()

    @torch.no_grad()
    def score(self, prompts:List[str], responses:List[str])->torch.Tensor:
        """Score (prompt, response) pairs; returns a float32 tensor of shape [len(prompts)].

        Raises:
            ValueError: if `prompts` and `responses` differ in length.
        """
        if len(prompts) != len(responses):
            raise ValueError(f"StrongRM.score: got {len(prompts)} prompts but {len(responses)} responses")
        if not prompts:
            return torch.zeros(0, dtype=torch.float32, device=self.device)
        texts = [f"Human: {p}\n\nAssistant: {r}" for p,r in zip(prompts, responses)]
        enc = self.tok(texts, padding=True, truncation=True, max_length=self.max_len, return_tensors="pt").to(self.device)
        out = self.rm(**enc)
        # Upcast before computing batch statistics: bf16/fp16 mean/std lose precision.
        logits = out.logits.float().squeeze(-1)
        if logits.dim()==2: logits = logits[:, self.label_index]
        mu = logits.mean(); sigma = logits.std(unbiased=False).clamp_min(1e-4)
        rew = ((logits - mu) / sigma).clamp(-5,5)
        return torch.nan_to_num(rew, nan=0.0, posinf=5.0, neginf=-5.0).to(torch.float32)

# =========================================================
# Data: Convert HH prompts to the 'messages' format for GRPOTrainer
# =========================================================
def make_grpo_dataset(prompts: List[str]) -> Dataset:
    """Build a conversational `Dataset` with a single "prompt" column.

    Each row's prompt is a two-message chat: a fixed system instruction followed
    by the user message holding the corresponding input prompt.
    """
    system_msg = {"role":"system", "content": "You are a helpful, honest and harmless assistant. Answer clearly."}
    rows = [
        {"prompt": [dict(system_msg), {"role":"user", "content": text}]}
        for text in prompts
    ]
    return Dataset.from_list(rows)

# =========================================================
# Reward Functions (for use with GRPOTrainer)
# =========================================================
def _extract_text_from_completions(completions) -> List[str]:
    # completions: List[List[{"content": "...", "role": "assistant"}]]
    outs = []
    for comp in completions:
        if isinstance(comp, list) and comp:
            msg = comp[0]
            if isinstance(msg, dict) and "content" in msg:
                outs.append(str(msg["content"]))
            else:
                outs.append(str(msg))
        else:
            outs.append("")
    return outs

def build_rm_reward_func(rm: StrongRM):
    """Wrap *rm* as a GRPOTrainer reward function named `rm_reward_func`.

    The returned callable receives the batch's `prompts` and `completions` and
    returns one float per completion, taken from `rm.score`. Each prompt may be:
      * a plain string, used as-is;
      * a list of chat messages, in which case the content of the last "user"
        message is used;
      * anything else, in which case it scores against an empty prompt.
    """
    def _prompt_text(msgs) -> str:
        if isinstance(msgs, str):
            return msgs
        if isinstance(msgs, (list, tuple)):
            for m in reversed(msgs):
                if isinstance(m, dict) and m.get("role")=="user":
                    return m.get("content","")
        return ""

    def rm_reward_func(prompts, completions, **kwargs) -> list[float]:
        prompt_texts = [_prompt_text(msgs) for msgs in prompts]
        responses = _extract_text_from_completions(completions)
        with torch.no_grad():
            scores = rm.score(prompt_texts, responses)  # float32 tensor, one reward per completion
        return [float(x) for x in scores.detach().cpu().tolist()]
    return rm_reward_func

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Optional shaping reward for a <reasoning>...</reasoning> ... <answer>...</answer> layout.

    Returns 2.0 for each completion containing that structure (tags matched
    case-insensitively, spanning newlines) and 0.0 otherwise.
    """
    matches = re.compile(r"<reasoning>.*?</reasoning>.*?<answer>.*?</answer>", re.S|re.I).search
    rewards = []
    for text in _extract_text_from_completions(completions):
        rewards.append(0.0 if matches(text) is None else 2.0)
    return rewards

def min_length_penalty_func(completions, min_tok: int = 8, penalty: float = -1.0, **kwargs) -> list[float]:
    """Penalize very short completions to discourage collapse to trivial outputs.

    Length is counted in whitespace-separated words (not tokenizer tokens).
    Completions with fewer than `min_tok` words get `penalty`; all others get 0.0.
    """
    texts = _extract_text_from_completions(completions)
    return [penalty if len(t.split()) < min_tok else 0.0 for t in texts]

# =========================================================
# Training & Logging (GRPO)
# =========================================================
@dataclass
class Args:
    """Run configuration. parse_args() exposes every field as a same-named --flag."""
    # policy / dtype
    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"  # hub id / path of the policy model
    dtype: str = "bfloat16"  # policy weight dtype; also selects GRPOConfig bf16/fp16 flags

    # devices
    use_vllm: int = 0  # 1 -> GRPOConfig(use_vllm=True) if the installed TRL has that field; env USE_VLLM overrides
    policy_device: str = "cuda:0"  # device the policy is moved to before building the trainer
    rm_device: str = "cuda:1"  # device the reward model is loaded on

    # rm
    use_rm: int = 1  # 1 -> load the reward model and register it as a reward function
    rm_name: str = "Skywork-Reward-Llama-3.1-8B"  # RM_ALIASES key or hub id; env RM_NAME overrides
    rm_quant: str = "8bit" # none|8bit|4bit (bitsandbytes quantization of the RM)
    rm_dtype: str = "bfloat16"  # RM weight dtype when unquantized (also the 4-bit compute dtype)
    rm_max_len: int = 512  # truncation length (tokens) of RM inputs
    rm_label_index: int = 0  # logit column used when the RM outputs more than one logit

    # data / sampling
    prompts_n: int = 4096  # number of HH prompts sampled for training
    max_prompt_tokens: int = 256  # -> GRPOConfig.max_prompt_length
    max_completion_tokens: int = 128  # -> GRPOConfig.max_completion_length
    temperature: float = 0.9  # sampling temperature forwarded to GRPOConfig
    top_p: float = 0.95  # nucleus sampling threshold forwarded to GRPOConfig
    top_k: int = 0  # forwarded as-is; presumably 0 disables top-k -- confirm for your TRL/vLLM version
    num_generations: int = 2  # -> GRPOConfig.num_generations (completions sampled per prompt)

    # train
    steps: int = 1000  # -> GRPOConfig.max_steps
    per_device_train_batch_size: int = 2
    grad_accum: int = 4  # -> gradient_accumulation_steps
    lr: float = 5e-6  # -> learning_rate
    warmup_ratio: float = 0.1
    weight_decay: float = 0.1
    adam_beta1: float = 0.9
    adam_beta2: float = 0.99
    logging_steps: int = 5  # trainer logging interval; save_steps is derived as max(50, 10 * logging_steps)

    # misc / io
    seed: int = 123  # seeds random/NumPy/torch and the prompt shuffle
    save_dir: str = "./rlhf_runs/grpo_hh_v1"  # checkpoints, train_log.csv, run_meta.json and joint/ go here
    rm_metrics_csv: str = "./precomputed_metrics/rm_global_metrics.csv"  # per-RM metrics (RSI_IQR_med, nGap_med, SEI_med, ...) joined into joint metrics
    joint_root: Optional[str] = None  # if set, joint_metrics.csv is also copied to <joint_root>/<rm tail>/<run>.csv
    slope_head: Optional[int] = None  # early-slope window: logged steps <= slope_head (if > 1); default first ~33% of log points
    auc_norm_to: int = 0  # if > 0, reward AUC is also reported divided by this value

def parse_args(argv: Optional[List[str]] = None) -> Args:
    """Parse CLI flags (one per `Args` field) and apply environment overrides.

    Args:
        argv: argument list to parse; None (default) reads `sys.argv[1:]`.

    Environment overrides (take precedence over the CLI):
        RM_NAME: replaces --rm_name when non-blank (surrounding whitespace stripped).
        USE_VLLM: integer replacing --use_vllm; a non-integer value is reported and ignored.

    Returns:
        A populated `Args` instance.
    """
    p = argparse.ArgumentParser()
    g = p.add_argument

    g("--model_name", type=str, default=Args.model_name)
    g("--dtype", type=str, default=Args.dtype, choices=["float16","bfloat16","float32"])
    g("--use_vllm", type=int, default=Args.use_vllm, choices=[0,1])
    g("--policy_device", type=str, default=Args.policy_device)

    g("--use_rm", type=int, default=Args.use_rm)
    g("--rm_name", type=str, default=Args.rm_name)
    g("--rm_device", type=str, default=Args.rm_device)
    g("--rm_quant", type=str, default=Args.rm_quant, choices=["none","8bit","4bit"])
    g("--rm_dtype", type=str, default=Args.rm_dtype, choices=["float16","bfloat16","float32"])
    g("--rm_max_len", type=int, default=Args.rm_max_len)
    g("--rm_label_index", type=int, default=Args.rm_label_index)

    g("--prompts_n", type=int, default=Args.prompts_n)
    g("--max_prompt_tokens", type=int, default=Args.max_prompt_tokens)
    g("--max_completion_tokens", type=int, default=Args.max_completion_tokens)
    g("--temperature", type=float, default=Args.temperature)
    g("--top_p", type=float, default=Args.top_p)
    g("--top_k", type=int, default=Args.top_k)
    g("--num_generations", type=int, default=Args.num_generations)

    g("--steps", type=int, default=Args.steps)
    g("--per_device_train_batch_size", type=int, default=Args.per_device_train_batch_size)
    g("--grad_accum", type=int, default=Args.grad_accum)
    g("--lr", type=float, default=Args.lr)
    g("--warmup_ratio", type=float, default=Args.warmup_ratio)
    g("--weight_decay", type=float, default=Args.weight_decay)
    g("--adam_beta1", type=float, default=Args.adam_beta1)
    g("--adam_beta2", type=float, default=Args.adam_beta2)
    g("--logging_steps", type=int, default=Args.logging_steps)

    g("--seed", type=int, default=Args.seed)
    g("--save_dir", type=str, default=Args.save_dir)

    g("--rm_metrics_csv", type=str, default=Args.rm_metrics_csv)
    g("--joint_root", type=str, default=Args.joint_root)
    g("--slope_head", type=int, default=Args.slope_head)
    g("--auc_norm_to", type=int, default=Args.auc_norm_to)

    ns = p.parse_args(argv)

    # Environment overrides. A blank RM_NAME is ignored rather than turning rm_name into "".
    env_rm = (os.environ.get("RM_NAME") or "").strip()
    if env_rm: ns.rm_name = env_rm
    env_vllm = os.environ.get("USE_VLLM")
    if env_vllm is not None:
        try:
            ns.use_vllm = int(env_vllm)
        except ValueError:
            print(f"[Args] ignoring non-integer USE_VLLM={env_vllm!r}; keeping use_vllm={ns.use_vllm}")

    return Args(**{f.name: getattr(ns, f.name) for f in dc.fields(Args)})

# =========================================================
# Logging and Joint Analysis (adapted from the PPO script)
# =========================================================
def _read_training_log(csv_path:str) -> Dict[str, List[float]]:
    xs, rmean = [], []
    if not os.path.isfile(csv_path): return {"step":xs, "reward_mean":rmean, "kl":[]}
    with open(csv_path, "r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            try:
                xs.append(float(row["step"]))
                rmean.append(float(row["reward_mean"]))
            except Exception:
                continue
    return {"step":xs, "reward_mean":rmean, "kl":[float('nan')]*len(xs)}

def _steps_to_threshold(xs: List[float], ys: List[float], thr: float) -> int:
    for x,y in zip(xs, ys):
        if y >= thr: return int(x)
    return -1

def _trapezoid_auc(xs: List[float], ys: List[float]) -> float:
    if len(xs) < 2: return 0.0
    s = 0.0
    for i in range(1, len(xs)):
        dx = xs[i] - xs[i-1]
        s += (ys[i] + ys[i-1]) * 0.5 * dx
    return float(s)

def _linreg_slope(xs: List[float], ys: List[float]) -> float:
    if len(xs) < 2: return 0.0
    x = np.asarray(xs, float); y = np.asarray(ys, float)
    x = x - x.mean(); y = y - y.mean()
    denom = (x*x).sum()
    if denom <= 1e-12: return 0.0
    return float((x*y).sum() / denom)

def _make_rm_key_candidates(rm_resolved: str) -> List[str]:
    """Possible `rm_key` values for an RM in the metrics CSV.

    Returns, in order: "rb_<sanitized id>/<full id>", "rb_<sanitized id>/<id tail>",
    and the bare id tail (the part after the last '/').
    """
    short = rm_resolved.rsplit("/", 1)[-1]
    prefix = "rb_" + sanitize(rm_resolved) + "/"
    return [prefix + rm_resolved, prefix + short, short]

def _load_rm_metrics_row(csv_path: str, rm_resolved: str) -> Tuple[Dict[str, str], List[Dict[str, str]]]:
    """Find the metrics-CSV row describing the RM *rm_resolved*.

    A row matches when any of the following holds:
      * its `rm_key` equals one of `_make_rm_key_candidates(rm_resolved)`;
      * its `rm_key` ends with "/<candidate>" for a slash-free candidate;
      * its `model_name` equals the full id or its tail.

    Returns:
        (hit, all_rows), where hit is the first matching row ({} if none) and
        all_rows is every row in the file. A missing file yields ({}, []).
    """
    if not os.path.isfile(csv_path): return {}, []
    # "utf-8-sig" drops a leading BOM (e.g. CSVs saved by Excel or pandas with
    # encoding="utf-8-sig"). Otherwise the first header reads "\ufeffrm_key" and no row
    # can match. Plain UTF-8 files are read identically.
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
        all_rows = list(csv.DictReader(f))
    cands = _make_rm_key_candidates(rm_resolved)
    suffixes = tuple("/" + c for c in cands if "/" not in c)
    # Empty names are excluded so rows with a blank model_name never match spuriously.
    names = {n for n in (rm_resolved, rm_resolved.split("/")[-1]) if n}
    for row in all_rows:
        # `or ""` also covers None, which DictReader yields for fields missing from short rows.
        rk = row.get("rm_key") or ""
        mn = row.get("model_name") or ""
        if rk in cands or rk.endswith(suffixes) or mn in names:
            return row, all_rows
    return {}, all_rows

def _compute_composite_madz(hit: Dict[str,str], all_rows: List[Dict[str,str]]) -> Dict[str, float]:
    keys = ["RSI_IQR_med", "nGap_med", "SEI_med"]
    cols = {k: [] for k in keys}
    for row in all_rows:
        for k in keys:
            try: cols[k].append(float(row[k]))
            except: pass
    zsum = 0.0; parts = {}
    for k in keys:
        if k not in hit: parts[k]=float("nan"); continue
        try: x = float(hit[k])
        except: parts[k]=float("nan"); continue
        arr = np.asarray(cols[k], float)
        med = float(np.median(arr)) if arr.size else 0.0
        mad = float(np.median(np.abs(arr - med))) if arr.size else 0.0
        scale = (mad * 1.4826) if mad>0 else (float(arr.std()) if arr.size>1 else 1.0)
        z = (x - med) / max(scale, 1e-12)
        parts[k] = float(z); zsum += float(z)
    parts["Composite_madz"] = float(zsum)
    return parts

def post_run_joint_analysis(meta: Dict[str,Any]):
    """Summarise a finished run into <save_dir>/joint/joint_metrics.{json,csv}.

    Reads the reward log at meta["train_log_csv"] and computes:
      * reward AUC (also divided by meta["auc_norm_to"] when > 0);
      * an early-phase least-squares slope;
      * the final reward mean.

    When the RM is found in meta["rm_metrics_csv"], its raw metrics are added
    under their own names. Their MAD-z scores are added as "<metric>_madz",
    plus "Composite_madz". KL-based fields are fixed to -1 / NaN because the
    training log carries no KL. If meta["joint_root"] is set, the CSV is also
    copied to <joint_root>/<rm tail>/<run dir name>.csv (best effort).
    """
    rm_resolved = meta.get("rm_name_resolved", meta.get("rm_name_input",""))
    csv_train = meta["train_log_csv"]
    rm_metrics_csv = meta["rm_metrics_csv"]
    slope_head = meta.get("slope_head", None)
    auc_norm_to = int(meta.get("auc_norm_to", 0) or 0)

    train = _read_training_log(csv_train)
    xs, rmean = train["step"], train["reward_mean"]

    reward_auc = _trapezoid_auc(xs, rmean)
    reward_auc_norm = (reward_auc / float(auc_norm_to)) if auc_norm_to > 0 else float("nan")

    # Early window: every logged step <= slope_head when given, else the first ~33% of points (min 3).
    if slope_head is not None and slope_head > 1:
        idx = [i for i,x in enumerate(xs) if x <= slope_head]
        n = (idx[-1]+1) if idx else min(len(xs), 3)
    else:
        n = max(3, int(len(xs) * 0.33))
    slope_early = _linreg_slope(xs[:n], rmean[:n]) if len(xs) >= 2 else 0.0

    final_reward = float(rmean[-1]) if rmean else 0.0

    def _to_float(v) -> float:
        # Missing, empty or non-numeric cells become NaN instead of raising.
        try:
            return float(v)
        except (TypeError, ValueError):
            return float("nan")

    hit, all_rows = _load_rm_metrics_row(rm_metrics_csv, rm_resolved)
    rm_part = {}
    if hit:
        rm_part = dict(
            rm_key=hit.get("rm_key",""),
            rm_model_name=hit.get("model_name",""),
        )
        for k in ("RSI_IQR_med", "nGap_med", "SEI_med", "nGMD_med"):
            rm_part[k] = _to_float(hit.get(k, ""))
        comps = _compute_composite_madz(hit, all_rows)
        # _compute_composite_madz reports z-scores under the raw metric names.
        # Previously rm_part.update(comps) overwrote the raw values parsed above,
        # so store the z-scores with a "_madz" suffix.
        for k, v in comps.items():
            rm_part[k if k == "Composite_madz" else f"{k}_madz"] = v

    out = dict(
        run_dir=meta["save_dir"],
        rm_metrics_csv=rm_metrics_csv,
        **rm_part,
        steps_to_kl80=int(-1), # No KL for GRPO
        steps_to_kl90=int(-1),
        reward_auc=float(reward_auc),
        reward_auc_norm=float(reward_auc_norm),
        reward_slope_early=float(slope_early),
        final_reward_mean=float(final_reward),
        final_kl=float("nan"),
    )

    out_dir = os.path.join(meta["save_dir"], "joint")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "joint_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    csv_out = os.path.join(out_dir, "joint_metrics.csv")
    with open(csv_out, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(out.keys())); writer.writeheader(); writer.writerow(out)
    print(f"[Joint] saved joint metrics → {csv_out}")

    joint_root = meta.get("joint_root", None)
    if joint_root:
        rm_tail = (rm_resolved.split("/")[-1] if rm_resolved else "unknown_rm")
        # normpath drops trailing separators portably (also '\' on Windows).
        base = os.path.basename(os.path.normpath(meta["save_dir"]))
        dst_dir = os.path.join(joint_root, rm_tail)
        os.makedirs(dst_dir, exist_ok=True)
        dst_csv = os.path.join(dst_dir, f"{base}.csv")
        try:
            shutil.copy2(csv_out, dst_csv)
            print(f"[Joint] copied to aggregator → {dst_csv}")
        except OSError as e:
            # Best effort: the run's own joint metrics are already written.
            print(f"[Joint] copy to aggregator failed: {e}")

# =========================================================
# Main
# =========================================================
# Log keys under which TRL reports the mean / std of the summed reward per logging step.
# TRL's GRPOTrainer logs "reward" and "reward_std". The other spellings are fallbacks
# kept from the original lookup (verify against the installed TRL version).
_REWARD_MEAN_KEYS = ("reward", "rewards/mean", "reward/mean", "rewards_mean", "reward_mean")
_REWARD_STD_KEYS = ("reward_std", "rewards/std", "reward/std", "rewards_std")


def _first_float(entry: Dict[str, Any], keys) -> Optional[float]:
    """Return float(entry[k]) for the first key k present in *entry*, or None if absent/unparseable."""
    for k in keys:
        if k in entry:
            try:
                return float(entry[k])
            except (TypeError, ValueError):
                return None
    return None


def _write_train_log(log_history: List[Dict[str, Any]], csv_path: str) -> int:
    """Append one row per *log_history* entry carrying a reward mean; return the number of rows.

    The step column is the entry's "step" (falling back to its index). reward_iqr
    is written as NaN because it is not logged.
    """
    nan = float("nan")
    n_rows = 0
    with open(csv_path, "a", encoding="utf-8") as f:
        for i, h in enumerate(log_history):
            if not isinstance(h, dict):
                continue
            r_mean = _first_float(h, _REWARD_MEAN_KEYS)
            if r_mean is None:
                continue  # entries without reward metrics (e.g. a final train summary)
            r_std = _first_float(h, _REWARD_STD_KEYS)
            step = int(h.get("step", i))
            # `is None` rather than `or`, so a legitimate std of 0.0 is not written as NaN.
            f.write(f"{step},{r_mean:.6f},{(nan if r_std is None else r_std):.6f},{nan:.6f}\n")
            n_rows += 1
    return n_rows


def main():
    """Train a GRPO policy against the configured reward model, then write run artifacts.

    Writes into save_dir:
      * trainer checkpoints and the final model + tokenizer;
      * train_log.csv: one row per logged step that carries a reward mean;
      * run_meta.json;
      * joint/ metrics produced by post_run_joint_analysis().

    Raises:
        ValueError: if no reward function is enabled.
    """
    try:
        torch.backends.cuda.matmul.allow_tf32=True; torch.backends.cudnn.allow_tf32=True
    except Exception: pass
    os.environ.setdefault("TOKENIZERS_PARALLELISM","false")

    args = parse_args()
    set_seed_all(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)
    train_log_csv = os.path.join(args.save_dir, "train_log.csv")
    # Start a fresh log. Rows left by an earlier run in the same save_dir would
    # otherwise be mixed into this run's AUC / slope.
    with open(train_log_csv, "w", encoding="utf-8") as f:
        f.write("step,reward_mean,reward_std,reward_iqr\n")

    # Prepare data
    prompts = get_hh_prompts(n=args.prompts_n, seed=args.seed)
    ds = make_grpo_dataset(prompts)

    # Tokenizer
    tok = leftpad_tokenizer(args.model_name)

    # Reward Model
    rm = None
    rm_resolved = resolve_rm_name(args.rm_name)
    if int(args.use_rm)==1:
        print(f"[RM] using: {rm_resolved} | quant={args.rm_quant}+{args.rm_dtype} on {args.rm_device}")
        rm = StrongRM(rm_model=rm_resolved, device=args.rm_device, max_len=args.rm_max_len,
                      label_index=args.rm_label_index, quant=args.rm_quant, torch_dtype=args.rm_dtype)

    # reward functions
    reward_funcs = []
    if rm is not None:
        reward_funcs.append(build_rm_reward_func(rm))
    # Can be enabled as needed:
    # reward_funcs.append(soft_format_reward_func)
    # reward_funcs.append(lambda completions, **kw: min_length_penalty_func(completions, min_tok=8, penalty=-0.2))
    if not reward_funcs:
        raise ValueError("No reward functions enabled: pass --use_rm 1 or enable an auxiliary reward in main().")

    # GRPO config (dynamically filters fields unsupported by the current TRL version)
    cfg_kwargs = dict(
        # runtime / io
        output_dir = args.save_dir,
        logging_steps = args.logging_steps,
        report_to = ["none"],

        # optimizer
        learning_rate = args.lr,
        adam_beta1 = args.adam_beta1,
        adam_beta2 = args.adam_beta2,
        weight_decay = args.weight_decay,
        warmup_ratio = args.warmup_ratio,

        # precision
        bf16 = (args.dtype == "bfloat16"),
        fp16 = (args.dtype == "float16"),

        # batching / steps
        per_device_train_batch_size = args.per_device_train_batch_size,
        gradient_accumulation_steps = args.grad_accum,
        max_steps = args.steps,

        # save
        save_strategy = "steps",
        save_steps = max(50, args.logging_steps*10),
        save_safetensors = True,

        # generation
        max_prompt_length = args.max_prompt_tokens,
        max_completion_length = args.max_completion_tokens,
        num_generations = args.num_generations,
        temperature = args.temperature,
        top_p = args.top_p,
        top_k = args.top_k,

        # vLLM (supported in some versions; will be filtered out if not)
        use_vllm = bool(int(args.use_vllm)),
    )
    try:
        available = {f.name for f in dc.fields(GRPOConfig)}
    except Exception:
        available = None
    if available is not None:
        cfg_kwargs = {k: v for k, v in cfg_kwargs.items() if k in available}

    cfg = GRPOConfig(**cfg_kwargs)

    # Manually load the model and force SDPA to avoid flash-attn/xformers crashes
    torch_dtype = torch.bfloat16 if args.dtype == "bfloat16" else (torch.float16 if args.dtype == "float16" else torch.float32)
    policy = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation="sdpa", # Key: Disable Flash Attention 2 / xformers
    )
    try:
        policy.config.use_cache = False
    except Exception:
        pass
    policy = policy.to(args.policy_device)

    # Trainer (compatible with different tokenizer argument names across TRL versions)
    try:
        trainer = GRPOTrainer(
            model = policy,
            processing_class = tok,
            reward_funcs = reward_funcs,
            args = cfg,
            train_dataset = ds,
        )
    except TypeError:
        trainer = GRPOTrainer(
            model = policy,
            tokenizer = tok,
            reward_funcs = reward_funcs,
            args = cfg,
            train_dataset = ds,
        )

    # Training. Trainer.train() has no step-yielding mode (the old yield_steps=True
    # call always raised TypeError), so the per-step reward log is rebuilt from
    # trainer.state.log_history once training stops, even if it stops with an error.
    try:
        trainer.train()
    finally:
        n_rows = _write_train_log(trainer.state.log_history, train_log_csv)
        print(f"[Log] wrote {n_rows} reward rows → {train_log_csv}")

    # Periodic checkpoints land in checkpoint-* subdirectories; save the final
    # policy at the top of save_dir, next to the tokenizer.
    trainer.save_model(args.save_dir)
    tok.save_pretrained(args.save_dir)

    meta = dict(
        model_name=args.model_name,
        seed=args.seed,
        steps=args.steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        grad_accum=args.grad_accum,
        lr=args.lr,
        save_dir=args.save_dir,
        train_log_csv=train_log_csv,
        slope_head=args.slope_head,
        auc_norm_to=args.auc_norm_to,
        rm_metrics_csv=args.rm_metrics_csv,
        joint_root=args.joint_root,
        rm_name_input=args.rm_name,
        rm_name_resolved=rm_resolved,
    )
    with open(os.path.join(args.save_dir, "run_meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    post_run_joint_analysis(meta)
    print(f"== Done: {args.save_dir} ==")

if __name__ == "__main__":
    # Run training + joint analysis only when executed as a script, not on import.
    main()