#!/usr/bin/env python3
import matplotlib.pyplot as plt
import argparse
from collections import defaultdict
import csv
import json
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Tuple

import matplotlib
import numpy as np
from matplotlib.ticker import FuncFormatter, MaxNLocator

# Select matplotlib's Agg backend so figures are rendered straight to files
# without needing a display.
# NOTE(review): pyplot is already imported at the top of the file; matplotlib
# allows switching to a non-interactive backend afterwards, but moving the
# pyplot import below this call would make the intent clearer.
matplotlib.use("Agg")


# Figure palette shared by the plotting helpers.
FIG_BG_COLOR = "#FFFFFF"
BG_COLOR = "#FBFCFE"
GRID_COLOR = "#D9DEE8"
SPINE_COLOR = "#2F3441"
TEXT_COLOR = "#141821"

# Colors for the first three curves of the distribution plot, in input order
# (additional curves use the extra colors listed in plot_distributions).
COLOR_BASELINE = "#00468B"
COLOR_GSPO_LENGTH = "#9B59B6"
COLOR_GSPO = "#2E8B57"

# Font sizes for titles, axis labels, tick labels and legends.
TITLE_FONT_SIZE = 35
AXIS_LABEL_FONT_SIZE = 25
TICK_FONT_SIZE = 20
LEGEND_FONT_SIZE = 25


# Defaults used when the corresponding command-line option is not given.
DEFAULT_TOKENIZER_PATH = "/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B-Base"
# (path, label) pairs evaluated when no --input is passed. Every entry,
# including the commented-out ones, ends with a trailing comma so runs can be
# toggled on and off without two adjacent tuples being parsed as a call
# (which fails with "'tuple' object is not callable").
DEFAULT_INPUTS = [
    # (
    #     "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_mar/Qwen3-4B-Base-aime128_32768_test.jsonl",
    #     "Qwen3-4B-Base",
    # ),
    # (
    #     "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_mar/LIE-aime128_32768_test.jsonl",
    #     "GSPO + skip-right (step600)",
    # ),
    # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_mar/gspo-baseline-aime128_32768_test.jsonl",
    #     "GSPO (step500)"),
    ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_nov/Qwen3-4B-Base-valid-all_32768_test.jsonl", "Qwen3-4B-Base"),
    # (
    #     # "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_mar/gspo-baseline-aime128_32768_test.jsonl",
    #     # "GSPO (step500)",
    #     "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/gspo-add1k-wo-repetition-step570-valid-all_32768_test.jsonl",
    #     "GSPO + LIE (add1k-wo-repetition)",
    # ),
    ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/gspo-step500-valid-all_32768_test.jsonl", "GSPO"),
    ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/gspo-skip-right-step600-valid-all_32768_test.jsonl", "GSPO + LINE"),
    # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/octothinker-valid-all_32768_test.jsonl", "OctoThinker"),
    # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/octothinker-gspo-step300-valid-all_32768_test.jsonl", "OctoThinker + GSPO (step300)"),
    # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/octothinker-add1k-step360-valid-all_32768_test.jsonl", "OctoThinker + add1k (step360)"),
    # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_mar/LIE-semantic-v6-280_32768_test.jsonl", "LIE-semantic-v6-280"),
    # ("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_mar/llama-length-310_32768_test.jsonl", "llama-length-310"),
]
DEFAULT_BUDGETS = [32768]
DEFAULT_DISTRIBUTION_BIN_WIDTH = 512


@dataclass
class BudgetStats:
    """Response-length statistics for one input file under one budget cap."""

    budget: int  # token cap B; each response length L is replaced by min(L, B)
    average_capped_length: float  # mean of min(L, budget) over all counted responses
    clipped_cases: int  # number of responses with L strictly greater than budget
    total_cases: int  # number of responses counted (records with non-empty response text)


@dataclass
class FileStats:
    """Aggregated response-length statistics for one JSONL input file."""

    path: str  # JSONL file the statistics were computed from
    label: str  # display name used in legends, console output and CSV rows
    total_cases: int  # number of records with non-empty response text
    average_raw_length: float  # mean uncapped token length
    budgets: List[BudgetStats]  # one entry per budget cap, in the order the caps were given
    length_frequencies: Dict[int, Dict[int, int]]  # budget -> {capped length -> count}, keys sorted by length


class TokenizersWrapper:
    """Adapt a raw ``tokenizers.Tokenizer`` to the call interface used here.

    Calling an instance returns ``{"input_ids": [...]}`` with one id list per
    input text, which is the only part of the result that
    ``batched_token_lengths`` reads. Padding and truncation are not supported.
    """

    def __init__(self, tokenizer) -> None:
        # Backend object; it must provide ``encode_batch``.
        self.tokenizer = tokenizer

    def __call__(
        self,
        texts: Sequence[str],
        add_special_tokens: bool = False,
        padding: bool = False,
        truncation: bool = False,
    ) -> Dict[str, List[List[int]]]:
        """Encode ``texts`` and return their token ids under ``"input_ids"``.

        Raises:
            ValueError: if ``padding`` or ``truncation`` is requested.
        """
        unsupported_requested = padding or truncation
        if unsupported_requested:
            raise ValueError(
                "TokenizersWrapper does not support padding/truncation.")
        text_batch = list(texts)
        encoded_batch = self.tokenizer.encode_batch(
            text_batch, add_special_tokens=add_special_tokens)
        ids_per_text = [encoded.ids for encoded in encoded_batch]
        return {"input_ids": ids_per_text}


