#!/usr/bin/env python3
import argparse
from datasets import load_dataset
from transformers import AutoTokenizer
from statistics import mean
from tqdm import tqdm
import csv

# English-language configurations of the THUDM/LongBench dataset, grouped by
# LongBench task category. EN_TASKS (the concatenation below) is the default
# task list, and its order is the order in which tasks are processed.
_MULTI_DOC_QA = ["hotpotqa", "2wikimqa", "musique"]
_SINGLE_DOC_QA = ["multifieldqa_en", "narrativeqa", "qasper"]
_SUMMARIZATION = ["gov_report", "qmsum", "multi_news"]
_FEW_SHOT = ["triviaqa", "trec", "samsum"]
_SYNTHETIC = ["passage_count", "passage_retrieval_en"]
# Code tasks are not analyzed by default; add them to EN_TASKS (or pass them
# via --tasks) to include them.
_CODE = ["lcc", "repobench-p"]

EN_TASKS = _MULTI_DOC_QA + _SINGLE_DOC_QA + _SUMMARIZATION + _FEW_SHOT + _SYNTHETIC

# Candidate column names for each prompt part, tried in order; the first one
# holding non-empty text is used (see extract_part).
INPUT_FIELDS = ["input", "question", "query", "prompt"]
CONTEXT_FIELDS = [
    "context", "document", "passage", "passages", "docs", "documents", "evidence",
]
OUTPUT_FIELDS = ["output", "answer", "answers", "label", "target", "gold"]


def _coerce_to_text(v):
    """Render a dataset field value as a single string.

    ``str`` values pass through unchanged, ``None`` becomes ``""``, a list or
    tuple becomes its items' ``str()`` forms joined by newlines, and any other
    value is converted with ``str()``.
    """
    if isinstance(v, str):
        return v
    if v is None:
        return ""
    if isinstance(v, (list, tuple)):
        # e.g. a list of reference answers -> one answer per line
        return "\n".join(map(str, v))
    return str(v)


def extract_part(example, candidate_fields):
    """Return the text of the first usable field in *example*.

    Each name in *candidate_fields* is tried in order. A field is skipped if it
    is absent, is ``None``/``""``/``[]``/``()``, or renders (via
    ``_coerce_to_text``) to only whitespace. The first remaining field's
    stripped text is returned; ``""`` if no field qualifies.
    """
    for field in candidate_fields:
        if field not in example:
            continue
        raw = example[field]
        if raw in (None, "", [], ()):
            continue
        candidate = _coerce_to_text(raw).strip()
        if candidate:
            return candidate
    return ""


def build_prompt(example, include_input=True, include_context=True, include_output=True,
                 input_prefix="### Input:\n", context_prefix="### Context:\n",
                 output_prefix="### Output:\n", separator="\n\n"):
    """Assemble the selected parts of *example* into one prompt string.

    Sections are emitted in the fixed order input, context, output. A section
    is included only if its ``include_*`` flag is true and ``extract_part``
    finds non-empty text for it; each included section is its prefix followed
    by that text. Sections are joined with *separator*, and the final string
    is stripped of surrounding whitespace.
    """
    sections = (
        (include_input, INPUT_FIELDS, input_prefix),
        (include_context, CONTEXT_FIELDS, context_prefix),
        (include_output, OUTPUT_FIELDS, output_prefix),
    )
    pieces = []
    for enabled, fields, prefix in sections:
        if not enabled:
            continue
        body = extract_part(example, fields)
        if body:
            pieces.append(f"{prefix}{body}")
    return separator.join(pieces).strip()


