import os
import json
import re
import numpy as np
import torch
import gc
import time
from tqdm import tqdm
from torch.nn.attention import SDPBackend, sdpa_kernel

# backends = [SDPBackend.FLASH_ATTENTION]
# if True: with sdpa_kernel(backends=backends):

# Directory holding the zero-shot prompt templates (resolved against the
# current working directory, so the harness must be launched from repo root).
_TEMPLATE_DIR = 'run/pipelines/benchmarks/utils/config'


def _read_template(filename: str) -> str:
    """Return the UTF-8 text of ``filename`` inside ``_TEMPLATE_DIR``.

    Uses a context manager so the file handle is closed deterministically
    instead of being left for the garbage collector.
    """
    with open(os.path.join(_TEMPLATE_DIR, filename), encoding='utf-8') as f:
        return f.read()


template_rag = _read_template('0shot_rag.txt')
template_no_context = _read_template('0shot_no_context.txt')
template_0shot = _read_template('0shot.txt')
template_0shot_cot = _read_template('0shot_cot.txt')
template_0shot_cot_ans = _read_template('0shot_cot_ans.txt')

from .utils import (
    build_chat_longbench,
    build_chat_pg19,
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
    score_longgenbench_single,
    build_input_ids,
    extract_longbenchv2_answer,
)

# Scoring function for each benchmark sub-task, keyed by task name. All
# scorers are imported from .utils above. The names appear to be LongBench
# task names (cf. the build_chat_longbench import) — TODO confirm callers look
# up exactly these keys. Keys ending in "_zh", plus "dureader" and "vcsum",
# map to the *_zh_score scorer variants.
dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench_p": code_sim_score,
}

def _get_batch_size(generator, default=1) -> int:
    """
    Return ``generator.batch_size`` as a positive int.

    Falls back to ``default`` when the attribute is missing, ``None``, or
    cannot be converted with ``int()``. The result is clamped to at least 1,
    so a zero or negative setting still yields single-sample batches.
    """
    raw = getattr(generator, "batch_size", None)
    if raw is None:
        bs = default
    else:
        try:
            bs = int(raw)
        except (TypeError, ValueError, OverflowError):
            # Only the errors int() raises for bad input; anything else is a bug.
            bs = default
    return max(bs, 1)

def _iter_batches(dataset, batch_size: int):
    """
    Lazily yield ``(start_index, dataset[start_index:start_index + batch_size])``.

    The final chunk may be shorter than ``batch_size``. ``dataset`` must
    support ``len()`` and slicing.

    NOTE(review): callers iterate each chunk as a sequence of row dicts. That
    holds for a ``list`` of dicts, but slicing a Hugging Face
    ``datasets.Dataset`` returns a dict of columns instead — confirm what
    callers pass.
    """
    width = batch_size
    yield from (
        (offset, dataset[offset : offset + width])
        for offset in range(0, len(dataset), width)
    )