def add_aligned_legend(ax):
    """Draw a legend on ``ax`` using the same style as the baseline line charts.

    The legend is framed, uses a bold font of ``LEGEND_FONT_SIZE`` and has a
    frame alpha of 0.95. Returns whatever ``ax.legend`` returned.
    """
    bold_font = {"weight": "bold", "size": LEGEND_FONT_SIZE}
    legend = ax.legend(
        frameon=True,
        fontsize=LEGEND_FONT_SIZE,
        prop=bold_font,
    )
    if legend is None:
        return legend
    legend.get_frame().set_alpha(0.95)
    return legend


def load_tokenizer(tokenizer_path_or_name: str):
    """Load a tokenizer, preferring ``transformers`` over a raw ``tokenizers`` file.

    First tries ``transformers.AutoTokenizer.from_pretrained``. Any failure
    there, including ``transformers`` not being installed, is remembered and
    triggers the fallback. The fallback looks for ``tokenizer.json`` directly
    under ``tokenizer_path_or_name`` and wraps it in :class:`TokenizersWrapper`.

    Raises:
        RuntimeError: if ``tokenizer.json`` exists but cannot be loaded (chained
            to the underlying error), or if there is no such file to fall back
            to (the message includes the original ``transformers`` error).
    """
    hf_failure: Exception | None = None
    try:
        from transformers import AutoTokenizer

        return AutoTokenizer.from_pretrained(tokenizer_path_or_name)
    except Exception as exc:  # deliberate: any failure means "try the fallback"
        hf_failure = exc

    fallback_file = Path(tokenizer_path_or_name) / "tokenizer.json"
    if not fallback_file.exists():
        raise RuntimeError(
            "Failed to load tokenizer with transformers, and no local tokenizer.json "
            f"fallback was found under: {tokenizer_path_or_name}\n"
            f"Original transformers error: {hf_failure}"
        )

    try:
        from tokenizers import Tokenizer

        return TokenizersWrapper(Tokenizer.from_file(str(fallback_file)))
    except Exception as exc:
        raise RuntimeError(
            f"Failed to load tokenizer from {fallback_file}"
        ) from exc


def parse_input_spec(spec: str) -> Tuple[str, str]:
    """Split an ``--input`` spec into ``(path, label)``.

    The separators ``::``, ``=`` and ``|`` are tried in that order. The first
    one present splits the spec at its first occurrence, and both parts are
    stripped. The label falls back to the file stem when there is no
    separator, or when the label part is blank (e.g. ``"run.jsonl::"``).

    Examples:
        ``"a/run.jsonl::GSPO"`` -> ``("a/run.jsonl", "GSPO")``
        ``"a/run.jsonl"``       -> ``("a/run.jsonl", "run")``
    """
    for sep in ("::", "=", "|"):
        if sep in spec:
            raw_path, raw_label = spec.split(sep, 1)
            path = raw_path.strip()
            # A blank label would be dropped from matplotlib legends, so use
            # the file stem instead, as for the separator-free form.
            return path, raw_label.strip() or Path(path).stem
    path = spec.strip()
    return path, Path(path).stem


def load_jsonl(path: str) -> Iterable[Dict]:
    """Lazily yield every JSON object stored one-per-line in ``path``.

    Blank or whitespace-only lines are ignored. Lines whose JSON value is not
    an object (a list, number, string, ...) are skipped silently.

    Raises:
        ValueError: on malformed JSON, naming the file and 1-based line
            number and chained to the decoder error.
    """
    with open(path, "r", encoding="utf-8") as handle:
        for row_number, raw_row in enumerate(handle, start=1):
            payload = raw_row.strip()
            if payload:
                try:
                    record = json.loads(payload)
                except json.JSONDecodeError as exc:
                    raise ValueError(f"Invalid JSON in {path}:{row_number}") from exc
                if isinstance(record, dict):
                    yield record


def normalize_text(value: object) -> str:
    """Coerce a response field into a clean string.

    A list is reduced to its first element, and an empty list counts as
    missing. ``None``, at the top level or as that first element, yields
    ``""``. Any other non-string value goes through ``str``. All
    ``<|endoftext|>`` markers are removed and surrounding whitespace is stripped.
    """
    if isinstance(value, list):
        value = value[0] if value else None
    # Checked after list unwrapping so that ``[None]`` is treated as missing
    # text rather than becoming the literal string "None".
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.replace("<|endoftext|>", "").strip()


