import os
import re
import json
import gc
import math
import argparse
import random
import hashlib
import time
import contextlib
import copy
from dataclasses import dataclass, asdict
from typing import List, Dict, Tuple
from datetime import datetime

import glob

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList, Qwen2TokenizerFast
from datasets import load_dataset
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
from peft import LoraConfig, get_peft_model, PeftModel, set_peft_model_state_dict

# parser: resource from qwen-2.5 math
# source: https://github.com/QwenLM/Qwen2.5-Math/blob/main/evaluation/math_eval.py
from parser import parse_ground_truth, parse_question, math_equal, extract_answer
from yuumi_utils import *

# -----------------------------
# -----------------------------
# Timestamp string (month-day-hour_minute_second) captured once, when this module is loaded.
ts = datetime.now().strftime("%m-%d-%H_%M_%S")
# Default torch device string: "cuda" when CUDA is available, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def _count_trainable_params(model: torch.nn.Module) -> int:
    """Return the total number of elements across parameters with ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def _get_output_embedding_weight_param(model: torch.nn.Module) -> Optional[torch.nn.Parameter]:
    """Find the output-projection weight of ``model``.

    First tries ``model.get_output_embeddings().weight``; any exception raised
    on that path is swallowed (best effort). Falls back to
    ``model.lm_head.weight`` and returns ``None`` when neither is available.
    """
    try:
        getter = getattr(model, "get_output_embeddings", None)
        head = getter() if getter is not None else None
        if head is not None and hasattr(head, "weight"):
            return head.weight
    except Exception:
        # Best effort only: fall through to the direct lm_head lookup.
        pass

    lm_head = getattr(model, "lm_head", None)
    if hasattr(lm_head, "weight"):
        return lm_head.weight
    return None

def _maybe_freeze_input_embeddings(model: torch.nn.Module, out_weight: torch.nn.Parameter, log_fn=print) -> None:
    """Disable gradients on the input embedding weight unless it is tied to ``out_weight``.

    The weight counts as tied when it is the same object as ``out_weight`` or
    shares its data pointer. Any exception is reported through ``log_fn`` as a
    warning instead of being raised.
    """
    try:
        if not hasattr(model, "get_input_embeddings"):
            return
        embedding = model.get_input_embeddings()
        if embedding is None or not hasattr(embedding, "weight"):
            return

        in_weight = embedding.weight
        shares_storage = (in_weight is out_weight) or (in_weight.data_ptr() == out_weight.data_ptr())
        if shares_storage or not in_weight.requires_grad:
            return

        in_weight.requires_grad_(False)
        log_fn("[ZIP-Train] Froze input embedding weight (embed_tokens) to avoid unintended updates.")
    except Exception as err:
        log_fn(f"[ZIP-Train] [WARN] Failed to inspect/freeze input embeddings: {err}")

def generate_strategy(strategy_demand, slm_multi_k, flipped_target=0.0, sc_k=5):
    """Build the list of strategy config dicts selected by ``strategy_demand``.

    Every recognised key is tested with ``key in strategy_demand`` and, when
    present, appends one dict to the result. The output order is fixed by the
    order of the checks below, not by the order of ``strategy_demand``.

    Args:
        strategy_demand: Container (or string) of strategy keys such as
            ``"slm_only"``, ``"kl"``, ``"nope"``, ``"ziprank"``.
        slm_multi_k: Stored under ``"slm_multi_k"`` on the SLM-resample strategies.
        flipped_target: Accepted for backward compatibility; not used here.
        sc_k: Stored under ``"sc_k"`` on the "multi" strategies.

    Returns:
        A list of strategy dicts, each with at least ``"name"`` and ``"mode"``.

    Note:
        The ``"trigger_only"`` entry reads ``args.header_trigger_threshold``
        from a module-level ``args`` object, which must exist when that key
        is requested.
    """
    strategies = []
    if "slm_only" in strategy_demand:
        strategies.append({"name": "SLM-only", "mode": "slm_only"})
    # NOTE(review): this entry needs both "slm_multi" and "multi" to be present;
    # kept as in the original -- confirm whether the second test is intended.
    if "slm_multi" in strategy_demand and "multi" in strategy_demand:
        strategies.append({"name": "SLM-only (multi)", "mode": "slm_only", "sc_k": sc_k})
    if "llm_only" in strategy_demand:
        strategies.append({"name": "LLM-only", "mode": "llm_only"})
    if "llm_multi" in strategy_demand:
        strategies.append(
            {
                "name": "LLM-only (multi)",
                "mode": "llm_only",
                "do_sample": True,
                "sc_k": sc_k,
                "temperature": 0.6,
                "top_p": 0.95,
            })
    if "en" in strategy_demand:
        strategies.append({"name": "Ablation-entropy", "mode": "s2t", "trigger_type": "entropy"})
    if "kl" in strategy_demand:
        strategies.append({"name": "Ablation-kl", "mode": "s2t", "trigger_type": "kl"})
    if "random" in strategy_demand:
        strategies.append(
            {
                "name": "Ablation-random",
                "mode": "s2t",
                "trigger_type": "Random",
            })
    if "help" in strategy_demand:
        strategies.append(
            {
                "name": "Ablation-kl-SLMResample-LLMHelp",
                "mode": "s2t",
                "trigger_type": "kl",
                "guidance_type": "SLM_RESAMPLE_WITH_LLM_HELP",
                "slm_multi_k": slm_multi_k,
            })
    if "nope" in strategy_demand:
        strategies.append(
            {
                "name": "Ablation-kl-SLMResample-NoLLM",
                "mode": "s2t",
                "trigger_type": "kl",
                "guidance_type": "SLM_RESAMPLE_NO_LLM",
                "slm_multi_k": slm_multi_k,
            })
    if "enpe" in strategy_demand:
        strategies.append(
            {
                # Entropy-triggered variant: it needs its own name so that it
                # does not collide with the KL-triggered "nope" strategy.
                "name": "Ablation-entropy-SLMResample-NoLLM",
                "mode": "s2t",
                "trigger_type": "entropy",
                "guidance_type": "SLM_RESAMPLE_NO_LLM",
                "slm_multi_k": slm_multi_k,
            })
    if "ranpe" in strategy_demand:
        strategies.append(
            {
                "name": "Ablation-random-SLMResample-NoLLM",
                "mode": "s2t",
                "trigger_type": "Random",
                "guidance_type": "SLM_RESAMPLE_NO_LLM",
                "slm_multi_k": slm_multi_k,
            })
    if "trigger_only" in strategy_demand:
        strategies.append({
            "name": "HeaderTrigger_SLMResampleNoLLM",
            "mode": "s2t",
            "trigger_type": "",
            "use_header_trigger": True,
            # Reads the module-level ``args`` (not a parameter of this function).
            "header_trigger_threshold": args.header_trigger_threshold,
        })
    if "ziprank" in strategy_demand:
        strategies.append(
            {
                "name": "KLTrigger-ZIPRanking",
                "mode": "s2t",
                "trigger_type": "kl",
                "guidance_type": "SLM_RESAMPLE_ZIP_RANKING",
                "slm_multi_k": slm_multi_k,
            })
    return strategies
        
def load_data_for(args):
    """Load the train/test splits for ``args.dataset`` and truncate them.

    The dataset is chosen by substring match on ``args.dataset`` (checked in
    the order gsm8k, aime, math, olympiadbench, mmlu_pro). The first
    ``args.train_num`` / ``args.test_num`` rows are kept, capped at each
    split's length.

    Args:
        args: Namespace with ``dataset``, ``train_num`` and ``test_num``; the
            olympiadbench branch also reads the optional ``lang``, ``subject``,
            ``test_size`` and ``seed`` attributes.

    Returns:
        ``(train_data, test_data)``.

    Raises:
        ValueError: If ``args.dataset`` matches none of the supported names.
    """
    if "gsm8k" in args.dataset:
        dataset = load_dataset("gsm8k", "main")
        train_dataset = dataset["train"]
        test_dataset = dataset["test"]
    elif "aime" in args.dataset:
        # Only a test split is loaded; it is used for both train and test.
        dataset = load_dataset("math-ai/aime25", split="test")
        train_dataset = dataset
        test_dataset = dataset
    elif "math" in args.dataset:
        train_dataset = load_dataset("qwedsacf/competition_math", split="train")
        test_dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
    elif "olympiadbench" in args.dataset:
        train_dataset, test_dataset = load_olympiadbench_split(
            lang=getattr(args, "lang", "en"),
            text_only=True,
            subject=getattr(args, "subject", "math"),
            test_size=getattr(args, "test_size", 0.1),
            seed=getattr(args, "seed", 42),
        )
    elif "mmlu_pro" in args.dataset:
        dataset = load_dataset("TIGER-Lab/MMLU-Pro")
        # NOTE(review): "test" is used for training and "validation" for
        # testing -- kept as in the original; confirm this is intended.
        train_dataset = dataset["test"]
        test_dataset = dataset["validation"]
    else:
        # Without this branch the function failed later with an opaque
        # UnboundLocalError on ``train_dataset``.
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected a name containing one of: "
            "gsm8k, aime, math, olympiadbench, mmlu_pro"
        )
    train_max_samples = min(args.train_num, len(train_dataset))
    test_max_samples = min(args.test_num, len(test_dataset))
    train_data = train_dataset.select(range(train_max_samples))
    test_data = test_dataset.select(range(test_max_samples))
    return train_data, test_data

def set_seed(seed: int):
    """Seed the Python, NumPy and torch RNGs (and every CUDA device when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def _to_float(x, default=0.0):
    """Convert ``x`` to a Python float.

    Tensors are converted through ``.item()``. ``None``, and any value that
    ``float()`` rejects, yield ``default``.
    """
    if isinstance(x, torch.Tensor):
        scalar = x.detach().cpu().item()
        return float(scalar)
    if x is not None:
        try:
            return float(x)
        except Exception:
            pass
    return default

def _zip_subsample_candidates(batch_train, k_keep: int):
    """
    Shrink the per-group candidate count K down to at most ``k_keep`` (memory saver).

    Expected entries of ``batch_train`` (it is modified in place and returned):
      - input_ids:      [B*K, L]
      - attention_mask: [B*K, L]
      - teacher_abs:    [B, K]
      - cand_mask:      [B, K]
      - B, K (ints)

    The new K is min(k_keep, smallest number of valid candidates in any row).
    For each row the valid candidate with the largest teacher_abs is always
    kept; the remaining slots are drawn without replacement with probability
    proportional to teacher_abs. The batch is returned untouched when
    ``k_keep`` is None, non-positive, or would not reduce K.
    """
    if k_keep is None:
        return batch_train

    num_groups = int(batch_train["B"])
    num_cands = int(batch_train["K"])
    k_keep = int(k_keep)
    if not (0 < k_keep < num_cands):
        return batch_train

    dev = batch_train["input_ids"].device
    mask = batch_train["cand_mask"]        # [B, K]
    teacher = batch_train["teacher_abs"]   # [B, K]

    # Every row must be able to supply the same number of columns.
    per_row_valid = mask.sum(dim=1).long().clamp_min(0)
    if num_groups > 0:
        smallest_valid = int(per_row_valid.min().item())
    else:
        smallest_valid = k_keep
    new_k = int(min(k_keep, smallest_valid))
    if not (0 < new_k < num_cands):
        return batch_train

    selected = []
    for row in range(num_groups):
        valid_idx = torch.nonzero(mask[row] > 0, as_tuple=True)[0]
        if valid_idx.numel() == 0:
            # Defensive fallback for an unexpected mask: treat all columns as valid.
            valid_idx = torch.arange(num_cands, device=dev)

        if valid_idx.numel() > new_k:
            # The teacher's favourite candidate is always retained.
            top = valid_idx[torch.argmax(teacher[row, valid_idx])]
            if new_k == 1:
                chosen = top.view(1)
            else:
                others = valid_idx[valid_idx != top]
                mass = teacher[row, others].clamp_min(1e-12)
                drawn = torch.multinomial(mass / mass.sum(), num_samples=new_k - 1, replacement=False)
                chosen = torch.cat([top.view(1), others[drawn]], dim=0)
        else:
            chosen = valid_idx[:new_k]
        selected.append(chosen)
    keep = torch.stack(selected, dim=0)  # [B, new_k]

    # Per-group tensors: [B, K] -> [B, new_k]
    for name in ("teacher_abs", "cand_mask"):
        if name not in batch_train:
            continue
        tensor = batch_train[name]
        if tensor.dim() == 2 and tensor.shape[1] == num_cands:
            batch_train[name] = torch.gather(tensor, 1, keep)

    # Flattened sequences: [B*K, L] -> [B*new_k, L]
    row_offsets = torch.arange(num_groups, device=dev).unsqueeze(1) * num_cands
    flat_rows = (row_offsets + keep).reshape(-1)  # [B*new_k]
    for name in ("input_ids", "attention_mask"):
        batch_train[name] = batch_train[name].index_select(0, flat_rows)

    batch_train["K"] = new_k
    return batch_train

def setup_run_dir(args, ts: str):
    """Resolve the run directory, create its sub-directories and record them on ``args``.

    Args:
        args: Namespace with ``run_dir``, ``runs_root`` and ``zip_lora_ckpt``.
            It is mutated in place: ``run_dir``, ``zip_data_dir``,
            ``zip_output_dir``, ``output_dir``, ``zip_debug_dir``,
            ``header_data_dir``, ``header_ckpt`` and ``zip_data_glob`` are set,
            and ``zip_lora_ckpt`` is filled in when it is empty and an adapter
            config already exists in ``zip_output_dir``.
        ts: Name of the run directory under ``args.runs_root``; only used when
            ``args.run_dir`` is empty.

    Returns:
        The same ``args`` object.

    Side effects:
        Creates the directories and writes ``run_config.json`` (a JSON dump of
        ``vars(args)``) into the run directory.
    """
    # 1) resolve run_dir
    if args.run_dir:
        run_dir = os.path.abspath(args.run_dir)
    else:
        run_dir = os.path.abspath(os.path.join(args.runs_root, ts))
    os.makedirs(run_dir, exist_ok=True)

    # 2) subdirs
    args.run_dir = run_dir
    args.zip_data_dir = os.path.join(run_dir, "data")
    args.zip_output_dir = os.path.join(run_dir, "ckpt", "zip_lora")
    args.output_dir = os.path.join(run_dir, "eval")
    args.zip_debug_dir = os.path.join(run_dir, "debug")
    args.header_data_dir = os.path.join(run_dir, "header_data")
    args.header_ckpt = os.path.join(args.header_data_dir, "header_router.pt")

    # header_data_dir is created with its siblings so that writing header_ckpt
    # does not depend on some other code path having created the directory.
    for d in [
        args.zip_data_dir,
        args.zip_output_dir,
        args.output_dir,
        args.zip_debug_dir,
        args.header_data_dir,
    ]:
        os.makedirs(d, exist_ok=True)

    # 3) default glob (train reads directly, no mid processing)
    args.zip_data_glob = os.path.join(args.zip_data_dir, "zip_samples*.jsonl")

    # 4) auto-load lora for eval if exists and user didn't specify
    if not args.zip_lora_ckpt:
        adapter_cfg = os.path.join(args.zip_output_dir, "adapter_config.json")
        if os.path.exists(adapter_cfg):
            args.zip_lora_ckpt = args.zip_output_dir

    # 5) dump config for reproducibility
    cfg_path = os.path.join(run_dir, "run_config.json")
    with open(cfg_path, "w", encoding="utf-8") as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=2)

    return args

