import os, sys, re, math, difflib, random, argparse, datetime, signal, traceback, faulthandler
from typing import List, Tuple, Optional, Dict

import numpy as np
import torch
import torch.distributed as dist
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Directory and run name used to build per-rank log and shard paths.
# These are only module-level defaults: main() overwrites both (via `global`)
# from --out_dir / --run_name before it installs the per-rank log hooks.
OUT_DIR = "runs_dp"
RUN_NAME = "unified_run"

def _rank_log_path(rank):
    """Return the absolute path of ``rank``'s log file, creating OUT_DIR if needed.

    The file is named ``{RUN_NAME}.rank{rank}.log``. OUT_DIR and RUN_NAME are
    module globals looked up at call time, so values assigned by main() apply.
    """
    log_dir = OUT_DIR
    os.makedirs(log_dir, exist_ok=True)
    file_name = f"{RUN_NAME}.rank{rank}.log"
    return os.path.join(os.path.abspath(log_dir), file_name)

def _log_exc(rank, where, extra=""):
    """Append the exception currently being handled to ``rank``'s log file.

    Writes a timestamped header naming ``rank`` and ``where``, then the optional
    ``extra`` context line, then ``traceback.format_exc()``. This is best
    effort: any failure while logging is swallowed on purpose, so reporting an
    error can never replace or mask the original exception.
    """
    try:
        log_path = _rank_log_path(rank)
        with open(log_path, "a") as log_file:
            stamp = datetime.datetime.now()
            log_file.write(f"\n[{stamp}] RANK {rank} EXCEPTION @ {where}\n")
            if extra:
                log_file.write(extra + "\n")
            log_file.write(traceback.format_exc() + "\n")
            log_file.flush()
    except Exception:
        # Deliberately silent: logging must never raise from an error path.
        pass

def _install_signal_dumps(rank):
    """Send fatal-fault and SIGTERM/SIGINT stack dumps for ``rank`` to its log.

    The log is opened line-buffered in append mode and intentionally never
    closed, because faulthandler keeps writing to it for the life of the
    process. On SIGTERM or SIGINT the handler writes a timestamped banner,
    dumps the stacks of all threads, flushes, and exits via SystemExit(1).
    """
    log_stream = open(_rank_log_path(rank), "a", buffering=1)
    faulthandler.enable(file=log_stream)

    def _on_signal(signum, frame):
        log_stream.write(
            f"\n[{datetime.datetime.now()}] RANK {rank} got signal {signum}, dumping stacks...\n"
        )
        faulthandler.dump_traceback(file=log_stream, all_threads=True)
        log_stream.flush()
        raise SystemExit(1)

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _on_signal)

# ==========================
#     DEFAULT MODELS (Qwen3)
# ==========================
# Hugging Face model ids used when --gen_model / --cls_model are not passed.
DEFAULT_GEN_MODEL = "Qwen/Qwen3-1.7B"
DEFAULT_CLS_MODEL = "Qwen/Qwen3-1.7B"  # by default reuse the same model

# ==========================
#   PROMPTS (dataset-specific)
# ==========================

# Instructions for GPQA multiple-choice items. main() appends the question and
# the shuffled "A. ...".."D. ..." option lines after this text.
PROMPT_GPQA = (
    "You are answering a multiple-choice question.\n"
    "Options are labeled A, B, C, and D.\n"
    "Think step-by-step and show your reasoning.\n"
    "At the very end, output ONE line exactly in this format:\n"
    "Final Answer: \\boxed{A}\n"
    "where the letter is A, B, C, or D.\n"
)

# UNIFIED PROMPT for TIGER, AIME, MATH-500
# main() appends the question text after this instruction.
PROMPT_STANDARD = (
    "Answer the following question step-by-step. "
    "At the very end, output exactly one line formatted as:\n"
    "Final Answer: \\boxed{...}\n"
)

# Tagging - KEEP 4 TAGS, add "unknown" only as fallback
# Labels the step classifier may return. "unknown" is not listed here; it is
# appended wherever the label set is saved, and it is returned when no reply
# matches a tag closely enough.
TAG_LIST = [
    "final_answer",
    "setup_and_retrieval",
    "analysis_and_computation",
    "uncertainty_and_verification",
]

# Header of the classifier prompt. classify_step_with_unknown() appends the
# sentence to label and the comma-separated TAG_LIST.
# NOTE: despite the name, this contains no worked examples (zero-shot).
FEW_SHOT_PREFIX = """You are an expert in reasoning analysis. 
Classify the function of each sentence into one of the following tags:
1. final_answer
2. setup_and_retrieval
3. analysis_and_computation
4. uncertainty_and_verification
"""

# Option letters for multiple-choice prompts, in display order.
LETTERS = ["A", "B", "C", "D"]
# Extracts an A-D answer letter, optionally preceded by "Final Answer:" and/or
# wrapped in \boxed{...}; group(1) is the letter.
# NOTE(review): because of (?i), a standalone lowercase "a".."d" word (e.g. the
# article "a") also matches. main() only searches the last non-empty output
# line, which limits but does not remove this risk.
LETTER_RE = re.compile(r"(?i)(?:Final Answer\s*:\s*)?(?:\\boxed\{|\b)([A-D])(?:\}|\.|\b)")