def extract_response_text(item: Dict) -> str:
    """Return the first non-empty response text found in a JSONL record.

    The keys ``generated_text``, ``output``, ``response``, ``text`` and
    ``completion`` are checked in that order. Each one present in ``item`` is
    cleaned with :func:`normalize_text`, and the first non-empty result wins.
    Returns ``""`` when no candidate key yields text.
    """
    candidate_keys = ("generated_text", "output", "response", "text", "completion")
    # Generator keeps evaluation lazy: normalisation stops at the first hit.
    cleaned = (normalize_text(item.get(key)) for key in candidate_keys if key in item)
    return next((text for text in cleaned if text), "")


def batched_token_lengths(texts: Sequence[str], tokenizer) -> List[int]:
    """Return the token count of each text, tokenized in a single batch call.

    ``tokenizer`` is called HF-style with special tokens, padding and
    truncation all disabled, and must return a mapping with an
    ``"input_ids"`` list of per-text id sequences.
    """
    id_sequences = tokenizer(
        list(texts),
        add_special_tokens=False,
        padding=False,
        truncation=False,
    )["input_ids"]
    return list(map(len, id_sequences))


def compute_file_stats(
    path: str,
    label: str,
    tokenizer,
    budgets: Sequence[int],
    batch_size: int,
) -> FileStats:
    """Tokenize every response in a JSONL file and aggregate capped-length stats.

    Records come from :func:`load_jsonl`. Records whose response text (see
    :func:`extract_response_text`) is empty are skipped and do not count
    towards ``total_cases``. Texts are tokenized in batches of ``batch_size``.
    For each budget ``B``, a token length ``L`` contributes ``min(L, B)`` to
    that budget's average and frequency table, and counts as clipped when
    ``L > B``.

    Args:
        path: JSONL file to read.
        label: Display label stored on the result.
        tokenizer: Callable accepted by :func:`batched_token_lengths`.
        budgets: Budget caps to evaluate.
        batch_size: Number of texts per tokenizer call.

    Returns:
        A :class:`FileStats` whose ``budgets`` follow the order of ``budgets``
        and whose ``length_frequencies`` are sorted by capped length.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if no record yields non-empty response text.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    total_cases = 0
    raw_length_sum = 0
    budget_length_sums = {budget: 0 for budget in budgets}
    budget_clipped_counts = {budget: 0 for budget in budgets}
    budget_length_frequencies = {
        budget: defaultdict(int) for budget in budgets
    }

    def _consume_batch(texts: List[str]) -> None:
        """Tokenize one batch and fold its lengths into the running totals."""
        nonlocal total_cases, raw_length_sum
        if not texts:
            return
        for length in batched_token_lengths(texts, tokenizer=tokenizer):
            total_cases += 1
            raw_length_sum += length
            for budget in budgets:
                capped_length = min(length, budget)
                budget_length_sums[budget] += capped_length
                budget_length_frequencies[budget][capped_length] += 1
                if length > budget:
                    budget_clipped_counts[budget] += 1

    batch_texts: List[str] = []
    for item in load_jsonl(path):
        text = extract_response_text(item)
        if not text:
            continue
        batch_texts.append(text)
        if len(batch_texts) >= batch_size:
            _consume_batch(batch_texts)
            batch_texts = []
    # Flush the final, possibly partial, batch.
    _consume_batch(batch_texts)

    if total_cases == 0:
        raise ValueError(f"No valid response text found in {path}")

    budget_stats: List[BudgetStats] = [
        BudgetStats(
            budget=budget,
            average_capped_length=budget_length_sums[budget] / total_cases,
            clipped_cases=budget_clipped_counts[budget],
            total_cases=total_cases,
        )
        for budget in budgets
    ]

    return FileStats(
        path=path,
        label=label,
        total_cases=total_cases,
        average_raw_length=raw_length_sum / total_cases,
        budgets=budget_stats,
        length_frequencies={
            budget: dict(sorted(length2count.items()))
            for budget, length2count in budget_length_frequencies.items()
        },
    )


def save_csv(results: Sequence[FileStats], out_csv: Path) -> None:
    """Write one CSV row per (file, budget) pair with length and clipping stats.

    Columns: label, path, total_cases, average_raw_length, budget,
    average_capped_length, clipped_cases, clip_ratio. Float columns use six
    decimal places. ``clip_ratio`` is ``clipped_cases / total_cases``, or 0.0
    when there are no cases. Parent directories are created as needed.
    """
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    header = [
        "label",
        "path",
        "total_cases",
        "average_raw_length",
        "budget",
        "average_capped_length",
        "clipped_cases",
        "clip_ratio",
    ]
    with out_csv.open("w", encoding="utf-8", newline="") as f:
        # csv.writer escapes quotes/commas/newlines in labels and paths;
        # hand-built '"..."' quoting produced malformed rows for such values.
        # "\n" keeps the Unix line endings this file has always used.
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(header)
        for result in results:
            for budget_stat in result.budgets:
                clip_ratio = (
                    budget_stat.clipped_cases / budget_stat.total_cases
                    if budget_stat.total_cases
                    else 0.0
                )
                writer.writerow([
                    result.label,
                    result.path,
                    result.total_cases,
                    f"{result.average_raw_length:.6f}",
                    budget_stat.budget,
                    f"{budget_stat.average_capped_length:.6f}",
                    budget_stat.clipped_cases,
                    f"{clip_ratio:.6f}",
                ])


def build_binned_distribution(
    length2count: Dict[int, int],
    budget: int,
    bin_width: int,
) -> List[Tuple[int, int, int]]:
    """Group a length histogram into fixed-width bins covering ``[0, budget]``.

    Returns ``(bin_start, bin_end, count)`` triples sorted by ``bin_start``,
    where ``bin_end = min(bin_start + bin_width, budget)``. For a non-negative
    budget, every bin in range is listed even when its count is zero. Lengths
    at or past the start of the final bin are folded into that bin; this
    includes capped lengths equal to ``budget``.

    Raises:
        ValueError: if ``bin_width`` is not positive.
    """
    if bin_width <= 0:
        raise ValueError("distribution bin width must be positive")

    final_start = (max(budget, 1) - 1) // bin_width * bin_width
    # Pre-seed empty bins 0, w, 2w, ..., final_start (none for budget < 0).
    totals: Dict[int, int] = dict.fromkeys(
        range(0, min(budget, final_start) + 1, bin_width), 0
    )
    for length, frequency in length2count.items():
        start = min(length // bin_width * bin_width, final_start)
        totals[start] = totals.get(start, 0) + frequency

    return [
        (start, min(start + bin_width, budget), totals[start])
        for start in sorted(totals)
    ]


def format_interval_label(bin_start: int, bin_end: int) -> str:
    """Render a bin as a half-open interval label such as ``"[0, 512)"``."""
    return "[{}, {})".format(bin_start, bin_end)


def get_budget_stat(result: FileStats, budget: int) -> BudgetStats:
    """Return the first :class:`BudgetStats` in ``result`` whose budget equals ``budget``.

    Raises:
        ValueError: if ``result`` has no entry for that budget.
    """
    match = next(
        (candidate for candidate in result.budgets if candidate.budget == budget),
        None,
    )
    if match is None:
        raise ValueError(f"Budget {budget} not found for {result.label}")
    return match


def smooth_histogram_line(
    x_values: Sequence[float],
    y_values: Sequence[float],
    sigma_bins: float = 1.5,
    upsample_factor: int = 12,
) -> Tuple[np.ndarray, np.ndarray]:
    """Smooth a polyline with monotone cubic (PCHIP-style) interpolation.

    The returned curve still passes through every original point. Knot slopes
    follow the Fritsch-Carlson rule, the same formulas SciPy's
    ``PchipInterpolator`` uses:
    - An interior knot gets slope 0 when the neighbouring secants change sign
      or one of them is zero. Otherwise it gets their weighted harmonic mean.
    - The end knots use a one-sided three-point estimate, clamped so the end
      segments stay monotone.

    Args:
        x_values: Knot x-coordinates. Assumed strictly increasing; a repeated
            x would divide by zero. TODO confirm at call sites.
        y_values: Knot y-values, same length as ``x_values``.
        sigma_bins: Not used by this implementation. Presumably kept for
            signature compatibility; TODO confirm before removing.
        upsample_factor: The output has ``len(x_values) * upsample_factor``
            evenly spaced samples, and never fewer than ``len(x_values)``.

    Returns:
        ``(dense_x, dense_y)`` float arrays. With at most one point, the inputs
        are returned as float arrays unchanged. With exactly two points, the
        result is linear interpolation between them.
    """
    x_array = np.asarray(x_values, dtype=float)
    y_array = np.asarray(y_values, dtype=float)
    if x_array.size <= 1 or y_array.size <= 1:
        return x_array, y_array
    # Two knots: a cubic has nothing to work with, so draw a straight line.
    if x_array.size == 2:
        dense_x = np.linspace(
            x_array[0],
            x_array[-1],
            max(int(x_array.size * upsample_factor), x_array.size),
            dtype=float,
        )
        dense_y = np.interp(dense_x, x_array, y_array)
        return dense_x, dense_y

    # Interval widths and secant slopes between consecutive knots.
    h = np.diff(x_array)
    delta = np.diff(y_array) / h
    slopes = np.zeros_like(y_array)

    # Interior knots: flat at local extrema / flat neighbours (prevents
    # overshoot), otherwise the weighted harmonic mean of adjacent secants.
    for idx in range(1, len(y_array) - 1):
        if delta[idx - 1] * delta[idx] <= 0:
            slopes[idx] = 0.0
        else:
            w1 = 2.0 * h[idx] + h[idx - 1]
            w2 = h[idx] + 2.0 * h[idx - 1]
            slopes[idx] = (w1 + w2) / (
                w1 / delta[idx - 1] + w2 / delta[idx]
            )

    # End-knot slope from the two nearest intervals: zeroed if its sign
    # disagrees with the end secant, and clamped to 3x that secant when the
    # data turns around, so the end segment stays monotone.
    def _endpoint_slope(first_h, second_h, first_delta, second_delta):
        slope = (
            ((2.0 * first_h + second_h) * first_delta) -
            (first_h * second_delta)
        ) / (first_h + second_h)
        if slope * first_delta <= 0:
            return 0.0
        if (first_delta * second_delta < 0) and (abs(slope) > abs(3.0 * first_delta)):
            return 3.0 * first_delta
        return slope

    slopes[0] = _endpoint_slope(h[0], h[1], delta[0], delta[1])
    slopes[-1] = _endpoint_slope(h[-1], h[-2], delta[-1], delta[-2])

    dense_x = np.linspace(
        x_array[0],
        x_array[-1],
        max(len(x_array) * upsample_factor, len(x_array)),
        dtype=float,
    )

    dense_y = np.empty_like(dense_x)
    # Map each dense sample to its interval index. The last sample
    # (== x_array[-1]) would fall past the end, so clip it into the final
    # interval, where it evaluates at t == 1.
    segment_ids = np.searchsorted(x_array[1:], dense_x, side="right")
    segment_ids = np.clip(segment_ids, 0, len(h) - 1)

    for seg_idx in range(len(h)):
        mask = segment_ids == seg_idx
        if not np.any(mask):
            continue
        x_left = x_array[seg_idx]
        x_right = x_array[seg_idx + 1]  # unused: seg_h already holds the width
        y_left = y_array[seg_idx]
        y_right = y_array[seg_idx + 1]
        seg_h = h[seg_idx]
        t = (dense_x[mask] - x_left) / seg_h

        # Cubic Hermite basis functions on the unit interval.
        h00 = (2.0 * t ** 3) - (3.0 * t ** 2) + 1.0
        h10 = (t ** 3) - (2.0 * t ** 2) + t
        h01 = (-2.0 * t ** 3) + (3.0 * t ** 2)
        h11 = (t ** 3) - (t ** 2)

        # Slopes are scaled by seg_h because the basis is defined on t in [0, 1].
        dense_y[mask] = (
            h00 * y_left
            + h10 * seg_h * slopes[seg_idx]
            + h01 * y_right
            + h11 * seg_h * slopes[seg_idx + 1]
        )

    return dense_x, dense_y


def trim_zero_count_bins(
    binned_distribution: Sequence[Tuple[int, int, int]],
) -> List[Tuple[int, int, int]]:
    """Drop leading and trailing bins whose count is zero.

    Zero-count bins between non-zero ones are kept. Returns an empty list when
    every bin is zero or the input is empty.
    """
    occupied = [
        position
        for position, entry in enumerate(binned_distribution)
        if entry[2] != 0
    ]
    if not occupied:
        return []
    first, last = occupied[0], occupied[-1]
    return list(binned_distribution[first:last + 1])


def compute_focus_cutoff(
    distributions: Sequence[Sequence[Tuple[int, int, int]]],
    bin_width: int,
    mass_threshold: float = 0.95,
    min_bins: int = 8,
) -> float:
    """Choose a shared x-axis cutoff that covers most of each distribution's mass.

    For each non-empty distribution with positive total count, take the end
    of the first bin where the cumulative share reaches ``mass_threshold``.
    That bin is never earlier than the ``min_bins``-th bin. The result is the
    maximum over distributions, then:
    - raised to at least ``min_bins * bin_width``;
    - capped at ``0.8 * max_end``, when ``max_end > 0``;
    - capped at ``max_end``.
    Here ``max_end`` is the largest final ``bin_end`` seen. If no
    distribution has positive mass, ``max_end`` (0.0 if none) is returned.
    """
    per_series_cutoffs: List[float] = []
    widest_end = 0.0

    for series in distributions:
        if not series:
            continue
        counts = np.asarray([c for _, _, c in series], dtype=float)
        ends = np.asarray([e for _, e, _ in series], dtype=float)
        if counts.size == 0:
            continue

        widest_end = max(widest_end, float(ends[-1]))
        mass = float(np.sum(counts))
        if mass <= 0:
            continue

        last_index = len(ends) - 1
        share = np.cumsum(counts) / mass
        hit = int(np.searchsorted(share, mass_threshold, side="left"))
        # Clamp into range, then enforce the minimum number of bins.
        hit = max(min(hit, last_index), min(min_bins - 1, last_index))
        per_series_cutoffs.append(float(ends[hit]))

    if not per_series_cutoffs:
        return widest_end

    cutoff = max(max(per_series_cutoffs), float(min_bins * bin_width))
    if widest_end > 0:
        cutoff = min(cutoff, 0.8 * widest_end)
    return min(cutoff, widest_end)


def save_distribution_wide_csv(
    results: Sequence[FileStats],
    out_csv: Path,
    bin_width: int,
    distribution_budget: int,
) -> None:
    """Write a wide CSV: one row per model, one column per length bin.

    The header is ``model`` followed by interval labels such as
    ``[0, 512)``. Each row holds the model label and its per-bin counts of
    capped lengths for ``distribution_budget``; a missing budget is treated
    as an empty histogram. Parent directories are created as needed.
    """
    out_csv.parent.mkdir(parents=True, exist_ok=True)

    def _bins(length2count):
        return build_binned_distribution(
            length2count,
            budget=distribution_budget,
            bin_width=bin_width,
        )

    header = ["model"]
    header.extend(
        format_interval_label(start, end) for start, end, _ in _bins({})
    )

    with out_csv.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        # A generator keeps binning interleaved with writing, row by row.
        writer.writerows(
            [result.label]
            + [
                count
                for _, _, count in _bins(
                    result.length_frequencies.get(distribution_budget, {})
                )
            ]
            for result in results
        )


def plot_results(results: Sequence[FileStats], out_path: Path) -> None:
    """Plot average capped response length against budget cap, one line per file.

    Args:
        results: Per-file statistics. Each contributes one line labelled with
            ``result.label``.
        out_path: Destination of the saved figure. Parent directories are
            created, and the format follows the file extension.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    plt.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans"]
    plt.rcParams["axes.unicode_minus"] = False
    plt.rcParams["font.size"] = TICK_FONT_SIZE

    fig, ax = plt.subplots(figsize=(15, 10))
    colors = ["#4E79A7", "#E15759", "#59A14F", "#B07AA1", "#76B7B2"]
    markers = ["o", "s", "^", "D", "P"]

    for idx, result in enumerate(results):
        ax.plot(
            [item.budget for item in result.budgets],
            [item.average_capped_length for item in result.budgets],
            label=result.label,
            color=colors[idx % len(colors)],
            marker=markers[idx % len(markers)],
            linewidth=2.4,
            markersize=7,
        )

    ax.set_title("Average Response Length Under Budget Caps",
                 fontsize=TITLE_FONT_SIZE, pad=12, fontweight="bold")
    ax.set_xlabel("Budget Cap", fontsize=AXIS_LABEL_FONT_SIZE,
                  fontweight="bold")
    ax.set_ylabel("Average Response Length (tokens)",
                  fontsize=AXIS_LABEL_FONT_SIZE, fontweight="bold")
    # Tick every budget that appears in any result, not only the last one's.
    x_ticks = sorted({
        budget_stat.budget
        for result in results
        for budget_stat in result.budgets
    })
    ax.set_xticks(x_ticks)
    ax.tick_params(axis="both", labelsize=TICK_FONT_SIZE)
    for tick in ax.get_xticklabels() + ax.get_yticklabels():
        tick.set_fontweight("bold")
    ax.grid(True, alpha=0.25)

    # default=0.0 keeps an empty `results` from crashing max().
    y_max = max(
        (
            budget_stat.average_capped_length
            for result in results
            for budget_stat in result.budgets
        ),
        default=0.0,
    )
    # Round the top up to the next multiple of 500, plus 500 of headroom.
    ax.set_ylim(0, max(1000, math.ceil(y_max / 500.0) * 500 + 500))

    add_aligned_legend(ax)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def _format_length_tick(value: float, _pos: object = None) -> str:
    """Format an x-axis token-length tick compactly: ``0``, ``0.8k``, ``2k``.

    ``:g`` keeps a fractional part only when one is needed. Rounding to whole
    thousands (``:.0f``) mislabels ticks on short axes: 800/1600/2400 would
    render as ``1k``/``2k``/``2k``.
    """
    if abs(value) < 1e-9:
        return "0"
    return f"{value / 1000.0:g}k"


