# -*- coding: utf-8 -*-
"""
ppo_hh_variance_v7_joint_from_csv.py

- PPO training (TRL 0.9.x).
- RM reward: z-score + clip[-5,5].
- Robust KL/Entropy: token-level calculation to avoid NaN.
- Performs "joint analysis" only after training is complete:
    * Reads RM metrics from a CSV (e.g., RSI_IQR_med, nGap_med, SEI_med, nGMD_med).
    * Applies MAD-z normalization to the entire table and calculates a weighted sum => Composite_madz.
    * Merges with the current run's convergence metrics (steps_to_kl80/90, reward_auc, etc.) and saves the result.

New/Compatible Arguments:
--rb_metrics_csv (alias of --rm_metrics_csv)
--joint_root
--slope_head (int steps for early-slope window)
--auc_norm_to (int, normalize reward AUC by this step count)

Environment Variables:
- If RM_NAME is set, it overrides the --rm_name command-line argument.
"""

import os, re, json, argparse, random, time, math, csv, shutil
from dataclasses import dataclass
from typing import List, Optional, Dict, Any, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForSequenceClassification
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# Optional import: try to bring in transformers' InfNanRemoveLogitsProcessor.
# If the import fails for any reason, both names are set to None and
# HAS_INFPROC records that the processor is unavailable.
try:
    from transformers import LogitsProcessorList, InfNanRemoveLogitsProcessor
    HAS_INFPROC = True
except Exception:
    LogitsProcessorList, InfNanRemoveLogitsProcessor = None, None
    HAS_INFPROC = False


# ------------------------ Utils ------------------------
def set_seed_all(seed:int):
    """Seed Python's ``random``, NumPy and PyTorch (plus every CUDA device when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def leftpad_tokenizer(name:str):
    """Load a fast tokenizer set up for batched decoder-only generation.

    Padding goes on the left. If the tokenizer defines no pad token, its EOS
    token is reused as the pad token.
    """
    tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True, trust_remote_code=True)
    needs_pad = tokenizer.pad_token_id is None
    if needs_pad:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return tokenizer

def first_human_turn(text:str)->Optional[str]:
    """Extract the first "Human:" utterance from an HH-style dialogue string.

    Returns the utterance with whitespace runs collapsed to single spaces. If
    no "Human: ... Assistant:" pattern is found, returns the stripped text cut
    to 800 characters. Non-string input yields None.
    """
    if not isinstance(text, str):
        return None
    match = re.search(r"Human:\s*(.*?)\n+Assistant:", text, flags=re.S)
    if match is None:
        return text.strip()[:800]
    utterance = match.group(1).strip()
    return re.sub(r"\s+", " ", utterance)

def get_hh_prompts(n=4096, seed=42):
    """Collect first-human-turn prompts from Anthropic/hh-rlhf (train split).

    The "chosen" dialogue is tried first, then "rejected"; empty results are
    dropped. The list is shuffled with ``random.Random(seed)`` and the first
    ``n`` prompts are returned.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split="train")

    def _prompt_of(example):
        return first_human_turn(example.get("chosen", "")) or first_human_turn(example.get("rejected", ""))

    prompts = [prompt for prompt in map(_prompt_of, dataset) if prompt]
    shuffler = random.Random(seed)
    shuffler.shuffle(prompts)
    return prompts[:n]

def p90_p10(vals):
    """Return the 90th minus the 10th percentile of ``vals`` (linear interpolation).

    An empty input yields NaN.
    """
    if not vals:
        return float("nan")
    ordered = np.sort(np.asarray(vals, float))
    last = len(ordered) - 1

    def _quantile(q):
        pos = last * q
        lo = int(np.floor(pos))
        hi = int(np.ceil(pos))
        if lo == hi:
            return float(ordered[lo])
        return float(ordered[lo] * (hi - pos) + ordered[hi] * (pos - lo))

    upper = _quantile(0.9)
    lower = _quantile(0.1)
    return upper - lower

def sanitize(mid: str) -> str:
    """Make a model id filesystem/key friendly by turning '/' and ':' into '_'."""
    table = str.maketrans({"/": "_", ":": "_"})
    return mid.translate(table)

def trim_after_pad_eos(seq: torch.Tensor, pad_id: Optional[int], eos_id: Optional[int]) -> torch.Tensor:
    """Trim a generated 1-D token sequence at its first terminator.

    The sequence is cut at the earliest of:
      * the first ``eos_id`` token, which is KEPT, and
      * the first ``pad_id`` token, which is DROPPED.

    When ``pad_id == eos_id`` (e.g. a tokenizer whose pad token was set to
    its EOS token), a token with that id is treated as EOS and kept. The
    response therefore still ends with its EOS instead of losing it.

    If trimming would leave nothing, a single fallback token (eos, else pad,
    else 0) is returned so later masks are never all-zero. Non-tensor inputs
    are returned unchanged.
    """
    if not isinstance(seq, torch.Tensor):
        return seq

    def _first_index(token_id: int) -> Optional[int]:
        hits = (seq == token_id).nonzero(as_tuple=False)
        return int(hits[0, 0]) if hits.numel() > 0 else None

    cut = int(seq.numel())
    if eos_id is not None:
        eos_at = _first_index(eos_id)
        if eos_at is not None:
            cut = min(cut, eos_at + 1)  # keep the eos token itself
    # Only treat the id as padding when it is distinct from EOS; otherwise the
    # pad rule would cut just before the EOS and silently drop it.
    if pad_id is not None and pad_id != eos_id:
        pad_at = _first_index(pad_id)
        if pad_at is not None:
            cut = min(cut, pad_at)

    seq = seq[:cut]
    if seq.numel() == 0:
        # Failsafe: keep at least 1 token to prevent an all-zero mask later.
        fallback = eos_id if eos_id is not None else (pad_id if pad_id is not None else 0)
        seq = torch.tensor([fallback], device=seq.device, dtype=seq.dtype)
    return seq