def _get_pad_id(tokenizer) -> int:
    """
    Return the tokenizer's padding id, installing EOS as PAD if none is set.

    Side effect: when ``tokenizer.pad_token_id`` is ``None`` it is overwritten
    with ``tokenizer.eos_token_id`` on the tokenizer object itself.
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # No dedicated PAD token: reuse EOS so batches can still be padded.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        return tokenizer.eos_token_id
    return pad_id

def _batch_tokenize_chat(tokenizer, prompts, device):
    """
    Chat-template and tokenize a batch of single-turn user prompts.

    Args:
        tokenizer: tokenizer exposing ``apply_chat_template``. If it has no
            PAD token, EOS is installed as PAD (via ``_get_pad_id``).
        prompts: list of user-message strings, one per batch row.
        device: device the resulting tensor is moved to.

    Returns:
        Tensor of input ids of shape (B, S), padded to the longest prompt.

    NOTE(review): no attention mask is returned, and padding follows
    ``tokenizer.padding_side``. With right padding, PAD tokens sit between a
    short prompt and where generation starts. Batched decoder-only generation
    normally expects ``padding_side="left"``, so confirm the tokenizer is
    configured that way.
    """
    conversations = [[{"role": "user", "content": prompt}] for prompt in prompts]
    # A rectangular batch needs a PAD id; reuse the shared EOS fallback.
    _get_pad_id(tokenizer)
    return tokenizer.apply_chat_template(
        conversations,
        tokenize=True,
        add_generation_prompt=True,
        padding=True,
        return_tensors="pt",
    ).to(device)

def run_gsm8k_eval(generator, tokenizer, past_key_values, draft_past_key_values, args, dataset, log_dir):
    """
    Batched GSM8K evaluation.

    Questions are chat-templated in batches of ``_get_batch_size(generator)``
    and generated with one ``generator.generate`` call per batch. A sample is
    correct when the first integer on the last line of the response equals
    the first integer on the last line of the reference answer. Both sides
    are normalised through ``int`` so that ``+7``, ``007`` and ``7`` compare
    equal.

    Args:
        generator: exposes ``generate``, ``device`` and ``profiling``, and
            optionally ``batch_size`` and ``exp_log``. ``exp_log`` holds
            per-call metrics; ``tput``, ``avg_sampled``, ``avg_draft_time``
            and ``avg_target_time`` are aggregated here.
        tokenizer: chat tokenizer used for templating and decoding.
        past_key_values: cache passed to ``generate``; ``reset()`` after every call.
        draft_past_key_values: optional second cache, reset likewise when not None.
        args: needs ``warmup_iter``, ``temperature``, ``max_length``, ``do_sample``.
        dataset: sliceable rows with ``"question"`` and ``"answer"`` fields.
        log_dir: per-sample records are appended to ``<log_dir>/0.jsonl``.

    Returns:
        (tput_mean, tput_std, tacc_mean, tacc_std, answer_accuracy,
         avg_draft_time, avg_target_time, peak_memory_gib)
    """
    bs = _get_batch_size(generator, default=1)
    int_regex = re.compile(r"[-+]?\d+")

    def _reset_caches():
        # Caches are reused across generate calls, so clear them after each one.
        past_key_values.reset()
        if draft_past_key_values is not None:
            draft_past_key_values.reset()

    def _generate(ids):
        out = generator.generate(
            ids,
            temperature=args.temperature,
            max_length=args.max_length,
            do_sample=args.do_sample,
            past_key_values=past_key_values,
            draft_past_key_values=draft_past_key_values,
        )
        _reset_caches()
        return out

    def _last_line_int(text):
        # First integer on the last line (whole text if it has no lines), as a
        # canonical decimal string; None when no integer is present. Empty
        # text is safe (no IndexError).
        lines = text.splitlines()
        match = int_regex.search(lines[-1] if lines else text)
        return str(int(match.group(0))) if match else None

    # 1) Warm-up with profiling disabled; the flag is restored even on error.
    warmup_prompt = "Solve this math problem. Give the reasoning steps ...\nWhat is 1 + 1?"
    original_profiling = generator.profiling
    generator.profiling = False
    try:
        for _ in range(args.warmup_iter):
            tokenizer.use_default_system_prompt = True
            warmup_ids = tokenizer.apply_chat_template(
                [[{"role": "user", "content": warmup_prompt}]] * bs,
                tokenize=True, add_generation_prompt=True, return_tensors="pt"
            ).to(generator.device)
            _generate(warmup_ids)
    finally:
        generator.profiling = original_profiling

    # 2) Main loop (batched)
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "0.jsonl")

    tput_list, tacc_list, draft_times, target_times = [], [], [], []
    perf_sinks = (
        ("tput", tput_list),
        ("avg_sampled", tacc_list),
        ("avg_draft_time", draft_times),
        ("avg_target_time", target_times),
    )
    total_q, correct_q = 0, 0

    for _, batch in tqdm(list(_iter_batches(dataset, bs)), desc="Evaluating GSM8K (batched)"):
        prompts = [x["question"] for x in batch]
        # str() tolerates non-string answers.
        gts = [str(x["answer"]).strip() for x in batch]

        tokenizer.use_default_system_prompt = True
        input_ids = _batch_tokenize_chat(tokenizer, prompts, generator.device)
        output_ids = _generate(input_ids)

        # exp_log describes the whole generate call; it is copied into every
        # sample record of this batch.
        batch_log = {**getattr(generator, "exp_log", {})}
        peak_memory = torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3)
        prompt_len = input_ids.size(1)  # padded width shared by every row

        records = []
        for row, (prompt, gt_text) in enumerate(zip(prompts, gts)):
            output_str = tokenizer.decode(
                output_ids[row][prompt_len:], skip_special_tokens=True
            ).strip()

            pred_int = _last_line_int(output_str)
            gt_int = _last_line_int(gt_text)
            is_correct = pred_int is not None and pred_int == gt_int
            total_q += 1
            correct_q += int(is_correct)

            record = dict(batch_log)
            record.update({
                "query": prompt,
                "response": output_str,
                "answer": gt_text,
                "Accuracy": int(is_correct),
                "peak_memory": float(peak_memory),
                "batch_size": int(input_ids.size(0)),
            })
            records.append(record)

            # Batch-level metrics are appended once per sample (sample-weighted).
            for key, sink in perf_sinks:
                if record.get(key) is not None:
                    sink.append(record[key])

        # Open the log once per batch rather than once per sample.
        with open(log_file, "a+", encoding="utf-8") as f:
            for record in records:
                json.dump(record, f, ensure_ascii=False)
                f.write("\n")

        del input_ids, output_ids
        gc.collect()
        torch.cuda.empty_cache()

    # 3) Aggregate
    tput_mean, tput_std = (np.mean(tput_list), np.std(tput_list)) if tput_list else (0, 0)
    tacc_mean, tacc_std = (np.mean(tacc_list), np.std(tacc_list)) if tacc_list else (0, 0)
    answer_accuracy = correct_q / total_q if total_q > 0 else 0
    avg_draft = np.mean(draft_times) if draft_times else 0
    avg_target = np.mean(target_times) if target_times else 0
    peak_memory = torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3)

    print("Final GSM8K Results:")
    print(f"\tThroughput       : {tput_mean:.3f} ± {tput_std:.3f} tokens/sec")
    print(f"\tToken Acceptance : {tacc_mean:.3f} ± {tacc_std:.3f}")
    print(f"\tAnswer Accuracy  : {answer_accuracy:.3f} ({correct_q}/{total_q})")
    print(f"\tAvg Draft Time   : {avg_draft:.3f} sec")
    print(f"\tAvg Target Time  : {avg_target:.3f} sec")
    print(f"\tPeak Memory      : {peak_memory:.3f} GiB")

    return (tput_mean, tput_std, tacc_mean, tacc_std, answer_accuracy, avg_draft, avg_target, peak_memory)

def run_aime_eval(
    generator,
    tokenizer,
    past_key_values,
    draft_past_key_values,
    args,
    dataset,
    log_dir,
):
    """
    Multi-batch AIME-2024 evaluation.

    Questions are chat-templated in batches of ``_get_batch_size(generator)``
    and generated with one ``generator.generate`` call per batch. A sample is
    correct when the first integer on the last line of the response equals
    the first integer on the last line of the reference answer, both
    normalised through ``int``.

    Batches whose padded prompt width exceeds ``args.max_length`` are skipped
    as a whole (a coarse filter, not per sample). Skipped samples are excluded
    from the accuracy denominator, and their count is printed in the summary.

    Per-sample records are appended to ``<log_dir>/aime_eval.jsonl``.

    Returns:
        (tput_mean, tput_std, tacc_mean, tacc_std,
         answer_accuracy, avg_draft_time, avg_target_time, peak_memory)
    """
    bs = _get_batch_size(generator, default=1)
    int_regex = re.compile(r"[-+]?\d+")

    def _reset_caches():
        # Caches are reused across generate calls, so clear them after each one.
        past_key_values.reset()
        if draft_past_key_values is not None:
            draft_past_key_values.reset()

    def _generate(ids):
        out = generator.generate(
            ids,
            temperature=args.temperature,
            max_length=args.max_length,
            do_sample=args.do_sample,
            past_key_values=past_key_values,
            draft_past_key_values=draft_past_key_values,
        )
        _reset_caches()
        return out

    def _last_line_int(text):
        # First integer on the last line (whole text if it has no lines), as a
        # canonical decimal string; None when no integer is present.
        lines = text.splitlines()
        match = int_regex.search(lines[-1] if lines else text)
        return str(int(match.group(0))) if match else None

    # 1. Warm-up with profiling disabled; the flag is restored even on error.
    warmup_prompt = "Solve this math problem. Give the reasoning steps ...\nWhat is 1 + 1?"
    original_profiling = generator.profiling
    generator.profiling = False
    try:
        for _ in range(args.warmup_iter):
            tokenizer.use_default_system_prompt = True
            warmup_ids = tokenizer.apply_chat_template(
                [[{"role": "user", "content": warmup_prompt}]] * bs,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(generator.device)
            _generate(warmup_ids)
    finally:
        generator.profiling = original_profiling

    # 2. Main loop (batched)
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "aime_eval.jsonl")

    tput_list, tacc_list = [], []
    draft_times, target_times = [], []
    perf_sinks = (
        ("tput", tput_list),
        ("avg_sampled", tacc_list),
        ("avg_draft_time", draft_times),
        ("avg_target_time", target_times),
    )
    total_q, correct_q, skipped_q = 0, 0, 0

    for _, batch in tqdm(
        list(_iter_batches(dataset, bs)),
        desc="Evaluating AIME (batched)",
    ):
        prompts = [x["question"] for x in batch]
        # str() tolerates datasets that store the answer as an int.
        gts = [str(x["answer"]).strip() for x in batch]

        tokenizer.use_default_system_prompt = True
        input_ids = _batch_tokenize_chat(tokenizer, prompts, generator.device)

        if input_ids.size(1) > args.max_length:
            # Coarse whole-batch filter on the padded width.
            skipped_q += len(prompts)
            del input_ids
            gc.collect()
            torch.cuda.empty_cache()
            continue

        output_ids = _generate(input_ids)

        # exp_log describes the whole generate call; copied into every record.
        batch_log = {**getattr(generator, "exp_log", {})}
        peak_memory = torch.cuda.max_memory_reserved(generator.device) / (1024**3)
        prompt_len = input_ids.size(1)  # padded prompt width shared by all rows

        records = []
        for row, (prompt, ground_truth) in enumerate(zip(prompts, gts)):
            response = tokenizer.decode(
                output_ids[row, prompt_len:], skip_special_tokens=True
            ).strip()

            pred_int = _last_line_int(response)
            gt_int = _last_line_int(ground_truth)
            is_correct = pred_int is not None and pred_int == gt_int
            total_q += 1
            correct_q += int(is_correct)

            record = dict(batch_log)
            record.update(
                {
                    "query": prompt,
                    "response": response,
                    "answer": ground_truth,
                    "peak_memory": float(peak_memory),
                    "batch_size": int(input_ids.size(0)),
                    "Accuracy": int(is_correct),
                }
            )
            records.append(record)

            # Batch-level metrics are appended once per sample (sample-weighted).
            for key, sink in perf_sinks:
                if record.get(key) is not None:
                    sink.append(record[key])

        # Open the log once per batch rather than once per sample.
        with open(log_file, "a+", encoding="utf-8") as f:
            for record in records:
                json.dump(record, f, ensure_ascii=False)
                f.write("\n")

        del input_ids, output_ids
        gc.collect()
        torch.cuda.empty_cache()

    # 3. Aggregate overall
    tput_mean, tput_std = (np.mean(tput_list), np.std(tput_list)) if tput_list else (0, 0)
    tacc_mean, tacc_std = (np.mean(tacc_list), np.std(tacc_list)) if tacc_list else (0, 0)
    accuracy = correct_q / total_q if total_q else 0
    avg_draft = np.mean(draft_times) if draft_times else 0
    avg_target = np.mean(target_times) if target_times else 0
    peak_mem = torch.cuda.max_memory_reserved(generator.device) / (1024**3)

    print("Final AIME Results:")
    print(f"\tThroughput       : {tput_mean:.3f} ± {tput_std:.3f} tokens/sec")
    print(f"\tToken Acceptance : {tacc_mean:.3f} ± {tacc_std:.3f}")
    print(f"\tAnswer Accuracy  : {accuracy:.3f} ({correct_q}/{total_q})")
    if skipped_q:
        print(f"\tSkipped (too long): {skipped_q} samples not evaluated")
    print(f"\tAvg Draft Time   : {avg_draft:.3f} sec")
    print(f"\tAvg Target Time  : {avg_target:.3f} sec")
    print(f"\tPeak Memory      : {peak_mem:.3f} GiB")

    return (
        tput_mean,
        tput_std,
        tacc_mean,
        tacc_std,
        accuracy,
        avg_draft,
        avg_target,
        peak_mem,
    )

# WARNING: This function is NOT ready
def run_mmlu_pro_eval(
    generator,
    tokenizer,
    past_key_values,
    draft_past_key_values,
    args,
    dataset,
    log_dir,
):
    """
    Batched MMLU-Pro evaluation (multiple choice, options A-J).

    Prompt: ``row["question"]``. When the row also carries an ``"options"``
    list, the options are appended as ``"A. ..."`` lines followed by
    ``"Answer:"``, the same layout as the warm-up prompt.
    NOTE(review): if the loader already embeds the options in ``"question"``,
    they would be listed twice — confirm against the dataset loader.

    Prediction cascade (as in the official MMLU-Pro evaluation script):
      1. the first ``answer is (X)``;
      2. else the ``Answer: X`` pattern;
      3. else the last standalone capital letter A-J in the whole response.

    Ground truth: ``row["answer"]`` as an upper-cased letter; an integer is
    interpreted as a 0-based option index.

    Batches whose padded prompt width exceeds ``args.max_length`` are skipped
    (excluded from accuracy; the count is printed). Per-sample records are
    appended to ``<log_dir>/mmlu_pro.jsonl``.

    Returns:
        (tput_mean, tput_std, tacc_mean, tacc_std,
         answer_accuracy, avg_draft_time, avg_target_time, peak_memory)
    """
    bs = _get_batch_size(generator, default=1)
    option_letters = "ABCDEFGHIJ"

    answer_is_re = re.compile(r"answer is \(?([A-J])\)?")
    answer_colon_re = re.compile(r".*[aA]nswer:\s*\(?([A-J])\)?")
    standalone_letter_re = re.compile(r"\b([A-J])\b")

    def _reset_caches():
        past_key_values.reset()
        if draft_past_key_values is not None:
            draft_past_key_values.reset()

    def _generate(ids):
        out = generator.generate(
            ids,
            temperature=args.temperature,
            max_length=args.max_length,
            do_sample=args.do_sample,
            past_key_values=past_key_values,
            draft_past_key_values=draft_past_key_values,
        )
        _reset_caches()
        return out

    def _format_prompt(row):
        # Mirror the warm-up layout: question, blank line, lettered options, "Answer:".
        question = row["question"]
        options = row.get("options")
        if not options:
            return question
        option_lines = "\n".join(
            f"{letter}. {text}" for letter, text in zip(option_letters, options)
        )
        return f"{question}\n\n{option_lines}\n\nAnswer:"

    def _extract_choice(text):
        # Scanning only the last line with \b([A-J])\b picks up words such as
        # "I" or "A"; use the explicit answer patterns first.
        match = answer_is_re.search(text) or answer_colon_re.search(text)
        if match:
            return match.group(1)
        letters = standalone_letter_re.findall(text)
        return letters[-1] if letters else None

    def _gt_letter(row):
        ans = row["answer"]
        if isinstance(ans, (int, np.integer)) and 0 <= ans < len(option_letters):
            return option_letters[int(ans)]
        return str(ans).strip().upper()

    # 1) Warm-up with profiling disabled; the flag is restored even on error.
    warmup = "What is 1 + 1?"
    warmup_prompt = (
        f"{warmup}\n\nA. 0\nB. 1\nC. 2\nD. 3\nE. 4\nF. 5\nG. 6\nH. 7\nI. 8\nJ. 9\n\nAnswer:"
    )
    orig_prof = generator.profiling
    generator.profiling = False
    try:
        for _ in range(args.warmup_iter):
            tokenizer.use_default_system_prompt = True
            ids = tokenizer.apply_chat_template(
                [[{"role": "user", "content": warmup_prompt}]] * bs,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(generator.device)
            _generate(ids)
    finally:
        generator.profiling = orig_prof

    # 2) Main loop (batched)
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "mmlu_pro.jsonl")

    tput_list, tacc_list = [], []
    draft_times, target_times = [], []
    perf_sinks = (
        ("tput", tput_list),
        ("avg_sampled", tacc_list),
        ("avg_draft_time", draft_times),
        ("avg_target_time", target_times),
    )
    total_q, correct_q, skipped_q = 0, 0, 0

    for _, batch in tqdm(
        list(_iter_batches(dataset, bs)),
        desc="Eval MMLU-Pro (batched)",
    ):
        prompts = [_format_prompt(x) for x in batch]
        gts = [_gt_letter(x) for x in batch]

        tokenizer.use_default_system_prompt = True
        input_ids = _batch_tokenize_chat(tokenizer, prompts, generator.device)

        if input_ids.size(1) > args.max_length:
            # Coarse whole-batch filter on the padded width.
            skipped_q += len(prompts)
            del input_ids
            gc.collect()
            torch.cuda.empty_cache()
            continue

        output_ids = _generate(input_ids)

        batch_log = {**getattr(generator, "exp_log", {})}
        peak_memory = torch.cuda.max_memory_reserved(generator.device) / (1024**3)
        prompt_len = input_ids.size(1)

        records = []
        for row, (prompt, gt) in enumerate(zip(prompts, gts)):
            resp = tokenizer.decode(
                output_ids[row, prompt_len:],
                skip_special_tokens=True,
            ).strip()

            pred = _extract_choice(resp)
            is_correct = pred is not None and pred == gt
            total_q += 1
            correct_q += int(is_correct)

            record = dict(batch_log)
            record.update(
                {
                    "query": prompt,
                    "response": resp,
                    "answer": gt,
                    "pred": pred,
                    "Accuracy": int(is_correct),
                    "peak_memory": float(peak_memory),
                    "batch_size": int(input_ids.size(0)),
                }
            )
            records.append(record)

            for key, sink in perf_sinks:
                if record.get(key) is not None:
                    sink.append(record[key])

        # Open the log once per batch rather than once per sample.
        with open(log_file, "a+", encoding="utf-8") as f:
            for record in records:
                json.dump(record, f, ensure_ascii=False)
                f.write("\n")

        del input_ids, output_ids
        gc.collect()
        torch.cuda.empty_cache()

    # 3) Aggregate
    tput_mean, tput_std = (np.mean(tput_list), np.std(tput_list)) if tput_list else (0, 0)
    tacc_mean, tacc_std = (np.mean(tacc_list), np.std(tacc_list)) if tacc_list else (0, 0)
    accuracy = correct_q / total_q if total_q else 0
    avg_draft = np.mean(draft_times) if draft_times else 0
    avg_target = np.mean(target_times) if target_times else 0
    peak_mem = torch.cuda.max_memory_reserved(generator.device) / (1024**3)

    print("Final MMLU-Pro Results:")
    print(f"\tThroughput       : {tput_mean:.3f} ± {tput_std:.3f} tokens/sec")
    print(f"\tToken Acceptance : {tacc_mean:.3f} ± {tacc_std:.3f}")
    print(f"\tAnswer Accuracy  : {accuracy:.3f} ({correct_q}/{total_q})")
    if skipped_q:
        print(f"\tSkipped (too long): {skipped_q} samples not evaluated")
    print(f"\tAvg Draft Time   : {avg_draft:.3f} sec")
    print(f"\tAvg Target Time  : {avg_target:.3f} sec")
    print(f"\tPeak Memory      : {peak_mem:.3f} GiB")

    return (
        tput_mean,
        tput_std,
        tacc_mean,
        tacc_std,
        accuracy,
        avg_draft,
        avg_target,
        peak_mem,
    )

import re
import json
import base64
import zlib
import pickle
import subprocess
import os
import tempfile
from typing import Any, List, Dict

# --- Utility functions consolidated from lcb_runner ---

# A fenced block: ``` + optional language tag (python, py, python3, ...) on the
# opening line, then the body up to the next closing ```.
_CODE_FENCE_RE = re.compile(r"```[^\n`]*\n(.*?)```", re.S)


def _extract_code(text: str) -> str:
    """
    Extract the program from a model response.

    Returns the stripped body of the LAST fenced code block. Any language tag
    is accepted (``python``, ``py``, ``python3`` or none), and ``\\r\\n`` line
    endings are tolerated. The last block is used because responses often
    show drafts or examples before the final solution; lcb_runner's
    extraction also takes the last block.

    If the response contains no complete fenced block, the whole response is
    returned stripped.
    """
    blocks = _CODE_FENCE_RE.findall(text)
    if blocks:
        return blocks[-1].strip()
    return text.strip()

def _decode_test_cases(field: Any) -> List[Dict[str, str]]:
    """
    Decode LiveCodeBench public/private test cases into a list of dicts.

    Accepted encodings:
      * an already-decoded ``list`` (returned as-is);
      * a JSON array as ``str`` or UTF-8 ``bytes``;
      * base64 of optionally zlib-compressed data. The data is either JSON or
        a pickle whose payload is the list itself or a JSON *string* encoding
        it. lcb_runner's loader decodes private tests as
        ``json.loads(pickle.loads(zlib.decompress(base64.b64decode(...))))``.

    Raises:
        ValueError: if no encoding yields a list.

    SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
    data; only call this on trusted benchmark files.
    """
    def _fail(reason: str) -> ValueError:
        return ValueError(f"Could not decode test case data: {reason}")

    if isinstance(field, list):
        return field

    if isinstance(field, bytes):
        s = field.decode("utf-8", errors="ignore").strip()
    else:
        s = str(field).strip()

    if s.startswith("["):
        try:
            return json.loads(s)
        except json.JSONDecodeError:
            pass  # not plain JSON; try the base64 encodings below

    try:
        data = base64.b64decode(s)
    except ValueError as e:  # binascii.Error subclasses ValueError
        raise _fail(str(e)) from None

    # Every zlib stream starts with CMF byte 0x78 regardless of compression
    # level (78 01 / 78 5e / 78 9c / 78 da); JSON and pickle payloads never do.
    if data[:1] == b"\x78":
        try:
            data = zlib.decompress(data)
        except zlib.error:
            pass  # coincidental prefix; treat as uncompressed

    try:
        payload = json.loads(data.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError):
        try:
            payload = pickle.loads(data)  # trusted benchmark data only (see docstring)
        except Exception as e:  # unpickling bad bytes can raise many error types
            raise _fail(str(e)) from None

    # lcb_runner pickles a JSON string rather than the list; unwrap it.
    if isinstance(payload, (bytes, bytearray)):
        payload = payload.decode("utf-8", errors="ignore")
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError as e:
            raise _fail(str(e)) from None

    if not isinstance(payload, list):
        raise _fail(f"expected a list, got {type(payload).__name__}")
    return payload

def _run_single_test(python_src: str, test_case: dict, timeout: float) -> bool:
    """
    Run ``python_src`` as a script against one stdin/stdout test case.

    The source is written to a fresh temporary directory. It is executed with
    the current interpreter (``sys.executable``, not whatever ``python`` is
    on PATH), using that directory as the working directory, and
    ``test_case["input"]`` is fed on stdin.

    The test passes when stdout matches ``test_case["output"]`` line by line
    after stripping the whole text and each line. This is similar to the
    line-wise normalisation of lcb_runner's stdin grader, so trailing spaces
    no longer cause false failures.

    Any failure (timeout, crash, undecodable output, malformed test case)
    counts as a failed test rather than raising.

    NOTE(review): only stdin-style tests are handled. Test cases that
    presumably expect a function call (e.g. LiveCodeBench "functional" tests)
    will fail here.
    SECURITY: executes untrusted model-generated code without a sandbox.
    """
    import sys  # local: the module has no top-level sys import

    def _normalise(text: str) -> List[str]:
        return [line.strip() for line in text.strip().split("\n")]

    with tempfile.TemporaryDirectory() as temp_dir:
        code_path = os.path.join(temp_dir, "main.py")
        with open(code_path, "w", encoding="utf-8") as f:
            f.write(python_src)

        try:
            proc = subprocess.run(
                [sys.executable, code_path],
                input=test_case["input"].encode("utf-8"),
                capture_output=True,
                timeout=timeout,
                cwd=temp_dir,
            )
            actual = proc.stdout.decode("utf-8")
            expected = test_case["output"]
            return _normalise(actual) == _normalise(expected)
        except Exception:
            # Deliberately best-effort: timeouts (subprocess.TimeoutExpired),
            # bad output encoding or malformed test cases all mean "failed".
            return False

# --- Main function to replace the library call ---

def check_correctness(problem: dict, completion: str, timeout: float = 2.0) -> dict:
    """
    Self-contained grader for one model completion on one problem.

    The program is extracted from ``completion`` and run against the
    problem's public and private test cases (``"public_test_cases"`` /
    ``"private_test_cases"``; a missing field counts as no tests).

    Args:
        problem: The problem dictionary from the dataset.
        completion: The string response generated by the model.
        timeout: Timeout in seconds for each test case.

    Returns:
        ``{"passed": bool}``. ``passed`` is False when no code is found, the
        tests cannot be decoded, there are no tests at all (nothing verified,
        so it is not counted as a pass), or any test fails. Testing stops at
        the first failure.
    """
    solution_code = _extract_code(completion)
    if not solution_code:
        return {"passed": False}

    try:
        public_tests = list(_decode_test_cases(problem.get("public_test_cases", [])))
        private_tests = list(_decode_test_cases(problem.get("private_test_cases", [])))
    except (ValueError, TypeError):
        return {"passed": False}  # undecodable test data

    all_tests = public_tests + private_tests
    if not all_tests:
        return {"passed": False}

    # all() short-circuits on the first failing test case.
    passed = all(_run_single_test(solution_code, test_case, timeout) for test_case in all_tests)
    return {"passed": passed}

def run_livecodebench_eval(
    generator,
    tokenizer,
    past_key_values,
    draft_past_key_values,
    args,
    dataset,
    log_dir,
    n_samples=1,
    test_timeout=2.0,
):
    """
    LiveCodeBench evaluation, graded locally with ``check_correctness``.

    For each problem:
      * ``n_samples`` completions are generated from ``problem["prompt"]``,
        one sequence per ``generate`` call;
      * every completion is graded against the problem's test cases;
      * one JSON line is appended to
        ``<log_dir>/livecodebench_eval_refactored.jsonl``.

    Per-problem pass@1 is the fraction of passing samples; with
    ``n_samples == 1`` this is just that sample's result. Overall and
    per-difficulty pass@1 are printed.

    Problems whose templated prompt exceeds ``args.max_length`` tokens are
    skipped and excluded from pass@1.

    SECURITY: grading executes model-generated code (see ``_run_single_test``).

    Returns:
        (tput_mean, tput_std, tacc_mean, tacc_std,
         avg_draft_time, avg_target_time, peak_memory)
        pass@1 is printed, not returned, to keep the existing 7-tuple.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "livecodebench_eval_refactored.jsonl")

    def _reset_caches():
        past_key_values.reset()
        if draft_past_key_values is not None:
            draft_past_key_values.reset()

    def _chat_ids(prompt):
        # Single conversation -> ids on the generator's device.
        tokenizer.use_default_system_prompt = True
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(generator.device)

    def _raw_generate(ids):
        return generator.generate(
            ids,
            temperature=args.temperature,
            max_length=args.max_length,
            do_sample=args.do_sample,
            past_key_values=past_key_values,
            draft_past_key_values=draft_past_key_values,
        )

    # === 1) Warm-up: profiling off, same single-sequence shape as the main loop ===
    warmup_prompt = "Write a Python program that reads two integers from stdin and prints their sum."
    original_profiling = generator.profiling
    generator.profiling = False
    try:
        for _ in range(args.warmup_iter):
            _raw_generate(_chat_ids(warmup_prompt))
            _reset_caches()
    finally:
        generator.profiling = original_profiling
    print("Warm-up complete.")

    # === 2) Main loop ===
    tput_list, tacc_list = [], []
    draft_times, target_times = [], []
    perf_sinks = (
        ("tput", tput_list),
        ("avg_sampled", tacc_list),
        ("avg_draft_time", draft_times),
        ("avg_target_time", target_times),
    )
    all_pass = []               # per-problem pass@1
    pass_by_difficulty = {}     # difficulty -> list of per-problem pass@1
    skipped = 0

    for problem in tqdm(dataset, total=len(dataset), desc="Evaluating LiveCodeBench"):
        prompt = problem["prompt"]
        input_ids = _chat_ids(prompt)
        prompt_len = input_ids.shape[1]

        if prompt_len > args.max_length:
            skipped += 1
            del input_ids
            continue

        responses, graded_list, gen_times = [], [], []
        exp_log = {}
        output_ids = None
        for _ in range(n_samples):
            start = time.perf_counter()
            output_ids = _raw_generate(input_ids)
            gen_times.append(time.perf_counter() - start)  # generation only
            _reset_caches()

            exp_log = {**getattr(generator, "exp_log", {})}
            for key, sink in perf_sinks:
                if exp_log.get(key) is not None:
                    sink.append(exp_log[key])

            response = tokenizer.decode(
                output_ids[0][prompt_len:], skip_special_tokens=True
            ).strip()
            responses.append(response)

            result = check_correctness(problem=problem, completion=response, timeout=test_timeout)
            graded_list.append(bool(result["passed"]))

        # Unbiased pass@1 estimate: fraction of passing samples.
        pass1 = sum(graded_list) / len(graded_list) if graded_list else 0.0
        all_pass.append(pass1)
        pass_by_difficulty.setdefault(str(problem.get("difficulty")), []).append(pass1)

        record = {
            **exp_log,  # perf log of the last generate call for this problem
            "query": prompt,
            "responses": responses,
            "graded_list": graded_list,
            "pass@1": pass1,
            "n": n_samples,
            "gen_times": gen_times,
            "platform": problem.get("platform"),
            "difficulty": problem.get("difficulty"),
            "contest_date": problem.get("contest_date"),
            "question_id": problem.get("question_id"),
            "peak_memory": torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3),
        }
        with open(log_file, "a+", encoding="utf-8") as f:
            json.dump(record, f, ensure_ascii=False)
            f.write("\n")

        del input_ids, output_ids
        gc.collect()
        torch.cuda.empty_cache()

    # === 3) Summaries ===
    tput_mean, tput_std = (np.mean(tput_list), np.std(tput_list)) if tput_list else (0, 0)
    tacc_mean, tacc_std = (np.mean(tacc_list), np.std(tacc_list)) if tacc_list else (0, 0)
    avg_draft = np.mean(draft_times) if draft_times else 0
    avg_target = np.mean(target_times) if target_times else 0
    peak_memory = torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3)
    overall_pass = float(np.mean(all_pass)) if all_pass else 0.0

    print("Final LiveCodeBench Results:")
    print(f"\tPass@1           : {overall_pass:.3f} ({len(all_pass)} problems, {skipped} skipped)")
    for difficulty, scores in sorted(pass_by_difficulty.items()):
        print(f"\t  {difficulty:<15}: {np.mean(scores):.3f} ({len(scores)} problems)")
    print(f"\tThroughput       : {tput_mean:.3f} ± {tput_std:.3f} tokens/sec")
    print(f"\tToken Acceptance : {tacc_mean:.3f} ± {tacc_std:.3f}")
    print(f"\tAvg Draft Time   : {avg_draft:.3f} sec")
    print(f"\tAvg Target Time  : {avg_target:.3f} sec")
    print(f"\tPeak Memory      : {peak_memory:.3f} GiB")

    return (tput_mean, tput_std, tacc_mean, tacc_std, avg_draft, avg_target, peak_memory)