def _style_distribution_axes(fig, ax) -> None:
    """Apply the shared background, spine, grid, tick and locator styling."""
    fig.patch.set_facecolor(FIG_BG_COLOR)
    ax.set_facecolor(BG_COLOR)
    for side in ["left", "bottom", "top", "right"]:
        ax.spines[side].set_linewidth(2)
        ax.spines[side].set_color(SPINE_COLOR)
        ax.spines[side].set_visible(True)
    ax.grid(
        True,
        axis="both",
        alpha=0.65,
        color=GRID_COLOR,
        linewidth=2,
        linestyle="-",
        zorder=0,
    )
    ax.set_axisbelow(True)
    ax.tick_params(
        axis="both",
        labelcolor=TEXT_COLOR,
        labelsize=TICK_FONT_SIZE,
        length=5.2,
        width=2,
        color=SPINE_COLOR,
        pad=4.5,
    )
    for tick in ax.get_xticklabels() + ax.get_yticklabels():
        tick.set_fontweight("bold")
    ax.xaxis.set_major_locator(MaxNLocator(nbins=6))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=5, integer=False))
    ax.xaxis.set_major_formatter(FuncFormatter(_format_length_tick))


def plot_distributions(
    results: Sequence[FileStats],
    distribution_budget: int,
    out_dir: Path,
    prefix: str,
    bin_width: int,
    *,
    max_plot_length: int = 10000,
) -> None:
    """Plot the capped-length distribution of every input on one figure.

    For each result, the capped lengths recorded for ``distribution_budget``
    are grouped into ``bin_width``-token bins. They are drawn as translucent
    bars plus a smoothed line, as a percentage of that file's ``total_cases``.
    Two groups of bins are left out:
    - the final bin, which ends at the budget and therefore also holds every
      clipped response;
    - bins starting at or beyond ``min(max_plot_length, distribution_budget)``.

    The figure is saved to
    ``out_dir / f"{prefix}_distribution_budget_{distribution_budget}.pdf"``.
    Nothing is written if no result has data to plot.

    Args:
        results: Per-file statistics, plotted in order with a rotating palette.
        distribution_budget: Budget whose length frequencies are plotted.
        out_dir: Output directory, created if missing.
        prefix: Output filename prefix.
        bin_width: Bin width in tokens.
        max_plot_length: Upper x-axis limit in tokens (default 10000).
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    plt.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans"]
    plt.rcParams["axes.unicode_minus"] = False
    plt.rcParams["font.size"] = TICK_FONT_SIZE

    colors = [
        COLOR_BASELINE,
        COLOR_GSPO_LENGTH,
        COLOR_GSPO,
        "#C97B84",
        "#6C91BF",
    ]
    y_max = 0.0
    plot_xmax = min(max_plot_length, distribution_budget)
    fig, ax = plt.subplots(figsize=(12, 7))
    _style_distribution_axes(fig, ax)

    plotted_any = False
    for idx, result in enumerate(results):
        length2count = result.length_frequencies.get(distribution_budget, {})
        if not length2count:
            continue

        binned_distribution = build_binned_distribution(
            length2count,
            budget=distribution_budget,
            bin_width=bin_width,
        )
        # Drop the bin ending at the budget: it also collects every clipped
        # response (capped length == budget) and would dominate the plot.
        if binned_distribution and binned_distribution[-1][1] >= distribution_budget:
            binned_distribution = binned_distribution[:-1]
        binned_distribution = [
            (bin_start, bin_end, count)
            for bin_start, bin_end, count in binned_distribution
            if bin_start < plot_xmax
        ]
        if not binned_distribution:
            continue

        bin_starts = np.asarray(
            [bin_start for bin_start, _, _ in binned_distribution],
            dtype=float,
        )
        shares = np.asarray(
            [
                count / max(result.total_cases, 1) * 100.0
                for _, _, count in binned_distribution
            ],
            dtype=float,
        )
        color = colors[idx % len(colors)]
        bar_width = bin_width * 0.92
        # The smooth line runs through bin centres.
        line_x = np.asarray(
            [(bin_start + bin_end) / 2.0
             for bin_start, bin_end, _ in binned_distribution],
            dtype=float,
        )
        smooth_x, smooth_y = smooth_histogram_line(line_x, shares)

        plotted_any = True
        y_max = max(
            y_max,
            float(np.max(shares)) if shares.size else 0.0,
            float(np.max(smooth_y)) if smooth_y.size else 0.0,
        )

        ax.bar(
            bin_starts,
            shares,
            width=bar_width,
            align="edge",
            color=color,
            alpha=0.14,
            edgecolor="none",
            zorder=1,
        )
        ax.plot(
            smooth_x,
            smooth_y,
            color=color,
            linewidth=3.2,
            alpha=0.98,
            label=result.label,
            zorder=3,
            solid_joinstyle="round",
            solid_capstyle="round",
        )

    if not plotted_any:
        plt.close(fig)
        return

    ax.set_xlim(0, plot_xmax)
    ax.set_ylim(0, max(1.0, y_max * 1.12))
    ax.set_ylabel(
        "Percentage of Samples (%)",
        fontsize=AXIS_LABEL_FONT_SIZE,
        fontweight="bold",
        labelpad=10,
        color=TEXT_COLOR,
    )
    # ax.set_xlabel(
    #     "Capped Response Length (tokens)",
    #     fontsize=17,
    #     fontweight="bold",
    #     labelpad=10,
    #     color=TEXT_COLOR,
    # )
    # ax.text(
    #     0.98,
    #     0.97,
    #     f"budget={distribution_budget}, bin={bin_width}",
    #     transform=ax.transAxes,
    #     ha="right",
    #     va="top",
    #     fontsize=11,
    #     color=SPINE_COLOR,
    # )

    add_aligned_legend(ax)

    fig.subplots_adjust(top=0.90, bottom=0.12, left=0.10, right=0.97)
    out_path = out_dir / \
        f"{prefix}_distribution_budget_{distribution_budget}.pdf"
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Defaults come from the module-level ``DEFAULT_*`` constants. The help
    strings are derived from the same constants so they cannot drift out of
    sync with the real defaults.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Compute average response length under budget caps. "
            "For each budget B, use min(response_length, B)."
        )
    )
    parser.add_argument(
        "--input",
        dest="inputs",
        action="append",
        help=(
            "Input spec as PATH::LABEL (PATH=LABEL and PATH|LABEL are also "
            "accepted). Can be passed multiple times."
        ),
    )
    parser.add_argument(
        "--tokenizer",
        default=DEFAULT_TOKENIZER_PATH,
        help=f"Tokenizer path or model id. Default: {DEFAULT_TOKENIZER_PATH}",
    )
    parser.add_argument(
        "--budgets",
        nargs="+",
        type=int,
        default=DEFAULT_BUDGETS,
        # Previously hard-coded as "8192 16384 32768", which did not match
        # DEFAULT_BUDGETS.
        help=(
            "Budget caps to evaluate. Default: "
            + " ".join(str(budget) for budget in DEFAULT_BUDGETS)
        ),
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Batch size for tokenization. Default: 128",
    )
    parser.add_argument(
        "--out-dir",
        default=str(Path(__file__).resolve().parent / "plots"),
        help="Output directory for the figure and csv.",
    )
    parser.add_argument(
        "--prefix",
        default="capped_response_length",
        help="Output filename prefix.",
    )
    parser.add_argument(
        "--skip-distribution-plots",
        action="store_true",
        help="Do not generate distribution figures.",
    )
    parser.add_argument(
        "--distribution-bin-width",
        type=int,
        default=DEFAULT_DISTRIBUTION_BIN_WIDTH,
        help=f"Bin width for distribution wide csv and distribution plots. Default: {DEFAULT_DISTRIBUTION_BIN_WIDTH}",
    )
    return parser


def _print_file_summary(
    result: FileStats,
    distribution_budget: int,
    bin_width: int,
) -> None:
    """Print one file's totals, per-budget clipping stats and distribution bin count."""
    print(
        f"  total_cases={result.total_cases}, "
        f"avg_raw_length={result.average_raw_length:.2f}"
    )
    for budget_stat in result.budgets:
        clip_ratio = (
            budget_stat.clipped_cases / budget_stat.total_cases
            if budget_stat.total_cases
            else 0.0
        )
        print(
            f"  budget={budget_stat.budget:<5d} "
            f"avg_capped={budget_stat.average_capped_length:.2f} "
            f"clipped={budget_stat.clipped_cases} "
            f"clip_ratio={clip_ratio:.3%}"
        )
        if budget_stat.budget != distribution_budget:
            continue
        length2count = result.length_frequencies.get(budget_stat.budget, {})
        num_bins = len(
            build_binned_distribution(
                length2count,
                budget=budget_stat.budget,
                bin_width=bin_width,
            )
        )
        print(
            f"           distribution_budget={distribution_budget}, "
            f"unique_lengths={len(length2count)}, bins={num_bins}")