# ------------------------ RM registry ------------------------
# Short reward-model names accepted via --rm_name / $RM_NAME -> HuggingFace repo ids.
# Each alias must point at its own repo: in particular the V2 Skywork Llama alias
# must not fall back to the V1 model, or runs and joint-analysis rows get mixed up.
RM_ALIASES = {
    "Skywork-Reward-Llama-3.1-8B":      "Skywork/Skywork-Reward-Llama-3.1-8B",
    "Skywork-Reward-V2-Llama-3.1-8B":   "Skywork/Skywork-Reward-V2-Llama-3.1-8B",
    "tulu-v2.5-13b-uf-rm":              "allenai/tulu-v2.5-13b-uf-rm",
    "beaver-7b-v2.0-reward":            "PKU-Alignment/beaver-7b-v2.0-reward",
    "Skywork-Reward-V2-Qwen3-1.7B":     "Skywork/Skywork-Reward-V2-Qwen3-1.7B",
    "Skywork-Reward-V2-Qwen3-8B":       "Skywork/Skywork-Reward-V2-Qwen3-8B",
    "Skywork-Reward-V2-Qwen3-4B":       "Skywork/Skywork-Reward-V2-Qwen3-4B",
    "RM-Mistral-7B":                    "weqweasdas/RM-Mistral-7B",
    "GRM-Llama3-8B-rewardmodel-ft":     "Ray2333/GRM-Llama3-8B-rewardmodel-ft",
    "BTRM_Qwen2_7b_0613":               "CIR-AMS/BTRM_Qwen2_7b_0613",
}
def resolve_rm_name(name:str)->str:
    """Map a short RM alias to its repo id; names not in RM_ALIASES are returned unchanged."""
    return RM_ALIASES.get(name, name)


# ------------------------ Safe math helpers ------------------------
def _log_softmax_safe(logits: torch.Tensor) -> torch.Tensor:
    """Robust log_softmax for logits, cleaning NaN/Inf to prevent all-NaN rows."""
    # First, set non-finite values to 0 to prevent torch.max from propagating NaN.
    logits = torch.where(torch.isfinite(logits), logits, torch.zeros_like(logits))
    # Row-level stabilization: subtract the max value of each row.
    row_max = logits.max(dim=-1, keepdim=True).values
    logits = logits - row_max
    logp = F.log_softmax(logits, dim=-1)
    # Clean again (non-finite values can still occur in extreme cases).
    logp = torch.nan_to_num(logp, nan=-1e4, posinf=0.0, neginf=-1e4)
    return logp