# ==========================
#   DISTRIBUTED HELPERS
# ==========================
def ddp_init():
    """
    Initialize the default process group and bind this process to its device.

    ``LOCAL_RANK`` is read from the environment (default 0) and, when CUDA is
    available, made the current CUDA device. The backend is NCCL when CUDA is
    available and Gloo otherwise, because NCCL cannot run on CPU-only hosts.
    Rendezvous uses ``init_process_group``'s default ``env://`` method, so
    MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE must be set (normally by the
    launcher). An already-initialized process group is reused.

    Returns:
        (rank, world_size, device, local_rank)
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(local_rank)

    if not dist.is_initialized():
        dist.init_process_group(backend="nccl" if use_cuda else "gloo")

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{local_rank}" if use_cuda else "cpu")
    return rank, world_size, device, local_rank

# ==========================
#   DETERMINISM HELPERS (best-effort reproducibility)
# ==========================
def set_global_seed(seed: int, rank: int = 0):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) with ``seed + rank``.

    ``seed=None`` is treated as 0. The function also sets CUBLAS_WORKSPACE_CONFIG
    and PYTHONHASHSEED in ``os.environ``. It turns on deterministic,
    non-benchmarked cuDNN kernels but leaves
    ``torch.use_deterministic_algorithms`` disabled.

    Reproducibility is best-effort, not guaranteed:
      * PYTHONHASHSEED only affects interpreters started after it is set, not
        the current process.
      * CUBLAS_WORKSPACE_CONFIG presumably only takes effect if set before
        cuBLAS initializes. TODO confirm against the call order in main().
      * With deterministic algorithms disabled, nondeterministic CUDA kernels
        may still be chosen.

    Returns:
        The effective seed ``seed + rank``.
    """
    base = 0 if seed is None else seed
    effective = base + int(rank)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(effective)
    os.environ.update({
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        "PYTHONHASHSEED": str(effective),
    })
    torch.use_deterministic_algorithms(False)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    return effective

def reseed_for_sample(base_seed: int, rank: int, global_idx: int):
    """
    Reseed Python, NumPy and torch just before one sample is generated.

    The seed is ``(base * 1_000_003 + rank * 97_931 + global_idx) mod (2**31 - 1)``,
    where ``base_seed=None`` counts as 0. Distinct samples therefore get
    distinct streams. Because ``rank`` is mixed in, a given sample's seed (and
    so its sampled generation) can change if the world size changes.

    Returns:
        The seed that was applied.
    """
    base = 0 if base_seed is None else base_seed
    mixed = (base * 1_000_003 + rank * 97_931 + global_idx) % (2**31 - 1)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(mixed)
    return mixed

# ==========================
#   TEXT ACCURACY HELPERS
# ==========================
def _extract_boxed_contents(text: str) -> List[str]:
    """Return the contents of every brace-balanced ``\\boxed{...}`` in *text*, in order.

    Nested braces are honoured, so ``\\boxed{\\frac{1}{2}}`` yields
    ``\\frac{1}{2}``. A ``\\boxed{`` whose braces never close ends the scan.
    """
    marker = "\\boxed{"
    contents = []
    pos = text.find(marker)
    while pos != -1:
        body_start = pos + len(marker)
        depth = 1
        i = body_start
        while i < len(text) and depth:
            ch = text[i]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            i += 1
        if depth:  # unbalanced: this \boxed{ is never closed
            break
        contents.append(text[body_start:i - 1])
        pos = text.find(marker, i)
    return contents


def _pick_answer(text: str) -> str:
    """Extract the final answer string from a model output or reference solution.

    The answer is chosen by priority:
      1. the contents of the LAST brace-balanced ``\\boxed{...}``;
      2. the text after the LAST line starting with "Final Answer:"
         (case-insensitive);
      3. the last non-empty line;
      4. the stripped text itself ("" for empty input).

    The last occurrence is used because the prompts ask for the answer at the
    very end, while earlier reasoning (e.g. a thinking trace) may contain
    intermediate ``\\boxed{}`` expressions. Balanced matching keeps nested
    LaTeX such as ``\\frac{a}{b}`` intact. A non-greedy regex would stop at the
    first "}" and could make different fractions compare equal.
    """
    boxed = _extract_boxed_contents(text)
    if boxed:
        return boxed[-1].strip()
    final_lines = re.findall(r'(?i)^\s*final\s*answer\s*:\s*(.+)$', text, flags=re.M)
    if final_lines:
        return final_lines[-1].strip()
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    return lines[-1] if lines else text.strip()

def _to_number_maybe(s: str):
    """Parse *s* as a number if it looks like one, else return None.

    Commas (thousands separators) are removed first. A plain fraction ``a/b``
    is evaluated as ``a / b``, returning None when ``b`` is zero. Any other
    string containing a digit is passed to ``float``. Strings with no digit,
    or that ``float`` rejects, give None.
    """
    s = s.strip().replace(',', '')
    if re.fullmatch(r'[+-]?\d+/\d+', s):
        numerator, denominator = s.split('/')
        try:
            return float(numerator) / float(denominator)
        except ZeroDivisionError:
            return None
    if re.search(r'\d', s):
        try:
            return float(s)
        except ValueError:
            # e.g. "12 apples" or "3x+1": contains digits but is not a float.
            return None
    return None

