import argparse
import math

from datasets import Dataset, DatasetDict

from utils import load_single_dataset

# Score cut-off used by categorize(). Comparisons are strict: a score exactly
# equal to THRESH is neither "true" nor "false". Switch the comparisons in
# categorize() to >= / <= if the boundary should be inclusive.
THRESH = 0.5


def categorize(batch, thresh=THRESH):
    """Label each sample of a batched ``Dataset.map`` call by its scores.

    Args:
        batch: Batch dict whose ``"scores"`` entry holds one item per sample;
            each item is expected to be a list (or tuple) of scores.
        thresh: Cut-off score; defaults to the module-level ``THRESH``.

    Returns:
        ``{"bucket": labels}`` with one label per sample:

        * ``"all_true"``: every usable score is strictly above ``thresh``;
        * ``"all_false"``: every usable score is strictly below ``thresh``;
        * ``"true_and_false"``: anything else, including samples whose scores
          are not a list/tuple, are empty, or contain no usable value.

        A score is "usable" if ``float()`` accepts it and the result is not
        NaN; unusable scores are ignored rather than raising.
    """
    bucket = []
    for scores in batch["scores"]:
        vals = []
        if isinstance(scores, (list, tuple)):
            for s in scores:
                try:
                    v = float(s)
                except (TypeError, ValueError):
                    continue  # non-numeric entry: skip it instead of failing the whole map
                # NaN compares False against everything, so keeping it would
                # force the sample into "true_and_false" regardless of the rest.
                if not math.isnan(v):
                    vals.append(v)

        if not vals:
            # Missing, non-list, empty, or entirely unusable scores
            # (an "empty" bucket could be introduced here instead).
            bucket.append("true_and_false")
        elif all(v > thresh for v in vals):
            bucket.append("all_true")
        elif all(v < thresh for v in vals):
            bucket.append("all_false")
        else:
            bucket.append("true_and_false")

    return {"bucket": bucket}

def main():
    """Split a scored dataset into "all_true" / "all_false" / "true_and_false".

    Loads ``--data`` with ``load_single_dataset``, optionally restricts it to
    rows ``[--begin, --end)``, labels every row with ``categorize`` and saves
    a ``DatasetDict`` with one split per label to ``--save_file``.

    Raises:
        ValueError: if the requested slice is out of range or the dataset has
            no ``"scores"`` column.
    """
    parser = argparse.ArgumentParser(
        description="Split a scored dataset into all_true / all_false / true_and_false splits."
    )
    parser.add_argument("--data", type=str, required=True,
                        help="Dataset path passed to load_single_dataset.")
    parser.add_argument("--begin", type=int, required=False, default=None,
                        help="First row to keep (inclusive); defaults to 0.")
    parser.add_argument("--end", type=int, required=False, default=None,
                        help="Row to stop at (exclusive); defaults to the dataset length.")
    parser.add_argument("--save_file", type=str, required=True,
                        help="Output directory for DatasetDict.save_to_disk.")
    parser.add_argument("--num_proc", type=int, default=6,
                        help="Worker processes for Dataset.map (clamped to 1..row count).")
    parser.add_argument("--batch_size", type=int, default=1024,
                        help="Rows per Dataset.map batch; lower it to reduce memory use.")
    parser.add_argument("--use_cache", action="store_true",
                        help="Reuse a cached Dataset.map result if one exists.")
    args = parser.parse_args()

    # NOTE(review): assumed to return a datasets.Dataset; confirm against utils.
    ds = load_single_dataset(args.data)

    # Slice with select() rather than Python slicing: indexing a Dataset with a
    # slice returns a plain dict of columns, not a Dataset.
    if args.begin is not None or args.end is not None:
        n = len(ds)
        begin = 0 if args.begin is None else args.begin
        end = n if args.end is None else args.end
        if not 0 <= begin <= end <= n:
            raise ValueError(f"Invalid slice range: begin={begin}, end={end}, length={n}")
        ds = ds.select(range(begin, end))

    # Fail early with a clear message instead of a KeyError inside map workers.
    if "scores" not in ds.column_names:
        raise ValueError(f"Dataset has no 'scores' column; columns are {ds.column_names}")

    bucket_names = ("all_true", "all_false", "true_and_false")
    groups = {name: [] for name in bucket_names}

    # Dataset.map does not run the function on an empty dataset, so no "bucket"
    # column would exist; an empty input simply yields three empty splits.
    if len(ds) > 0:
        labeled = ds.map(
            categorize,
            batched=True,
            batch_size=args.batch_size,
            num_proc=max(1, min(args.num_proc, len(ds))),
            load_from_cache_file=args.use_cache,  # off by default: always recompute
            desc="categorize",
        )
        # Single pass over the label column (materialized in memory; for very
        # large data a filter()-based approach would avoid that). An unexpected
        # label raises KeyError instead of being silently dropped.
        for i, label in enumerate(labeled["bucket"]):
            groups[label].append(i)
        # Drop the temporary column before saving.
        ds = labeled.remove_columns(["bucket"])

    out = DatasetDict({name: ds.select(groups[name]) for name in bucket_names})
    out.save_to_disk(args.save_file)

if __name__ == "__main__":
    # Script entry point. main() returns None, and SystemExit(None) exits with
    # status 0, exactly as falling off the end of the script would.
    raise SystemExit(main())

    
"""
Example usage:

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/dsfilter_3_save_datasetdict.py \
    --data      ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048_scored.json \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048_classify
"""