# For pg-19
def run_pg19_eval(
    generator,
    tokenizer,
    past_key_values,
    draft_past_key_values,
    args,
    dataset,
    log_dir,
):
    """
    Multi-batch PG-19 evaluation (throughput-oriented).

    - Wrap each text with ``build_chat_pg19`` and tokenize it without
      truncation (no length filtering or cropping is applied).
    - LEFT-pad the batch into a (B, Smax) tensor. Short batches are filled
      with EOS-only dummy rows so ``generate`` always sees the same batch
      size; dummy rows are never decoded, logged, or aggregated.
    - Because rows are left-padded, every prompt ends at column Smax, so each
      row's generated tokens are sliced from ``input_ids.size(1)`` onward.
    - Reset KV caches once per batch.
    - No answer accuracy is computed/returned.

    Args:
        generator: object exposing ``generate``, ``device``, ``profiling``,
            ``target_model.name_or_path`` and, optionally, ``batch_size``,
            ``exp_log`` and the latency lists printed in the summary.
        tokenizer: HF-style tokenizer. If ``pad_token_id`` is None it is set
            to ``eos_token_id`` (this mutates the tokenizer).
        past_key_values: target KV cache; ``reset()`` is called after every
            ``generate`` call.
        draft_past_key_values: draft KV cache (reset likewise) or None.
        args: namespace providing ``warmup_iter``, ``temperature``,
            ``do_sample`` and optionally ``max_new_tokens`` (default 1024).
        dataset: entries with a ``"text"`` field, batched via ``_iter_batches``.
        log_dir: directory; one JSON record per real sample is appended to
            ``<log_dir>/0.jsonl``.

    Returns:
        ``(tput_mean, tput_std, tput_excl_target_prefill,
        tput_excl_all_prefill, tacc_mean, tacc_std, avg_draft, avg_target,
        peak_memory)`` with ``peak_memory`` in GiB. Each statistic is 0.0 when
        no value was collected.

    Assumes:
      - _get_batch_size, _iter_batches are available in module scope.
      - generator.generate accepts (B, S) input_ids and returns a tensor whose
        first S columns are the (padded) input — TODO confirm for every
        generator implementation.
    """

    # 0) Resolve max_new_tokens (PG-19 has no per-task table; default 1024).
    max_new_tokens = getattr(args, "max_new_tokens", 1024)

    # 1) Warm-up: a fixed synthetic prompt duplicated to the batch size.
    #    Profiling is paused so warm-up runs are not recorded, and restored in
    #    ``finally`` even if a warm-up generation raises.
    bs = _get_batch_size(generator, default=1)
    original_profiling = generator.profiling
    generator.profiling = False
    warmup_prompt = ("This is a warmup for PG-19 throughput.\n" * 64) + "Hello world."
    try:
        for _ in tqdm(range(args.warmup_iter), desc="Warmup PG-19"):
            warmup_ids = tokenizer(
                [warmup_prompt] * bs, truncation=False, return_tensors="pt"
            ).input_ids.to(generator.device)

            generator.generate(
                warmup_ids,
                temperature=args.temperature,
                max_new_tokens=max_new_tokens,
                do_sample=args.do_sample,
                past_key_values=past_key_values,
                draft_past_key_values=draft_past_key_values,
            )

            past_key_values.reset()
            if draft_past_key_values is not None:
                draft_past_key_values.reset()
    finally:
        generator.profiling = original_profiling

    # 2) Main evaluation loop (batched)
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "0.jsonl")

    tput_list, tacc_list = [], []
    tput_excl_target_prefill_list, tput_excl_all_prefill_list = [], []
    draft_times, target_times = [], []

    # Record key -> collector list. ``exp_log`` is read once per batch, so each
    # value is appended once per real sample (i.e. weighted by sample count).
    metric_sinks = {
        "tput": tput_list,
        "avg_sampled": tacc_list,
        "avg_draft_time": draft_times,
        "avg_target_time": target_times,
        "tput_excl_target_prefill": tput_excl_target_prefill_list,
        "tput_excl_all_prefill": tput_excl_all_prefill_list,
    }

    # Manual padding needs a pad id; fall back to EOS (mutates the tokenizer).
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
        tokenizer.pad_token_id = pad_id
    eos_id = tokenizer.eos_token_id

    # torch.cuda.cudart().cudaProfilerStart()

    for _, batch in tqdm(list(_iter_batches(dataset, bs)), desc="Evaluating PG-19 (batched)"):
        texts, token_ids_list, prompt_lens, is_real = [], [], [], []

        # ---- Per-sample token ids (variable length, no truncation) ----
        for entry in batch:
            # NOTE: adjust the key if your PG-19 loader does not use "text".
            text = build_chat_pg19(
                tokenizer=tokenizer,
                prompt=entry["text"],
                model_name=generator.target_model.name_or_path,
            )
            ids_1d = tokenizer(text, truncation=False, return_tensors="pt").input_ids[0]

            texts.append(text)
            token_ids_list.append(ids_1d)
            prompt_lens.append(int(ids_1d.numel()))
            is_real.append(True)

        real_B = len(token_ids_list)
        if real_B == 0:
            continue

        # ---- Keep the batch size fixed by appending EOS-only dummy rows ----
        if real_B < bs:
            s_dummy = max(prompt_lens)
            dummy_ids = torch.full((s_dummy,), eos_id, dtype=torch.long)
            for _ in range(bs - real_B):
                token_ids_list.append(dummy_ids)
                prompt_lens.append(s_dummy)
                texts.append("")  # dummy
                is_real.append(False)

        # Defensive: never send more than ``bs`` rows (extra real rows are dropped).
        if len(token_ids_list) > bs:
            token_ids_list = token_ids_list[:bs]
            prompt_lens = prompt_lens[:bs]
            texts = texts[:bs]
            is_real = is_real[:bs]

        # ---- LEFT-pad into a (B, Smax) batch tensor ----
        B = len(token_ids_list)
        Smax = max(prompt_lens)
        input_ids = torch.full((B, Smax), pad_id, dtype=torch.long, device=generator.device)
        for i, (ids_1d, l) in enumerate(zip(token_ids_list, prompt_lens)):
            # ``Smax - l`` rather than ``-l``: an empty prompt then writes nothing
            # instead of trying to fill the whole row.
            input_ids[i, Smax - l:] = ids_1d.to(generator.device)

        # ---- Generate once per batch ----
        output_ids = generator.generate(
            input_ids,
            temperature=args.temperature,
            max_new_tokens=max_new_tokens,
            do_sample=args.do_sample,
            past_key_values=past_key_values,
            draft_past_key_values=draft_past_key_values,
        )

        past_key_values.reset()
        if draft_past_key_values is not None:
            draft_past_key_values.reset()

        batch_log = {**getattr(generator, "exp_log", {})}
        peak_memory = torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3)

        # Left padding makes every prompt end at column Smax, so generated tokens
        # start there for every row. Slicing at the per-sample prompt length
        # would leak pad/prompt tokens into the responses of shorter rows.
        gen_start = input_ids.size(1)

        # ---- Per-sample decode + log (no accuracy); dummy rows are skipped ----
        with open(log_file, "a", encoding="utf-8") as f:
            for i in range(B):
                if not is_real[i]:
                    continue

                response = tokenizer.decode(output_ids[i, gen_start:], skip_special_tokens=True)

                record = dict(batch_log)
                record.update(
                    {
                        "query": texts[i],
                        "response": response,
                        "prompt_len": int(prompt_lens[i]),
                        "padded_len": int(gen_start),
                        "batch_size": int(B),
                        "peak_memory": float(peak_memory),
                    }
                )

                for key, sink in metric_sinks.items():
                    value = record.get(key)
                    if value is not None:
                        sink.append(value)

                json.dump(record, f, ensure_ascii=False)
                f.write("\n")

        # Release per-batch tensors before the next (possibly longer) batch.
        del input_ids, output_ids
        gc.collect()
        torch.cuda.empty_cache()

    # torch.cuda.cudart().cudaProfilerStop()

    # 3) Aggregate overall metrics (0.0 when nothing was collected)
    if tput_list:
        tput_mean, tput_std = float(np.mean(tput_list)), float(np.std(tput_list))
    else:
        tput_mean, tput_std = 0.0, 0.0
    tput_excl_target_prefill = float(np.mean(tput_excl_target_prefill_list)) if tput_excl_target_prefill_list else 0.0
    tput_excl_all_prefill = float(np.mean(tput_excl_all_prefill_list)) if tput_excl_all_prefill_list else 0.0
    if tacc_list:
        tacc_mean, tacc_std = float(np.mean(tacc_list)), float(np.std(tacc_list))
    else:
        tacc_mean, tacc_std = 0.0, 0.0
    avg_draft = float(np.mean(draft_times)) if draft_times else 0.0
    avg_target = float(np.mean(target_times)) if target_times else 0.0
    peak_memory = torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3)

    # 4) Print summary
    print("Final PG-19 Results:")
    print(f"\tThroughput       : {tput_mean:.3f} ± {tput_std:.3f} tokens/sec")
    print(f"\tThroughput_excl_target_prefill : {tput_excl_target_prefill:.3f} tokens/sec")
    print(f"\tThroughput_excl_all_prefill    : {tput_excl_all_prefill:.3f} tokens/sec")
    print(f"\tToken Acceptance : {tacc_mean:.3f} ± {tacc_std:.3f}")
    print(f"\tAvg Draft Time   : {avg_draft:.3f} sec")
    print(f"\tAvg Target Time  : {avg_target:.3f} sec")
    print(f"\tPeak Memory      : {peak_memory:.3f} GiB")
    # An empty list would make np.mean return nan (with a RuntimeWarning).
    judge_acc_lens = getattr(generator, "judge_acc_len_list", None)
    if judge_acc_lens is not None and len(judge_acc_lens) > 0:
        print(f"\tTacc_judge       : {np.mean(judge_acc_lens):.3f}")
    else:
        print("\tTacc_judge       : 0.000 (not available)")
    if hasattr(generator, "all_attention_latencies") and generator.all_attention_latencies:
        all_attention_latencies = torch.tensor(generator.all_attention_latencies)
        all_attention_latencies = all_attention_latencies.sum(dim=1)  # sum over layers
        avg_self_attn_latency = np.mean(all_attention_latencies.cpu().numpy())
        print(f"\tAvg Self-Attn Latency: {avg_self_attn_latency:.6f} ms")
    else:
        print("\tAvg Self-Attn Latency: 0.000000 ms (not available)")
    if hasattr(generator, "all_compresskv_latencies") and generator.all_compresskv_latencies:
        avg_compresskv_latency = np.mean(generator.all_compresskv_latencies)
        print(f"\tAvg CompressKV Latency: {avg_compresskv_latency:.6f} ms")
    else:
        print("\tAvg CompressKV Latency: 0.000000 ms (not available)")
    if hasattr(generator, "all_criticality_estimation") and generator.all_criticality_estimation:
        avg_criticality_estimation = np.mean(generator.all_criticality_estimation)
        print(f"\tAvg Criticality Estimation: {avg_criticality_estimation:.6f} ms")

    return (
        tput_mean,
        tput_std,
        tput_excl_target_prefill,
        tput_excl_all_prefill,
        tacc_mean,
        tacc_std,
        avg_draft,
        avg_target,
        peak_memory,
    )