def _parse_args(argv=None):
    """Define the command-line interface and parse *argv*.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    ap = argparse.ArgumentParser(
        description="Report token-length statistics for English LongBench tasks."
    )
    ap.add_argument("--model", type=str, required=True,
                    help="HF tokenizer checkpoint (e.g. meta-llama/Llama-3.1-8B)")
    ap.add_argument("--split", type=str, default="test", help="Dataset split")
    ap.add_argument("--csv_out", type=str, default="longbench_en_token_lengths.csv", help="CSV path")
    ap.add_argument("--tasks", type=str, nargs="*", default=EN_TASKS, help="Tasks to analyze")

    # Defaults: include everything; use flags to EXCLUDE parts.
    ap.add_argument("--exclude_context", action="store_true", help="Exclude context from prompt")
    ap.add_argument("--exclude_output", action="store_true", help="Exclude output from prompt")

    # Chat template options (off by default).
    ap.add_argument("--use_chat_template", action="store_true",
                    help="Render via tokenizer.apply_chat_template before tokenizing")
    ap.add_argument("--add_generation_prompt", action="store_true",
                    help="With --use_chat_template, append the assistant generation "
                         "prompt to the rendered text (ignored otherwise)")

    # Threshold for "long" samples.
    ap.add_argument("--threshold", type=int, default=8192,
                    help="Count samples with token length > threshold (default: 8192)")
    return ap.parse_args(argv)


def _load_tokenizer(model_name):
    """Load the fast tokenizer for *model_name*, configured for length counting."""
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # Lift any finite length cap. Truncation is already disabled at every call,
    # so this presumably only serves to quiet transformers' over-length warning.
    if getattr(tok, "model_max_length", None) and tok.model_max_length < 10**9:
        tok.model_max_length = 10**9
    # Nothing is padded in this script; kept so the tokenizer is configured
    # exactly as before.
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    return tok


def _token_length(tok, text, use_chat_template, add_generation_prompt):
    """Return how many tokens *text* occupies as model input (never truncated)."""
    if use_chat_template:
        rendered = tok.apply_chat_template(
            [{"role": "user", "content": text}],
            add_generation_prompt=add_generation_prompt,
            tokenize=False,
        )
        # The rendered template is expected to carry its own special tokens,
        # so do not let the tokenizer add them a second time.
        enc = tok(rendered, add_special_tokens=False, truncation=False)
    else:
        enc = tok(text, add_special_tokens=True, truncation=False)
    return len(enc["input_ids"])


def _task_stats(task, tok, args):
    """Tokenize every sample of one LongBench *task*.

    Returns:
        ``(n_samples, avg_len, max_len, n_over_threshold)``; the averages and
        maxima are 0 for an empty split.
    """
    ds = load_dataset("THUDM/LongBench", task, split=args.split)
    lengths = []
    for ex in tqdm(ds, desc=f"Tokenizing {task}"):
        text = build_prompt(
            ex,
            include_input=True,
            include_context=not args.exclude_context,
            include_output=not args.exclude_output,
        )
        lengths.append(
            _token_length(tok, text, args.use_chat_template, args.add_generation_prompt)
        )

    avg_len = float(mean(lengths)) if lengths else 0.0
    max_len = max(lengths, default=0)
    n_over = sum(1 for length in lengths if length > args.threshold)
    return len(lengths), avg_len, max_len, n_over


def _write_csv(path, rows, over_column):
    """Write the per-task summary *rows* to *path* as UTF-8 CSV with a header."""
    fieldnames = ["task", "split", "samples", "avg_token_len", "max_token_len", over_column]
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main(argv=None):
    """Print and save token-length statistics for each requested LongBench task.

    Args:
        argv: Optional argument list (defaults to the process command line),
            allowing the script to be driven programmatically.
    """
    args = _parse_args(argv)
    tok = _load_tokenizer(args.model)
    # Single source for the threshold-dependent column name (row key + header).
    over_column = f"num_over_{args.threshold}"

    rows = []
    for task in args.tasks:
        print(f"\n=== {task} ===")
        n, avg_len, max_len, n_over = _task_stats(task, tok, args)

        print(f"samples: {n}")
        print(f"avg tokens: {avg_len:,.2f}")
        print(f"max tokens: {max_len:,}")
        print(f"> {args.threshold} tokens: {n_over} samples")

        rows.append({
            "task": task,
            "split": args.split,
            "samples": n,
            "avg_token_len": f"{avg_len:.2f}",
            "max_token_len": max_len,
            over_column: n_over,
        })

    _write_csv(args.csv_out, rows, over_column)
    print(f"\nSaved CSV to: {args.csv_out}")


if __name__ == "__main__":
    main()