def main() -> None:
    """Command-line entry point.

    Steps:
    1. Parse and validate arguments.
    2. Load the tokenizer.
    3. Compute :class:`FileStats` for every input, printing a summary for each.
    4. Write the summary CSV and the wide distribution CSV.
    5. Write the budget-curve figure and, unless ``--skip-distribution-plots``
       is given, the distribution figure for the largest budget.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Validate up front. A bad bin width used to fail only after the first
    # (slow) tokenization pass, and non-positive budgets silently produced
    # meaningless statistics.
    if args.distribution_bin_width <= 0:
        parser.error("--distribution-bin-width must be a positive integer")
    if any(budget <= 0 for budget in args.budgets):
        parser.error("--budgets values must be positive integers")

    input_specs = args.inputs or [
        f"{path}::{label}" for path, label in DEFAULT_INPUTS
    ]
    parsed_inputs = [parse_input_spec(spec) for spec in input_specs]
    budgets = sorted(set(args.budgets))
    distribution_budget = max(budgets)

    print(f"Loading tokenizer from: {args.tokenizer}")
    tokenizer = load_tokenizer(args.tokenizer)

    results: List[FileStats] = []
    for path, label in parsed_inputs:
        print(f"Processing {label}: {path}")
        result = compute_file_stats(
            path=path,
            label=label,
            tokenizer=tokenizer,
            budgets=budgets,
            batch_size=args.batch_size,
        )
        results.append(result)
        _print_file_summary(
            result,
            distribution_budget=distribution_budget,
            bin_width=args.distribution_bin_width,
        )

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_figure = out_dir / f"{args.prefix}.pdf"
    out_csv = out_dir / f"{args.prefix}.csv"
    out_distribution_wide_csv = out_dir / \
        f"{args.prefix}_distribution_wide.csv"

    save_csv(results, out_csv)
    save_distribution_wide_csv(
        results,
        out_distribution_wide_csv,
        bin_width=args.distribution_bin_width,
        distribution_budget=distribution_budget,
    )
    plot_results(results, out_figure)
    if not args.skip_distribution_plots:
        plot_distributions(
            results,
            distribution_budget=distribution_budget,
            out_dir=out_dir,
            prefix=args.prefix,
            bin_width=args.distribution_bin_width,
        )

    print(f"Saved figure to: {out_figure}")
    print(f"Saved table to: {out_csv}")
    print(f"Saved wide distribution csv to: {out_distribution_wide_csv}")
    print(
        f"Distribution is only computed for the largest budget: {distribution_budget}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