# -----------------------------
# -----------------------------
def load_models_and_assets(args):
    """Load the LLM and/or SLM with their chat formatters and tokenizers.

    Args:
        args: Namespace read for ``not_llm``, ``not_slm``, ``llm_model``,
            ``slm_model``, ``llm_device``, ``slm_device``, ``zip_use_lora``,
            ``zip_lora_ckpt`` and ``zip_num_bins``. ``args.llm_device`` is
            overwritten with ``"cuda:0"`` when either model is disabled.

    Returns:
        ``(models, fmts, assets)``:
          - ``models``: ``"llm"`` / ``"slm"`` model objects plus
            ``"llm_model"`` / ``"slm_model"`` name strings, for each model loaded.
          - ``fmts``: ``"llm"`` / ``"slm"`` ``ChatFormatter`` instances and
            ``fmts["tok"]`` mapping the same keys to their tokenizers.
          - ``assets``: ``"same_vocab"`` (bool), set only when both models are loaded.
    """
    print("[Load Assets] Loading config ...")
    assets, models, fmts = dict(), dict(), dict()
    fmts["tok"] = dict()
    # With only one model in use, the LLM device is pinned to the first GPU.
    if args.not_llm or args.not_slm:
        args.llm_device = "cuda:0"
    if not args.not_llm:
        print(f"[Load Models] Loading LLM: {args.llm_model} on {args.llm_device}")
        llm = AutoModelForCausalLM.from_pretrained(
            args.llm_model, trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map={"": args.llm_device} if torch.cuda.is_available() else "auto"
        )
        llm.eval()
        models["llm"] = llm
        models["llm_model"] = args.llm_model
        fmts["llm"]= ChatFormatter(args.llm_model)
        fmts["tok"]["llm"] = fmts["llm"].tokenizer
    if not args.not_slm:
        print(f"[Load Models] Loading SLM: {args.slm_model} on {args.slm_device}")
        slm = AutoModelForCausalLM.from_pretrained(
            args.slm_model, trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map={"": args.slm_device} if torch.cuda.is_available() else "cpu"
        )
        slm.eval()

        fmts["slm"] = ChatFormatter(args.slm_model)
        slm_tok = fmts["slm"].tokenizer
        fmts["tok"]["slm"] = slm_tok

        # Optionally wrap the SLM with a previously trained LoRA adapter.
        if args.zip_use_lora and args.zip_lora_ckpt:
            print(f"[Load Models] Loading ZIP LoRA from {args.zip_lora_ckpt}")
            slm = PeftModel.from_pretrained(
                slm,
                args.zip_lora_ckpt,
            )
            slm.eval()
        models["slm"] = slm
        models["slm_model"] = args.slm_model

        # NOTE(review): hard-coded size passed to get_zip_token_ids and printed
        # below as the "Model Embedding Size"; it is not read from the loaded
        # model -- confirm it matches the SLM in use.
        max_len = 151935
        zip_token_ids = get_zip_token_ids(slm_tok, max_len, num_bins=args.zip_num_bins)
        print("\n" + "="*40)
        print(f"[DEBUG] ZIP Token IDs (Count: {len(zip_token_ids)}): {zip_token_ids}")
        # default=None keeps this debug print from crashing on an empty id list.
        print(f"[DEBUG] Max ZIP ID: {max(zip_token_ids, default=None)}")
        print(f"[DEBUG] Original Vocab Size: {len(slm_tok)}")
        print(f"[DEBUG] Model Embedding Size: {max_len}")

        try:
            decoded_tokens = [slm_tok.decode([tid]) for tid in zip_token_ids]
            print(f"[DEBUG] Decoded ZIP Tokens: {decoded_tokens}")
        except Exception:
            # Debug output only; do not swallow KeyboardInterrupt/SystemExit.
            print("[DEBUG] Could not decode ZIP tokens.")
        print("="*40 + "\n")

        # Grow the embedding table if the tokenizer has more entries than the model.
        if slm.get_input_embeddings().weight.shape[0] < len(slm_tok):
            print(f"[Model Fix] Resizing token embeddings from {slm.get_input_embeddings().weight.shape[0]} to {len(slm_tok)}")
            slm.resize_token_embeddings(len(slm_tok))

    if not args.not_llm and not args.not_slm:
        slm_tok = fmts["tok"]["slm"]
        llm_tok = fmts["tok"]["llm"]

        # Vocabularies count as identical only if both the reported size and
        # the full token->id mapping agree.
        same_vocab = (getattr(slm_tok, "vocab_size", None) == getattr(llm_tok, "vocab_size", None)) and \
                        (slm_tok.get_vocab() == llm_tok.get_vocab())
        assets["same_vocab"] = same_vocab
        print(f"[Load Toker] same_vocab {same_vocab}...")

    return models, fmts, assets