def _norm_text(s: str) -> str:
    """Lower-case and trim *s*, collapsing every whitespace run to a single space."""
    trimmed = s.lower().strip()
    return re.sub(r'\s+', ' ', trimmed)

def compare_answers(pred_text: str, gold_text: str, atol: float = 1e-6) -> bool:
    """Return True when the answers extracted from prediction and gold match.

    Both texts are reduced to an answer string with ``_pick_answer``. If both
    answers parse as numbers (``_to_number_maybe``), they are compared with an
    absolute tolerance ``atol`` only (``rel_tol=0``). Otherwise they are
    compared as lower-cased, whitespace-collapsed text.
    """
    pred = _pick_answer(pred_text)
    gold = _pick_answer(gold_text)
    pred_value = _to_number_maybe(pred)
    gold_value = _to_number_maybe(gold)
    if pred_value is None or gold_value is None:
        return _norm_text(pred) == _norm_text(gold)
    return math.isclose(pred_value, gold_value, rel_tol=0.0, abs_tol=atol)

def _newline_token_kinds(tokenizer):
    """Return ``(single_newline_ids, double_newline_ids)`` for *tokenizer*.

    Each of "\\n", "\\r\\n" (single) and "\\n\\n", "\\r\\n\\r\\n" (double) is encoded
    without special tokens. An id is kept only when the encoding is a list of
    exactly one token. Not called anywhere in the visible section of this file.
    """
    def _single_token_ids(patterns):
        found = set()
        for pattern in patterns:
            encoded = tokenizer.encode(pattern, add_special_tokens=False)
            if isinstance(encoded, list) and len(encoded) == 1:
                found.add(encoded[0])
        return found

    singles = _single_token_ids(['\n', '\r\n'])
    doubles = _single_token_ids(['\n\n', '\r\n\r\n'])
    return singles, doubles

def segment_by_newlines_2plus(token_ids, tokenizer):
    """Split a generated token sequence into steps at runs of two or more newlines.

    Each token is decoded on its own (special tokens skipped, "\\r\\n" treated
    as "\\n"). Consecutive newlines are counted across token boundaries, and
    any other character resets the count. When the count reaches 2, the
    current segment is closed *including* the token that completed the run,
    and the next segment starts at the following token.

    The separator token is kept in the segment it closes so that content
    merged into it is not lost. For example, with a vocabulary token like
    ".\\n\\n" (presumably present in Qwen-style BPE vocabularies), excluding it
    would silently drop the sentence's final punctuation from the step text.
    Segment *start* indices, which main() uses to pick each step's hidden
    state, are the same as when the separator is excluded.

    Args:
        token_ids: 1-D torch tensor or sequence of token ids.
        tokenizer: object with ``decode(ids, skip_special_tokens=...)``. A
            token whose decode raises is treated as empty text.

    Returns:
        Contiguous half-open ``(start, end)`` token spans that cover the
        sequence in order, or ``[]`` for empty input.
    """
    ids = token_ids.tolist() if isinstance(token_ids, torch.Tensor) else list(token_ids)
    n = len(ids)
    if n == 0:
        return []

    segments = []
    start_tok = 0
    consec_nl = 0
    for i, tok_id in enumerate(ids):
        try:
            piece = tokenizer.decode([tok_id], skip_special_tokens=True)
        except Exception:
            piece = ""
        for ch in piece.replace("\r\n", "\n"):
            if ch != "\n":
                consec_nl = 0
                continue
            consec_nl += 1
            if consec_nl >= 2:
                end_tok = i + 1  # close the segment *after* the separator token
                # Guard: a second run inside the same token would be empty.
                if end_tok > start_tok:
                    segments.append((start_tok, end_tok))
                start_tok = end_tok
                consec_nl = 0

    if start_tok < n:
        segments.append((start_tok, n))
    return segments or [(0, n)]


def decode_token_span(token_ids: torch.Tensor, start: int, end: int, tokenizer) -> str:
    """Decode ``token_ids[start:end]`` to stripped text, or return "" on any problem.

    ``token_ids`` may be a 1-D torch tensor or any sequence of ids. ``end`` is
    clamped to the sequence length. An empty span, a negative ``start``, or a
    span entirely past the end yields "". Special tokens are skipped. Errors
    raised while decoding (``Exception`` subclasses) yield "".
    KeyboardInterrupt and SystemExit propagate normally.
    """
    if start >= end or start < 0:
        return ""
    total_len = len(token_ids) if isinstance(token_ids, torch.Tensor) else len(list(token_ids))
    end = min(end, total_len)
    if start >= end:
        return ""
    span_ids = token_ids[start:end]
    if isinstance(span_ids, torch.Tensor):
        span_ids = span_ids.tolist()
    try:
        text = tokenizer.decode(span_ids, skip_special_tokens=True)
        return text.strip()
    except Exception:
        # Best effort: a span that cannot be decoded is treated as empty.
        return ""

