#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
把任意 .jsonl 按行数均分或按给定比例分割成 N 份，保证三台机器跑完结果完全一致。
用法:
    均分 : python split_jsonl.py --input xxx.jsonl --parts 10
    比例 : python split_jsonl.py --input xxx.jsonl --ratio 1,2,3,4
"""
import argparse
import os


def parse_ratio(ratio_str: str) -> list[int]:
    """Convert a comma-separated string such as '1,2,3' into [1, 2, 3].

    Intended for use as an argparse ``type=`` callable.

    Args:
        ratio_str: Comma-separated positive integers; whitespace around each
            item is ignored.

    Returns:
        The parsed integers, in input order.

    Raises:
        argparse.ArgumentTypeError: If any item is not an integer or is not
            strictly positive.
    """
    weights = []
    try:
        for token in ratio_str.split(","):
            weight = int(token.strip())
            if weight <= 0:
                raise ValueError
            weights.append(weight)
    except Exception:
        # Any failure is reported uniformly as an argparse type error.
        raise argparse.ArgumentTypeError("ratio 必须是逗号分隔的正整数，如 1,2,3")
    return weights


def split_jsonl(input_path: str, parts: int | None, ratio: list[int] | None):
    if parts is None and ratio is None:
        raise ValueError("必须指定 --parts 或 --ratio 之一")
    if parts is not None and ratio is not None:
        raise ValueError("--parts 与 --ratio 不能同时给出")

    # 1. 先数总行数
    print("Counting lines...")
    with open(input_path, "rb") as f:
        total = sum(1 for _ in f)
    print(f"Total lines: {total}")

    # 2. 计算每份多少行
    if ratio:  # 按比例
        parts = len(ratio)
        sum_ratio = sum(ratio)
        base = total // sum_ratio
        remainder = total % sum_ratio
        # 先按 base*ratio 分配，再把 remainder 从前到后各加 1
        chunks = [base * r for r in ratio]
        chunks[0] += remainder
    else:  # 均分
        base = total // parts
        remainder = total % parts
        chunks = [base + 1 if i < remainder else base for i in range(parts)]
    print("Lines per part:", chunks)

    # 3. 按行写文件
    outfile_pattern = input_path.replace(".jsonl", ".part-{idx:02d}.jsonl")
    writers = [
        open(outfile_pattern.format(idx=idx), "w", encoding="utf-8")
        for idx in range(parts)
    ]
    with open(input_path, "r", encoding="utf-8") as fin:
        written = 0
        for idx, size in enumerate(chunks):
            for _ in range(size):
                line = fin.readline()
                if not line:
                    break
                writers[idx].write(line)
                written += 1
    for fh in writers:
        fh.close()
    print(f"Split finished! {written} lines written into {parts} parts.")


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the split.
    parser = argparse.ArgumentParser(
        description="Evenly or proportionally split a .jsonl into N parts."
    )
    parser.add_argument("--input", required=True, help="原始 .jsonl 路径")
    # Exactly one of --parts / --ratio must be supplied.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--parts", type=int, help="要切成的份数（均分）"
    )
    group.add_argument(
        "--ratio", type=parse_ratio, help="按比例分割，如 1,2,3"
    )
    args = parser.parse_args()
    # Reject a zero or negative part count with a normal usage error instead
    # of letting the split fail on a division or silently write nothing.
    if args.parts is not None and args.parts <= 0:
        parser.error("--parts 必须是正整数")
    split_jsonl(args.input, args.parts, args.ratio)