def masked_group_norm_2d(scores: torch.Tensor, mask: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Standardise each row of ``scores`` using statistics of its unmasked entries.

    scores: [B, K]
    mask:   [B, K] in {0,1}

    The row mean and (biased) variance are computed over positions where
    ``mask`` is 1; the valid count is clamped to at least 1. Every position,
    masked or not, is then shifted and scaled by those row statistics, so the
    within-row ordering is unchanged. The output has the same shape as ``scores``.
    """
    weights = mask.float()
    count = weights.sum(dim=1, keepdim=True).clamp_min(1.0)
    row_mean = (scores * weights).sum(dim=1, keepdim=True) / count
    centered = scores - row_mean
    row_var = (centered ** 2 * weights).sum(dim=1, keepdim=True) / count
    return centered / torch.sqrt(row_var + eps)

def zip_scores_from_logits_bins(
    logits_bins: torch.Tensor,
    zip_label_mode: str,
    score_mode: str = "logits_dot",
    bin_softmax_temp: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Turn per-bin logits into (one scalar score per candidate, softmax probs over bins).

    zip_label_mode:
      - topk_binary: requires exactly 2 bins; the score is the raw logit of bin 1.
      - prob_bin:    the score depends on ``score_mode``.
    score_mode (prob_bin only):
      - logits_dot:  sum(softmax(logits/temp) * linspace(0, 1, num_bins)),
                     i.e. the expected bin value under the bin probabilities.

    Returns (scores [...], probs_bins [..., num_bins]); higher score => better.
    The temperature falls back to 1.0 when None and is floored at 1e-6.
    Raises ValueError for an unknown mode or a wrong bin count in topk_binary.
    """
    temp = max(1.0 if bin_softmax_temp is None else float(bin_softmax_temp), 1e-6)

    # Bin probabilities are always returned (used for diagnostics).
    logits_f32 = logits_bins.float()
    probs_bins = F.softmax(logits_f32 / temp, dim=-1)

    if zip_label_mode == "topk_binary":
        n_bins = logits_bins.size(-1)
        if n_bins != 2:
            raise ValueError(f"topk_binary requires num_bins=2, got {n_bins}")
        return logits_f32[..., 1], probs_bins

    if zip_label_mode != "prob_bin":
        raise ValueError(f"Unknown zip_label_mode: {zip_label_mode}")

    if score_mode != "logits_dot":
        raise ValueError(f"Unknown score_mode: {score_mode}")

    # Bin values are spaced evenly on [0, 1]; the score is their expectation.
    bin_values = torch.linspace(
        0.0, 1.0, logits_bins.size(-1), device=logits_bins.device, dtype=torch.float32
    )
    expected = (probs_bins * bin_values).sum(dim=-1)
    return expected, probs_bins

# -----------------------------
# -----------------------------
def sample_with_top_kp(logits, k_eff, top_k=20, top_p=0.95, temperature=0.6):
    """
    Sample up to ``k_eff`` distinct indices after temperature / top-k / top-p filtering.

    logits: 1D tensor [N] (or last dim N); the caller's tensor is not modified.
    k_eff:  maximum number of indices to draw.
    top_k:  keep only the ``top_k`` largest logits (None or <= 0 disables).
    top_p:  nucleus threshold; the first entry that crosses it is kept
            (None or >= 1.0 disables).
    temperature: logits are divided by this value (None or 1.0 disables).

    returns: local indices sampled without replacement, length <= k_eff.
    Only indices with strictly positive probability are returned, so the
    result may be shorter than both ``k_eff`` and the number of finite logits
    when softmax underflows. If nothing can be sampled (no finite logit, or
    k_eff < 1), one index is drawn uniformly from all positions.
    """
    # do not mutate caller tensor
    logits = logits.float().clone()

    # 1) temperature
    if temperature is not None and temperature != 1.0:
        logits = logits / float(temperature)

    # 2) top-k
    if top_k is not None and top_k > 0:
        top_k = min(int(top_k), logits.size(-1))
        kth_values, _ = torch.topk(logits, top_k, dim=-1)
        cutoff = kth_values[..., -1].unsqueeze(-1)
        logits = logits.masked_fill(logits < cutoff, -float("inf"))

    # 3) top-p
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        sorted_indices_to_remove = cumulative_probs > float(top_p)
        # shift right to keep the first token above threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        indices_to_remove = torch.zeros_like(sorted_indices_to_remove, dtype=torch.bool)
        indices_to_remove.scatter_(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, -float("inf"))

    # 4) sample without replacement
    valid_mask = torch.isfinite(logits)
    valid_count = int(valid_mask.sum().item())
    actual_k = min(int(k_eff), valid_count)

    if actual_k < 1:
        # all filtered out -> uniform among all tokens
        probs = torch.ones_like(logits)
        actual_k = 1
    else:
        probs = F.softmax(logits, dim=-1)
        # Safety: if numeric issue, fallback to uniform over valid tokens
        if torch.isnan(probs).any() or float(probs.sum().item()) <= 0.0:
            probs = valid_mask.float()
            total = float(probs.sum().item())
            if total <= 0.0:
                probs = torch.ones_like(logits)
            else:
                probs = probs / total
        # A finite logit can still underflow to probability 0. Drawing more
        # samples than there are positive-probability entries would either
        # raise or return zero-probability (possibly masked) indices,
        # depending on the torch version, so cap the draw count.
        positive_count = int((probs > 0).sum(dim=-1).min().item())
        actual_k = max(1, min(actual_k, positive_count))

    sample_local_idx = torch.multinomial(probs, num_samples=actual_k, replacement=False)
    return sample_local_idx

@torch.no_grad()
def generate_answer(
    question_idx,
    not_llm, not_slm,
    data_name, 
    models,
    fmts,
    assets: Dict,
    question: str,
    strategy: Dict,
    output_dir: str,
    K_s = 64,
    N_l = 32,
    C_MAX = 128,
    problem_id: int = 0,
    collect_zip_data: bool = False,
    collect_header_data: bool = False,
    zip_data_dir: Optional[str] = None,
    zip_label_mode: str = "prob_bin",
    header_data_dir: str="",
    random_flipped_target: float=0.1,
) -> Dict:
    def teacher_to_C(llm_logits_full, C_ids):
        llm_probs_full = F.softmax(llm_logits_full, dim=0)
        l_probs_C = llm_probs_full[C_ids]

        mass_Tl = float(l_probs_C.sum().item())
        if mass_Tl > 0:
            l_probs_C_cond = l_probs_C / mass_Tl
        else:
            l_probs_C_cond = torch.zeros_like(l_probs_C)

        return l_probs_C, l_probs_C_cond, mass_Tl

    def _rank_in_student(logits_1d, token_id: int) -> int:
        """1 = highest logit. O(V) compare; called only on triggered steps."""
        v = logits_1d[token_id]
        return 1 + int((logits_1d > v).sum().item())

    if not not_llm:
        llm = models['llm']
        llm_tok = fmts["tok"]["llm"]
    if not not_slm:
        slm = models['slm']
        slm_tok = fmts["tok"]["slm"]
    if not not_llm and not not_slm: 
        same_vocab = assets['same_vocab']

    mode = strategy['mode']
    strategy_name = strategy.get('name', 'Unknown')
    trigger_type = strategy.get('trigger_type', "")  
    guidance_type = strategy.get("guidance_type", "LLM_TAKEOVER")
    slm_multi_k = int(strategy.get("slm_multi_k", 4))
    flipped_target = strategy.get("flipped_target", 0.0)
    sc_k = int(strategy.get("sc_k", 1))

    d_en_history, d_kl_history = [], []
    token_history = []
    triggered_steps, flipped_num = 0, 0
    slm_selected_probs_on_trigger = []
    llm_selected_probs_on_trigger = []
    hit_at_top1 = 0
    step_trace = []

    zip_data_path = None
    f_z = None
    if collect_zip_data:
        os.makedirs(zip_data_dir, exist_ok=True)
        base = f"{strategy.get('name', 'Unknown').replace(' ', '_')}"
        zip_data_path = os.path.join(zip_data_dir, f"zip_samples_new_{data_name}_{base}.jsonl")
        f_z = open(zip_data_path, "a", encoding="utf-8")

    header_trigger_path = None
    if collect_header_data:
        os.makedirs(header_data_dir, exist_ok=True)
        base = f"{strategy.get('name', 'Unknown').replace(' ', '_')}"
        header_trigger_path = os.path.join(header_data_dir, f"trigger_samples_{base}.jsonl")

    cmp_steps_path, f_cmp_steps = None, None
    if guidance_type == "SLM_RESAMPLE_WITH_LLM_HELP":
        os.makedirs(output_dir, exist_ok=True)
        safe_name = re.sub(r"[^a-zA-Z0-9_\-]+", "_", strategy.get("name", "strategy"))
        cmp_steps_path = os.path.join(output_dir, f"klcmp_steps_{data_name}_{safe_name}.jsonl")
        f_cmp_steps = open(cmp_steps_path, "a", encoding="utf-8")

    start_time = time.time()
    
    if "only" in mode:
        if "slm" in mode:
            max_len = 151935
            zip_token_ids = get_zip_token_ids(slm_tok, max_len, num_bins=args.zip_num_bins)
            enc = fmts["slm"].build_inputs(question, data_name)
            input_ids = enc.input_ids.to(slm.device)
            prompt_len = input_ids.shape[1]
            eos_list = resolve_stop_ids(slm_tok)
            attention_mask = enc.attention_mask.to(slm.device) if "attention_mask" in enc else None
            if slm_tok.pad_token_id is None:
                slm_tok.pad_token_id = slm_tok.eos_token_id

            logits_processor = LogitsProcessorList()
            logits_processor.append(ZIPMaskLogitsProcessor(zip_token_ids))

            outs = slm.generate(
                input_ids,
                max_new_tokens=MAX_NEW_TOKENS,
                attention_mask=attention_mask,
                pad_token_id=slm_tok.pad_token_id,
                do_sample=True,  
                eos_token_id=eos_list, 
                temperature=0.7, top_p=0.8, top_k=20,
                num_return_sequences=sc_k,
                logits_processor=logits_processor
            )
            
            if sc_k > 1:  
                seqs = outs[:, prompt_len:]
                total_steps = seqs.shape[1] 
                texts = slm_tok.batch_decode(seqs, skip_special_tokens=True)
                generated_text = [t.strip() for t in texts]
            else:
                generated_ids = outs[0].tolist() 
                total_steps = len(generated_ids) - prompt_len
                text = slm_tok.decode(generated_ids[prompt_len:], skip_special_tokens=True)
                generated_text = [text.strip()]

        elif "llm" in mode: 
            enc = fmts["llm"].build_inputs(question, data_name)
            input_ids = enc.input_ids.to(llm.device)
            attention_mask = enc.attention_mask.to(llm.device)
            prompt_len = input_ids.shape[1]
            
            eos_list = resolve_stop_ids(llm_tok)
            extra_stops = [151645, 151643] 
            for x in extra_stops:
                if x not in eos_list:
                    eos_list.append(x)
            if problem_id == 0:
                print(f"[LLM-Only DEBUG] Using EOS List: {eos_list}")
            if llm_tok.pad_token_id is None:
                llm_tok.pad_token_id = eos_list[0] 

            outs = llm.generate(
                    input_ids,
                    attention_mask=attention_mask,       
                    pad_token_id=llm_tok.pad_token_id if llm_tok.pad_token_id is not None else llm_tok.eos_token_id,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=True, 
                    temperature=0.6, top_p=0.95, num_return_sequences=1,
                    top_k=None, eos_token_id=eos_list,   
                )
            if sc_k > 1:  
                seqs = outs[:, prompt_len:]
                total_steps = seqs.shape[1] 
                texts = llm_tok.batch_decode(seqs, skip_special_tokens=True)
                generated_text = [t.strip() for t in texts]
            else:
                generated_ids = outs[0].tolist() 
                total_steps = len(generated_ids) - prompt_len
                text = llm_tok.decode(generated_ids[prompt_len:], skip_special_tokens=True)
                generated_text = [text.strip()]
        
        end_time = time.time()
        result = {
        "strategy_name": strategy_name,
        "generated_text": generated_text,
        "total_steps": total_steps,
        "generation_time_sec": end_time - start_time}
        return result

    router = models.get("header_router", None)
    use_header_trigger = bool(strategy.get("use_header_trigger", False) and (router is not None))
    header_trigger_threshold = float(strategy.get("header_trigger_threshold", 0.90))
    router_device = next(router.parameters()).device if router is not None else None
    past_kv = {}

    prompt_slm = fmts["slm"].build_prompt(question, data_name)
    inputs_slm = slm_tok(prompt_slm, return_tensors="pt")
    input_ids  = inputs_slm.input_ids.to(DEVICE)
    prompt_len = input_ids.shape[1]
    generated_ids = input_ids[0].tolist()
    with torch.no_grad():
        out_slm = slm(input_ids=inputs_slm.input_ids.to(slm.device), use_cache=True)
        past_kv["slm"] = out_slm.past_key_values   
    prompt_llm = fmts["llm"].build_prompt(question, data_name)
    inputs_llm = llm_tok(prompt_llm, return_tensors="pt").to(llm.device)
    with torch.no_grad():
        out_llm = llm(input_ids=inputs_llm.input_ids, use_cache=True)
        past_kv["llm"] = out_llm.past_key_values
    eos_list = resolve_stop_ids(slm_tok)
    eos_token_id = slm_tok.eos_token_id
    last_token = generated_ids[-1]

    max_len = 151935
    zip_token_ids = get_zip_token_ids(slm_tok, max_len, num_bins=args.zip_num_bins)
    zip_token_ids_t = torch.tensor(zip_token_ids, device=slm.device, dtype=torch.long)

    for step in range(MAX_NEW_TOKENS):  
        kl_divergence = 0
        slm_entropy = None
        crit_prob = None
        last_token_tensor = torch.tensor([[last_token]], device=DEVICE)

        need_hidden = collect_header_data or use_header_trigger
        slm_out = slm(
            input_ids=last_token_tensor,
            use_cache=True,
            past_key_values=past_kv["slm"],
            output_hidden_states=need_hidden,
        )
        slm_logits_full = slm_out.logits[:, -1, :].squeeze(0).float()
        past_kv["slm"] = slm_out.past_key_values
        if need_hidden and slm_out.hidden_states is not None:
            slm_hidden_t = slm_out.hidden_states[-1][:, -1, :].squeeze(0).detach()
        else:
            slm_hidden_t = None

        slm_logits_full[zip_token_ids_t] = -float('inf')

        inp = torch.tensor([[int(last_token)]], device=llm.device)
        out = llm(input_ids=inp, use_cache=True, past_key_values=past_kv["llm"])
        llm_logits_full = out.logits[:, -1, :].squeeze(0).to(slm.device).float()
        # Keep consistent with Student: never allow ZIP tokens to be proposed by Teacher when vocab is shared
        if (not not_llm) and (not not_slm) and same_vocab:
            llm_logits_full[zip_token_ids_t] = -float("inf")
        past_kv["llm"] = out.past_key_values

        _, slm_topk_ids = torch.topk(slm_logits_full, K_s)
        _, llm_topn_ids = torch.topk(llm_logits_full, N_l)
        C_ids = torch.unique(torch.cat([slm_topk_ids, llm_topn_ids]), sorted=True)

        s_probs = F.softmax(slm_logits_full, dim=0)
        s_probs_C = F.softmax(slm_logits_full[C_ids], dim=0)
        s_probs_alone = F.softmax(slm_logits_full[slm_topk_ids], dim=0)
        l_probs_C, l_probs_C_cond, mass_Tl = teacher_to_C(llm_logits_full, C_ids)
            
        final_dist_C = s_probs_C 

        # Trigger Logic
        WARMUP_STEPS = 3
        if step < WARMUP_STEPS:
            triggered = False
        else:
            if "kl" in trigger_type: 
                mask_tl = l_probs_C > 1e-12
                s_tl = s_probs_C[mask_tl]
                l_tl = l_probs_C[mask_tl]
                s_tl = s_tl / (s_tl.sum() + 1e-12)
                l_tl = l_tl / (l_tl.sum() + 1e-12)
                epsilon = 1e-9
                kl_divergence = torch.sum(s_tl * torch.log((s_tl + epsilon) / (l_tl + epsilon)))
                tau = assets.get("kl_calib_threshold", None)
                tau = float(tau) if tau is not None else float("inf")  
                triggered = bool(kl_divergence > tau)
                d_kl_history.append(float(kl_divergence.item()))

                if (
                    collect_header_data
                    and header_trigger_path is not None
                    and slm_hidden_t is not None
                    and problem_id is not None
                ):
                    y_trig = 1 if kl_divergence > tau else 0
                    trig_sample = {
                        "problem_id": int(problem_id),
                        "step": int(step),
                        "kl_value": float(kl_divergence.item()),
                        "y_trig": int(y_trig),
                        "hidden": slm_hidden_t.cpu().tolist(),
                    }
                    with open(header_trigger_path, "a", encoding="utf-8") as f_ht:
                        f_ht.write(json.dumps(trig_sample, ensure_ascii=False) + "\n")


            if "entropy" in trigger_type:
                slm_entropy = -torch.sum(s_probs_C * torch.log(s_probs_C + 1e-9))
                tau = assets.get("en_calib_threshold", None)
                tau = float(tau) if tau is not None else float("inf") 
                triggered = bool(slm_entropy > tau)
                d_en_history.append(float(slm_entropy.item()))
            
            if trigger_type == "Random":
                triggered = torch.rand(1).item() > random_flipped_target 

            if use_header_trigger and (router is not None) and (slm_hidden_t is not None):
                with torch.no_grad():
                    h_in = slm_hidden_t.to(router_device).unsqueeze(0)  # [1, H]
                    crit_logit = router.predict_criticality(h_in)  # [1, 1]
                    crit_prob = torch.sigmoid(crit_logit).item()
                triggered = bool(crit_prob > header_trigger_threshold)
        if triggered:
            triggered_steps += 1 
            if guidance_type == "LLM_TAKEOVER":
                if (l_probs_C > 0).any():
                    final_dist_C = l_probs_C / (l_probs_C.sum() + 1e-12)
                else:
                    final_dist_C = s_probs_C
            elif "SLM_RESAMPLE" in guidance_type:
                if "ZIP_RANKING" not in guidance_type and not (l_probs_C > 0).any():
                    final_dist_C = s_probs_C
                    selected_idx_C = int(torch.argmax(s_probs_C))
                else:
                    k_eff = min(max(1, slm_multi_k), s_probs_alone.numel())
                    # sample_local_idx = sample_with_top_kp(slm_logits_full[slm_topk_ids], k_eff)
                    sample_local_idx = torch.multinomial(
                        s_probs_alone, num_samples=k_eff, replacement=False
                    )
                    sample_ids = slm_topk_ids[sample_local_idx]
                    sample_idx_in_C = torch.searchsorted(C_ids, sample_ids)

                    if guidance_type == "SLM_RESAMPLE_NO_LLM":
                        teacher_scores = l_probs_C[sample_idx_in_C]  # [k_eff]
                        max_teacher_score = float(teacher_scores.max().item())
                        if (
                            collect_zip_data
                            and float(teacher_scores.sum().item()) > 0.0
                            # and max_teacher_score > 0.05  
                        ):
                            # prefix truncate at collection time to avoid extra processing
                            context_ids_full = list(generated_ids)
                            max_prefix = args.zip_max_len - 1
                            context_ids = context_ids_full[-max_prefix:] if len(context_ids_full) > max_prefix else context_ids_full

                            # normalize (group-wise) for listwise targets (still store abs)
                            t_sum = float(teacher_scores.sum().item())
                            teacher_scores_list = teacher_scores.detach().cpu().tolist()

                            slm_scores_raw = s_probs_C[sample_idx_in_C].detach().float().clamp_min(0.0)
                            slm_scores_cond = (slm_scores_raw / (slm_scores_raw.sum() + 1e-12)).detach().cpu().tolist()

                            rec = {
                                "problem_id": int(problem_id),
                                "step": int(step),
                                "group_id": f"{int(problem_id)}_{int(step)}",
                                "context_ids": context_ids,
                                "cand_ids": sample_ids.detach().cpu().tolist(),
                                "slm_prob_cond": slm_scores_cond,
                                "teacher_prob_abs": teacher_scores_list,
                                "teacher_sum_on_K": float(t_sum),
                                "k_eff": int(sample_ids.numel()),
                            }
                            f_z.write(json.dumps(rec, ensure_ascii=False) + "\n")

                        if float(teacher_scores.sum().item()) <= 0.0:
                            final_dist_C = s_probs_C
                            selected_idx_C = int(torch.argmax(s_probs_C))
                        else:
                            best_local = int(torch.argmax(teacher_scores))
                            selected_idx_C = int(sample_idx_in_C[best_local])
                    
                    elif guidance_type == "SLM_RESAMPLE_ZIP_RANKING":
                        num_bins = len(zip_token_ids)
                        bin_indices = torch.arange(num_bins, device=slm.device, dtype=torch.float32)

                        prefix_list = generated_ids
                        max_ctx = args.zip_max_len - 1
                        if max_ctx > 0 and len(prefix_list) > max_ctx:
                            prefix_list = prefix_list[-max_ctx:]

                        prefix_ids = torch.tensor(
                            prefix_list,
                            device=slm_logits_full.device,
                            dtype=torch.long,
                        ).unsqueeze(0)  # [1, L_prefix]

                        prefix_batch = prefix_ids.expand(k_eff, -1)            # [K, L_prefix]
                        cand_tensor = sample_ids.view(k_eff, 1).to(slm_logits_full.device)  # [K, 1]
                        concat_ids = torch.cat([prefix_batch, cand_tensor], dim=1)          # [K, L_prefix+1]

                        att = torch.ones_like(concat_ids, dtype=torch.long, device=slm_logits_full.device)

                        with torch.no_grad():
                            zip_out = slm(
                                input_ids=concat_ids,
                                attention_mask=att,
                                use_cache=False,   
                            )

                        logits_next = zip_out.logits[:, -1, :]          # [k_eff, V]

                        logits_bins = logits_next[:, zip_token_ids_t]   # [k_eff, num_bins]

                        scores, probs_bins = zip_scores_from_logits_bins(
                            logits_bins=logits_bins,
                            zip_label_mode=zip_label_mode,
                            score_mode=args.zip_score_mode,
                        )
                       
                        if args.zip_score_group_norm:
                            _m = torch.ones((1, scores.numel()), device=scores.device, dtype=torch.float32)
                            scores = masked_group_norm_2d(scores.view(1, -1), _m).view(-1)

                        best_local = int(scores.argmax().item())
                        selected_idx_C = int(sample_idx_in_C[best_local])
                
                    final_dist_C = torch.zeros_like(s_probs_C)
                    final_dist_C[selected_idx_C] = 1.0

                slm_prob_sel = float(s_probs_C[selected_idx_C].item())
                llm_prob_sel = float(l_probs_C[selected_idx_C].item()) if l_probs_C is not None else 0.0
                slm_selected_probs_on_trigger.append(slm_prob_sel)
                llm_selected_probs_on_trigger.append(llm_prob_sel)

            else:
                final_dist_C = s_probs_C

        next_token_idx = torch.argmax(final_dist_C, dim=0) 
        next_token_id = C_ids[next_token_idx]
        generated_ids.append(next_token_id.item())
        last_token = int(next_token_id.item())
        token_history.append(slm_tok.decode([int(next_token_id)]))

        # --- Record
        if triggered:
            with torch.no_grad():
                slm_top_idx_C = int(torch.argmax(s_probs_C))
                slm_top_id = int(C_ids[slm_top_idx_C])
                slm_top_prob = _to_float(s_probs_C[slm_top_idx_C])
                slm_top_str = slm_tok.decode([slm_top_id]).strip()

                llm_top_idx_C = int(torch.argmax(l_probs_C))
                llm_top_id = int(C_ids[llm_top_idx_C])
                llm_top_prob = _to_float(l_probs_C[llm_top_idx_C])
                llm_top_str = llm_tok.decode([llm_top_id]).strip()

                final_top_idx_C = int(torch.argmax(final_dist_C))
                final_top_id = int(C_ids[final_top_idx_C])
                final_top_prob = _to_float(final_dist_C[final_top_idx_C])
                final_top_str = slm_tok.decode([final_top_id]).strip()

            flipped_from_slm = (final_top_id != slm_top_id)
            flipped_num += int(flipped_from_slm)

            # -------------------
            with torch.no_grad():
                mass_s_C = _to_float(s_probs[C_ids].sum())
                mass_l_C_lb = _to_float(l_probs_C.sum()) if l_probs_C is not None else 0.0

            trigger_log = {
                "problem_id": question_idx,
                "step": int(step),
                "strategy": strategy_name,
                "trigger_type": trigger_type,
                "trigger_score": _to_float(kl_divergence) if "entropy" not in trigger_type else _to_float(slm_entropy),
                "slm_top_id": slm_top_id,
                "slm_top_str": slm_top_str,
                "slm_top_prob": slm_top_prob,

                "llm_top_id": llm_top_id,
                "llm_top_str": llm_top_str,
                "llm_top_prob": llm_top_prob,

                "final_top_id": final_top_id,
                "final_top_str": final_top_str,
                "final_top_prob": final_top_prob,

                "flipped_from_slm": bool(flipped_from_slm),

                "mass_s_C": mass_s_C,          
                "mass_l_C_lb": mass_l_C_lb,    
            }

            if "SLM_RESAMPLE" in guidance_type:
                r_teacher = _rank_in_student(llm_logits_full, final_top_id)
                r_slm = _rank_in_student(slm_logits_full, final_top_id)
                r_llm_in_slm = _rank_in_student(slm_logits_full, llm_top_id)
                trigger_log.update({
                    "slm_prob_sel": slm_prob_sel,          
                    "llm_prob_sel": llm_prob_sel, 
                    "slm_rank_sel": r_slm,
                    "llm_rank_sel": r_teacher,
                    "llm_top1_rank_in_slm": r_llm_in_slm,
                })

            # -------------------
            trigger_log_path = os.path.join(
                output_dir,
                f"trigger_log_{strategy_name}_{ts}.jsonl"
            )
            with open(trigger_log_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(trigger_log, ensure_ascii=False) + "\n")


        if next_token_id.item() in eos_list:
            break
    total_steps = len(token_history)
    generated_text = slm_tok.decode(generated_ids[prompt_len:], skip_special_tokens=True) 
    end_time = time.time()
    if f_z is not None:
        f_z.close()
    if f_cmp_steps is not None:
        f_cmp_steps.close()

    trigger_rate = (triggered_steps / total_steps) if total_steps > 0 else 0.0
    flip_rate = (flipped_num / total_steps) if total_steps > 0 else 0.0

    result = {
        "strategy_name": strategy_name,
        "generated_text": generated_text,
        "total_steps": total_steps,
        "trigger_count": triggered_steps,
        "hit_at_top1": hit_at_top1,
        "trigger_rate": trigger_rate,
        "flip_rate": flip_rate,
        "flipped_num": flipped_num,
        "generation_time_sec": end_time - start_time,
        "token_history": token_history,
        "d_en_history": d_en_history,
        "d_kl_history": d_kl_history,
        "slm_selected_probs_on_trigger": slm_selected_probs_on_trigger,
        "llm_selected_probs_on_trigger": llm_selected_probs_on_trigger,
    }    
    if step_trace is not None:
        result["step_trace"] = step_trace
    return result

def train_header_router(args, models, fmts):
    """Train the ACG_Router criticality (trigger) head on logged trigger samples.

    Samples are loaded with ``TriggerSampleDataset(args.header_data_dir)``; each
    batch is expected to provide a ``"hidden"`` tensor and a ``"y_trig"`` label.
    The router's ``predict_criticality`` logits are fit with a class-weighted
    ``BCEWithLogitsLoss`` on a fixed-seed 90/10 train/validation split.

    Args:
        args: Namespace providing ``header_data_dir``, ``header_ckpt``,
            ``slm_device``, ``zip_batch_size``, ``zip_lr`` and ``zip_num_epochs``.
        models: Mapping that contains the small language model under ``"slm"``;
            only its hidden size and output-embedding width are read.
        fmts: Unused here; kept so the call signature stays unchanged.

    Returns:
        None. Writes two checkpoints: the best-validation-accuracy router to
        ``args.header_ckpt`` and the final-epoch router to the same path with
        ``_last`` inserted before the file extension.
    """
    data_dir = args.header_data_dir
    trigger_ds = TriggerSampleDataset(data_dir)

    if len(trigger_ds) == 0:
        print("[Header-Train] No trigger/ranking samples found, skip training.")
        return

    slm = models["slm"]
    device = torch.device(args.slm_device if torch.cuda.is_available() else "cpu")

    # ---- Class balance -> positive-class weight for the BCE loss ----
    total_count = len(trigger_ds)
    raw_pos = sum(1 for rec in trigger_ds.samples if float(rec.get("y_trig", 0.0)) > 0.5)
    raw_neg = total_count - raw_pos
    print(f"[Header-Train] Trigger samples: total={total_count}, pos={raw_pos}, neg={raw_neg}")

    # Clamp both counts to >= 1 so a single-class dataset still yields a finite weight.
    pos_count = max(raw_pos, 1)
    neg_count = max(raw_neg, 1)
    pos_weight_val = float(neg_count) / float(pos_count)
    print(f"[Header-Train] Trigger pos_count={pos_count}, neg_count={neg_count}, pos_weight={pos_weight_val:.3f}")

    crit_loss_fn = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor(pos_weight_val, device=device)
    )

    # Only the dimensions are needed; read them without copying lm_head to `device`.
    hidden_dim = slm.config.hidden_size
    emb_dim = slm.get_output_embeddings().weight.size(1)

    router = ACG_Router(hidden_dim, emb_dim).to(device)
    router.train()

    # ---- 90/10 split with a fixed seed so the validation set is reproducible ----
    trig_len = len(trigger_ds)
    # Keep at least one training sample (int(0.9 * 1) would otherwise be 0).
    trig_train_len = max(1, int(trig_len * 0.9))
    trig_train_ds, trig_val_ds = torch.utils.data.random_split(
        trigger_ds,
        [trig_train_len, trig_len - trig_train_len],
        generator=torch.Generator().manual_seed(42),
    )

    trig_train_loader = DataLoader(trig_train_ds, batch_size=args.zip_batch_size, shuffle=True)
    trig_val_loader = DataLoader(trig_val_ds, batch_size=args.zip_batch_size, shuffle=False)

    opt = torch.optim.AdamW(router.parameters(), lr=args.zip_lr)

    best_val_acc = 0.0
    best_epoch = -1
    epochs = args.zip_num_epochs * 10

    # ---- Checkpoint paths ----
    os.makedirs(args.header_data_dir, exist_ok=True)
    best_ckpt_path = args.header_ckpt
    # splitext handles any extension length (".pt", ".pth", none), unlike a fixed [-3:] slice.
    ckpt_root, ckpt_ext = os.path.splitext(args.header_ckpt)
    ckpt_path = f"{ckpt_root}_last{ckpt_ext}"
    ckpt_dir = os.path.dirname(best_ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(epochs):
        # ========== TRAINING ==========
        router.train()
        total_trig_loss = 0.0
        n_trig = 0
        for batch in trig_train_loader:
            h = batch["hidden"].to(device)
            # Flatten logits and labels so [B] and [B, 1] layouts both line up;
            # BCEWithLogitsLoss needs float targets of the same shape as the input.
            y = batch["y_trig"].to(device).float().reshape(-1)
            opt.zero_grad()
            crit_logits = router.predict_criticality(h).reshape(-1)
            loss_trig = crit_loss_fn(crit_logits, y)
            loss_trig.backward()
            opt.step()
            total_trig_loss += loss_trig.item() * h.size(0)
            n_trig += h.size(0)
        avg_trig_loss = total_trig_loss / max(n_trig, 1)

        # ========== VALIDATION ==========
        router.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for batch in trig_val_loader:
                h = batch["hidden"].to(device)
                y = batch["y_trig"].to(device).float().reshape(-1)
                logits = router.predict_criticality(h).reshape(-1)
                preds = (torch.sigmoid(logits) > 0.5).float()
                # Same-shape comparison; avoids a silent [B, 1] vs [B] broadcast to [B, B].
                val_correct += (preds == y).sum().item()
                val_total += y.size(0)

        val_acc = val_correct / max(val_total, 1)

        print(f"[Header-Train] Epoch {epoch+1}/{epochs} loss={avg_trig_loss:.4f} val_acc={val_acc*100:.2f}%")

        # --- Save Best ---
        # `best_epoch < 0` guarantees a best checkpoint exists after the first epoch,
        # even if validation accuracy never rises above 0 (e.g. empty validation split).
        if best_epoch < 0 or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1
            torch.save({
                "model_state_dict": router.state_dict(),
                "hidden_dim": hidden_dim,
                "emb_dim": emb_dim,
                "val_acc": val_acc
            }, best_ckpt_path)

    print(f"[Header-Train] Training complete. Best Epoch: {best_epoch} with Val Acc: {best_val_acc*100:.2f}%")
    # Always keep the final-epoch weights next to the best checkpoint.
    torch.save(
        {
            "model_state_dict": router.state_dict(),
            "hidden_dim": hidden_dim,
            "emb_dim": emb_dim,
        },
        ckpt_path
    )
    print(f"[Header-Train] Saved header router to {ckpt_path}")

def train_zip_distill(args):
    if int(os.environ.get("RANK", "0")) == 0:
        print("[ZIP-Train] Loading tokenizer & models...")
    # --- Optional: multi-GPU DDP (torchrun) ---
    # If launched with torchrun, torch.distributed env vars are set.
    # We auto-enable DDP when WORLD_SIZE>1.
    ddp_world_size = int(os.environ.get("WORLD_SIZE", "1"))
    ddp_enabled = ddp_world_size > 1
    ddp_rank = int(os.environ.get("RANK", "0"))
    ddp_local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
    is_main_process = (ddp_rank == 0)

    if ddp_enabled:
        if not torch.cuda.is_available():
            raise RuntimeError("DDP requires CUDA. Please run on a GPU node.")
        if ddp_local_rank < 0:
            raise RuntimeError("DDP detected (WORLD_SIZE>1) but LOCAL_RANK is missing. Please use torchrun.")
        dist.init_process_group(backend="nccl")
        # --- GPU placement policy for ref_slm under DDP ---
        # default: same-GPU (1 GPU / rank)
        ref_place = getattr(args, "zip_ref_placement", "same")
        if ref_place == "paired":
            # 2 GPUs / rank: train=cuda:(LOCAL_RANK*2), ref=cuda:(LOCAL_RANK*2+1)
            train_id = ddp_local_rank * 2
            ref_id = train_id + 1
            n_gpus = torch.cuda.device_count()
            if ref_id >= n_gpus:
                raise RuntimeError(
                    f"zip_ref_placement=paired requires 2 GPUs per rank. "
                    f"Need at least {ref_id+1} visible GPUs, but only {n_gpus} available. "
                    f"(LOCAL_RANK={ddp_local_rank})"
                )
            torch.cuda.set_device(train_id)
            ddp_train_device_id = train_id
            device_train = f"cuda:{train_id}"
            device_ref = f"cuda:{ref_id}"
        elif ref_place == "cpu":
            torch.cuda.set_device(ddp_local_rank)
            ddp_train_device_id = ddp_local_rank
            device_train = f"cuda:{ddp_local_rank}"
            device_ref = "cpu"
        else:
            torch.cuda.set_device(ddp_local_rank)
            ddp_train_device_id = ddp_local_rank
            device_train = f"cuda:{ddp_local_rank}"
            device_ref = device_train
    else:
        ddp_train_device_id = None
        device_train = args.slm_device if torch.cuda.is_available() else "cpu"
        device_ref = args.llm_device if torch.cuda.is_available() else "cpu"

    # Rank-aware logging helper
    def log0(*a, **k):
        if is_main_process:
            print(*a, **k)

    log0(f"[ZIP-Train] ddp_enabled={ddp_enabled} rank={ddp_rank} local_rank={ddp_local_rank} world_size={ddp_world_size} "
         f"device_train={device_train} device_ref={device_ref}")
    # Make RNG streams different across ranks for better shuffling/regularization.
    set_seed(args.SEED + ddp_rank)
    os.makedirs(args.zip_output_dir, exist_ok=True)

    tok = AutoTokenizer.from_pretrained(args.slm_model, trust_remote_code=True, use_fast=False)
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    slm = AutoModelForCausalLM.from_pretrained(
        args.slm_model, trust_remote_code=True, torch_dtype=torch.bfloat16,
        device_map={"": device_train} if torch.cuda.is_available() else "cpu",
    )

    ref_slm = AutoModelForCausalLM.from_pretrained(
        args.slm_model, trust_remote_code=True, torch_dtype=torch.bfloat16,
        device_map={"": device_ref} if torch.cuda.is_available() else "cpu",
    )
    ref_slm.eval()
    for p in ref_slm.parameters(): p.requires_grad = False
    start_epoch = 0
    global_step = 0
    resume_step_idx = -1 
    resume_opt_state = None

    checkpoint_dir = os.path.join(args.zip_output_dir, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    state_file_path = os.path.join(checkpoint_dir, "latest_trainer_state.pt")
    
    if args.zip_use_lora and args.zip_lora_ckpt:
        log0(f"[ZIP-Train] Resuming/Loading existing ZIP LoRA from {args.zip_lora_ckpt} ...")
        slm = PeftModel.from_pretrained(
            slm, 
            args.zip_lora_ckpt, 
            is_trainable=True  
        )
    else:
        if os.path.exists(state_file_path):
            log0(f"[ZIP-Train] Found checkpoint state at {state_file_path}, resuming...")
            train_state = torch.load(state_file_path, map_location="cpu")
            
            log0("[ZIP-Train] NEW START: Ignoring saved epoch/step/batch_idx from checkpoint.")
            start_epoch = 0
            global_step = 0
            resume_step_idx = -1 
            
            log0(f"[ZIP-Train] Loading model weights from {checkpoint_dir}...")
            try:
                slm = PeftModel.from_pretrained(slm, checkpoint_dir, is_trainable=True)
            except Exception as e:
                slm = add_lora_to_slm(
                    slm, 
                    r=args.zip_lora_r, 
                    alpha=args.zip_lora_alpha, 
                    dropout=args.zip_lora_dropout
                )

                safe_path = os.path.join(checkpoint_dir, "adapter_model.safetensors")
                bin_path = os.path.join(checkpoint_dir, "adapter_model.bin")
                
                adapters_weights = None
                if os.path.exists(safe_path):
                    print(f"[ZIP-Train] Detected safetensors format: {safe_path}")
                    from safetensors.torch import load_file
                    adapters_weights = load_file(safe_path)
                elif os.path.exists(bin_path):
                    print(f"[ZIP-Train] Detected bin format: {bin_path}")
                    adapters_weights = torch.load(bin_path, map_location="cpu")
                else:
                    print(f"[ZIP-Train] [ERROR] No adapter model found in {checkpoint_dir}")
                    print(f"[ZIP-Train] Content of dir: {os.listdir(checkpoint_dir)}")
                    raise FileNotFoundError(f"Could not find adapter_model.safetensors or .bin in {checkpoint_dir}")

                load_result = set_peft_model_state_dict(slm, adapters_weights)
                print(f"[ZIP-Train] LoRA weights loaded successfully. Result: {load_result}")
                # =========================================================

                resume_opt_state = train_state.get("optimizer_state_dict", None)

            log0(f"[ZIP-Train] Resumed from Epoch {start_epoch}, Global Step {global_step}, Batch Index {resume_step_idx}")

        else:
            log0("[ZIP-Train] Enabling LoRA on SLM...")
            slm = add_lora_to_slm(
                slm,
                r=args.zip_lora_r,
                alpha=args.zip_lora_alpha,
                dropout=args.zip_lora_dropout,
            )

    # 4.5) ZIP reserved bins: restrict lm_head training to only rows in R
    try:
        model_vocab_size = int(slm.get_output_embeddings().weight.shape[0])
    except Exception:
        model_vocab_size = int(slm.config.vocab_size)

    zip_token_ids = get_zip_token_ids(tok, model_vocab_size - 1, num_bins=args.zip_num_bins)
    zip_token_ids_t = torch.tensor(zip_token_ids, device=device_train, dtype=torch.long)
    log0(f"[ZIP-Train] Trainable params after row-wise lm_head tuning: {_count_trainable_params(slm):,}")

    if hasattr(slm, "module"): 
        lm_head_layer = slm.module.lm_head
    else:
        lm_head_layer = slm.lm_head

    lm_head_layer.weight.requires_grad = True
    def scrub_lm_head_grad_hook(grad):
        kept_grads = grad[zip_token_ids_t].clone()
        grad.zero_() 
        grad.index_copy_(0, zip_token_ids_t, kept_grads) 
        
        return grad

    hook_handle = lm_head_layer.weight.register_hook(scrub_lm_head_grad_hook)
    
    if is_main_process:
        print(f"[ZIP-Train] Registered Gradient Mask on lm_head. "
              f"Only {len(zip_token_ids)} rows will be updated.")

    log0(f"[ZIP-Train] Loading ZIPGroupDataset (group-wise) from: {args.zip_data_glob}")
    dataset = ZIPGroupDataset(
        data_glob=args.zip_data_glob,
        num_bins=args.zip_num_bins,
        max_len=args.zip_max_len,
        expected_k=args.slm_multi_k,
    )
    log0(f"\n{'='*40}")
    log0(f"Set info: Training dataset size: {len(dataset)}")
    log0(f"{'='*40}\n")

    collate = lambda batch: zip_group_collate_fn(batch, pad_token_id=pad_id, num_bins=args.zip_num_bins)

    sampler = None
    if ddp_enabled:
        sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)

    loader = DataLoader(
        dataset,
        batch_size=args.zip_batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=2,
        collate_fn=collate,
        pin_memory=torch.cuda.is_available(),
    )
    log0(f"[ZIP-Train] zip_num_bins {args.zip_num_bins}...")

    # 6) Optimizer
    trainable_params = [p for p in slm.parameters() if p.requires_grad]
    
    lm_head_weight = lm_head_layer.weight
    other_params = [p for p in trainable_params if p is not lm_head_weight]
    
    param_groups = [
        {"params": other_params, "weight_decay": 0.01}, 
        {"params": [lm_head_weight], "weight_decay": 0.0} 
    ]

    opt = torch.optim.AdamW(param_groups, lr=args.zip_lr)

    # If resuming, load optimizer state now that opt exists.
    if resume_opt_state is not None:
        try:
            opt.load_state_dict(resume_opt_state)
            # move optimizer state tensors to local device
            for state in opt.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device_train)
            log0("[ZIP-Train] Optimizer state loaded.")
        except Exception as e:
            log0(f"[ZIP-Train] [WARN] Failed to load optimizer state: {e}")

    slm.train()

    # Wrap with DDP after LoRA is attached and optimizer is created.
    if ddp_enabled:
        slm = DDP(
            slm,
            device_ids=[ddp_train_device_id],
            output_device=ddp_train_device_id,
        )
    model_to_save = slm.module if hasattr(slm, "module") else slm
    # avoid caching KV during training
    try:
        slm.config.use_cache = False
    except Exception:
        pass

    bin_vals = torch.arange(args.zip_num_bins, device=device_train).float()

    try:
        for epoch in range(args.zip_num_epochs):
            slm.train()

            running_loss_zip = 0.0
            running_loss_kl = 0.0
            running_loss_cap = 0.0
            log_every = 10
            running_zip_acc = 0.0

            accum = int(args.zip_grad_accum_steps or 1)
            accum = max(accum, 1)
            opt.zero_grad(set_to_none=True)
            accum_i = 0

            for batch_idx, batch in enumerate(loader):
                batch_train = {k: (v.to(device_train) if isinstance(v, torch.Tensor) else v)
                    for k, v in batch.items()}
                
                out = slm(
                    input_ids=batch_train["input_ids"],
                    attention_mask=batch_train["attention_mask"],
                    labels=None,
                    use_cache=False,
                )
                logits = out.logits  # [B, L, V]

                att_mask = batch_train["attention_mask"]  # [B, L]
                lengths = att_mask.sum(dim=1) - 1  # [B]
                bsz = logits.size(0)
                batch_indices_tensor = torch.arange(bsz, device=logits.device)

                logits_last = logits[batch_indices_tensor, lengths, :]         # [B*K, V]
                logits_bins = logits_last[:, zip_token_ids_t]       # [B*K, num_bins]
                # ---- 1) scalar score for each candidate (for listwise) ----
                # continuous score channel (avoid E[bin] one-hot collapse):
                #   - logits_dot:   dot(logits, linspace(0,1))            (continuous)
                scores_1d, probs_bins = zip_scores_from_logits_bins(
                    logits_bins=logits_bins,
                    zip_label_mode=args.zip_label_mode,
                    score_mode=args.zip_score_mode,
                )
                scores = scores_1d.view(batch_train["B"], batch_train["K"])            # [B, K]

                # ---- 2) teacher distribution (absolute probs on candidate set C) ----
                cand_mask = batch_train["cand_mask"].float()        # [B, K] in {0,1}
                teacher_abs = batch_train["teacher_abs"].float()    # [B, K] >=0 (not necessarily summing to 1)
                # optional group-wise normalization (does not change argmax ordering within group)
                if args.zip_score_group_norm:
                    scores = masked_group_norm_2d(scores, cand_mask)

                # IMPORTANT: all candidate-softmax / cap / logit-reg should work in the SAME z-space:
                cand_T = max(args.zip_cand_softmax_temp, 1e-6)
                scores_eff = scores / cand_T   # [B,K] (unmasked effective logits)
                scores_eff_masked = scores_eff.masked_fill(cand_mask <= 0, -1e9)
                # normalize teacher to a per-group distribution (only over valid candidates)
                teacher_mass = teacher_abs * cand_mask
                teacher_mass_sum = teacher_mass.sum(dim=1, keepdim=True).clamp_min(1e-12)
                p_teacher = teacher_mass / teacher_mass_sum

                # ---- 3) teacher confidence handling ("only learn from a confident teacher") ----
                # When teacher is uncertain (top1 ~ top2), hard supervision (e.g., top1 CE) is noisy.
                # We optionally (a) drop those groups, or (b) down-weight them.
                with torch.no_grad():
                    if p_teacher.size(1) >= 2:
                        topv, topi = torch.topk(p_teacher, k=2, dim=1)
                        t1 = topv[:, 0]
                        t2 = topv[:, 1]
                        teacher_best = topi[:, 0]   # [B]
                        teacher_second = topi[:, 1] # [B]
                    else:
                        t1, teacher_best = torch.max(p_teacher, dim=1)
                        t2 = torch.zeros_like(t1)
                        teacher_second = teacher_best

                    teacher_margin = (t1 - t2).clamp_min(0.0)
                    conf_mask = torch.ones_like(teacher_margin, dtype=torch.bool)
                    if args.zip_teacher_min_margin > 0.0:
                        conf_mask = teacher_margin >= args.zip_teacher_min_margin

                    # group weights
                    if args.zip_teacher_filter_mode == "none":
                        group_w = torch.ones_like(teacher_margin)
                    else:
                        eps_w = float(getattr(args, "zip_teacher_weight_eps", 1e-3))
                        pow_w = float(getattr(args, "zip_teacher_margin_power", 1.0))
                        base_w = (teacher_margin + eps_w).pow(max(pow_w, 0.0))
                        if args.zip_teacher_filter_mode == "drop":
                            group_w = base_w * conf_mask.float()
                        else:  # "weight"
                            group_w = base_w

                    w_denom = group_w.sum().clamp_min(1e-12)

                    # debug-only summary stats (useful when filtering by teacher confidence)
                    teacher_margin_mean = float(teacher_margin.mean().item())
                    teacher_keep_ratio = float(conf_mask.float().mean().item())

                # student distribution over candidates (masked)
                p_student = torch.softmax(scores_eff_masked, dim=-1)

                # teacher-student KL term (can be primary or auxiliary)
                kl_listwise = (p_teacher * (torch.log(p_teacher.clamp_min(1e-12)) - torch.log(p_student.clamp_min(1e-12)))).sum(dim=1)  # [B]
                kl_mean = (group_w * kl_listwise).sum() / w_denom

                loss_mode = args.zip_listwise_loss
                kl_w = args.zip_teacher_kl_weight
                if loss_mode == "top1_ce":
                    ce = F.cross_entropy(scores_eff_masked, teacher_best, reduction="none")  # [B]

                    pw = torch.zeros_like(ce)
                    if args.zip_pairwise_margin > 1e-9 and scores_eff_masked.size(1) >= 2:
                        s_target = scores_eff.gather(1, teacher_best.unsqueeze(1)).squeeze(1)  # [B]
                        neg_scores = scores_eff_masked.clone()
                        neg_scores.scatter_(1, teacher_best.unsqueeze(1), -1e9)
                        s_hard_neg, _ = neg_scores.max(dim=1)  # [B]

                        with torch.no_grad():
                            if args.zip_score_group_norm:
                                row_std = torch.ones_like(s_target)
                            else:
                                cm = cand_mask.float()
                                denom = cm.sum(dim=1).clamp_min(1.0)
                                mean = (scores * cm).sum(dim=1) / denom
                                var = ((scores - mean.unsqueeze(1)) ** 2 * cm).sum(dim=1) / denom
                                row_std = torch.sqrt(var + 1e-6)

                            margin_eff = (float(args.zip_pairwise_margin) * row_std) / cand_T

                        pw = F.relu(margin_eff - (s_target - s_hard_neg))  # [B]

                    loss_per_sample = ce + pw
                    loss_listwise = (group_w * loss_per_sample).sum() / w_denom
                
                elif loss_mode == "pairwise_margin":
                    # Pairwise margin on teacher top-1 vs runner-up ("ceasefire" beyond margin)
                    margin = args.zip_pairwise_margin
                    margin = max(margin, 0.0)
                    if scores_eff_masked.size(1) < 2:
                        loss_listwise = torch.zeros((), device=device_train)
                    else:
                        with torch.no_grad():
                            topv, topi = torch.topk(p_teacher, k=2, dim=1)
                            best_idx = topi[:, 0]
                            second_idx = topi[:, 1]
                        s_best = scores_eff.gather(1, best_idx.unsqueeze(1)).squeeze(1)
                        s_second = scores_eff.gather(1, second_idx.unsqueeze(1)).squeeze(1)
                        gap = s_best - s_second
                        loss_pair = F.relu(margin - gap)
                        loss_listwise = (group_w * loss_pair).sum() / w_denom

                else:
                    loss_listwise = (group_w * kl_listwise).sum() / w_denom
                    
                if loss_mode == "kl":
                    # Listwise KL(p_teacher || p_student)
                    loss_listwise = kl_w * loss_listwise
                elif kl_w > 0.0:
                    loss_listwise = loss_listwise + kl_w * kl_mean

                # diagnostics: top-1 agreement (listwise "accuracy")
                with torch.no_grad():
                    pred_best = torch.argmax(scores_eff_masked, dim=1)
                    zip_hit1 = (pred_best == teacher_best).float().mean().item()

                with torch.no_grad():
                    input_ids_ref = batch_train["input_ids"].to(device_ref) if device_ref != device_train else batch_train["input_ids"]
                    att_mask_ref  = batch_train["attention_mask"].to(device_ref) if device_ref != device_train else batch_train["attention_mask"]
                    ref_out = ref_slm(input_ids=input_ids_ref, attention_mask=att_mask_ref)

                ref_logits = ref_out.logits.detach()
                ref_logits_on_train = ref_logits.to(device_train)

                log_p_ref = F.log_softmax(ref_logits_on_train, dim=-1)
                log_p_new = F.log_softmax(logits, dim=-1)
                p_ref = log_p_ref.exp()

                kl_token = (p_ref * (log_p_ref - log_p_new)).sum(dim=-1)

                kl_mask = att_mask.clone()  # [B, L]
                kl_mask[batch_indices_tensor, lengths] = 0
                kl = (kl_token * kl_mask).sum() / (kl_mask.sum() + 1e-9)

                # ---- 4) optional: margin cap regularizer (tie / "ceasefire" on ambiguous groups) ----
                # Intuition: when teacher is uncertain (small top1-top2 margin), discourage student from
                # becoming over-confident (huge score gap). This reduces "confident wrong" damage online.
                loss_margin_cap = torch.tensor(0.0, device=device_train)
                if args.zip_margin_cap_lambda > 0.0 and scores_eff_masked.size(1) >= 2:
                    with torch.no_grad():
                        # teacher margin on normalized teacher distribution
                        k_top = 2 if p_teacher.size(1) >= 2 else 1
                        tv = torch.topk(p_teacher, k=k_top, dim=1).values
                        tm1 = tv[:, 0]
                        tm2 = tv[:, 1] if k_top == 2 else torch.zeros_like(tm1)
                        teacher_margin = (tm1 - tm2).clamp(min=0.0)  # [B]
                        tau = args.zip_margin_cap_tau
                        tau = max(tau, 1e-6)
                        cap_min = args.zip_margin_cap_min
                        cap_max = args.zip_margin_cap_max
                        alpha = torch.clamp(teacher_margin / tau, 0.0, 1.0)
                        cap = cap_min + (cap_max - cap_min) * alpha  # [B]
                    # student margin in score space
                    student_top2 = torch.topk(scores_eff_masked, k=2, dim=1).values
                    stud_margin = (student_top2[:, 0] - student_top2[:, 1]).clamp(min=0.0)  # [B]
                    valid_cnt = cand_mask.sum(dim=1)
                    stud_margin = torch.where(valid_cnt >= 2.0, stud_margin, torch.zeros_like(stud_margin))
                    loss_margin_cap = F.relu(stud_margin - cap).mean()

                loss = (
                    loss_listwise
                    + args.zip_kl_lambda * kl
                    + args.zip_margin_cap_lambda * loss_margin_cap
                )
                # gradient accumulation: scale loss
                # In DDP, avoid gradient sync on micro-steps (saves bandwidth/time).
                sync_ctx = contextlib.nullcontext()
                if ddp_enabled and hasattr(slm, "no_sync") and (accum_i % accum != accum - 1):
                    sync_ctx = slm.no_sync()
                with sync_ctx:
                    (loss / accum).backward()
                accum_i += 1
                do_step = (accum_i % accum == 0)

                running_loss_zip += loss_listwise.item()
                running_loss_kl += kl.item()
                running_loss_cap += float(loss_margin_cap.item()) 
                running_zip_acc += zip_hit1  

                if do_step:
                    global_step += 1
                    torch.nn.utils.clip_grad_norm_(slm.parameters(), args.zip_max_grad_norm)
                    opt.step()
                    opt.zero_grad(set_to_none=True)
                    
                if do_step and global_step % args.zip_save_steps == 0:
                    if is_main_process:
                        log0(f"\n[ZIP-Train] Saving checkpoint at step {global_step}...")
                        # 1) Save LoRA weights
                        model_to_save.save_pretrained(checkpoint_dir)
                        # 2) Save training state (optimizer, epoch, step)
                        save_state = {
                            'epoch': epoch,
                            'global_step': global_step,
                            'batch_idx_in_epoch': batch_idx,
                            'optimizer_state_dict': opt.state_dict(),
                        }
                        gc.collect()
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()

                        torch.save(save_state, state_file_path)
                        log0(f"[ZIP-Train] Checkpoint saved to {checkpoint_dir} (Opt skipped)")

                    if ddp_enabled:
                        dist.barrier()

                if do_step and global_step % log_every == 0:
                    divs = float(log_every) * accum
                    avg_zip = running_loss_zip / divs
                    avg_kl  = running_loss_kl  / divs * args.zip_kl_lambda
                    avg_cap = running_loss_cap / divs * args.zip_margin_cap_lambda
                    avg_acc = running_zip_acc / divs

                    if is_main_process:
                        print(
                            f"global_step {global_step} / {epoch}: "
                            f"loss_listwise={avg_zip:.4f}, loss_kl={avg_kl:.4f}, "
                            f"loss_cap={avg_cap:.4f}, zip_bin_acc={avg_acc:.4f}"
                        )

                    running_loss_zip = 0.0
                    running_loss_kl  = 0.0
                    running_loss_cap = 0.0
                    running_zip_acc  = 0.0

                    kl_at_action = kl_token[batch_indices_tensor, lengths].mean().item()
                    kl_at_prefix = (kl_token * kl_mask).sum().item() / (kl_mask.sum().item() + 1e-9)
                    if is_main_process:
                        print(f"[ZIPDBG] teacher_margin_mean={teacher_margin_mean:.4f}  kept_ratio={teacher_keep_ratio:.3f}")
                        print(f"[ZIPDBG] KL_at_action(unmasked raw)={kl_at_action:.4f}  KL_prefix(masked avg)={kl_at_prefix:.4f}")

                    with torch.no_grad():
                        # Reuse probs_bins from zip_scores_from_logits_bins so debug uses the *same* bin_softmax_temp.
                        pb = probs_bins.detach()
                        ent = -(pb.clamp_min(1e-12) * pb.clamp_min(1e-12).log()).sum(dim=-1)
                        top1 = pb.max(dim=-1).values

                    if is_main_process:
                        print(
                            f"[ZIPDBG] bins_entropy={ent.mean().item():.3f} (max={math.log(probs_bins.size(-1)):.3f})  "
                            f"top1_prob={top1.mean().item():.3f}"
                        )

            if is_main_process:
                out_path = os.path.join(args.zip_output_dir, f"slm_zip_inter_{epoch}")
                if not os.path.exists(out_path):
                    os.makedirs(out_path, exist_ok=True)
                model_to_save.save_pretrained(out_path)
                print(f"[ZIP-Train] Finished epoch {epoch} and intermediate lora saved to {out_path}")

            if is_main_process:
                save_state = {
                    'epoch': epoch + 1, 
                    'global_step': global_step,
                    'batch_idx_in_epoch': -1, 
                    # 'optimizer_state_dict': opt.state_dict()
                }
                torch.save(save_state, state_file_path)
            if ddp_enabled:
                dist.barrier()

    except torch.OutOfMemoryError as e:
        if is_main_process:
            print("\n" + "!"*50)
            print(f"[ZIP-Train][CRITICAL] OutOfMemoryError: {e}")        

        if ddp_enabled:
            dist.barrier()
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            
        raise e

    finally:
        if is_main_process:
            model_to_save.save_pretrained(args.zip_output_dir)
            print(f"[ZIP-Train] LoRA checkpoint saved to {args.zip_output_dir}")

        if ddp_enabled:
            # Ensure rank0 finishes writing before other ranks exit.
            dist.barrier()
            dist.destroy_process_group()