def _gather_logprobs_safe(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Select log p(label) at every position from a NaN/Inf-cleaned log-softmax.

    ``labels`` must have the shape of ``logits`` without its last (vocab) dim;
    the result has that same shape. Non-finite picks are mapped to -1e4 / 0.
    """
    log_probs = _log_softmax_safe(logits)
    picked = torch.gather(log_probs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    return torch.nan_to_num(picked, nan=-1e4, posinf=0.0, neginf=-1e4)


# ------------------------ RM scorer ------------------------
class StrongRM:
    """Wrapper around a HuggingFace sequence-classification reward model.

    Supports optional bitsandbytes 8-bit / 4-bit loading. ``score`` returns
    z-scored rewards clipped to [-5, 5] as a float32 tensor.

    The mean and std are taken over the batch passed to a single ``score``
    call. Before clipping, the returned rewards therefore have zero mean
    within that batch, and absolute reward levels are not preserved.
    """
    def __init__(self, rm_model:str, device:str="cuda:1", max_len:int=512, label_index:int=0,
                 quant:str="8bit", torch_dtype:str="bfloat16"):
        self.device = device
        self.max_len = max_len
        self.label_index = label_index  # column used when the RM emits more than one logit
        self.tok = AutoTokenizer.from_pretrained(rm_model, use_fast=True, trust_remote_code=True)

        dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
        hf_dtype = dtype_map.get(torch_dtype, torch.bfloat16)

        qconf = None
        if quant == "8bit":
            qconf = BitsAndBytesConfig(load_in_8bit=True)
        elif quant == "4bit":
            qconf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True,
                                       bnb_4bit_quant_type="nf4",
                                       bnb_4bit_compute_dtype=hf_dtype if hf_dtype in (torch.float16, torch.bfloat16) else torch.float16)

        kwargs = dict(trust_remote_code=True)
        if qconf is not None:
            kwargs["quantization_config"] = qconf
            kwargs["device_map"] = {"": device}
        else:
            kwargs["torch_dtype"] = hf_dtype

        self.rm = AutoModelForSequenceClassification.from_pretrained(rm_model, **kwargs)
        if qconf is None: self.rm = self.rm.to(device)
        self.rm.eval()

    @torch.no_grad()
    def score(self, prompts:List[str], responses:List[str])->torch.Tensor:
        """Score (prompt, response) pairs; returns batch z-scores clipped to [-5, 5] (float32)."""
        texts = [f"Human: {p}\n\nAssistant: {r}" for p,r in zip(prompts, responses)]
        enc = self.tok(texts, padding=True, truncation=True, max_length=self.max_len, return_tensors="pt").to(self.device)
        out = self.rm(**enc)
        # Upcast before computing batch statistics. bf16 keeps only ~3 significant
        # digits, so mean/std computed in the RM's native dtype visibly distort
        # the z-scores.
        logits = out.logits.float().squeeze(-1)
        if logits.dim()==2: logits = logits[:, self.label_index]
        mu = logits.mean(); sigma = logits.std(unbiased=False).clamp_min(1e-4)
        rew = ((logits - mu) / sigma).clamp(-5,5)
        return torch.nan_to_num(rew, nan=0.0, posinf=5.0, neginf=-5.0).to(torch.float32)


# ------------------------ KL/Entropy (Robust version) ------------------------
@torch.no_grad()
def compute_tokenwise_kl_nll(tok, policy, ref_model, queries, responses, device):
    """Token-weighted KL(policy || ref) estimate and NLL over response tokens.

    Args:
        tok: tokenizer. Its pad id (or eos id, if pad is falsy) fills padding.
        policy / ref_model: causal LMs, optionally TRL value-head wrappers
            (unwrapped via ``pretrained_model``). ``ref_model`` may be None.
        queries / responses: lists of 1-D LongTensors.
        device: device on which the padded batch is assembled. Each model
            receives a copy on its own device, so policy and reference may
            live on different GPUs.

    Returns:
        (kl_tok, nll_tok):
          * kl_tok is the sum of (log p_pol - log p_ref) over all response-token
            positions, divided by the number of such positions (token-weighted,
            not per-sequence). It is NaN when ``ref_model`` is None.
          * nll_tok is the sum of -log p_pol over the same positions, divided
            the same way.
        Both are NaN for an empty batch.
    """
    if len(queries) == 0 or len(responses) == 0:
        return float("nan"), float("nan")

    def get_base(m):
        return getattr(m, "pretrained_model", m)

    def _device_of(m, fallback):
        dev = getattr(m, "device", None)
        if isinstance(dev, torch.device):
            return dev
        params = getattr(m, "parameters", None)
        if callable(params):
            for p in params():
                return p.device
        return fallback

    policy_base = get_base(policy)
    ref_base = get_base(ref_model) if ref_model is not None else None

    pad_id = tok.pad_token_id or tok.eos_token_id
    fulls, q_lens = [], []
    max_len = 0
    for q, r in zip(queries, responses):
        full = torch.cat([q, r], dim=0)
        fulls.append(full)
        q_lens.append(int(q.numel()))
        max_len = max(max_len, int(full.numel()))

    B = len(fulls)
    ids = torch.full((B, max_len), fill_value=pad_id, dtype=torch.long, device=device)
    attn = torch.zeros((B, max_len), dtype=torch.long, device=device)
    resp_mask = torch.zeros((B, max_len - 1), dtype=torch.float32, device=device)  # Align with labels length L-1
    for b, (full, qlen) in enumerate(zip(fulls, q_lens)):
        L = int(full.numel())
        ids[b, :L] = full
        attn[b, :L] = 1
        # Label index j predicts ids[j+1]; response tokens start at ids[qlen].
        start = max(0, qlen - 1)
        if L - 1 > start:
            resp_mask[b, start:L-1] = 1.0

    labels = ids[:, 1:]
    pol_dev = _device_of(policy_base, ids.device)
    out_pol = policy_base(input_ids=ids.to(pol_dev), attention_mask=attn.to(pol_dev))
    logits_pol = out_pol.logits[:, :-1, :]
    lp_pol = _gather_logprobs_safe(logits_pol, labels.to(logits_pol.device))
    resp_mask = resp_mask.to(lp_pol.device)
    denom = resp_mask.sum().clamp_min(1.0)

    if ref_base is None:
        kl_tok = float("nan")
    else:
        ref_dev = _device_of(ref_base, ids.device)
        out_ref = ref_base(input_ids=ids.to(ref_dev), attention_mask=attn.to(ref_dev))
        logits_ref = out_ref.logits[:, :-1, :]
        lp_ref = _gather_logprobs_safe(logits_ref, labels.to(logits_ref.device)).to(lp_pol.device)
        diff = (lp_pol - lp_ref) * resp_mask
        kl_tok = float(diff.sum().item() / denom.item())

    nll = (-lp_pol) * resp_mask
    nll_tok = float(nll.sum().item() / denom.item())
    return kl_tok, nll_tok

def _get_base_lm(m):
    return getattr(m, "pretrained_model", m)

def _pad_stack(seqs, pad_id: int):
    lens = [int(x.numel()) for x in seqs]
    L = max(lens)
    out = []
    dev = seqs[0].device
    for x in seqs:
        if x.numel() < L:
            pad = torch.full((L - x.numel(),), pad_id, dtype=x.dtype, device=dev)
            out.append(torch.cat([x, pad], dim=0))
        else:
            out.append(x)
    batch = torch.stack(out, dim=0)
    return batch, lens

@torch.no_grad()
def compute_kl_entropy_tokenwise(policy, ref_model, queries, responses, pad_id: int) -> Tuple[float, float]:
    """Per-sequence-averaged token-level KL and NLL over response tokens.

    NLL serves as a robust proxy for entropy.

    * Only response label positions are counted: (prompt_len-1) .. (L-2) in
      the shifted label index space, i.e. positions whose label is a response
      token.
    * Right-side padding is excluded to avoid dilution.
    * Non-finite logits / log-probs are cleaned to avoid 0 or NaN results.
    * Each sequence is averaged over its own response tokens first; the finite
      per-sequence means are then averaged.
    * Policy and reference may live on different devices (e.g. --ref_device
      differs from --policy_device). Inputs go to each model's own device, and
      reference log-probs are moved back before comparison.

    Returns (nan, nan) when ``ref_model`` is None or the batch is empty.
    """
    if ref_model is None or len(queries) == 0 or len(responses) == 0:
        return float("nan"), float("nan")

    def _base(m):
        return getattr(m, "pretrained_model", m)

    def _device_of(m, fallback):
        dev = getattr(m, "device", None)
        if isinstance(dev, torch.device):
            return dev
        params = getattr(m, "parameters", None)
        if callable(params):
            for p in params():
                return p.device
        return fallback

    pi = _base(policy)
    rf = _base(ref_model)

    fulls, qlens, Lmax = [], [], 0
    for q, r in zip(queries, responses):
        f = torch.cat([q, r], dim=0)
        fulls.append(f)
        qlens.append(int(q.numel()))
        Lmax = max(Lmax, int(f.numel()))

    B = len(fulls)
    device = fulls[0].device
    ids    = torch.full((B, Lmax), pad_id, dtype=torch.long, device=device)
    attn   = torch.zeros((B, Lmax), dtype=torch.long, device=device)
    rmask  = torch.zeros((B, Lmax - 1), dtype=torch.float32, device=device)
    for i, (f, ql) in enumerate(zip(fulls, qlens)):
        L = int(f.numel())
        ids[i, :L]  = f
        attn[i, :L] = 1
        st = max(0, ql - 1)
        if L - 1 > st:
            rmask[i, st:L - 1] = 1.0

    pi_dev = _device_of(pi, device)
    rf_dev = _device_of(rf, device)
    out_pi = pi(input_ids=ids.to(pi_dev), attention_mask=attn.to(pi_dev))
    out_rf = rf(input_ids=ids.to(rf_dev), attention_mask=attn.to(rf_dev))

    logits_pi = out_pi.logits[:, :-1, :]
    logits_rf = out_rf.logits[:, :-1, :]
    tgt       = ids[:, 1:]

    # Count and print once if any non-finite values are found.
    n_bad_pi = (~torch.isfinite(logits_pi)).sum().item()
    n_bad_rf = (~torch.isfinite(logits_rf)).sum().item()
    if (n_bad_pi + n_bad_rf) > 0:
        print(f"[KL/NLL] cleaned non-finite logits: pi={n_bad_pi} rf={n_bad_rf}")

    logp_pi = _gather_logprobs_safe(logits_pi, tgt.to(logits_pi.device))
    logp_rf = _gather_logprobs_safe(logits_rf, tgt.to(logits_rf.device)).to(logp_pi.device)

    diff    = logp_pi - logp_rf
    diff    = torch.nan_to_num(diff, nan=0.0, posinf=0.0, neginf=0.0)

    rmask   = rmask.to(diff.device)
    lengths = rmask.sum(dim=1).clamp_min(1.0)

    kl_seq  = (diff * rmask).sum(dim=1) / lengths
    nll_seq = (-(logp_pi * rmask).sum(dim=1) / lengths)

    def _mean_ok(x):
        msk = torch.isfinite(x)
        return x[msk].mean().item() if msk.any() else float("nan")

    return float(_mean_ok(kl_seq)), float(_mean_ok(nll_seq))


# ------------------------ Args ------------------------
@dataclass
class Args:
    """Run configuration; class-level defaults also serve as the argparse defaults."""

    # ---- policy and reference model ----
    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"   # HF id of the policy model
    dtype: str = "float16"                          # float16 | bfloat16 | float32
    policy_device: str = "cuda:0"
    ref_mode: str = "copy"                          # none | copy | 8bit
    ref_device: str = "cuda:0"
    policy_quant: str = "4bit"                      # none | 4bit | 8bit

    # ---- PPO optimisation ----
    steps: int = 600                                # loop bound; advanced by batch_size per iteration
    batch_size: int = 8
    mini_batch_size: int = 2
    ppo_epochs: int = 4
    lr: float = 1e-5
    init_kl: float = 0.02                           # initial KL coefficient
    target_kl: float = 0.2                          # KL target

    # ---- sampling ----
    max_prompt_tokens: int = 256                    # prompt truncation length
    gen_max_new: int = 64                           # max new tokens per response
    temp: float = 0.9
    top_p: float = 0.95

    # ---- reward model ----
    use_rm: int = 1                                 # 1 = score with the RM, otherwise zero rewards
    rm_name: str = "Skywork-Reward-Llama-3.1-8B"    # alias or repo id
    rm_device: str = "cuda:1"
    rm_quant: str = "8bit"                          # none | 8bit | 4bit
    rm_dtype: str = "bfloat16"
    rm_max_len: int = 512
    rm_label_index: int = 0

    # ---- bookkeeping / IO ----
    seed: int = 123
    save_dir: str = "./rlhf_runs/ppo_hh_v7"
    rm_metrics_csv: str = "./precomputed_metrics/rm_global_metrics.csv"

    # ---- post-run joint analysis ----
    joint_root: Optional[str] = None
    slope_head: Optional[int] = None                # early-slope window length
    auc_norm_to: int = 0                            # normalise reward AUC to this step count

def parse_args() -> Args:
    """Parse the command line into an ``Args`` instance.

    Post-processing:
      * ``--rb_metrics_csv``, when given, overrides ``--rm_metrics_csv`` (alias).
      * A non-empty ``RM_NAME`` environment variable overrides ``--rm_name``.
    Only namespace attributes that are ``Args`` fields are forwarded.
    """
    parser = argparse.ArgumentParser()
    precisions = ["float16", "bfloat16", "float32"]
    # (flag, type, default, choices) -- registration order matches the original CLI.
    spec = [
        ("--model_name", str, Args.model_name, None),
        ("--dtype", str, Args.dtype, precisions),
        ("--policy_device", str, Args.policy_device, None),
        ("--ref_mode", str, Args.ref_mode, ["none", "copy", "8bit"]),
        ("--ref_device", str, Args.ref_device, None),
        ("--policy_quant", str, Args.policy_quant, ["none", "4bit", "8bit"]),

        ("--steps", int, Args.steps, None),
        ("--batch_size", int, Args.batch_size, None),
        ("--mini_batch_size", int, Args.mini_batch_size, None),
        ("--ppo_epochs", int, Args.ppo_epochs, None),
        ("--lr", float, Args.lr, None),
        ("--init_kl", float, Args.init_kl, None),
        ("--target_kl", float, Args.target_kl, None),

        ("--max_prompt_tokens", int, Args.max_prompt_tokens, None),
        ("--gen_max_new", int, Args.gen_max_new, None),
        ("--temp", float, Args.temp, None),
        ("--top_p", float, Args.top_p, None),

        ("--use_rm", int, Args.use_rm, None),
        ("--rm_name", str, Args.rm_name, None),
        ("--rm_device", str, Args.rm_device, None),
        ("--rm_quant", str, Args.rm_quant, ["none", "8bit", "4bit"]),
        ("--rm_dtype", str, Args.rm_dtype, precisions),
        ("--rm_max_len", int, Args.rm_max_len, None),
        ("--rm_label_index", int, Args.rm_label_index, None),

        ("--seed", int, Args.seed, None),
        ("--save_dir", str, Args.save_dir, None),

        # New / compatibility flags.
        ("--rm_metrics_csv", str, Args.rm_metrics_csv, None),
        ("--rb_metrics_csv", str, None, None),
        ("--joint_root", str, Args.joint_root, None),
        ("--slope_head", int, Args.slope_head, None),
        ("--auc_norm_to", int, Args.auc_norm_to, None),
    ]
    for flag, kind, default, choices in spec:
        parser.add_argument(flag, type=kind, default=default, choices=choices)

    ns = parser.parse_args()

    alias_csv = getattr(ns, "rb_metrics_csv", None)
    if alias_csv:
        ns.rm_metrics_csv = alias_csv

    env_rm = os.environ.get("RM_NAME")
    if env_rm:
        ns.rm_name = env_rm.strip()
        print(f"[Args] Overriding rm_name from ENV RM_NAME={ns.rm_name}")

    known = {name: getattr(ns, name) for name in Args.__annotations__ if hasattr(ns, name)}
    return Args(**known)


# ------------------------ Build policy/ref ------------------------
# New build_policy function (quantization removed)
def build_policy(args:Args, tok):
    """Load the trainable policy (causal LM + TRL value head) on ``args.policy_device``.

    Weights are loaded unquantized in ``args.dtype``. Quantized policy loading
    was removed, so ``args.policy_quant`` has no effect. A warning is printed
    when it requests quantization, so the run log does not silently
    misrepresent the setup (the CLI default is "4bit").

    ``tok`` is unused and kept for call-site compatibility. ``use_cache`` is
    turned off on the wrapped and wrapper configs (best effort).
    """
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    torch_dtype = dtype_map.get(args.dtype, torch.bfloat16)  # ensure a default value

    if args.policy_quant != "none":
        print(f"[Policy][WARN] policy_quant={args.policy_quant!r} is ignored: "
              f"quantized policy loading is disabled.")

    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        args.model_name,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    )
    model = model.to(args.policy_device)

    try:
        base = getattr(model, "pretrained_model", None)
        if base is not None and hasattr(base, "config"):
            base.config.use_cache = False
        if hasattr(model, "config"):
            model.config.use_cache = False
    except Exception:
        pass  # best effort only; a failure here must not abort training

    print(f"[Policy] Loaded model on {args.policy_device} with dtype {torch_dtype} (quantization disabled).")
    return model

def build_ref(args:Args):
    """Build the frozen reference model according to ``args.ref_mode``.

    * "none" -> returns None.
    * "8bit" -> bitsandbytes 8-bit load on ``args.ref_device``.
    * otherwise ("copy") -> unquantized copy in the same dtype as the policy
      (``args.dtype``), so measured KL is not inflated by a precision
      mismatch between policy and reference.

    ``use_cache`` is turned off on the wrapped model's config (best effort).
    """
    if args.ref_mode == "none":
        return None

    def _disable_cache(model):
        try:
            base = getattr(model, "pretrained_model", None)
            if base is not None and hasattr(base, "config"):
                base.config.use_cache = False
        except Exception:
            pass  # best effort only

    if args.ref_mode == "8bit":
        qconf = BitsAndBytesConfig(load_in_8bit=True)
        ref = AutoModelForCausalLMWithValueHead.from_pretrained(
            args.model_name, trust_remote_code=True, quantization_config=qconf, device_map={"": args.ref_device}
        )
        _disable_cache(ref)
        print("[Ref] 8-bit on", args.ref_device)
        return ref

    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    ref_dtype = dtype_map.get(args.dtype, torch.bfloat16)  # same fallback as build_policy
    ref = AutoModelForCausalLMWithValueHead.from_pretrained(
        args.model_name, trust_remote_code=True, torch_dtype=ref_dtype, device_map={"": args.ref_device}
    )
    _disable_cache(ref)
    print(f"[Ref] {args.dtype} copy on", args.ref_device)
    return ref


# ------------------------ Train & Log ------------------------
def _generate_with_fallback(policy, enc, gen_kwargs: Dict[str, Any]):
    """Generate continuations for ``enc``: sample first, retry greedily on RuntimeError.

    NaN/Inf logits are stripped only via ``remove_invalid_values=True``. That
    flag makes ``generate`` install an ``InfNanRemoveLogitsProcessor`` itself.
    Passing a second, user-built instance through ``logits_processor`` as well
    is rejected with a ValueError by transformers releases contemporary with
    TRL 0.9.x (duplicate processor type).
    """
    try:
        return policy.generate(**enc, do_sample=True, remove_invalid_values=True, **gen_kwargs)
    except RuntimeError:
        # NOTE(review): presumably guards against sampling failures on
        # degenerate probabilities; greedy decoding does not sample.
        return policy.generate(**enc, do_sample=False, remove_invalid_values=True, **gen_kwargs)


def _first_float(stats: Dict[str, Any], keys: List[str]) -> Optional[float]:
    """Return the first entry among ``keys`` in ``stats`` that converts to float, else None."""
    for k in keys:
        if k in stats:
            try:
                return float(stats[k])
            except (TypeError, ValueError, RuntimeError):
                continue  # e.g. a non-scalar array; try the next candidate key
    return None


def _finite_or(value: Optional[float], default: float = 0.0) -> float:
    """Return ``value`` as float if it is a finite number, otherwise ``default``."""
    if value is not None and math.isfinite(value):
        return float(value)
    return float(default)


def _pick_finite(main: Optional[float], backup: Optional[float], default: float = 0.0) -> float:
    """Return the first finite of ``main``, ``backup``; fall back to ``default``."""
    if main is not None and math.isfinite(main):
        return float(main)
    return _finite_or(backup, default)


def main_train_and_log(args: Args) -> Dict[str, Any]:
    """Train the policy with PPO, log per-batch metrics, and save the result.

    Side effects, all under ``args.save_dir``:
      * ``train_log.csv`` -- one row per PPO batch. The header is written only
        if the file does not exist yet; later runs append to it.
      * ``final/`` -- policy (with value head) and tokenizer.
      * ``run_meta.json`` -- the returned metadata dict.

    ``args.steps`` counts prompts, not PPO updates: each iteration consumes
    ``args.batch_size`` prompts and advances ``step`` by that amount.
    Iteration stops early if the prompt pool cannot fill a full batch, because
    TRL's ``step`` requires exactly ``batch_size`` samples.
    """
    import inspect  # only needed here, to probe trainer.step's signature

    print("[Env] CUDA_VISIBLE_DEVICES =", os.getenv("CUDA_VISIBLE_DEVICES"))
    os.makedirs(args.save_dir, exist_ok=True)
    set_seed_all(args.seed)

    rm_resolved = resolve_rm_name(args.rm_name)

    tok = leftpad_tokenizer(args.model_name)
    print(f"[Tok] padding_side={tok.padding_side} pad={tok.pad_token_id} eos={tok.eos_token_id}")
    metrics_pad_id = tok.pad_token_id or tok.eos_token_id

    policy = build_policy(args, tok)
    ref_model = build_ref(args)

    # In TRL 0.9.x every PPO hyperparameter lives on PPOConfig. PPOTrainer.__init__
    # only takes (config, model, ref_model, tokenizer, dataset, optimizer, ...), so
    # epoch / KL settings passed to it directly raise TypeError.
    cfg = PPOConfig(
        batch_size=args.batch_size,
        mini_batch_size=args.mini_batch_size,
        gradient_accumulation_steps=max(1, args.batch_size // max(1, args.mini_batch_size)),
        learning_rate=args.lr,
        seed=args.seed,
        vf_coef=0.1,
        ppo_epochs=args.ppo_epochs,
        kl_penalty="kl",
        adap_kl_ctrl=True,
        init_kl_coef=args.init_kl,
        target=args.target_kl,  # KL target of TRL's adaptive KL controller
    )
    trainer = PPOTrainer(
        config=cfg,
        model=policy,
        ref_model=ref_model,
        tokenizer=tok,
        dataset=None,
    )
    print(f"[Dev] accelerator.device={trainer.accelerator.device} | world={trainer.accelerator.num_processes}")

    # Decide once whether step() takes response_masks. This replaces a
    # try/except TypeError, which could also swallow TypeErrors raised inside step().
    try:
        step_params = inspect.signature(trainer.step).parameters
        step_accepts_masks = ("response_masks" in step_params) or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in step_params.values()
        )
    except (TypeError, ValueError):
        step_accepts_masks = False

    rm = None
    if int(args.use_rm) == 1:
        print(f"[RM] using: {rm_resolved} | quant={args.rm_quant}+{args.rm_dtype} on {args.rm_device}")
        rm = StrongRM(rm_model=rm_resolved, device=args.rm_device, max_len=args.rm_max_len,
                      label_index=args.rm_label_index, quant=args.rm_quant, torch_dtype=args.rm_dtype)

    prompts_all = get_hh_prompts(n=4096, seed=args.seed)
    if not prompts_all:
        raise RuntimeError("No prompts could be extracted from Anthropic/hh-rlhf.")

    gen_kwargs = dict(
        max_new_tokens=max(1, args.gen_max_new),
        temperature=max(args.temp, 1e-5),
        top_p=args.top_p,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )

    # Sanity check on a single prompt.
    policy.eval()
    with torch.no_grad():
        enc = tok([prompts_all[0]], return_tensors="pt", padding=True, truncation=True,
                  max_length=args.max_prompt_tokens).to(args.policy_device)
        g = _generate_with_fallback(policy, enc, gen_kwargs)
        prompt_width = int(enc["input_ids"].size(1))
        resp_trim = trim_after_pad_eos(g[0][prompt_width:], tok.pad_token_id, tok.eos_token_id)
        print("[Sanity] first resp_len=", int(resp_trim.numel()))
    policy.train()

    csv_path = os.path.join(args.save_dir, "train_log.csv")
    if not os.path.exists(csv_path):
        with open(csv_path, "w", encoding="utf-8") as f:
            f.write("step,reward_mean,reward_std,reward_iqr,resp_len,kl,entropy,approx_kl,kl_trl,entropy_trl,kl_tok,entropy_tok\n")

    step = 0
    last_row = {}
    while step < args.steps:
        batch_prompts = prompts_all[step: step + args.batch_size]
        if len(batch_prompts) < args.batch_size:
            print(f"[Train] prompt pool exhausted at step {step} "
                  f"({len(batch_prompts)} < batch_size={args.batch_size}); stopping.")
            break

        enc = tok(batch_prompts, return_tensors="pt", padding=True, truncation=True,
                  max_length=args.max_prompt_tokens).to(args.policy_device)
        input_ids = enc["input_ids"]; attn = enc["attention_mask"]
        # Left padding: each row's prompt occupies columns [0, prompt_width), and the
        # generated continuation starts at prompt_width for EVERY row. Slicing at the
        # per-row unpadded length (attn.sum()) would leak the tail of the prompt into
        # the response of every row shorter than the longest prompt.
        prompt_width = int(input_ids.size(1))

        policy.eval()
        with torch.no_grad():
            gen = _generate_with_fallback(policy, enc, gen_kwargs)
        policy.train()

        queries, responses, resp_masks, texts = [], [], [], []
        for b in range(input_ids.size(0)):
            q = input_ids[b][attn[b].bool()]
            r = trim_after_pad_eos(gen[b][prompt_width:], tok.pad_token_id, tok.eos_token_id)
            queries.append(q)
            responses.append(r)
            resp_masks.append(torch.ones_like(r, dtype=torch.float32, device=args.policy_device))
            texts.append(tok.decode(r, skip_special_tokens=True).strip())

        # ---- Debug: Confirm the true number of response tokens, ensuring padding is not counted ----
        if step % max(1, args.batch_size) == 0:
            true_resp_toks = sum(int(r.numel()) for r in responses)
            avg_len = float(np.mean([int(x.numel()) for x in responses]))
            print(f"[Debug step {step}] avg_resp_len={avg_len:.1f}, true_resp_toks={true_resp_toks}")

        if rm is not None:
            with torch.no_grad():
                rs = rm.score(batch_prompts, texts).to(args.policy_device)
        else:
            rs = torch.zeros(len(batch_prompts), device=args.policy_device, dtype=torch.float32)

        # Robust KL/Entropy (token-level), measured with dropout disabled.
        policy.eval()
        kl_tok, ent_tok = compute_kl_entropy_tokenwise(
            policy, ref_model, queries, responses, pad_id=metrics_pad_id
        )
        policy.train()

        scores = [s.view(()) for s in rs]
        if step_accepts_masks:
            stats = trainer.step(queries, responses, scores, response_masks=resp_masks)
        else:
            stats = trainer.step(queries, responses, scores)

        kl_trl = _first_float(stats, ["ppo/policy/kl", "ppo/kl", "kl", "stats/kl", "objective/kl"])
        ent_trl = _first_float(stats, ["policy/entropy", "ppo/policy/entropy", "objective/entropy"])
        # Non-finite approx_kl is logged as 0.0, like the other fallback columns.
        approx_kl = _finite_or(_first_float(stats, ["approx_kl", "ppo/policy/approxkl"]), 0.0)

        r_mean = float(rs.mean().item()); r_std = float(rs.std(unbiased=False).item())
        r_iqr = float(p90_p10(rs.detach().cpu().tolist()))
        resp_len = float(np.mean([int(x.numel()) for x in responses]))
        kl_val  = _pick_finite(kl_tok,  kl_trl,  0.0)
        ent_val = _pick_finite(ent_tok, ent_trl, 0.0)

        last_row = dict(
            step=float(step),
            reward_mean=r_mean, reward_std=r_std, reward_iqr=r_iqr, resp_len=resp_len,
            kl=kl_val, entropy=ent_val, approx_kl=approx_kl,
            kl_trl=_finite_or(kl_trl, 0.0),
            entropy_trl=_finite_or(ent_trl, 0.0),
            kl_tok=_finite_or(kl_tok, 0.0),
            entropy_tok=_finite_or(ent_tok, 0.0),
        )

        if step % max(1, args.batch_size) == 0:
            print(f"[Step {step}] r_mean={r_mean:.4f} std={r_std:.4f} kl={kl_val:.3f} ent={ent_val:.3f} "
                  f"(kl_tok={kl_tok if math.isfinite(kl_tok) else float('nan'):.3f}, kl_trl={kl_trl if (kl_trl is not None and math.isfinite(kl_trl)) else float('nan')})")

        with open(csv_path, "a", encoding="utf-8") as f:
            f.write("{step},{reward_mean},{reward_std},{reward_iqr},{resp_len},{kl},{entropy},{approx_kl},"
                    "{kl_trl},{entropy_trl},{kl_tok},{entropy_tok}\n".format(**last_row))

        step += args.batch_size

    final_dir = os.path.join(args.save_dir, "final"); os.makedirs(final_dir, exist_ok=True)
    policy.save_pretrained(final_dir); tok.save_pretrained(final_dir)
    print(f"[Done] saved to {final_dir}")

    meta = dict(
        model_name=args.model_name,
        rm_name_input=args.rm_name,
        rm_name_resolved=rm_resolved,
        seed=args.seed,
        steps=args.steps,
        batch_size=args.batch_size,
        lr=args.lr,
        target_kl=args.target_kl,
        save_dir=args.save_dir,
        train_log_csv=csv_path,
        slope_head=args.slope_head,
        auc_norm_to=args.auc_norm_to,
        rm_metrics_csv=args.rm_metrics_csv,
        joint_root=args.joint_root,
    )
    with open(os.path.join(args.save_dir, "run_meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    return meta


# ------------------------ Post-run Joint Analysis ------------------------
def _read_training_log(csv_path:str) -> Dict[str, List[float]]:
    xs, rmean, kl = [], [], []
    if not os.path.isfile(csv_path): return {"step":xs, "reward_mean":rmean, "kl":kl}
    with open(csv_path, "r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            try:
                xs.append(float(row["step"]))
                rmean.append(float(row["reward_mean"]))
                kl.append(float(row["kl"]))
            except Exception:
                continue
    return {"step":xs, "reward_mean":rmean, "kl":kl}

def _steps_to_threshold(xs: List[float], ys: List[float], thr: float) -> int:
    for x,y in zip(xs, ys):
        if y >= thr:
            return int(x)
    return -1

def _trapezoid_auc(xs: List[float], ys: List[float]) -> float:
    if len(xs) < 2: return 0.0
    s = 0.0
    for i in range(1, len(xs)):
        dx = xs[i] - xs[i-1]
        s += (ys[i] + ys[i-1]) * 0.5 * dx
    return float(s)

def _linreg_slope(xs: List[float], ys: List[float]) -> float:
    if len(xs) < 2: return 0.0
    x = np.asarray(xs, float); y = np.asarray(ys, float)
    x = x - x.mean(); y = y - y.mean()
    denom = (x*x).sum()
    if denom <= 1e-12: return 0.0
    return float((x*y).sum() / denom)

def _mad_z(values: List[float]) -> Tuple[float, float]:
    arr = np.asarray(values, float)
    med = float(np.median(arr))
    mad = float(np.median(np.abs(arr - med)))
    return med, mad

def _make_rm_key_candidates(rm_resolved: str) -> List[str]:
    """Candidate ``rm_key`` strings for an RM id, strictest first.

    The candidates are ``rb_<sanitized>/<full id>``, ``rb_<sanitized>/<tail>``
    and the bare ``<tail>`` (last path component of the id).
    """
    prefix = "rb_" + sanitize(rm_resolved)
    short = rm_resolved.rsplit("/", 1)[-1]
    return [prefix + "/" + rm_resolved, prefix + "/" + short, short]

def _load_rm_metrics_row(csv_path: str, rm_resolved: str) -> Tuple[Dict[str, str], List[Dict[str, str]]]:
    """Load the RM metrics CSV and pick the row describing ``rm_resolved``.

    Matching is tiered, so a loose match can never shadow a stricter one that
    appears later in the file (e.g. ``...-8B-v0.2`` listed before ``...-8B``):
      1. ``rm_key`` equals one of the candidate keys;
      2. ``rm_key`` ends with ``"/<repo tail>"``;
      3. ``model_name`` equals the full repo id or its tail;
      4. ``rm_key`` contains the repo tail as a substring.
    Within a tier the first row in file order wins.

    Returns:
        ``(matched_row or {}, all_rows)``; ``({}, [])`` when the file is missing.
    """
    if not os.path.isfile(csv_path):
        print(f"[Joint] RM metrics CSV not found: {csv_path}")
        return {}, []

    with open(csv_path, "r", encoding="utf-8") as f:
        all_rows = list(csv.DictReader(f))

    cands = _make_rm_key_candidates(rm_resolved)
    tail = rm_resolved.split("/")[-1]
    loose = [c for c in cands if "/" not in c and c]  # bare tail(s); skip empty strings

    def _exact_key(row):
        rk = row.get("rm_key") or ""
        return bool(rk) and rk in cands

    def _suffix_key(row):
        rk = row.get("rm_key") or ""
        return bool(rk) and any(rk.endswith("/" + c) for c in loose)

    def _model_name(row):
        mn = row.get("model_name") or ""
        return bool(mn) and (mn == rm_resolved or mn == tail)

    def _substring_key(row):
        rk = row.get("rm_key") or ""
        return bool(rk) and any(c in rk for c in loose)

    hit = {}
    for matches in (_exact_key, _suffix_key, _model_name, _substring_key):
        hit = next((row for row in all_rows if matches(row)), {})
        if hit:
            break

    if not hit:
        print(f"[Joint] No matching rm_key/model_name in CSV for: {rm_resolved} | tried {cands}")
    else:
        print(f"[Joint] Matched RM row: rm_key={hit.get('rm_key','')} model_name={hit.get('model_name','')}")
    return hit, all_rows

def _compute_composite_madz(hit: Dict[str,str], all_rows: List[Dict[str,str]]) -> Dict[str, float]:
    keys = ["RSI_IQR_med", "nGap_med", "SEI_med"]
    cols = {k: [] for k in keys}
    for row in all_rows:
        for k in keys:
            try:
                cols[k].append(float(row[k]))
            except Exception:
                pass

    zsum = 0.0
    parts = {}
    for k in keys:
        if k not in hit:
            parts[k] = float("nan")
            continue
        try:
            x = float(hit[k])
        except Exception:
            parts[k] = float("nan")
            continue
        med = float(np.median(np.asarray(cols[k], float))) if cols[k] else 0.0
        mad = float(np.median(np.abs(np.asarray(cols[k], float) - med))) if cols[k] else 0.0
        scale = (mad * 1.4826)
        if (not math.isfinite(scale)) or scale <= 1e-12:
            arr = np.asarray(cols[k], float)
            if arr.size >= 4:
                iqr = float(np.percentile(arr, 75) - np.percentile(arr, 25))
                scale = iqr / 1.349 if iqr > 1e-12 else (float(arr.std()) or 1.0)
            else:
                scale = float(arr.std()) if arr.size > 1 else 1.0
        z = (x - med) / max(scale, 1e-12)
        parts[k] = float(z)
        zsum += float(z)

    parts["Composite_madz"] = float(zsum)
    return parts

def post_run_joint_analysis(meta: Dict[str,Any]):
    """Merge this run's convergence metrics with its reward model's CSV metrics.

    Steps:
      1. Read the training log (``meta["train_log_csv"]``) and compute
         steps-to-KL thresholds, the reward AUC (optionally normalized), the
         early reward slope, and the final reward/KL.
      2. Look up the reward model's row in ``meta["rm_metrics_csv"]``. Add its
         raw metrics under their own names and the MAD z-scores under
         ``<metric>_madz``, plus ``Composite_madz``.
      3. Write ``<save_dir>/joint/joint_metrics.json`` and ``joint_metrics.csv``.
      4. If ``meta["joint_root"]`` is set, copy the CSV to
         ``<joint_root>/<rm basename>/<run dir basename>.csv`` (best effort).

    Required ``meta`` keys: ``save_dir``, ``train_log_csv``, ``rm_metrics_csv``.
    Optional: ``rm_name_resolved`` / ``rm_name_input``, ``target_kl``
    (default 0.2), ``slope_head``, ``auc_norm_to``, ``joint_root``.
    """
    def _csv_float(value: Any) -> float:
        # Blank, missing or malformed CSV cells become NaN instead of raising.
        # The RM metrics CSV is produced elsewhere, and a stray "N/A" must not
        # abort the analysis after a full training run has already finished.
        try:
            return float(value)
        except (TypeError, ValueError):
            return float("nan")

    rm_resolved = meta.get("rm_name_resolved", meta.get("rm_name_input",""))
    csv_train = meta["train_log_csv"]
    rm_metrics_csv = meta["rm_metrics_csv"]
    slope_head = meta.get("slope_head", None)
    auc_norm_to = int(meta.get("auc_norm_to", 0) or 0)

    # ---- convergence metrics from this run's training log ----
    train = _read_training_log(csv_train)
    xs, rmean, kl = train["step"], train["reward_mean"], train["kl"]

    target_kl = float(meta.get("target_kl", 0.2))
    steps_to_kl80 = _steps_to_threshold(xs, kl, 0.8 * target_kl)
    steps_to_kl90 = _steps_to_threshold(xs, kl, 0.9 * target_kl)
    reward_auc = _trapezoid_auc(xs, rmean)
    reward_auc_norm = (reward_auc / float(auc_norm_to)) if auc_norm_to > 0 else None

    # Early-slope window: all steps <= slope_head when given (the prefix logic
    # assumes xs is ascending). Otherwise use the first third of the log, with a
    # minimum of 3 points.
    if slope_head is not None and slope_head > 1:
        idx = [i for i, x in enumerate(xs) if x <= slope_head]
        n = (idx[-1] + 1) if idx else min(len(xs), 3)
    else:
        n = max(3, int(len(xs) * 0.33))
    slope_early = _linreg_slope(xs[:n], rmean[:n]) if len(xs) >= 2 else 0.0

    final_reward = float(rmean[-1]) if rmean else 0.0
    final_kl = float(kl[-1]) if kl else 0.0

    # ---- reward-model metrics from the shared CSV ----
    hit, all_rows = _load_rm_metrics_row(rm_metrics_csv, rm_resolved)
    rm_part: Dict[str, Any] = {}
    if hit:
        rm_part = dict(
            rm_key=hit.get("rm_key",""),
            rm_model_name=hit.get("model_name",""),
            RSI_IQR_med=_csv_float(hit.get("RSI_IQR_med")),
            nGap_med=_csv_float(hit.get("nGap_med")),
            SEI_med=_csv_float(hit.get("SEI_med")),
            nGMD_med=_csv_float(hit.get("nGMD_med")),
        )
        comps = _compute_composite_madz(hit, all_rows)
        # The per-metric z-scores share their names with the raw metrics. Store
        # them with a "_madz" suffix so they do not overwrite the raw values above.
        for k, v in comps.items():
            rm_part[k if k == "Composite_madz" else f"{k}_madz"] = v
    else:
        print("[Joint] WARNING: RM row not found; rm metrics will be empty")

    out = dict(
        run_dir=meta["save_dir"],
        rm_metrics_csv=rm_metrics_csv,
        **rm_part,
        steps_to_kl80=int(steps_to_kl80),
        steps_to_kl90=int(steps_to_kl90),
        reward_auc=float(reward_auc),
        reward_auc_norm=float(reward_auc_norm) if (reward_auc_norm is not None) else float("nan"),
        reward_slope_early=float(slope_early),
        final_reward_mean=float(final_reward),
        final_kl=float(final_kl),
    )

    # ---- per-run outputs ----
    out_dir = os.path.join(meta["save_dir"], "joint")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "joint_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    csv_out = os.path.join(out_dir, "joint_metrics.csv")
    with open(csv_out, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(out.keys()))
        writer.writeheader()
        writer.writerow(out)
    print(f"[Joint] saved joint metrics → {csv_out}")

    # ---- optional aggregator copy (best effort: never fail the run here) ----
    joint_root = meta.get("joint_root", None)
    if joint_root:
        rm_tail = rm_resolved.rstrip("/").split("/")[-1] if rm_resolved else ""
        rm_tail = rm_tail or "unknown_rm"
        # normpath drops trailing separators (either OS style) before taking the basename.
        base = os.path.basename(os.path.normpath(meta["save_dir"]))
        dst_dir = os.path.join(joint_root, rm_tail)
        dst_csv = os.path.join(dst_dir, f"{base}.csv")
        try:
            os.makedirs(dst_dir, exist_ok=True)
            shutil.copy2(csv_out, dst_csv)
            print(f"[Joint] copied to aggregator → {dst_csv}")
        except OSError as e:
            print(f"[Joint] copy to aggregator failed: {e}")


# ------------------------ Entry ------------------------
def main():
    """CLI entry point: run PPO training, then the post-training joint analysis."""
    # Enable TF32 matmul/cuDNN paths as a best effort. Any exception raised
    # while setting the flags is ignored.
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    except Exception:
        pass
    # Respect an explicit user setting; otherwise disable tokenizer parallelism.
    if "TOKENIZERS_PARALLELISM" not in os.environ:
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    run_meta = main_train_and_log(parse_args())
    post_run_joint_analysis(run_meta)
    print(f"== Done: {run_meta['save_dir']} ==")


if __name__ == "__main__":
    main()