def qwen_build_gen_prompt(tokenizer, user_text: str) -> str:
    """
    Render one user turn as a Qwen chat prompt string for GENERATION.

    Applies the tokenizer's chat template untokenized, with
    ``add_generation_prompt=True`` and ``enable_thinking=True`` (a Qwen3
    template variable that enables the thinking section).

    Raises:
        RuntimeError: if the tokenizer has no chat template, or if applying it
            fails. In the latter case the original exception is chained as
            ``__cause__``.
    """
    if getattr(tokenizer, 'chat_template', None) is None:
        raise RuntimeError(
            f"Tokenizer {tokenizer.name_or_path} does not have a chat_template! "
            f"This is required for Qwen models."
        )

    messages = [{"role": "user", "content": user_text}]
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True  # Qwen-specific for generation
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to apply chat template for {tokenizer.name_or_path}: {e}"
        ) from e

def qwen_build_cls_prompt(tokenizer, user_text: str) -> str:
    """
    Render one user turn as a Qwen chat prompt string for CLASSIFICATION.

    Applies the tokenizer's chat template untokenized, with
    ``add_generation_prompt=True`` and ``enable_thinking=False`` (a Qwen3
    template variable), so the classifier answers without a thinking section.

    Raises:
        RuntimeError: if the tokenizer has no chat template, or if applying it
            fails. In the latter case the original exception is chained as
            ``__cause__``.
    """
    if getattr(tokenizer, 'chat_template', None) is None:
        raise RuntimeError(
            f"Classifier tokenizer {tokenizer.name_or_path} does not have a chat_template! "
            f"This is required for Qwen models."
        )

    messages = [{"role": "user", "content": user_text}]
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False  # Qwen-specific for classification
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to apply classifier chat template: {e}"
        ) from e