# -----------------------------
# -----------------------------
def run_evaluation(args, models, fmts, assets, strategy, test_data, start_idx=0):
    """Evaluate one decoding strategy over ``test_data`` and print a summary.

    Every example whose index is >= ``start_idx`` is answered through
    ``generate_answer``. The returned generation(s) are graded with
    ``extract_answer`` + ``math_equal`` (a sample counts as correct as soon as
    any returned text matches the reference answer), and the annotated result
    dict is written as one JSON line to ``results_<strategy>_<ts>.jsonl``
    inside ``args.output_dir``.

    Args:
        args: Parsed CLI namespace. Fields read here include ``output_dir``,
            ``dataset``, ``not_llm``, ``not_slm``, ``K_s``, ``N_l``, the
            ``collect_*`` / ``zip_*`` / ``header_data_dir`` data-collection
            options, ``random_flipped_target``, and (for the summary line)
            ``ablation_quantile`` and ``slm_multi_k``.
        models: Passed through unchanged to ``generate_answer``.
        fmts: Passed through unchanged to ``generate_answer``.
        assets: Passed through unchanged to ``generate_answer``.
        strategy: Strategy dict; ``strategy['name']`` labels the output file
            and the log lines, and selects the summary layout (names
            containing ``"only"`` get the short summary).
        test_data: Iterable of dataset examples.
        start_idx: Examples with a smaller index are skipped.

    Returns:
        None. Results are written to disk and the summary is printed.
    """
    strategy_name = strategy['name']
    print(f"\n--- Running Strategy: {strategy_name} ---")

    results_path = os.path.join(args.output_dir, f"results_{strategy_name.replace(' ', '_')}_{ts}.jsonl")
    all_results, all_per_sample_total_steps = [], []
    all_per_sample_trigger_steps = []
    all_per_sample_flipped_steps, all_per_sample_flipped_rate = [], []
    all_llm_selected_probs_on_trigger, all_slm_selected_probs_on_trigger = [], []
    correct_count = 0
    all_hit_at_top1 = []

    with open(results_path, "w", encoding="utf-8") as f:
        for i, example in enumerate(tqdm(test_data, desc=f"Strategy {strategy_name}")):
            if i < start_idx:
                continue

            question = parse_question(example, args.dataset)
            _, ans = parse_ground_truth(example, args.dataset)

            result = generate_answer(
                i, args.not_llm, args.not_slm, args.dataset, models, fmts, assets, question, strategy, args.output_dir, K_s=args.K_s,
                N_l=args.N_l, problem_id=i,
                collect_zip_data=args.collect_zip_data,
                collect_header_data=args.collect_header_data,
                zip_data_dir=args.zip_data_dir,
                zip_label_mode=args.zip_label_mode,
                header_data_dir=args.header_data_dir,
                random_flipped_target=args.random_flipped_target,
            )

            # generate_answer may return one text or a list of texts; normalize
            # to a list once and reuse it for grading and for the debug print.
            gt = result['generated_text']
            texts = gt if isinstance(gt, list) else [gt]

            # Correct if ANY candidate text matches; any() short-circuits on
            # the first match just like the original loop-with-break.
            is_correct = any(math_equal(extract_answer(text), ans) for text in texts)
            if is_correct:
                correct_count += 1

            result['problem_id'] = f"{args.dataset}_{i}"
            result['question'] = question
            result['ans'] = ans
            result['is_correct'] = is_correct
            f.write(json.dumps(result) + "\n")

            all_results.append(result)
            all_per_sample_trigger_steps.append(result.get('trigger_count', 0))
            all_per_sample_total_steps.append(result.get('total_steps', 0))
            all_per_sample_flipped_steps.append(result.get("flipped_num", 0))
            all_per_sample_flipped_rate.append(result.get("flip_rate", 0))
            all_llm_selected_probs_on_trigger.extend(result.get("llm_selected_probs_on_trigger", []))
            all_slm_selected_probs_on_trigger.extend(result.get("slm_selected_probs_on_trigger", []))
            all_hit_at_top1.append(result.get("hit_at_top1", 0))

            # Console preview: flatten newlines and show only the tail of the
            # first generation.
            question = question.replace("\n", " ")
            dbg_text = texts[0] if texts else ""
            dbg_text_flat = dbg_text.replace("\n", " ")
            print(f"\nPROBLEM: {question}")
            print(f"GROUND TRUTH: {ans}")
            print(f"GENERATED ANSWER:{dbg_text_flat[-80:]} ======> {is_correct}")
            print("-------------------------------")

    total_count = len(all_results)
    if total_count == 0:
        # Nothing was evaluated (empty dataset or start_idx past the end).
        # The averages below would divide by zero, so report and stop here.
        print(
            f"[Evaluation][WARN] No examples evaluated for strategy {strategy_name} "
            f"(start_idx={start_idx}); skipping summary."
        )
        print(f"  Results saved to: {results_path}")
        print("-------------------------------------------\n")
        return

    def _pct(numerator, denominator):
        # Percentage that degrades to NaN (without a NumPy divide-by-zero
        # RuntimeWarning) when the denominator is zero, e.g. no trigger steps.
        denominator = float(denominator)
        if denominator == 0.0:
            return float("nan")
        return 100.0 * float(numerator) / denominator

    accuracy = (correct_count / total_count) * 100
    avg_token = np.sum(all_per_sample_total_steps) / len(all_per_sample_total_steps)
    avg_time = np.mean([r['generation_time_sec'] for r in all_results])

    if "only" in strategy_name:
        print(f"--- Strategy Summary: {strategy_name} on {args.dataset} ---")
        print(f"  Accuracy: {accuracy:.2f}% ({correct_count}/{total_count})")
        print(f"  Avg. Token: {avg_token: .2f}")
        print(f"  Avg. Time per Answer: {avg_time:.2f} sec")
    else:
        avg_trigger_rate = np.mean([r['trigger_rate'] for r in all_results]) * 100
        # Flipped steps as a share of all generated steps (micro-average).
        avg_trigger_flip_step_rate = _pct(
            np.sum(all_per_sample_flipped_steps), np.sum(all_per_sample_total_steps)
        )
        avg_flip_sample = 100 * np.mean(all_per_sample_flipped_rate)
        std_flip_sample = 100 * np.std(all_per_sample_flipped_rate)
        # Top-1 hits as a share of all trigger steps (micro-average).
        avg_hit_at_top1 = _pct(
            np.sum(all_hit_at_top1), np.sum(all_per_sample_trigger_steps)
        )

        print(f"--- Strategy Summary: {strategy_name} on {args.dataset} with ({args.ablation_quantile}, {args.slm_multi_k}) ---")
        print(f"  Accuracy: {accuracy:.2f}% ({correct_count}/{total_count})")
        print(f"  Avg. Token: {avg_token: .2f}")
        print(f"  Avg. Trigger Rate: {avg_trigger_rate:.2f}%")
        print(f"  Avg. Flip Step: {avg_trigger_flip_step_rate:.2f}%")
        print(f"  Avg. Flip Ratio: {avg_flip_sample:.2f} +- {std_flip_sample:.3f}")
        print(f"  Avg. Hit Ratio: {avg_hit_at_top1:.2f}")
        print(f"  Avg. Time per Answer: {avg_time:.2f} sec")
        if len(all_llm_selected_probs_on_trigger):
            avg_llm_probs, std_llm_probs = np.mean(all_llm_selected_probs_on_trigger), np.std(all_llm_selected_probs_on_trigger)
            avg_slm_probs, std_slm_probs = np.mean(all_slm_selected_probs_on_trigger), np.std(all_slm_selected_probs_on_trigger)
            print(f"  Avg. LLM Probs on Selected: {avg_llm_probs:.2f} +- {std_llm_probs:.2f}")
            print(f"  Avg. SLM Probs on Selected: {avg_slm_probs:.2f} +- {std_slm_probs:.2f}")

    print(f"  Results saved to: {results_path}")
    # summarize_zipdebug(results_path)
    print("-------------------------------------------\n")

