#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import ast
import argparse
import sys
import textwrap

def count_loops_in_code(code_str: str) -> int:
    """
    先去除统一的缩进，然后解析一段 Python 代码字符串，
    统计其中 for/while 节点的数量。
    如果语法解析失败，则抛出 SyntaxError。
    """
    clean_code = textwrap.dedent(code_str).strip()
    if not clean_code:
        return 0
    # 直接在这里尝试 parse，解析失败会抛出
    tree = ast.parse(clean_code)
    return sum(1 for node in ast.walk(tree) if isinstance(node, (ast.For, ast.While)))

def analyze_file(input_path: str):
    total_loops = 0
    total_programs = 0
    json_errors = 0
    ast_errors = 0

    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                json_errors += 1
                continue

            total_programs += 1
            code = obj.get('program', '')
            try:
                loop_count = count_loops_in_code(code)
            except (SyntaxError, ValueError):
                # AST 解析失败
                ast_errors += 1
                continue

            total_loops += loop_count

    if total_programs == 0:
        print("没有有效的程序可以分析。", file=sys.stderr)
        return

    avg_loops = total_loops / (total_programs - ast_errors)
    print(f"样本总数: {total_programs}")
    print(f"JSON 解析失败行数: {json_errors}")
    print(f"AST 解析失败数量: {ast_errors}")
    print(f"成功解析并统计循环的程序数量: {total_programs - ast_errors}")
    print(f"循环总数: {total_loops}")
    print(f"平均每个程序的循环数: {avg_loops:.3f}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="统计 JSONL 文件中每个程序的 for/while 循环数，并计算平均值，同时输出 JSON 解析失败行数和 AST 解析失败数量。"
    )
    parser.add_argument(
        'input_file',
        help="待分析的 JSONL 文件路径，每行应包含一个包含 program 字段的 JSON 对象。"
    )
    args = parser.parse_args()
    analyze_file(args.input_file)