# For longbench
def run_longbench_eval(
    generator,
    tokenizer,
    past_key_values,
    draft_past_key_values,
    args,
    dataset,
    log_dir,
    bench_name,
    max_len,
):
    """
    Multi-batch LongBench evaluation.

    - Wrap each prompt with ``build_chat_longbench`` (except raw-prompt tasks),
      tokenize it without truncation, skip samples longer than
      ``args.max_length`` or shorter than ``args.min_length`` (default 0), and
      crop the rest to ``max_len`` tokens by keeping head and tail halves.
    - LEFT-pad the batch into a (B, Smax) tensor. Short batches are filled with
      EOS-only dummy rows so ``generate`` always sees the same batch size;
      dummy rows are never decoded, scored, or logged.
    - Because rows are left-padded, every row's generated tokens start at
      ``input_ids.size(1)`` (the padded width), not at its own prompt length.
    - Reset KV caches once per batch.

    Args:
        generator: object exposing ``generate``, ``device``, ``profiling``,
            ``target_model.name_or_path`` and, optionally, ``batch_size``,
            ``exp_log`` and ``judge_acc_len_list``.
        tokenizer: HF-style tokenizer. If ``pad_token_id`` is None it is set
            to ``eos_token_id`` (this mutates the tokenizer).
        past_key_values: target KV cache; ``reset()`` after every generate.
        draft_past_key_values: draft KV cache (reset likewise) or None.
        args: namespace providing ``warmup_iter``, ``temperature``,
            ``do_sample``, ``max_length`` and optionally ``min_length``.
        dataset: entries with ``"question"``, ``"answer"`` (list of reference
            answers) and optional ``"classes"``; batched via ``_iter_batches``.
        log_dir: directory; one JSON record per scored sample is appended to
            ``<log_dir>/0.jsonl``.
        bench_name: LongBench task name; selects ``max_new_tokens`` from
            ``dataset2maxlen.json`` and the metric from ``dataset2metric``.
        max_len: crop limit (tokens) applied after the length filter.

    Returns:
        ``(tput_mean, tput_std, tput_excl_target_prefill,
        tput_excl_all_prefill, tacc_mean, tacc_std, answer_accuracy,
        avg_draft, avg_target, peak_memory)``. ``answer_accuracy`` is the mean
        best-over-references score in percent, rounded to 2 decimals, and
        ``peak_memory`` is in GiB. Each statistic is 0.0 when nothing was
        collected.

    Raises:
        KeyError: if ``bench_name`` has no entry in ``dataset2metric`` (checked
            after warm-up, before the main loop).

    Assumes:
      - build_chat_longbench, dataset2metric, _get_batch_size, _iter_batches
        are defined in the same module scope.
      - generator.generate accepts (B, S) input_ids and returns a tensor whose
        first S columns are the (padded) input.
    """
    print("bench name", bench_name)

    # Tasks tokenized as-is, without the chat wrapper.
    raw_prompt_tasks = {"trec", "triviaqa", "samsum", "lsht", "lcc", "repobench_p"}
    # Tasks scored on the first line of the response (after leading newlines).
    first_line_tasks = {"trec", "triviaqa", "samsum", "lsht"}

    # 0) Load max_new_tokens for this benchmark.
    # NOTE(review): benchmarks missing from the table fall back to
    # args.max_length, which is also the prompt-length cap below — confirm intent.
    with open("run/pipelines/benchmarks/utils/config/dataset2maxlen.json", "r", encoding="utf-8") as f:
        benchmark_max_len = json.load(f)
    max_new_tokens = benchmark_max_len.get(bench_name, args.max_length)

    # 1) Warm-up: a fixed synthetic prompt duplicated to the batch size.
    #    Profiling is paused so warm-up runs are not recorded, and restored in
    #    ``finally`` even if a warm-up generation raises.
    bs = _get_batch_size(generator, default=1)
    original_profiling = generator.profiling
    generator.profiling = False
    base_warmup_prompt = "Solve this math problem. Give the reasoning steps ...\nWhat is 1 + 1?" * 64
    try:
        for _ in tqdm(range(args.warmup_iter), desc=f"Warmup {bench_name}"):
            # NOTE(review): this flag is only set when warm-up actually runs, so
            # main-loop prompt formatting may differ when warmup_iter == 0.
            tokenizer.use_default_system_prompt = True

            warmup_prompt = base_warmup_prompt
            if bench_name not in raw_prompt_tasks:
                warmup_prompt = build_chat_longbench(
                    tokenizer=tokenizer,
                    prompt=warmup_prompt,
                    model_name=generator.target_model.name_or_path,
                )

            warmup_ids = tokenizer(
                [warmup_prompt] * bs, truncation=False, return_tensors="pt"
            ).input_ids.to(generator.device)

            generator.generate(
                warmup_ids,
                temperature=args.temperature,
                max_new_tokens=max_new_tokens,
                do_sample=args.do_sample,
                past_key_values=past_key_values,
                draft_past_key_values=draft_past_key_values,
            )

            past_key_values.reset()
            if draft_past_key_values is not None:
                draft_past_key_values.reset()
    finally:
        generator.profiling = original_profiling

    # 2) Main evaluation loop (batched)
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "0.jsonl")
    metric_fn = dataset2metric[bench_name]

    tput_list, tacc_list = [], []
    tput_excl_target_prefill_list, tput_excl_all_prefill_list = [], []
    draft_times, target_times = [], []
    total_q, correct_q = 0, 0.0  # correct_q is a score sum for LongBench

    # Record key -> collector list. ``exp_log`` is read once per batch, so each
    # value is appended once per real sample (i.e. weighted by sample count).
    metric_sinks = {
        "tput": tput_list,
        "avg_sampled": tacc_list,
        "avg_draft_time": draft_times,
        "avg_target_time": target_times,
        "tput_excl_target_prefill": tput_excl_target_prefill_list,
        "tput_excl_all_prefill": tput_excl_all_prefill_list,
    }

    min_length = getattr(args, "min_length", 0)

    # Manual padding needs a pad id; fall back to EOS (mutates the tokenizer).
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
        tokenizer.pad_token_id = pad_id
    eos_id = tokenizer.eos_token_id

    for _, batch in tqdm(list(_iter_batches(dataset, bs)), desc=f"Evaluating {bench_name} (batched)"):
        # ---- Per-sample token ids (variable length) ----
        prompts_text, token_ids_list, prompt_lens = [], [], []
        gts_list, classes_list = [], []

        for entry in batch:
            prompt = entry["question"]
            ground_truth_list = entry["answer"]
            all_classes = entry.get("classes", None)

            if bench_name not in raw_prompt_tasks:
                prompt = build_chat_longbench(
                    tokenizer=tokenizer,
                    prompt=prompt,
                    model_name=generator.target_model.name_or_path,
                )

            ids_1d = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
            L = int(ids_1d.numel())

            # Samples outside [min_length, args.max_length] are skipped, not cut.
            if L > args.max_length or L < min_length:
                continue

            # Longer than the model context: keep the first and last max_len // 2
            # tokens. ``L - half`` instead of ``-half`` because ``ids[-0:]``
            # would return the whole tensor when half == 0.
            if L > max_len:
                half = max_len // 2
                ids_1d = torch.cat([ids_1d[:half], ids_1d[L - half:]], dim=0)
                L = int(ids_1d.numel())

            prompts_text.append(prompt)
            token_ids_list.append(ids_1d)
            prompt_lens.append(L)
            gts_list.append(ground_truth_list)
            classes_list.append(all_classes)

        real_B = len(token_ids_list)
        if real_B == 0:
            continue

        # ---- Keep the batch size fixed by appending EOS-only dummy rows ----
        is_real = [True] * real_B
        if real_B < bs:
            s_dummy = max(prompt_lens)
            dummy_ids = torch.full((s_dummy,), eos_id, dtype=torch.long)
            for _ in range(bs - real_B):
                token_ids_list.append(dummy_ids)
                prompt_lens.append(s_dummy)
                prompts_text.append("")  # dummy
                gts_list.append([])  # dummy (no ground truth)
                classes_list.append(None)
                is_real.append(False)

        # Defensive: never send more than ``bs`` rows (extra real rows are dropped).
        if len(token_ids_list) > bs:
            token_ids_list = token_ids_list[:bs]
            prompt_lens = prompt_lens[:bs]
            prompts_text = prompts_text[:bs]
            gts_list = gts_list[:bs]
            classes_list = classes_list[:bs]
            is_real = is_real[:bs]

        # ---- LEFT-pad into a (B, Smax) batch tensor ----
        B = len(token_ids_list)
        Smax = max(prompt_lens)
        input_ids = torch.full((B, Smax), pad_id, dtype=torch.long, device=generator.device)
        for i, (ids_1d, l) in enumerate(zip(token_ids_list, prompt_lens)):
            # ``Smax - l`` rather than ``-l`` so an empty prompt writes nothing.
            input_ids[i, Smax - l:] = ids_1d.to(generator.device)

        # ---- Generate once per batch ----
        output_ids = generator.generate(
            input_ids,
            temperature=args.temperature,
            max_new_tokens=max_new_tokens,
            do_sample=args.do_sample,
            past_key_values=past_key_values,
            draft_past_key_values=draft_past_key_values,
        )

        past_key_values.reset()
        if draft_past_key_values is not None:
            draft_past_key_values.reset()

        batch_log = {**getattr(generator, "exp_log", {})}
        peak_memory = torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3)

        # Left padding makes every prompt end at column Smax, so generated tokens
        # start at the padded width for every row (NOT at the per-sample length).
        gen_start = input_ids.size(1)

        # ---- Per-sample decode + score; dummy rows are skipped entirely ----
        with open(log_file, "a", encoding="utf-8") as f:
            for i in range(B):
                if not is_real[i]:
                    continue
                ground_truth_list = gts_list[i]
                all_classes = classes_list[i]

                response = tokenizer.decode(output_ids[i, gen_start:], skip_special_tokens=True)
                if bench_name in first_line_tasks:
                    prediction = response.lstrip("\n").split("\n")[0]
                else:
                    prediction = response

                # Best score over all reference answers, floored at 0.0.
                score = 0.0
                for gt in ground_truth_list:
                    score = max(score, metric_fn(prediction, gt, all_classes=all_classes))

                total_q += 1
                correct_q += float(score)

                record = dict(batch_log)
                record.update(
                    {
                        "query": prompts_text[i],
                        "response": response,
                        "answer": ground_truth_list,
                        "Accuracy": score,
                        "prompt_len": int(prompt_lens[i]),
                        "padded_len": int(gen_start),
                        "batch_size": int(B),
                        "peak_memory": float(peak_memory),
                    }
                )

                for key, sink in metric_sinks.items():
                    value = record.get(key)
                    if value is not None:
                        sink.append(value)

                json.dump(record, f, ensure_ascii=False)
                f.write("\n")

        # Release per-batch tensors before the next (possibly longer) batch.
        del input_ids, output_ids
        gc.collect()
        torch.cuda.empty_cache()

    # 3) Aggregate overall metrics (0.0 when nothing was collected)
    if tput_list:
        tput_mean, tput_std = float(np.mean(tput_list)), float(np.std(tput_list))
    else:
        tput_mean, tput_std = 0.0, 0.0
    tput_excl_target_prefill = float(np.mean(tput_excl_target_prefill_list)) if tput_excl_target_prefill_list else 0.0
    tput_excl_all_prefill = float(np.mean(tput_excl_all_prefill_list)) if tput_excl_all_prefill_list else 0.0
    if tacc_list:
        tacc_mean, tacc_std = float(np.mean(tacc_list)), float(np.std(tacc_list))
    else:
        tacc_mean, tacc_std = 0.0, 0.0
    answer_accuracy = round(100.0 * correct_q / total_q, 2) if total_q > 0 else 0.0
    avg_draft = float(np.mean(draft_times)) if draft_times else 0.0
    avg_target = float(np.mean(target_times)) if target_times else 0.0
    peak_memory = torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3)

    # 4) Print summary
    print(f"Final {bench_name} Results:")
    print(f"\tThroughput       : {tput_mean:.3f} ± {tput_std:.3f} tokens/sec")
    print(f"\tThroughput_excl_target_prefill : {tput_excl_target_prefill:.3f} tokens/sec")
    print(f"\tThroughput_excl_all_prefill : {tput_excl_all_prefill:.3f} tokens/sec")
    print(f"\tToken Acceptance : {tacc_mean:.3f} ± {tacc_std:.3f}")
    print(f"\tAnswer Accuracy  : {answer_accuracy:.3f} ({correct_q}/{total_q})")
    print(f"\tAvg Draft Time   : {avg_draft:.3f} sec")
    print(f"\tAvg Target Time  : {avg_target:.3f} sec")
    print(f"\tPeak Memory      : {peak_memory:.3f} GiB")
    # An empty list would make np.mean return nan (with a RuntimeWarning).
    judge_acc_lens = getattr(generator, "judge_acc_len_list", None)
    if judge_acc_lens is not None and len(judge_acc_lens) > 0:
        print(f"\tTacc_judge       : {np.mean(judge_acc_lens):.3f}")
    else:
        print("\tTacc_judge       : 0.000 (not available)")

    return (
        tput_mean,
        tput_std,
        tput_excl_target_prefill,
        tput_excl_all_prefill,
        tacc_mean,
        tacc_std,
        answer_accuracy,
        avg_draft,
        avg_target,
        peak_memory,
    )