def run_calibration(args, models, fmts, assets, calib_data, 
                    need_kl_calib=False, need_en_calib=False):
    """Collect trigger-distance samples on ``calib_data`` and derive thresholds.

    Each calibration example is run through ``generate_answer`` with a
    ``CALIB`` strategy. The ``d_kl_history`` / ``d_en_history`` lists found in
    the results are pooled (only for the distances that were requested), and
    the ``args.ablation_quantile`` quantile of each pool becomes its threshold.

    Args:
        args: Parsed CLI namespace; reads ``dataset``, ``not_llm``,
            ``not_slm``, ``output_dir``, ``K_s``, ``N_l``, ``zip_label_mode``
            and ``ablation_quantile``.
        models: Forwarded unchanged to ``generate_answer``.
        fmts: Forwarded unchanged to ``generate_answer``.
        assets: Forwarded unchanged to ``generate_answer``.
        calib_data: Iterable of dataset examples used for calibration.
        need_kl_calib: Pool KL distances when truthy.
        need_en_calib: Pool entropy distances when truthy.

    Returns:
        Dict with keys ``kl_calib_threshold`` and ``en_calib_threshold``;
        a value is ``None`` when no samples were pooled for that distance.
    """
    kl_pool = []
    entropy_pool = []

    # Human-readable trigger description, also handed to the CALIB strategy.
    requested = [
        label
        for label, wanted in (("kl", need_kl_calib), ("entropy", need_en_calib))
        if wanted
    ]
    trigger_str = " ".join(requested).strip() or "Random"
    print(f"[Calibration] Collecting on {trigger_str}...")

    for i, example in enumerate(tqdm(calib_data, desc="Calib distance")):
        question = parse_question(example, args.dataset)

        # A fresh strategy dict is built per example, as in the original call.
        res = generate_answer(
            i, args.not_llm, args.not_slm, args.dataset,
            models, fmts, assets, question,
            strategy={"name": "CALIB", "mode": "s2t", "trigger_type": trigger_str},
            output_dir=args.output_dir, K_s=args.K_s, N_l=args.N_l,
            zip_label_mode=args.zip_label_mode,
        )
        if need_kl_calib:
            kl_pool.extend(res.get("d_kl_history", []))
        if need_en_calib:
            entropy_pool.extend(res.get("d_en_history", []))

    # KL-distance threshold.
    kl_threshold = None
    if kl_pool:
        thr = float(np.quantile(np.array(kl_pool, dtype=float), args.ablation_quantile))
        kl_threshold = thr
        print(f"[Calibration] kl-distance threshold (q={args.ablation_quantile:.2f}) = {thr:.6f}")
    else:
        print("[Calibration][WARN] No kl-distance collected; ablation will be disabled.")

    # Entropy-distance threshold.
    entropy_threshold = None
    if entropy_pool:
        thr = float(np.quantile(np.array(entropy_pool, dtype=float), args.ablation_quantile))
        entropy_threshold = thr
        print(f"[Calibration] entropy-distance threshold (q={args.ablation_quantile:.2f}) = {thr:.6f}")
    else:
        print("[Calibration][WARN] No entropy-distance collected; ablation will be disabled.")

    return {
        "kl_calib_threshold": kl_threshold,
        "en_calib_threshold": entropy_threshold,
    }