# ==========================
#        CLASSIFIER
# ==========================
def classify_step_with_unknown(cls_tok, cls_model, text: str, cutoff: float, device) -> str:
    """
    Label one reasoning step with a tag from TAG_LIST, or "unknown".

    Builds the prompt from FEW_SHOT_PREFIX, the quoted sentence and the allowed
    labels, then renders it with the classifier chat template (thinking off).
    It greedily decodes up to 64 new tokens and takes the first line of the
    reply, lower-cased with any "label:" removed. That line is fuzzy-matched
    against TAG_LIST with difflib at ``cutoff``.

    Args:
        cls_tok: classifier tokenizer; must have a chat template.
        cls_model: causal LM exposing ``generate``.
        text: the step text to classify.
        cutoff: difflib similarity threshold. A higher value rejects more
            replies, i.e. yields more "unknown".
        device: device the prompt tensors are moved to.

    Returns:
        The closest tag in TAG_LIST, or "unknown" if none clears ``cutoff``.
    """
    allowed = ", ".join(TAG_LIST)
    user_prompt = "".join([
        f"{FEW_SHOT_PREFIX}\n\n",
        "Now classify the following sentence with ONE label only.\n",
        f'Sentence: "{text}"\n',
        f"Choose one label from: {allowed}\n",
        "Label:",
    ])

    encoded = cls_tok(qwen_build_cls_prompt(cls_tok, user_prompt), return_tensors="pt")
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    # Fall back to EOS when the tokenizer defines no pad token.
    pad_id = cls_tok.pad_token_id if cls_tok.pad_token_id is not None else cls_tok.eos_token_id

    with torch.no_grad():
        output_ids = cls_model.generate(
            **model_inputs,
            max_new_tokens=64,
            do_sample=False,
            eos_token_id=cls_tok.eos_token_id,
            pad_token_id=pad_id,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = model_inputs["input_ids"].shape[-1]
    reply = cls_tok.decode(output_ids[:, prompt_len:][0], skip_special_tokens=True)

    first_line = reply.strip().split("\n")[0]
    candidate = first_line.lower().replace("label:", "").strip()
    best = difflib.get_close_matches(candidate, TAG_LIST, n=1, cutoff=cutoff)
    return best[0] if best else "unknown"

# ==========================
#          MAIN
# ==========================
def main():
    """
    Command-line entry point: generate, segment, optionally tag, and save one
    rank's shard of a benchmark run.

    Meant to be started once per process by a distributed launcher: it reads
    LOCAL_RANK and calls ddp_init(). Each rank:
      1. parses CLI args and sets the module-level OUT_DIR / RUN_NAME. The
         default RUN_NAME is taken from rank 0 so all ranks agree;
      2. installs per-rank fault/signal stack dumps and seeds the RNGs;
      3. loads the generation model and, with --classify_steps, the classifier;
      4. takes a strided shard (rows rank, rank+world, ...) of the selected
         dataset. For each sample it generates a completion, splits it into
         steps at blank lines, keeps each step's hidden state at its first
         token, optionally tags each step, and scores the final answer;
      5. saves ``{RUN_NAME}.rank{rank}.pt`` under OUT_DIR.

    A failing sample is logged to the rank log and stored as an
    ``{"error": ...}`` record instead of aborting the shard. Any other error
    inside the task/save phase is logged and re-raised. In every case that
    reaches that phase, the process group is destroyed on exit.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task",
                        choices=["tiger", "aime24", "math500", "gpqa_diamond"],
                        required=True)
    parser.add_argument("--gen_model", default=DEFAULT_GEN_MODEL)
    parser.add_argument("--cls_model", default=DEFAULT_CLS_MODEL)
    parser.add_argument("--out_dir", default="runs_dp")
    parser.add_argument("--run_name", default=None)
    parser.add_argument("--num_samples", type=int, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=4000)
    # Sampling parameters, forwarded unchanged to gen_model.generate().
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=20)
    parser.add_argument("--min_p", type=float, default=0)
    parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
    parser.add_argument("--classify_steps", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--shuffle_dataset", action="store_true",
                        help="Shuffle dataset order deterministically with --seed.")
    parser.add_argument("--tiger_answer_types", nargs="*",
                        default=["Float", "Multiple Choice", "Integer", "Percentage"])
    parser.add_argument("--tiger_difficulties", nargs="*",
                        default=["Senior High School", "Junior High School", "Primary School"])
    # difflib similarity cutoff: replies scoring below it are tagged "unknown",
    # so a HIGHER value produces MORE unknowns.
    parser.add_argument("--classifier_cutoff", type=float, default=0.6,
                        help="Minimum similarity for a classifier label to be accepted "
                             "(higher = more unknowns)")
    args = parser.parse_args()

    global OUT_DIR, RUN_NAME
    OUT_DIR = args.out_dir

    # ===== DDP init =====
    rank, world, device, local_rank = ddp_init()

    # ===== Run name (must be identical on every rank) =====
    if args.run_name:
        RUN_NAME = args.run_name
    else:
        # Each rank reads its own clock. Ranks that start in different seconds
        # would otherwise write logs/shards under different names and break the
        # merge step, so every rank adopts rank 0's timestamp.
        run_name_box = [f"{args.task}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"]
        if world > 1:
            dist.broadcast_object_list(run_name_box, src=0)
        RUN_NAME = run_name_box[0]

    # ===== Per-rank signal dump hooks =====
    _install_signal_dumps(rank)
    if rank == 0:
        os.makedirs(OUT_DIR, exist_ok=True)

    # ===== Seeds / determinism (best effort; see set_global_seed) =====
    full_seed = set_global_seed(args.seed, rank)
    gpqa_seed = args.seed

    if rank == 0:
        print(f"[INFO] Full reproducibility enabled with seed={args.seed}, rank_seed={full_seed}")

    # ===== Perf toggles =====
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # ===== Models on THIS rank =====
    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
    DTYPE = dtype_map[args.dtype]

    if rank == 0:
        print(f"[INFO] Loading generation model: {args.gen_model}")

    gen_tok = AutoTokenizer.from_pretrained(args.gen_model, trust_remote_code=True)
    if gen_tok.pad_token is None and gen_tok.eos_token is not None:
        gen_tok.pad_token = gen_tok.eos_token
    # Pad id handed to generate(); falls back to EOS if the tokenizer has none.
    GEN_PAD_ID = gen_tok.pad_token_id if gen_tok.pad_token_id is not None else gen_tok.eos_token_id

    gen_model = AutoModelForCausalLM.from_pretrained(
        args.gen_model,
        torch_dtype=DTYPE,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        device_map={"": local_rank}
    ).eval()

    if not hasattr(gen_tok, 'chat_template') or gen_tok.chat_template is None:
        raise RuntimeError(
            f"Generation model {args.gen_model} does not have a chat_template! "
            f"This is required for Qwen models."
        )

    if rank == 0:
        print("[INFO] Generation model loaded successfully with Qwen chat template")

    cls_tok = cls_model = None
    if args.classify_steps:
        if rank == 0:
            print(f"[INFO] Loading classifier model: {args.cls_model}")

        cls_tok = AutoTokenizer.from_pretrained(args.cls_model, trust_remote_code=True)
        if cls_tok.pad_token is None and cls_tok.eos_token is not None:
            cls_tok.pad_token = cls_tok.eos_token

        cls_model = AutoModelForCausalLM.from_pretrained(
            args.cls_model,
            torch_dtype=DTYPE,
            attn_implementation="flash_attention_2",
            trust_remote_code=True,
            device_map={"": local_rank}
        ).eval()

        if not hasattr(cls_tok, 'chat_template') or cls_tok.chat_template is None:
            raise RuntimeError(
                f"Classifier model {args.cls_model} does not have a chat_template! "
                f"This is required for Qwen models."
            )

        # Force greedy decoding defaults so the model's shipped sampling config
        # (temperature/top_p/top_k) cannot leak into classification.
        try:
            ccfg = cls_model.generation_config
            ccfg.do_sample = False
            ccfg.top_p = None
            ccfg.top_k = None
            ccfg.temperature = None
        except Exception:
            pass

        if rank == 0:
            print("[INFO] Classifier model loaded successfully with Qwen chat template")

    records, num_correct = [], 0
    MAX_CLASSIFY_CHARS = 32768

    # ----- small text helpers -----
    def _last_nonempty_line(s: str) -> str:
        """Return the last non-blank line of s, stripped ("" if none)."""
        for line in reversed(s.splitlines()):
            line = line.strip()
            if line:
                return line
        return ""

    def _pick_letter(text: str):
        """Return the first A-D letter matched by LETTER_RE in text, upper-cased, or None."""
        m = LETTER_RE.search(text)
        return m.group(1).upper() if m else None

    def _last_boxed(text: str):
        """Return the contents of the last brace-balanced \\boxed{...} in text, or None.

        Balanced matching keeps nested LaTeX such as \\frac{a}{b} intact. The
        last occurrence is used because the thinking trace can contain
        intermediate \\boxed{} expressions before the final answer line.
        """
        marker = "\\boxed{"
        last = None
        pos = text.find(marker)
        while pos != -1:
            body_start = pos + len(marker)
            depth, i = 1, body_start
            while i < len(text) and depth:
                if text[i] == "{":
                    depth += 1
                elif text[i] == "}":
                    depth -= 1
                i += 1
            if depth:  # never closed: stop scanning
                break
            last = text[body_start:i - 1]
            pos = text.find(marker, i)
        return last

    def build_last_token_hidden_3d(hidden_states_out):
        """Stack the last-position hidden state of every layer at every step.

        ``hidden_states_out`` is generate()'s per-step tuple of per-layer
        tensors. The result has shape [num_layer_outputs, T_gen, D]; the
        original author labels the first axis L+1. Assumes batch size 1
        (checked by the caller).
        """
        per_step = []
        for step_idx, per_layer in enumerate(hidden_states_out):
            step_last = []
            for h in per_layer:
                if h.dim() == 3:
                    v = h[:, -1, :].squeeze(0)  # [D]
                elif h.dim() == 2:
                    v = h[-1, :]                 # [D]
                else:
                    raise RuntimeError(f"Unexpected hidden state shape {tuple(h.shape)} at step {step_idx}")
                step_last.append(v)
            per_step.append(torch.stack(step_last, dim=0))  # [L+1, D]
        # [T_gen, L+1, D] -> [L+1, T_gen, D]
        return torch.stack(per_step, dim=0).transpose(0, 1).contiguous()

    def generate_and_collect(prompt: str, question: str, gold_answer: str,
                             label_note: str, is_mc: bool = False,
                             extra_mc=None, global_sample_idx: int = 0):
        """Generate one completion for ``prompt``, segment/tag/score it, and store a record.

        On success, appends a record to ``records`` and adds the correctness
        bit to ``num_correct``. Raises on failure; the caller converts that
        into an error record.
        """
        nonlocal num_correct

        reseed_for_sample(args.seed, rank, global_sample_idx)

        # Use Qwen chat template for generation (enable_thinking=True)
        formatted_prompt = qwen_build_gen_prompt(gen_tok, prompt)

        gen_inputs = gen_tok(formatted_prompt, return_tensors="pt").to(device)
        if gen_inputs["input_ids"].shape[0] != 1:
            raise RuntimeError(f"Batch size must be 1, got {gen_inputs['input_ids'].shape[0]}")

        with torch.no_grad():
            out = gen_model.generate(
                **gen_inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=True,
                temperature=args.temperature,
                top_p=args.top_p,
                top_k=args.top_k,
                min_p=args.min_p,
                eos_token_id=gen_tok.eos_token_id,
                pad_token_id=GEN_PAD_ID,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )

        seq = out.sequences[0]                      # [T_in + T_gen]
        inp_len = gen_inputs["input_ids"].shape[1]
        gen_only_ids = seq[inp_len:]                # [T_gen]
        if gen_only_ids.dim() != 1:
            raise RuntimeError(f"gen_only_ids should be 1D, got shape {gen_only_ids.shape}")

        generated_text = gen_tok.decode(gen_only_ids, skip_special_tokens=True)

        # Column t holds the last-position state of generation step t,
        # presumably the state from which token t was sampled (HF semantics).
        hs_out = out.hidden_states
        hidden_states = build_last_token_hidden_3d(hs_out).cpu()  # [L+1, T_gen, D]

        if hidden_states.shape[1] != len(gen_only_ids):
            raise RuntimeError(
                f"hidden_states tokens {hidden_states.shape[1]} != gen_only_ids {len(gen_only_ids)}"
            )
        if isinstance(hs_out, (list, tuple)) and len(hs_out) != len(gen_only_ids):
            raise RuntimeError(f"steps {len(hs_out)} != gen_tokens {len(gen_only_ids)}")

        token_segments = segment_by_newlines_2plus(gen_only_ids, gen_tok)
        if not token_segments:
            token_segments = [(0, len(gen_only_ids))]

        step_hidden_states = []      # List[Tensor [L+1, D]]
        sentences_with_labels = []   # List[(segment_text, tag)]

        for (t0, t1) in token_segments:
            if t1 <= t0 or t0 < 0 or t1 > hidden_states.shape[1]:
                continue

            segment_text = decode_token_span(gen_only_ids, t0, t1, gen_tok)
            if not segment_text or not segment_text.strip():
                continue

            # clone() so the kept slice does not pin the full [L+1, T_gen, D] storage.
            step_hidden_states.append(hidden_states[:, t0, :].clone())  # [L+1, D]

            if args.classify_steps and (cls_tok is not None) and (cls_model is not None):
                try:
                    tag = classify_step_with_unknown(
                        cls_tok, cls_model, segment_text[:MAX_CLASSIFY_CHARS],
                        cutoff=args.classifier_cutoff, device=device
                    )
                except Exception:
                    tag = "unknown"
            else:
                tag = "unknown"

            sentences_with_labels.append((segment_text, tag))

        del hidden_states
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if len(step_hidden_states) != len(sentences_with_labels):
            raise ValueError(
                f"Step/label mismatch: len(steps)={len(step_hidden_states)} "
                f"len(labels)={len(sentences_with_labels)} "
                f"global_idx={global_sample_idx}"
            )

        if not is_mc:
            is_correct = compare_answers(generated_text, gold_answer) if gold_answer else False
            pred_letter = None
        else:
            pred_letter = _pick_letter(_last_nonempty_line(generated_text))
            gold_letter = extra_mc.get("correct_letter") if extra_mc else None
            is_correct = (pred_letter == gold_letter) if (pred_letter and gold_letter) else False

        rec = {
            "global_sample_idx": global_sample_idx,
            "label_note": label_note,
            "prompt": prompt,
            "formatted_prompt": formatted_prompt,
            "generated_text": generated_text,
            "question": question,
            "ground_truth_answer": gold_answer,
            "is_correct": is_correct,
            "sentences_with_labels": sentences_with_labels,
            "step_hidden_states": step_hidden_states,
            "gen_token_count": int(gen_only_ids.numel()),
            "tag_list": TAG_LIST + ["unknown"],
            "num_steps": len(step_hidden_states),
        }

        if is_mc and extra_mc:
            if "options_shuffled" in extra_mc:
                rec["options_shuffled"] = extra_mc["options_shuffled"]
            if "correct_letter" in extra_mc:
                rec["ground_truth_answer"] = extra_mc["correct_letter"]
            rec["predicted_letter"] = pred_letter

        # At most one element: the last balanced \boxed{...} in the output.
        boxed = _last_boxed(generated_text)
        rec["predicted_boxed"] = [boxed] if boxed is not None else []

        records.append(rec)
        num_correct += int(is_correct)

    # ----- dataset plumbing shared by all tasks -----
    def _take_and_shard(ds_full, banner: str):
        """Apply --num_samples / --shuffle_dataset and return this rank's strided shard."""
        total = len(ds_full)
        take = min(args.num_samples if args.num_samples is not None else total, total)

        if rank == 0:
            print(f"[INFO] {banner}: {total} samples, taking {take}")

        # NOTE: the first `take` rows are selected *before* shuffling, so
        # --shuffle_dataset reorders that subset rather than sampling from the
        # whole dataset.
        ds_subset = ds_full.select(range(take))
        if args.shuffle_dataset:
            ds_subset = ds_subset.shuffle(seed=args.seed)

        # contiguous=False gives rank r the rows r, r+world, r+2*world, ...,
        # which is what the global_idx formula in _run_samples assumes. It is
        # passed explicitly because the default has differed across
        # `datasets` versions.
        ds = ds_subset.shard(num_shards=world, index=rank, contiguous=False)

        if rank == 0:
            print(f"[INFO] Rank {rank} processing {len(ds)} samples")
        return ds

    def _run_samples(ds, where: str, run_one):
        """Call run_one(sample, global_idx) for every sample in this rank's shard.

        A failure is logged to the rank log under ``where`` and recorded as an
        error record, so one bad sample does not abort the shard.
        """
        for local_idx, sample in enumerate(tqdm(ds, disable=(rank != 0), desc=f"Rank {rank}")):
            # Index of the sample within the (optionally shuffled) subset.
            global_idx = rank + local_idx * world
            try:
                run_one(sample, global_idx)
            except Exception as e:
                _log_exc(rank, where, f"local_idx={local_idx}, global_idx={global_idx}")
                records.append({
                    "error": f"{type(e).__name__}: {e}",
                    "local_idx": local_idx,
                    "global_sample_idx": global_idx,
                })

    def _run_standard(sample, global_idx, question_key, answer_key, label_note):
        """Free-form (boxed-answer) sample: build the standard prompt and generate."""
        question = str(sample[question_key]).strip()
        gold_answer = str(sample[answer_key]).strip()
        prompt = f"{PROMPT_STANDARD}\n{question}\n"
        generate_and_collect(
            prompt, question, gold_answer,
            label_note=label_note,
            is_mc=False,
            global_sample_idx=global_idx
        )

    def _run_gpqa(sample, global_idx):
        """GPQA sample: shuffle the four options reproducibly, build the MC prompt, and generate."""
        # Option order is seeded from --seed and the sample's global index.
        rng = random.Random((gpqa_seed if gpqa_seed is not None else 0) + global_idx)

        needed = ["Question", "Correct Answer", "Incorrect Answer 1",
                  "Incorrect Answer 2", "Incorrect Answer 3"]
        for k in needed:
            if k not in sample:
                raise ValueError(f"Missing column '{k}'")

        q = str(sample["Question"]).strip()
        opts = [
            str(sample["Correct Answer"]).strip(),
            str(sample["Incorrect Answer 1"]).strip(),
            str(sample["Incorrect Answer 2"]).strip(),
            str(sample["Incorrect Answer 3"]).strip(),
        ]

        idxs = [0, 1, 2, 3]
        rng.shuffle(idxs)
        shuf = [opts[i] for i in idxs]
        correct_letter = LETTERS[idxs.index(0)]  # where the correct option (index 0) landed

        options_block = "\n".join(f"{LETTERS[i]}. {shuf[i]}" for i in range(4))
        prompt = f"{PROMPT_GPQA}\n{q}\n\n{options_block}\n"

        extra = {
            "options_shuffled": {LETTERS[i]: shuf[i] for i in range(4)},
            "correct_letter": correct_letter
        }

        generate_and_collect(
            prompt, q, correct_letter,
            label_note="gpqa_diamond",
            is_mc=True,
            extra_mc=extra,
            global_sample_idx=global_idx
        )

    # ===== Load + run per task =====
    try:
        task = args.task

        if task == "tiger":
            ds_full = load_dataset("TIGER-Lab/WebInstruct-verified", split="test")
            ats = set(args.tiger_answer_types)
            diffs = set(args.tiger_difficulties)
            ANSWER_TYPE_KEY = "answer_type"

            # Filter only on the columns this dataset revision actually has.
            if ANSWER_TYPE_KEY in ds_full.column_names and "difficulty" in ds_full.column_names:
                ds_full = ds_full.filter(lambda x: x[ANSWER_TYPE_KEY] in ats and x["difficulty"] in diffs)
            elif ANSWER_TYPE_KEY in ds_full.column_names:
                ds_full = ds_full.filter(lambda x: x[ANSWER_TYPE_KEY] in ats)
            elif "difficulty" in ds_full.column_names:
                ds_full = ds_full.filter(lambda x: x["difficulty"] in diffs)

            ds = _take_and_shard(ds_full, "Tiger filtered dataset")
            _run_samples(ds, "tiger-per-sample",
                         lambda s, g: _run_standard(s, g, "question", "answer", "tiger_filtered"))

        elif task == "aime24":
            ds = _take_and_shard(load_dataset("HuggingFaceH4/aime_2024", split="train"),
                                 "AIME24 dataset")
            _run_samples(ds, "aime24-per-sample",
                         lambda s, g: _run_standard(s, g, "problem", "answer", "aime24"))

        elif task == "math500":
            ds = _take_and_shard(load_dataset("HuggingFaceH4/MATH-500", split="test"),
                                 "MATH-500 dataset")
            # Gold is the full solution text; compare_answers extracts its \boxed{} answer.
            _run_samples(ds, "math500-per-sample",
                         lambda s, g: _run_standard(s, g, "problem", "solution", "math500"))

        elif task == "gpqa_diamond":
            ds = _take_and_shard(load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train"),
                                 "GPQA Diamond dataset")
            _run_samples(ds, "gpqa-per-sample", _run_gpqa)

        # ===== Save shard =====
        os.makedirs(OUT_DIR, exist_ok=True)
        shard_path = os.path.join(OUT_DIR, f"{RUN_NAME}.rank{rank}.pt")

        save_data = {
            "records": records,
            # Denominator counts error records too (they score as incorrect).
            "accuracy_partial": (num_correct, len(records)),
            "tag_list": TAG_LIST + ["unknown"],
            "config": vars(args),
            "seed_info": {
                "base_seed": args.seed,
                "rank_seed": full_seed,
                "rank": rank,
                "world_size": world,
            }
        }

        torch.save(save_data, shard_path)

        if rank == 0:
            print(f"\n{'='*60}")
            print(f"[Rank 0] Completed {len(records)} samples")
            if len(records) > 0:
                acc_pct = num_correct / len(records) * 100
                print(f"[Rank 0] Partial accuracy: {num_correct}/{len(records)} = {acc_pct:.2f}%")
            print(f"[Rank 0] Shard saved to: {shard_path}")
            print(f"{'='*60}\n")
        else:
            print(f"[Rank {rank}] Completed {len(records)} samples, saved to {shard_path}")

    except Exception as e:
        r = dist.get_rank() if dist.is_initialized() else 0
        _log_exc(r, "main-fatal", f"Fatal error: {type(e).__name__}: {e}")
        raise
    finally:
        if dist.is_initialized():
            try:
                dist.destroy_process_group()
            except Exception:
                pass

        try:
            local_rank_env = int(os.environ.get("LOCAL_RANK", 0))
            if local_rank_env == 0:
                print(f"\n[Rank 0] All shards/logs saved to {OUT_DIR}")
                print(f"[Rank 0] Merge shards with: python merge_shards.py --run_name {RUN_NAME} --out_dir {OUT_DIR}")
        except Exception:
            pass

# Script entry point: run main() only when this file is executed directly
# (e.g. once per process by a distributed launcher), not when imported.
if __name__ == "__main__":
    main()