def run_longbenchv2_eval(
    generator,
    tokenizer,
    past_key_values,
    draft_past_key_values,
    args,
    dataset,
    log_dir,
    bench_name,
    max_len,
):
    """
    Multi-batch LongBench-v2 multiple-choice evaluation (supports optional CoT / no_context / RAG).

    Design:
      - We batch only *single-call* generations (non-CoT).
      - If use_cot=True, the evaluation is inherently 2-stage and depends on model output,
        so we keep it single-sample (still correct, but not batched).
        If you want, we can batch stage-1 and stage-2 separately later with careful bookkeeping.

    Assumes these symbols exist in outer scope:
      template_rag, template_no_context, template_0shot, template_0shot_cot, template_0shot_cot_ans
      build_input_ids(prompt, tokenizer, device, max_len, args) -> (tokenized_1d_or_None, actual_len)
      extract_longbenchv2_answer(str) -> str
    """

    # 0) parse filters from bench_name
    split_bench_name = bench_name.split("-")
    if len(split_bench_name) == 3:
        bench_name, length_filter, diff_filter = split_bench_name
    else:
        length_filter, diff_filter = "overall", "overall"

    def _keep(item):
        ok = True
        if length_filter != "overall":
            ok = ok and str(item.get("length", "")).lower() == length_filter
        if diff_filter != "overall":
            ok = ok and str(item.get("difficulty", "")).lower() == diff_filter
        return ok

    filtered_dataset = [it for it in dataset if _keep(it)]
    print(
        f"LongBench-v2 filter: length={length_filter}, difficulty={diff_filter}, "
        f"{len(filtered_dataset)}/{len(dataset)} samples kept"
    )
    if len(filtered_dataset) == 0:
        print("WARNING: no samples after filtering, return zeros.")
        return 0, 0, 0, 0, 0.0, 0.0, 0.0, 0.0

    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "0.jsonl")

    # 1) Warm-up 
    bs = _get_batch_size(generator, default=1)
    original_profiling = generator.profiling
    generator.profiling = False
    for _ in tqdm(range(getattr(args, "warmup_iter", 0)), desc="Warmup LongBench-v2"):
        warmup_prompt = (
            "You are given a long document and a multiple-choice question.\n"
            "Read the document carefully and answer with only the letter of the correct option.\n\n"
            "DOCUMENT:\n"
            + ("This is a dummy document. " * 32)
            + "\n\nQUESTION:\nWhat is 1 + 1?\n\n"
            "OPTIONS:\nA. 0\nB. 1\nC. 2\nD. 3\n\nAnswer:"
        )
        warmup_tokenized, _ = build_input_ids(
            prompt=warmup_prompt,
            tokenizer=tokenizer,
            device=generator.device,
            max_len=max_len,
            args=args,
        )
        if warmup_tokenized is None:
            continue

        # warmup_ids = warmup_tokenized.unsqueeze(0)
        # duplicate to batch size
        warmup_ids = warmup_tokenized.unsqueeze(0).repeat(bs, 1)
        warmup_ids = warmup_ids.to(generator.device)
        generator.generate(
            warmup_ids,
            temperature=args.temperature,
            max_new_tokens=128,
            do_sample=args.do_sample,
            past_key_values=past_key_values,
            draft_past_key_values=draft_past_key_values,
        )
        past_key_values.reset()
        if draft_past_key_values is not None:
            draft_past_key_values.reset()
    generator.profiling = original_profiling

    # 2) Main loop
    tput_list, tacc_list = [], []
    draft_times, target_times = [], []
    total_q, correct_q = 0, 0

    use_cot = getattr(args, "cot", False)
    use_no_context = getattr(args, "no_context", False)
    rag_topk = getattr(args, "rag", 0)

    # pad token handling for manual padding
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
        tokenizer.pad_token_id = pad_id

    # ----------------------------
    # Case A: CoT mode (keep single-sample; 2-stage dependency)
    # ----------------------------
    if use_cot:
        for _, item in tqdm(enumerate(filtered_dataset), total=len(filtered_dataset), desc="Evaluating LongBench-v2 (CoT, single)"):
            context = item["context"]

            # template selection
            max_new_tokens = 1024
            if rag_topk > 0 and "retrieved_context" in item:
                template = template_rag
                retrieved = item["retrieved_context"][:rag_topk]
                retrieved = sorted(retrieved, key=lambda x: x["c_idx"])
                context_used = "\n\n".join(
                    [f"Retrieved chunk {i+1}: {x['content']}" for i, x in enumerate(retrieved)]
                )
            elif use_no_context:
                template = template_no_context
                context_used = ""
            else:
                template = template_0shot_cot
                context_used = context

            prompt = (
                template.replace("$DOC$", context_used.strip())
                .replace("$Q$", item["question"].strip())
                .replace("$C_A$", item["choice_A"].strip())
                .replace("$C_B$", item["choice_B"].strip())
                .replace("$C_C$", item["choice_C"].strip())
                .replace("$C_D$", item["choice_D"].strip())
            )

            tokenized_prompt, actual_len = build_input_ids(
                prompt=prompt,
                tokenizer=tokenizer,
                device=generator.device,
                max_len=max_len,
                args=args,
            )
            if tokenized_prompt is None:
                continue

            input_ids = tokenized_prompt.unsqueeze(0).to(generator.device)

            # stage 1: generate CoT
            out1 = generator.generate(
                input_ids,
                temperature=args.temperature,
                max_new_tokens=max_new_tokens,
                do_sample=args.do_sample,
                past_key_values=past_key_values,
                draft_past_key_values=draft_past_key_values,
            )
            past_key_values.reset()
            if draft_past_key_values is not None:
                draft_past_key_values.reset()

            response_cot = tokenizer.decode(out1[0, input_ids.size(1):], skip_special_tokens=True).strip()

            # stage 2: answer only
            prompt2 = (
                template_0shot_cot_ans.replace("$DOC$", context_used.strip())
                .replace("$Q$", item["question"].strip())
                .replace("$C_A$", item["choice_A"].strip())
                .replace("$C_B$", item["choice_B"].strip())
                .replace("$C_C$", item["choice_C"].strip())
                .replace("$C_D$", item["choice_D"].strip())
                .replace("$COT$", response_cot)
            )

            tokenized_prompt2, actual_len2 = build_input_ids(
                prompt=prompt2,
                tokenizer=tokenizer,
                device=generator.device,
                max_len=max_len,
                args=args,
            )
            if tokenized_prompt2 is None:
                continue

            input_ids2 = tokenized_prompt2.unsqueeze(0).to(generator.device)

            out2 = generator.generate(
                input_ids2,
                temperature=args.temperature,
                max_new_tokens=128,
                do_sample=args.do_sample,
                past_key_values=past_key_values,
                draft_past_key_values=draft_past_key_values,
            )
            past_key_values.reset()
            if draft_past_key_values is not None:
                draft_past_key_values.reset()

            response = tokenizer.decode(out2[0, input_ids2.size(1):], skip_special_tokens=True).strip()

            pred = extract_longbenchv2_answer(response)
            gt = item["answer"].strip()
            judge = (pred == gt)

            total_q += 1
            correct_q += int(judge)

            record = dict(getattr(generator, "exp_log", {}))
            record.update(
                {
                    "_id": item.get("_id"),
                    "domain": item.get("domain"),
                    "sub_domain": item.get("sub_domain"),
                    "difficulty": item.get("difficulty"),
                    "length": item.get("length"),
                    "actual_length": actual_len,
                    "query": item["question"],
                    "response": response,
                    "response_cot": response_cot,
                    "answer": gt,
                    "pred": pred,
                    "judge": judge,
                    "Accuracy": int(judge),
                    "context": context_used[:1000],
                    "peak_memory": float(torch.cuda.max_memory_reserved(generator.device) / (1024**3)),
                    "batch_size": 1,
                    "cot": True,
                }
            )

            if record.get("tput") is not None:
                tput_list.append(record["tput"])
            if record.get("avg_sampled") is not None:
                tacc_list.append(record["avg_sampled"])
            if record.get("avg_draft_time") is not None:
                draft_times.append(record["avg_draft_time"])
            if record.get("avg_target_time") is not None:
                target_times.append(record["avg_target_time"])

            with open(log_file, "a+", encoding="utf-8") as f:
                json.dump(record, f, ensure_ascii=False)
                f.write("\n")

            del input_ids, input_ids2, out1, out2
            gc.collect()
            torch.cuda.empty_cache()

    # ----------------------------
    # Case B: non-CoT mode (batched)
    # ----------------------------
    else:

        for _, batch in tqdm(list(_iter_batches(filtered_dataset, bs)), desc="Evaluating LongBench-v2 (batched)"):
            # build per-sample token ids (variable length) + metadata
            token_ids_list = []
            prompt_lens = []
            actual_lens = []
            items_kept = []
            context_used_list = []
            max_new_tokens = 128  # constant in non-CoT path

            for item in batch:
                context = item["context"]

                # template selection
                if rag_topk > 0 and "retrieved_context" in item:
                    template = template_rag
                    retrieved = item["retrieved_context"][:rag_topk]
                    retrieved = sorted(retrieved, key=lambda x: x["c_idx"])
                    context_used = "\n\n".join(
                        [f"Retrieved chunk {i+1}: {x['content']}" for i, x in enumerate(retrieved)]
                    )
                elif use_no_context:
                    template = template_no_context
                    context_used = ""
                else:
                    template = template_0shot
                    context_used = context

                prompt = (
                    template.replace("$DOC$", context_used.strip())
                    .replace("$Q$", item["question"].strip())
                    .replace("$C_A$", item["choice_A"].strip())
                    .replace("$C_B$", item["choice_B"].strip())
                    .replace("$C_C$", item["choice_C"].strip())
                    .replace("$C_D$", item["choice_D"].strip())
                )

                tokenized_prompt, actual_len = build_input_ids(
                    prompt=prompt,
                    tokenizer=tokenizer,
                    device=generator.device,
                    max_len=max_len,
                    args=args,
                )
                if tokenized_prompt is None:
                    continue

                token_ids_list.append(tokenized_prompt)  # 1D
                prompt_lens.append(int(tokenized_prompt.numel()))
                actual_lens.append(int(actual_len) if actual_len is not None else int(tokenized_prompt.numel()))
                items_kept.append(item)
                context_used_list.append(context_used)

            if len(token_ids_list) == 0:
                continue

            B = len(token_ids_list)
            Smax = max(prompt_lens)

            input_ids = torch.full((B, Smax), pad_id, dtype=torch.long, device=generator.device)
            for i, ids_1d in enumerate(token_ids_list):
                l = prompt_lens[i]
                input_ids[i, :l] = ids_1d.to(generator.device)

            output_ids = generator.generate(
                input_ids,
                temperature=args.temperature,
                max_new_tokens=max_new_tokens,
                do_sample=args.do_sample,
                past_key_values=past_key_values,
                draft_past_key_values=draft_past_key_values,
            )

            past_key_values.reset()
            if draft_past_key_values is not None:
                draft_past_key_values.reset()

            batch_log = dict(getattr(generator, "exp_log", {}))
            peak_memory = float(torch.cuda.max_memory_reserved(generator.device) / (1024**3))

            for i in range(B):
                item = items_kept[i]
                pl = prompt_lens[i]

                response = tokenizer.decode(output_ids[i, pl:], skip_special_tokens=True).strip()
                pred = extract_longbenchv2_answer(response)
                gt = item["answer"].strip()
                judge = (pred == gt)

                total_q += 1
                correct_q += int(judge)

                record = dict(batch_log)
                record.update(
                    {
                        "_id": item.get("_id"),
                        "domain": item.get("domain"),
                        "sub_domain": item.get("sub_domain"),
                        "difficulty": item.get("difficulty"),
                        "length": item.get("length"),
                        "actual_length": actual_lens[i],
                        "query": item["question"],
                        "response": response,
                        "response_cot": None,
                        "answer": gt,
                        "pred": pred,
                        "judge": judge,
                        "Accuracy": int(judge),
                        "context": context_used_list[i][:1000],
                        "peak_memory": peak_memory,
                        "batch_size": int(B),
                        "cot": False,
                    }
                )

                if record.get("tput") is not None:
                    tput_list.append(record["tput"])
                if record.get("avg_sampled") is not None:
                    tacc_list.append(record["avg_sampled"])
                if record.get("avg_draft_time") is not None:
                    draft_times.append(record["avg_draft_time"])
                if record.get("avg_target_time") is not None:
                    target_times.append(record["avg_target_time"])

                with open(log_file, "a+", encoding="utf-8") as f:
                    json.dump(record, f, ensure_ascii=False)
                    f.write("\n")

            del input_ids, output_ids
            gc.collect()
            torch.cuda.empty_cache()

    # 3) Aggregate
    tput_mean, tput_std = ((np.mean(tput_list), np.std(tput_list)) if tput_list else (0.0, 0.0))
    tacc_mean, tacc_std = ((np.mean(tacc_list), np.std(tacc_list)) if tacc_list else (0.0, 0.0))
    answer_accuracy = round(100.0 * correct_q / total_q, 2) if total_q > 0 else 0.0
    avg_draft = float(np.mean(draft_times)) if draft_times else 0.0
    avg_target = float(np.mean(target_times)) if target_times else 0.0
    peak_memory = float(torch.cuda.max_memory_reserved(generator.device) / (1024**3))

    print(f"Final LongBench-v2 ({bench_name}) Results:")
    print(f"\tFilter length={length_filter}, difficulty={diff_filter}")
    print(f"\tThroughput       : {tput_mean:.3f} ± {tput_std:.3f} tokens/sec")
    print(f"\tToken Acceptance : {tacc_mean:.3f} ± {tacc_std:.3f}")
    print(f"\tAnswer Accuracy  : {answer_accuracy:.2f} ({correct_q}/{total_q})")
    print(f"\tAvg Draft Time   : {avg_draft:.3f} sec")
    print(f"\tAvg Target Time  : {avg_target:.3f} sec")
    print(f"\tPeak Memory      : {peak_memory:.3f} GiB")
    if hasattr(generator, "judge_acc_len_list"):
        print(f"\tTacc_judge       : {np.mean(generator.judge_acc_len_list):.3f}")
    else:
        print("\tTacc_judge       : 0.000 (not available)")

    return (
        tput_mean,
        tput_std,
        tacc_mean,
        tacc_std,
        answer_accuracy,
        avg_draft,
        avg_target,
        peak_memory,
    )