# -----------------------------
# -----------------------------
if __name__ == "__main__":
    # Script entry point. Parses the CLI flags, prepares the run directory and
    # then, depending on the flags:
    #   1) evaluates (or collects data for) every requested strategy,
    #   2) trains the header router (--train_header),
    #   3) runs ZIP distillation training (--train_zip).
    # NOTE: MAX_NEW_TOKENS is assigned below at module scope, which already
    # binds the module-level global; a `global` statement is unnecessary here.

    parser = argparse.ArgumentParser(description="Preprocessing and Inference Demo")
    parser.add_argument("--SEED", type=int, default=2025)
    parser.add_argument("--output_dir", type=str, default="./assets_math")
    parser.add_argument("--dataset", type=str, default="math")

    parser.add_argument("--strategy_demand", type=str, default="")
    parser.add_argument("--slm_model", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--llm_model", type=str, default="Qwen/Qwen2.5-32B-Instruct")
    parser.add_argument("--slm_device", type=str, default="cuda:0")
    parser.add_argument("--llm_device", type=str, default="cuda:1")

    parser.add_argument("--not_llm", action="store_true")
    parser.add_argument("--not_slm", action="store_true")

    parser.add_argument("--collect_zip_data", action="store_true")
    parser.add_argument("--collect_header_data", action="store_true")
    parser.add_argument("--header_trigger_threshold", type=float, default=0.5)
    parser.add_argument("--zip_bin_scale", type=str, default="logit",
                        choices=["linear", "logit"],
                        help="How to map teacher_prob_cond into bins.")

    parser.add_argument("--train_num", type=int, default=1000)
    parser.add_argument("--test_num", type=int, default=300)
    parser.add_argument("--K_s", type=int, default=64)
    parser.add_argument("--N_l", type=int, default=64)
    parser.add_argument("--slm_multi_k", type=int, default=16)
    parser.add_argument("--sc_k", type=int, default=5)
    parser.add_argument("--start_idx", type=int, default=0)

    parser.add_argument("--ablation_calib_num", type=int, default=20)
    parser.add_argument("--ablation_quantile", type=float, default=0.95)

    parser.add_argument("--runs_root", type=str, default="./runs")
    parser.add_argument("--run_dir", type=str, default="",
                        help="Root folder of this run. If empty, auto-create runs/<ts>.")

    parser.add_argument("--zip_label_mode", type=str, default="prob_bin",
                        choices=["prob_bin", "topk_binary"],
                        help="prob_bin: teacher_prob_cond -> bins; topk_binary: teacher top-k=1 else 0.")
    parser.add_argument("--zip_topk", type=int, default=1,
                        help="When zip_label_mode=topk_binary, mark teacher top-k as positive.")

    parser.add_argument("--flipped_target", type=float, default=0.0)
    parser.add_argument("--random_flipped_target", type=float, default=0.0)
    parser.add_argument("--train_header", action="store_true")
    parser.add_argument("--train_zip", action="store_true")
    parser.add_argument("--zip_num_bins", type=int, default=16)
    parser.add_argument("--zip_max_len", type=int, default=512)
    parser.add_argument("--zip_lr", type=float, default=5e-5)
    parser.add_argument("--zip_batch_size", type=int, default=4)
    parser.add_argument("--zip_num_epochs", type=int, default=3)
    parser.add_argument("--zip_use_lora", action="store_true")
    parser.add_argument("--zip_lora_ckpt", type=str, default="")
    parser.add_argument("--zip_save_steps", type=int, default=500,
                        help="Save checkpoint every X global steps (intra-epoch).")

    parser.add_argument("--zip_grad_accum_steps", type=int, default=1,
                        help="Gradient accumulation steps for ZIP training. ")

    # ZIP scoring: how to turn ZIP-bin logits into per-candidate scalar scores
    # (help text lists only the mode that `choices` actually accepts).
    parser.add_argument("--zip_score_mode", type=str, default="logits_dot",
                        choices=["logits_dot"],
                        help="logits_dot: dot(logits, linspace(0,1)) (continuous).")
    parser.add_argument("--zip_score_group_norm", action="store_true",
                        help="Group-wise normalize candidate scores before softmax (preserves argmax ordering).")

    # Optional: margin-cap regularizer to avoid over-confident gaps on teacher-ambiguous groups
    parser.add_argument("--zip_margin_cap_lambda", type=float, default=0.0,
                        help="Weight for margin-cap regularizer. 0 disables it.")
    parser.add_argument("--zip_margin_cap_min", type=float, default=1.0,
                        help="Min allowed student score gap when teacher is ambiguous (cap).")
    parser.add_argument("--zip_margin_cap_max", type=float, default=8.0,
                        help="Max allowed student score gap when teacher is confident (cap).")
    parser.add_argument("--zip_margin_cap_tau", type=float, default=0.5,
                        help="Teacher margin scale for interpolating cap between min and max.")

    parser.add_argument("--zip_kl_lambda", type=float, default=0.1)
    # Help text lists exactly the objectives accepted by `choices`.
    parser.add_argument("--zip_listwise_loss", type=str, default="top1_ce",
                        choices=["kl", "pairwise_margin", "top1_ce"],
                        help="ZIP listwise objective over candidates: "
                             "kl (match teacher distribution), "
                             "pairwise_margin (top1 vs top2 margin), top1_ce (hard top-1 NLL).")

    # Weight for teacher-distribution KL term (can be used as the primary loss or as an auxiliary term)
    parser.add_argument("--zip_teacher_kl_weight", type=float, default=1.0,
                        help="Weight for KL(p_teacher || p_student) in the listwise objective. "
                             "If zip_listwise_loss==kl, this scales the KL loss. "
                             "If zip_listwise_loss!=kl, this acts as an auxiliary KL term (set small, e.g., 0.1).")

    # --- Teacher confidence filtering / weighting ---
    # Motivation: when the teacher itself is uncertain (top1 ~ top2), hard losses like top1_ce
    # can inject label noise. We can drop or down-weight those groups based on teacher margin.
    parser.add_argument(
        "--zip_teacher_filter_mode",
        type=str,
        default="none",
        choices=["none", "weight", "drop"],
        help="How to use teacher confidence (top1-top2 margin) in ZIP training. "
             "none: uniform weights. weight: weight each group by (margin+eps)^p. "
             "drop: zero-weight groups with margin < zip_teacher_min_margin.")
    parser.add_argument(
        "--zip_teacher_min_margin",
        type=float,
        default=0.0,
        help="Teacher top1-top2 prob margin threshold for filtering (only used when zip_teacher_filter_mode=drop).")
    parser.add_argument(
        "--zip_teacher_margin_power",
        type=float,
        default=1.0,
        help="Exponent p in group weight (margin+eps)^p for zip_teacher_filter_mode=weight/drop.")
    parser.add_argument(
        "--zip_teacher_weight_eps",
        type=float,
        default=1e-3,
        help="Small epsilon added to teacher margin before applying power weighting.")
    parser.add_argument("--zip_pairwise_margin", type=float, default=1.0,
                        help="Margin for pairwise_margin loss in candidate-logit space (after candidate temperature).")

    # Candidate-level temperature for listwise distribution over candidates (z-space)
    parser.add_argument("--zip_cand_softmax_temp", type=float, default=1.0,
                        help="Temperature for candidate-level logits/softmax in listwise loss (z-space). "
                             "Smaller => sharper distribution and stronger top-1 gradients. ")

    parser.add_argument("--zip_max_grad_norm", type=float, default=1.0)
    parser.add_argument("--zip_lora_r", type=int, default=16)
    parser.add_argument("--zip_lora_alpha", type=int, default=32)
    parser.add_argument("--zip_lora_dropout", type=float, default=0.05)

    parser.add_argument("--zip_debug", action="store_true",
                        help="Enable ZIP sanity checks + per-step ZIPRanking debug dump")
    parser.add_argument(
        "--zip_ddp",
        action="store_true",
        help="Enable DistributedDataParallel (multi-GPU). Normally auto-enabled when launched with torchrun.")
    parser.add_argument(
        "--zip_ref_placement",
        type=str,
        default="same",
        choices=["same", "paired", "cpu"],
        help="Where to place reference SLM in DDP. 'same': put ref on the same GPU as train (default). "
             "'paired': each rank uses 2 GPUs: train=cuda:(LOCAL_RANK*2), ref=cuda:(LOCAL_RANK*2+1). "
             "'cpu': put ref on CPU (slow).",
    )
    parser.add_argument("--zip_debug_max_records", type=int, default=20000,
                        help="Max debug records per run (avoid huge files)")
    parser.add_argument("--zip_debug_dump_candidates", action="store_true",
                        help="Dump per-candidate details (can be large)")

    parser.add_argument("--zip_train_k_keep", type=int, default=16,
                        help="Subsample candidates K during ZIP training to save memory. "
                             "Set <=0 to disable.")

    args = parser.parse_args()
    print("\n" + "="*50)
    print(f"[Main] Now lets start at {ts}")
    print(f"[EVALUATION] Starting dataset evaluation...{args.strategy_demand}")
    print(f"ablation_quantile {args.ablation_quantile}, slm_multi_k {args.slm_multi_k}, sc_k {args.sc_k}")
    print(f"flipped_target {args.flipped_target}")
    print(f"zip_listwise_loss {args.zip_listwise_loss}, zip_teacher_kl_weight {args.zip_teacher_kl_weight}")
    print(f"zip_kl_lambda {args.zip_kl_lambda}, zip_cand_softmax_temp {args.zip_cand_softmax_temp}")
    print(f"zip_teacher_filter_mode {args.zip_teacher_filter_mode}, zip_teacher_min_margin {args.zip_teacher_min_margin}, zip_teacher_margin_power {args.zip_teacher_margin_power}")

    # NOTE(review): several fields used later (e.g. zip_data_dir,
    # header_data_dir, header_ckpt) are not declared by the parser above;
    # presumably setup_run_dir populates them - confirm against that helper.
    args = setup_run_dir(args, ts)

    # Decide which model(s) can be skipped for the requested strategy.
    if not args.strategy_demand or args.strategy_demand in ["slm_only", "slm", "slm_multi"] or args.train_header:
        args.not_llm = True
    elif args.strategy_demand in ["llm_only", "llm", "llm_multi"]:
        args.not_slm = True
    if args.zip_label_mode == "topk_binary":
        # Binary labels need exactly two bins.
        args.zip_num_bins = 2

    print(f"not_slm {args.not_slm}, not_llm {args.not_llm}")
    set_seed(args.SEED)

    strategies = generate_strategy(args.strategy_demand, args.slm_multi_k, args.flipped_target, args.sc_k)
    for strategy in strategies:
        print(f"[Load Strategies] {strategy}")

    if len(strategies):
        models, fmts, assets = load_models_and_assets(args)
        assert models is not None
        print("="*50)

        print("[Load Dataset] Loading data split...")
        train_data, test_data = load_data_for(args)

        print(f"Loaded {len(train_data)} (train) {len(test_data)} (test) examples from {args.dataset} test set.")
        # Longer generation budget for the datasets matched below.
        if "aime" in args.dataset or "olympiadbench" in args.dataset or "mmlu" in args.dataset:
            MAX_NEW_TOKENS = 4096
        else:
            MAX_NEW_TOKENS = 1024

        # Data collection runs on the training split; plain evaluation on test.
        if args.collect_zip_data or args.collect_header_data:
            data = train_data
        else:
            data = test_data

        need_header = any(s.get("use_header_trigger") for s in strategies)
        if need_header and "header_router" not in models:
            print("[Main] Loading Header Router for evaluation...")
            # Read defensively: a missing/empty attribute falls into the
            # "checkpoint not found" branch instead of raising.
            ckpt_path = getattr(args, "header_ckpt", "") or ""
            if os.path.exists(ckpt_path):
                state = torch.load(ckpt_path, map_location=args.slm_device)
                slm = models["slm"]
                hidden_dim = state.get("hidden_dim", slm.config.hidden_size)
                emb_dim = state.get("emb_dim", slm.get_input_embeddings().weight.size(1))

                router = ACG_Router(hidden_dim, emb_dim)
                router.load_state_dict(state["model_state_dict"])
                router.to(
                    device=args.slm_device if torch.cuda.is_available() else "cpu",
                    dtype=slm.dtype
                )

                router.eval()
                models["header_router"] = router
                print(f"[Header-Eval] Successfully loaded router from {ckpt_path}")
            else:
                print(f"Warning: Header checkpoint {ckpt_path} not found!")
                # The original assigned an unused `eval_strategies = []` here,
                # so header-dependent strategies still ran without a router.
                # Drop exactly those strategies; the others are unaffected.
                skipped = [s for s in strategies if s.get("use_header_trigger")]
                strategies = [s for s in strategies if not s.get("use_header_trigger")]
                for s in skipped:
                    print(f"[Main] Skipping strategy {s} (header router unavailable).")

        # Computed after the header check so that skipped strategies do not
        # trigger an unnecessary calibration pass.
        need_kl_calib = any("kl" in s.get("trigger_type","") for s in strategies)
        need_en_calib = any("entropy" in s.get("trigger_type","")  for s in strategies)

        if need_kl_calib or need_en_calib:
            # Use at most 10% of the data (capped by --ablation_calib_num).
            calib_n = min(args.ablation_calib_num, int(0.1 * len(data)))
            calib_data = data.select(range(calib_n))
            if "aime" not in args.dataset:
                # Hold the calibration examples out of the evaluated split
                # (not done for "aime" datasets, which keep every example).
                data = data.select(range(calib_n, len(data)))

            print(f"[Calibration] Collecting on {calib_n} samples (quantile={args.ablation_quantile:.2f}) ...")
            calib_assets = run_calibration(args, models, fmts, assets, calib_data,
                                            need_kl_calib, need_en_calib)
            assets.update(calib_assets)

        for strategy in strategies:
            print(f"[Main] Collection on strategy {strategy} ...")
            run_evaluation(args, models, fmts, assets, strategy, data, start_idx=args.start_idx)

    if args.train_header:
        print("[Main] Start training Header Router...")
        models, fmts, assets = load_models_and_assets(args)
        assert models is not None
        print("="*50)

        train_header_router(args, models, fmts)

    if args.train_zip:
        # Log message fixed: this branch runs ZIP distillation, not the
        # header-router training announced by the copy-pasted original text.
        print("[Main] Start ZIP distillation training...")
        train_zip_distill(args)

    print("[Main] Experiment run complete.")