#!/usr/bin/env python3
import argparse
import json
import math
import re
from pathlib import Path


# Matches file names such as "step_12_traindata.jsonl"; group(1) captures the
# step number digits. Anchored at the end of the name.
STEP_RE = re.compile(r"step_(\d+)_traindata\.jsonl$")


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous zero-argument call.

    Returns:
        Namespace with ``training_data_dir`` (Path), ``low`` and ``high``
        (floats, inclusive bounds) and ``save_csv`` (Path or None).

    Raises:
        SystemExit: via ``parser.error`` when ``--low`` is greater than
            ``--high`` or either bound is NaN, since such a range can never
            match anything and would silently report all-zero frequencies.
    """
    parser = argparse.ArgumentParser(
        description="Compute per-step frequency of response lengths in a target range."
    )
    parser.add_argument(
        "--training-data-dir",
        type=Path,
        required=True,
        help="Directory containing step_*_traindata.jsonl files.",
    )
    parser.add_argument(
        "--low",
        type=float,
        default=6000.0,
        help="Lower bound (inclusive). Default: 6000.",
    )
    parser.add_argument(
        "--high",
        type=float,
        default=8000.0,
        help="Upper bound (inclusive). Default: 8000.",
    )
    parser.add_argument(
        "--save-csv",
        type=Path,
        default=None,
        help="Optional output CSV path.",
    )
    args = parser.parse_args(argv)
    # Written as a negated <= so that a NaN bound (accepted by type=float) is
    # rejected too: every comparison with NaN is False.
    if not args.low <= args.high:
        parser.error(f"--low ({args.low}) must be <= --high ({args.high})")
    return args


def iter_step_files(training_data_dir: Path):
    """Return ``(step, path)`` pairs for every step file in *training_data_dir*.

    Only regular files whose name matches ``STEP_RE`` are kept. The result is
    a list sorted by the integer step number. The file name is the
    tie-breaker, so the order is deterministic when two names parse to the
    same step (e.g. ``step_1_...`` and ``step_01_...``); glob order alone is
    filesystem-dependent.
    """
    files = []
    for path in training_data_dir.glob("step_*_traindata.jsonl"):
        match = STEP_RE.search(path.name)
        # A directory (or other non-regular entry) can match the glob too;
        # opening it later would raise, so only keep real files.
        if match and path.is_file():
            files.append((int(match.group(1)), path))
    files.sort(key=lambda entry: (entry[0], entry[1].name))
    return files


def to_float_list(value):
    """Coerce a ``response_length`` field into a list of finite floats.

    Args:
        value: ``None``, a single scalar, or a list of scalars.

    Returns:
        The elements that convert cleanly with ``float()``, in input order.
        ``None`` yields ``[]``. The following elements are dropped rather than
        counted as samples:

        * elements that cannot be converted (including ints too large for
          a float, which raise ``OverflowError``);
        * booleans, because ``float(True)`` would silently become ``1.0``;
        * non-finite values (NaN/±inf), which ``json.loads`` accepts but
          which are not meaningful lengths.
    """
    if value is None:
        return []
    values = value if isinstance(value, list) else [value]

    out = []
    for v in values:
        if isinstance(v, bool):
            continue
        try:
            number = float(v)
        except (TypeError, ValueError, OverflowError):
            continue
        if math.isfinite(number):
            out.append(number)
    return out


def in_range(x: float, low: float, high: float) -> bool:
    """Report whether *x* lies in the closed interval ``[low, high]``.

    Both endpoints count as inside. Any comparison involving NaN is False,
    so a NaN argument is never reported as in range.
    """
    above_floor = x >= low
    return above_floor and x <= high


def compute_step_stats(file_path: Path, low: float, high: float):
    """Count how many response lengths in one JSONL file fall in ``[low, high]``.

    Each non-blank line is parsed as JSON. A line is skipped when it:

    * fails to parse;
    * is not a JSON object;
    * has no usable ``response_length`` values (as returned by
      ``to_float_list``).

    Two granularities are reported:

    * sample level: every individual length value is one sample;
    * item level: every counted line is one item, which is a hit if any of
      its values is in range.

    Args:
        file_path: Path of the ``.jsonl`` file to read (UTF-8).
        low: Inclusive lower bound.
        high: Inclusive upper bound.

    Returns:
        dict with keys ``sample_total``, ``sample_in_range``, ``sample_freq``,
        ``item_total``, ``item_has_in_range`` and ``item_freq``. Each
        frequency is 0.0 when its total is zero.
    """
    sample_total = 0
    sample_in_range = 0
    item_total = 0
    item_has_in_range = 0

    with file_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                continue  # best-effort: tolerate corrupt or truncated lines
            # Valid JSON that is not an object (list, number, string) has no
            # ``.get``; skip it like a malformed line instead of crashing.
            if not isinstance(item, dict):
                continue

            lengths = to_float_list(item.get("response_length", []))
            if not lengths:
                continue

            hits = sum(1 for length in lengths if in_range(length, low, high))
            item_total += 1
            sample_total += len(lengths)
            sample_in_range += hits
            if hits:
                item_has_in_range += 1

    sample_freq = (sample_in_range / sample_total) if sample_total else 0.0
    item_freq = (item_has_in_range / item_total) if item_total else 0.0
    return {
        "sample_total": sample_total,
        "sample_in_range": sample_in_range,
        "sample_freq": sample_freq,
        "item_total": item_total,
        "item_has_in_range": item_has_in_range,
        "item_freq": item_freq,
    }


# Output column order, shared by the TSV printed to stdout and the optional CSV.
_COLUMNS = (
    "step",
    "sample_in_range",
    "sample_total",
    "sample_freq",
    "item_has_in_range",
    "item_total",
    "item_freq",
)
# Columns rendered as fixed-point floats; all others are printed with str().
_FREQ_COLUMNS = frozenset({"sample_freq", "item_freq"})


def _format_row(row, sep: str, freq_digits: int) -> str:
    """Render one stats row as a *sep*-delimited line (no trailing newline)."""
    return sep.join(
        f"{row[col]:.{freq_digits}f}" if col in _FREQ_COLUMNS else str(row[col])
        for col in _COLUMNS
    )


def _require_directory(path: Path) -> None:
    """Raise FileNotFoundError / NotADirectoryError unless *path* is a directory."""
    if not path.exists():
        raise FileNotFoundError(f"Directory not found: {path}")
    if not path.is_dir():
        raise NotADirectoryError(f"Not a directory: {path}")


def _write_csv(path: Path, rows) -> None:
    """Write *rows* to *path* as CSV with 8-digit frequencies, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    lines = [",".join(_COLUMNS)]
    lines.extend(_format_row(row, ",", 8) for row in rows)
    with path.open("w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def main():
    """Print per-step in-range statistics as TSV and optionally save them as CSV.

    Raises:
        FileNotFoundError: if the training-data directory is missing or
            contains no ``step_*_traindata.jsonl`` files.
        NotADirectoryError: if the given path is not a directory.
    """
    args = parse_args()
    training_data_dir = args.training_data_dir
    _require_directory(training_data_dir)

    step_files = iter_step_files(training_data_dir)
    if not step_files:
        raise FileNotFoundError(
            f"No step_*_traindata.jsonl files found in: {training_data_dir}"
        )

    print("\t".join(_COLUMNS))

    rows = []
    for step, file_path in step_files:
        stats = compute_step_stats(file_path, args.low, args.high)
        row = {"step": step, **stats}
        rows.append(row)
        # Print as we go so progress is visible on long runs.
        print(_format_row(row, "\t", 6))

    if args.save_csv is not None:
        _write_csv(args.save_csv, rows)


if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when imported.
    main()