def run_longgenbench_eval(generator, tokenizer, past_key_values, draft_past_key_values, args, dataset, log_dir, length_tag = "short"):
    """
    Evaluate LongGenBench long-form generation quality alongside performance metrics.
    Ex. "longgenbench" from https://github.com/mozhu621/LongGenBench

    For every entry the prompt is wrapped with the tokenizer's chat template,
    generated with ``generator.generate``, scored with
    ``score_longgenbench_single`` and appended as one JSON line to
    ``<log_dir>/0.jsonl``.

    Args:
        generator: the model generator instance. Uses ``device``, ``profiling``
            (disabled during warm-up, then restored) and the optional
            ``exp_log`` dict, which is merged into each per-sample record; its
            ``tput``, ``avg_sampled``, ``avg_draft_time`` and
            ``avg_target_time`` keys feed the aggregate metrics.
        tokenizer: tokenizer with chat template functionality
        past_key_values: primary KV cache; ``reset()`` is called after every generation
        draft_past_key_values: draft KV cache for speculative decoding, or None
        args: namespace providing ``temperature``, ``do_sample`` and ``warmup_iter``
        dataset: sized iterable of dicts, each with keys:
            "question": the prompt string
            optional "checks_once" / "checks_range" / "checks_periodic"
            (default ``{}``), "prefix" (default ``""``) and "number"
            (default ``0``), forwarded as ``meta`` to ``score_longgenbench_single``
        log_dir: directory for the per-sample JSONL log (created if missing)
        length_tag: "short" -> max_new_tokens = 16 * 1024 (16k);
            any other value (e.g. "long") -> max_new_tokens = 32 * 1024 (32k)

    Returns:
        A 7-tuple:
        (tput_mean, tput_std, tacc_mean, tacc_std,
         {"cr_mean": ..., "stic1_mean": ..., "stic2_mean": ..., "wavg_mean": ...},
         avg_draft_time, avg_target_time, peak_memory_gib)
    """
    max_new_tokens = 16 * 1024 if length_tag == "short" else 32 * 1024

    # 1. Warm-up with profiling disabled, so warm-up runs are presumably kept
    #    out of generator.exp_log measurements.
    # NOTE(review): warm-up uses the full evaluation max_new_tokens budget;
    # confirm this is intended, as it can be expensive.
    original_profiling = generator.profiling
    generator.profiling = False
    try:
        for _ in range(args.warmup_iter):
            warmup_prompt = "Solve this math problem. Give the reasoning steps ...\nWhat is 1 + 1?" * 32
            tokenizer.use_default_system_prompt = True
            warmup_ids = tokenizer.apply_chat_template(
                [{"role": "user", "content": warmup_prompt}],
                tokenize=True, add_generation_prompt=True, return_tensors="pt"
            ).to(generator.device)
            if True:  # placeholder, presumably swapped for `with sdpa_kernel(...)`
                generator.generate(
                    warmup_ids,
                    temperature=args.temperature,
                    max_new_tokens=max_new_tokens,
                    do_sample=args.do_sample,
                    past_key_values=past_key_values,
                    draft_past_key_values=draft_past_key_values
                )

            past_key_values.reset()
            if draft_past_key_values is not None:
                draft_past_key_values.reset()
    finally:
        # Restore the caller's profiling flag even if a warm-up generation raises.
        generator.profiling = original_profiling

    # 2. Main evaluation loop
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "0.jsonl")

    # Lists to accumulate throughput, token acceptance, draft/target times
    tput_list = []
    tacc_list = []  # average token acceptance per sample
    draft_times = []
    target_times = []
    per_sample = []  # per-sample score dicts from score_longgenbench_single

    for entry in tqdm(dataset, total=len(dataset), desc="Evaluating LongGenBench"):
        prompt = entry["question"]
        meta = {
            "checks_once": entry.get("checks_once", {}),
            "checks_range": entry.get("checks_range", {}),
            "checks_periodic": entry.get("checks_periodic", {}),
            "prefix": entry.get("prefix", ""),
            "number": entry.get("number", 0),
        }

        tokenizer.use_default_system_prompt = True
        input_ids = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(generator.device)

        if True:  # placeholder, presumably swapped for `with sdpa_kernel(...)`
            output_ids = generator.generate(
                input_ids,
                temperature=args.temperature,
                max_new_tokens=max_new_tokens,
                do_sample=args.do_sample,
                past_key_values=past_key_values,
                draft_past_key_values=draft_past_key_values
            )

        past_key_values.reset()
        if draft_past_key_values is not None:
            draft_past_key_values.reset()

        # Decode only the newly generated tokens (output_ids includes the prompt).
        response = tokenizer.decode(
            output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
        )

        score = score_longgenbench_single(response, meta)
        per_sample.append(score)

        record = {**getattr(generator, "exp_log", {})}
        record.update({
            "query": prompt,
            "response": response,
            "meta": meta,
            "CR": score["cr"],
            "STIC1_once": score["stic1_once"],
            "STIC1_range": score["stic1_range"],
            "STIC1_periodic": score["stic1_periodic"],
            "STIC1_overall": score["stic1_overall"],
            "STIC2_once": score["stic2_once"],
            "STIC2_range": score["stic2_range"],
            "STIC2_periodic": score["stic2_periodic"],
            "STIC2_overall": score["stic2_overall"],
            "wAvg": score["wavg"],
            "peak_memory": float(torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3)),
        })

        # Performance metrics are only present when the generator reported them.
        if record.get("tput") is not None:
            tput_list.append(record["tput"])
        if record.get("avg_sampled") is not None:
            tacc_list.append(record["avg_sampled"])
        if record.get("avg_draft_time") is not None:
            draft_times.append(record["avg_draft_time"])
        if record.get("avg_target_time") is not None:
            target_times.append(record["avg_target_time"])

        # Append per sample so partial results survive an interrupted run.
        with open(log_file, "a+", encoding="utf-8") as f:
            json.dump(record, f, ensure_ascii=False)
            f.write("\n")

        # Free per-sample tensors before the next long generation.
        del input_ids, output_ids
        gc.collect()
        torch.cuda.empty_cache()

    # 3. Aggregate overall metrics
    tput_mean, tput_std = (np.mean(tput_list), np.std(tput_list)) if tput_list else (0.0, 0.0)
    tacc_mean, tacc_std = (np.mean(tacc_list), np.std(tacc_list)) if tacc_list else (0.0, 0.0)
    cr_mean = float(np.mean([s["cr"] for s in per_sample])) if per_sample else 0.0
    stic1_mean = float(np.mean([s["stic1_overall"] for s in per_sample])) if per_sample else 0.0
    wavg_mean = float(np.mean([s["wavg"] for s in per_sample])) if per_sample else 0.0
    stic2_mean = float(np.mean([s["stic2_overall"] for s in per_sample])) if per_sample else 0.0
    avg_draft = float(np.mean(draft_times)) if draft_times else 0.0
    avg_target = float(np.mean(target_times)) if target_times else 0.0
    peak_mem = float(torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3))

    # 4. Print summary
    print(f"Final LongGenBench-{length_tag} Results:")
    print(f"\tCR (Main Task Completion) : {cr_mean:.3f}")
    print(f"\tSTIC-1 (overall)          : {stic1_mean:.3f}")
    print(f"\tSTIC-2 (overall)          : {stic2_mean:.3f}")
    print(f"\twAvg = CR × STIC-2        : {wavg_mean:.3f}")
    print(f"\tThroughput                : {tput_mean:.3f} ± {tput_std:.3f} tokens/sec")
    print(f"\tToken Acceptance          : {tacc_mean:.3f} ± {tacc_std:.3f}")
    print(f"\tAvg Draft Time            : {avg_draft:.3f} sec")
    print(f"\tAvg Target Time           : {avg_target:.3f} sec")
    print(f"\tPeak Memory               : {peak_mem:.3f} GiB")

    # 5. Return metrics tuple
    return (
        tput_mean, tput_std,
        tacc_mean, tacc_std,
        {"cr_mean": cr_mean, "stic1_mean": stic1_mean, "stic2_mean": stic2_mean, "wavg_mean": wavg_mean},  # quality scores
        avg_draft, avg_target,
        peak_mem
    )

def run_longwriter_eval(generator, tokenizer, past_key_values, draft_past_key_values, args, dataset, log_dir):
    """
    Evaluate LongWriter dataset for long-form generation throughput and acceptance rate.

    Each prompt is generated with up to 16384 new tokens. Prompts whose
    chat-templated length is ``>= args.max_length`` are skipped and not logged.
    The generated length of each evaluated sample is compared with a
    NaiveBuilder baseline run: ``response_ratio = min(gen_len / baseline_len, 1.0)``
    (rounded to 2 decimals; 1.0 for the baseline run itself).

    Args:
        generator: the model generator instance (uses ``device``, ``profiling``
            and the optional ``exp_log`` dict merged into each record).
        tokenizer: tokenizer with chat template functionality
        past_key_values: primary KV cache; ``reset()`` is called after every generation
        draft_past_key_values: draft KV cache for speculative decoding, or None
        args: namespace providing ``temperature``, ``do_sample``, ``max_length``
            and optionally ``warmup_iter`` (default 1)
        dataset: sized iterable of dicts with a "question" prompt string
        log_dir: directory for ``0.jsonl`` (created if missing). The log path is
            split on "/": component [2] is read as the builder name and [3] as
            the draft name. Unless the builder is "NaiveBuilder", the log at the
            same path with those components replaced by "NaiveBuilder" and
            "no_draft" must already exist. Its per-line ``n_tokens`` values are
            the baseline lengths, matched in order to the non-skipped samples.

    Returns:
        (tput_mean, tput_std, tacc_mean, tacc_std, avg_len_ratio,
         avg_draft_time, avg_target_time, peak_memory_gib)

    Raises:
        IndexError: if the baseline log has fewer records than evaluated samples.
    """
    max_new_tokens = 16384

    # 1. Warm-up with profiling disabled; short generations suffice here.
    original_profiling = generator.profiling
    generator.profiling = False
    try:
        for _ in range(getattr(args, "warmup_iter", 1)):
            warmup_prompt = "Write a short story about a cat."
            tokenizer.use_default_system_prompt = True
            warmup_ids = tokenizer.apply_chat_template(
                [{"role": "user", "content": warmup_prompt}],
                tokenize=True, add_generation_prompt=True, return_tensors="pt"
            ).to(generator.device)

            if True:  # placeholder, presumably swapped for `with sdpa_kernel(...)`
                generator.generate(
                    warmup_ids,
                    temperature=args.temperature,
                    max_new_tokens=128,  # warm-up does not need a long generation
                    do_sample=args.do_sample,
                    past_key_values=past_key_values,
                    draft_past_key_values=draft_past_key_values
                )
            past_key_values.reset()
            if draft_past_key_values is not None:
                draft_past_key_values.reset()
    finally:
        # Restore the caller's profiling flag even if a warm-up generation raises.
        generator.profiling = original_profiling

    # 2. Main evaluation loop
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "0.jsonl")

    # Locate the NaiveBuilder baseline log by swapping the builder/draft path
    # components. "/".join exactly inverts split("/"); os.path.join(*parts)
    # would silently drop the leading "/" of an absolute path.
    path_parts = log_file.split('/')
    is_baseline = path_parts[2] == 'NaiveBuilder'
    naive_file = None
    naive_lengths = []
    if not is_baseline:
        path_parts[2] = 'NaiveBuilder'
        path_parts[3] = 'no_draft'
        naive_file = '/'.join(path_parts)
        with open(naive_file, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    naive_lengths.append(json.loads(line)['n_tokens'])

    tput_list, tacc_list = [], []
    draft_times, target_times = [], []
    ratio_list = []
    # Position among non-skipped samples == line index in the baseline log,
    # because skipped prompts are never logged (in either run).
    n_evaluated = 0

    for entry in tqdm(dataset, total=len(dataset), desc="Evaluating LongWriter"):
        prompt = entry["question"]

        tokenizer.use_default_system_prompt = True
        input_ids = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(generator.device)
        prompt_len = input_ids.shape[1]

        # Skip prompts that leave no room for generation; must not advance n_evaluated.
        if prompt_len >= args.max_length:
            continue

        if True:  # placeholder, presumably swapped for `with sdpa_kernel(...)`
            # Use fixed long generation settings
            output_ids = generator.generate(
                input_ids,
                temperature=args.temperature,
                max_new_tokens=max_new_tokens,
                do_sample=args.do_sample,
                past_key_values=past_key_values,
                draft_past_key_values=draft_past_key_values
            )

        past_key_values.reset()
        if draft_past_key_values is not None:
            draft_past_key_values.reset()

        # output_ids holds prompt + generated tokens; keep only the generated part.
        response = tokenizer.decode(
            output_ids[0, prompt_len:], skip_special_tokens=True
        )
        gen_len = len(output_ids[0]) - prompt_len

        if is_baseline:
            response_ratio = 1.0
        else:
            if n_evaluated >= len(naive_lengths):
                raise IndexError(
                    f"Baseline log {naive_file} has {len(naive_lengths)} records; "
                    f"cannot compare evaluated sample #{n_evaluated}"
                )
            naive_len = naive_lengths[n_evaluated]
            # Capped at 1.0; a zero-length baseline counts as fully matched
            # instead of raising ZeroDivisionError.
            response_ratio = round(min(gen_len / naive_len, 1.0), 2) if naive_len > 0 else 1.0

        record = {**getattr(generator, "exp_log", {})}
        # A baseline log must carry n_tokens for later comparison runs; keep the
        # generator-reported value when present.
        record.setdefault("n_tokens", gen_len)
        record.update({
            "response_ratio": response_ratio,
            "query": str(prompt[:200]) + "...",  # Only store the beginning to avoid large logs
            "response_snippet": response[:200] + "...",  # Only store the beginning to avoid large logs
            "peak_memory": float(torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3)),
        })

        # Performance metrics are only present when the generator reported them.
        if record.get("tput") is not None:
            tput_list.append(record["tput"])
        if record.get("avg_sampled") is not None:
            tacc_list.append(record["avg_sampled"])
        ratio_list.append(response_ratio)
        if record.get("avg_draft_time") is not None:
            draft_times.append(record["avg_draft_time"])
        if record.get("avg_target_time") is not None:
            target_times.append(record["avg_target_time"])

        with open(log_file, "a+", encoding="utf-8") as f:
            json.dump(record, f, ensure_ascii=False)
            f.write("\n")
        n_evaluated += 1

        # Free per-sample tensors before the next long generation.
        del input_ids, output_ids
        gc.collect()
        torch.cuda.empty_cache()

    # 3. Aggregate metrics
    tput_mean, tput_std = (np.mean(tput_list), np.std(tput_list)) if tput_list else (0.0, 0.0)
    tacc_mean, tacc_std = (np.mean(tacc_list), np.std(tacc_list)) if tacc_list else (0.0, 0.0)
    avg_draft = float(np.mean(draft_times)) if draft_times else 0.0
    avg_target = float(np.mean(target_times)) if target_times else 0.0
    avg_len_ratio = float(np.mean(ratio_list)) if ratio_list else 0.0
    peak_mem = float(torch.cuda.max_memory_reserved(generator.device) / (1024 ** 3))

    print("Final LongWriter Results:")
    print(f"\tThroughput       : {tput_mean:.3f} ± {tput_std:.3f} tokens/sec")
    print(f"\tToken Acceptance : {tacc_mean:.3f} ± {tacc_std:.3f}")
    print(f"\tAvg Output Ratio : {avg_len_ratio:.2f}")
    print(f"\tAvg Draft Time   : {avg_draft:.3f} sec")
    print(f"\tAvg Target Time  : {avg_target:.3f} sec")
    print(f"\tPeak Memory      : {peak_mem:.3f} GiB")

    return (
        tput_mean, tput_std,
        tacc_mean, tacc_std,
        avg_len_ratio,
        avg_draft, avg_target,
        peak_mem
    )