import os
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
import ast
import json
import math
import random
from tqdm import tqdm
from code_analysis import is_followup_more_complex
def compute_depth_width_metrics_graph_based(dep_graph_old, dep_graph_new):
    """
    Compare two variable-dependency graphs and report how much the new one grew.

    Each graph maps a variable name to an iterable of the names it depends on.

    Definitions:
    - depth = number of nodes on the longest dependency chain, found by DFS
      (0 for an empty/falsy graph)
    - width = the largest in-degree of any name, i.e. how many times the most
      depended-upon name appears across all dependency lists

    NOTE(review): the DFS memoises per-node results while skipping nodes that
    are already on the current path (to break cycles), so for cyclic graphs
    the reported depth may depend on the iteration order of the graph keys.

    :return: dict with depth_old/new, width_old/new, the non-negative
        depth_increase/width_increase, and depth_coef/width_coef = each increase
        divided by their sum (rounded to 2 decimals; both 0 when nothing grew).
    """

    def _max_in_degree(graph):
        # Count how often every name is the target of a dependency edge.
        hits = Counter(dep for deps in graph.values() for dep in deps)
        return max(hits.values(), default=0)

    def _longest_chain(graph):
        cache = {}
        on_path = set()

        def walk(node):
            if node in cache:
                return cache[node]
            on_path.add(node)
            deepest = 0
            for nxt in graph.get(node, []):
                if nxt in on_path:
                    # Back-edge (or self-loop): skip it to avoid infinite recursion.
                    continue
                deepest = max(deepest, walk(nxt))
            on_path.remove(node)
            cache[node] = deepest + 1
            return cache[node]

        return max((walk(start) for start in graph), default=0)

    def _growth(before, after):
        return max(after - before, 0)

    # Depth: longest dependency chain.
    depth_old = _longest_chain(dep_graph_old) if dep_graph_old else 0
    depth_new = _longest_chain(dep_graph_new) if dep_graph_new else 0
    depth_increase = _growth(depth_old, depth_new)

    # Width: maximum in-degree.
    width_old = _max_in_degree(dep_graph_old)
    width_new = _max_in_degree(dep_graph_new)
    width_increase = _growth(width_old, width_new)

    combined = width_increase + depth_increase
    denominator = combined if combined > 0 else 1

    return {
        "depth_increase": depth_increase,
        "width_increase": width_increase,
        "depth_old": depth_old,
        "depth_new": depth_new,
        "width_old": width_old,
        "width_new": width_new,
        "depth_coef": round(depth_increase / denominator, 2),
        "width_coef": round(width_increase / denominator, 2),
    }
def find_dependencies_from_source(source_code: str) -> dict:
    """
    Statically analyse Python source code and extract variable dependencies.

    Every node of the parsed AST is visited (ast.walk). Only Assign, AugAssign,
    If and For nodes are inspected. For each variable, the names it depends on
    are recorded together with a dependency kind: assign, augassign, call,
    subscript, attribute, control or loop. The kind is only used internally;
    the returned mapping drops it.

    :param source_code: Python source text.
    :return: {variable name: sorted list of names it depends on}. Besides
        variable names, the lists may contain called function names and
        stringified constants used as subscripts (e.g. "0"). If anything raises
        during parsing/analysis, the source and the exception are printed and
        whatever was collected so far is returned (an empty dict when
        ast.parse itself fails).
    """
    # variable -> {(dependency name, dependency kind), ...}
    dependencies = defaultdict(set)

    try:
        tree = ast.parse(source_code)

        for node in ast.walk(tree):

            # 1. Plain assignment. Only `name = ...` and `name[...] = ...`
            #    targets are handled; tuple-unpacking / attribute targets are ignored.
            if isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        target_var = target.id
                        # ast.walk also yields the Name nodes nested inside
                        # subscripts/attributes/calls, so those names become
                        # "assign" dependencies of the target as well.
                        for child in ast.walk(node.value):
                            if isinstance(child, ast.Name):
                                dependencies[target_var].add((child.id, "assign"))
                            elif isinstance(child, ast.Subscript):
                                # NOTE: reading base[idx] on the right-hand side makes
                                # *base* (not the assignment target) depend on the
                                # names/constants used in the index.
                                if isinstance(child.value, ast.Name):
                                    base = child.value.id
                                    for s in ast.walk(child.slice):
                                        if isinstance(s, ast.Name):
                                            dependencies[base].add((s.id, "subscript"))
                                        elif isinstance(s, ast.Constant):
                                            dependencies[base].add((str(s.value), "subscript"))
                            elif isinstance(child, ast.Attribute):
                                # obj.attr -> target depends on obj
                                if isinstance(child.value, ast.Name):
                                    dependencies[target_var].add((child.value.id, "attribute"))
                            elif isinstance(child, ast.Call):
                                # f(args) -> target depends on f and on names in positional args
                                if isinstance(child.func, ast.Name):
                                    dependencies[target_var].add((child.func.id, "call"))
                                for arg in child.args:
                                    for arg_node in ast.walk(arg):
                                        if isinstance(arg_node, ast.Name):
                                            dependencies[target_var].add((arg_node.id, "call"))

                    elif isinstance(target, ast.Subscript):  # x[b] = c
                        # x depends on the index names/constants and on every name in the value.
                        if isinstance(target.value, ast.Name):
                            x_var = target.value.id
                            for s in ast.walk(target.slice):
                                if isinstance(s, ast.Name):
                                    dependencies[x_var].add((s.id, "subscript"))
                                elif isinstance(s, ast.Constant):
                                    dependencies[x_var].add((str(s.value), "subscript"))
                            for child in ast.walk(node.value):
                                if isinstance(child, ast.Name):
                                    dependencies[x_var].add((child.id, "assign"))

            # 2. Augmented assignment a += b: a depends on itself and on every name in b
            elif isinstance(node, ast.AugAssign):
                if isinstance(node.target, ast.Name):
                    target_var = node.target.id
                    dependencies[target_var].add((target_var, "augassign"))
                    for child in ast.walk(node.value):
                        if isinstance(child, ast.Name):
                            dependencies[target_var].add((child.id, "augassign"))

            # 3. Control dependency: every name assigned (plain `name = ...`) anywhere in
            #    the if body or else branch, including nested statements, depends on
            #    every name in the if condition.
            elif isinstance(node, ast.If):
                condition_vars = {child.id for child in ast.walk(node.test) if isinstance(child, ast.Name)}
                for stmt in node.body + node.orelse:
                    for sub_node in ast.walk(stmt):
                        if isinstance(sub_node, ast.Assign):
                            for target in sub_node.targets:
                                if isinstance(target, ast.Name):
                                    for cond in condition_vars:
                                        dependencies[target.id].add((cond, "control"))

            # 4. Loop dependency: every name assigned (plain `name = ...`) anywhere in the
            #    loop body depends on every name in the iterable expression. The loop
            #    target itself and the for-else clause are NOT recorded.
            elif isinstance(node, ast.For):
                loop_vars = {child.id for child in ast.walk(node.iter) if isinstance(child, ast.Name)}
                for stmt in node.body:
                    for sub_node in ast.walk(stmt):
                        if isinstance(sub_node, ast.Assign):
                            for target in sub_node.targets:
                                if isinstance(target, ast.Name):
                                    for lv in loop_vars:
                                        dependencies[target.id].add((lv, "loop"))

    except Exception as e:
        # Best effort: report the failure and fall through with partial results.
        print(f"Error analyzing source code:\n{source_code}")
        print(f"Exception: {e}")

    # Format output: keep only the depended-on names (drop the dependency kind).
    dep_dict = {k: sorted(set(src for (src, _) in v)) for k, v in dependencies.items()}
    return dep_dict
def rich_escape(text):
    """Return str(text) with every '[' and ']' prefixed by a backslash.

    Presumably used to keep bracketed text from being read as console markup.
    """
    bracket_escapes = str.maketrans({"[": "\\[", "]": "\\]"})
    return str(text).translate(bracket_escapes)
def analyze_and_export_combined(dict_list, output_dir="analysis_output", percentage_thresholds=(50, 80, 90)):
    """
    Jointly analyse two groups of metric fields collected from dict_list.

    1. Continuous fields: cumulative-percentage analysis, a cumulative curve
       plot, and a table of the first value at which each cumulative
       percentage threshold is reached.
    2. Discrete fields: frequency counts plotted as one bar chart per field.

    Files written to output_dir: summary_cumsum.csv, summary_cumsum_plot.png
    and summary_frequency_plot.png.

    :param dict_list: iterable of dicts. Only keys listed in the two field
        groups below are read; their values must be hashable and, per field,
        mutually sortable.
    :param output_dir: output directory (created if missing).
    :param percentage_thresholds: cumulative percentages (0-100) to report.
        Duplicates are ignored (first occurrence order kept).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Group 1: continuous fields
    continuous_fields = ["vocabulary", "length", "volume", "difficulty", "effort"]

    # Group 2: discrete fields
    discrete_fields = ["depth_increase", "width_increase", "depth_old", "depth_new", "width_old", "width_new"]

    # Built once instead of concatenating both lists for every key of every dict.
    tracked_fields = set(continuous_fields) | set(discrete_fields)
    # Deduplicate so the number of result columns matches the renamed headers.
    thresholds = list(dict.fromkeys(percentage_thresholds))

    # field -> Counter(value -> number of occurrences)
    field_value_counter = defaultdict(Counter)
    for d in dict_list:
        for k, v in d.items():
            if k in tracked_fields:
                field_value_counter[k][v] += 1

    # === Continuous fields: cumulative percentage ===
    results = {}
    fig_cum, ax_cum = plt.subplots(figsize=(10, 6))
    for field in continuous_fields:
        counter = field_value_counter[field]
        df = pd.DataFrame(list(counter.items()), columns=["value", "count"])
        df = df.sort_values(by="value").reset_index(drop=True)
        df["cumsum"] = df["count"].cumsum()
        df["cum_percentage"] = df["cumsum"] / df["count"].sum() * 100

        # Cumulative curve
        ax_cum.plot(df["value"], df["cum_percentage"], label=field, marker='o')

        # First value whose cumulative share reaches each threshold
        # (None when never reached, e.g. the field has no data).
        results[field] = {}
        for threshold in thresholds:
            reached = df[df["cum_percentage"] >= threshold]
            results[field][threshold] = None if reached.empty else reached.iloc[0]["value"]

    # Save the table
    result_table = pd.DataFrame(results).T
    result_table.columns = [f"{p}%" for p in thresholds]
    result_table.to_csv(os.path.join(output_dir, "summary_cumsum.csv"))
    print("\n✅ 连续变量结果表已保存: summary_cumsum.csv")

    # Save the cumulative plot
    ax_cum.set_xlabel("Value")
    ax_cum.set_ylabel("Cumulative Percentage (%)")
    ax_cum.set_title("Cumulative Distribution of Continuous Fields")
    ax_cum.legend()
    ax_cum.grid(True)
    fig_cum.tight_layout()
    fig_cum.savefig(os.path.join(output_dir, "summary_cumsum_plot.png"))
    plt.close(fig_cum)
    print("✅ 连续变量累计图已保存: summary_cumsum_plot.png")

    # === Discrete fields: frequency bar charts ===
    n = len(discrete_fields)
    # squeeze=False always yields a 2-D array of axes, even when n == 1.
    fig_freq, axs = plt.subplots(n, 1, figsize=(10, 3 * n), constrained_layout=True, squeeze=False)

    for ax, field in zip(axs[:, 0], discrete_fields):
        counter = field_value_counter[field]
        df = pd.DataFrame(list(counter.items()), columns=["value", "count"]).sort_values(by="value")
        ax.bar(df["value"].astype(str), df["count"])
        ax.set_title(field)
        ax.set_ylabel("Frequency")
        ax.tick_params(axis='x', rotation=45)

    fig_freq.savefig(os.path.join(output_dir, "summary_frequency_plot.png"))
    plt.close(fig_freq)
    print("✅ 离散变量频次图已保存: summary_frequency_plot.png")

    # Optional: console table
    print("\n📋 连续变量各百分比值:")
    print(result_table)
def json2csv(data):
    """
    Flatten a loaded JSON payload into a pandas DataFrame, one row per sample.

    Each row takes the original question/program fields from
    sample["origin_data"] (all required) and the first follow-up fields from
    sample["extend_data_1"] (empty strings when absent). It also has an
    "expansion_index" column holding
    extend_data_1["static analysis"]["extend_method"] (None when missing).

    :param data: dict with a "data" list of samples.
    :return: pandas.DataFrame
    """
    origin_columns = (
        ("id", "source_id"),
        ("query", "question"),
        ("img_path", "image"),
        ("possible_answers", "golden_answer"),
        ("original_code", "program"),
        ("original_result", "program_answer"),
    )
    followup_columns = (
        ("followup_question", "question"),
        ("followup_code", "program"),
        ("followup_result", "program_answer"),
    )

    def to_row(sample):
        origin = sample["origin_data"]
        followup = sample.get("extend_data_1", {})
        row = {column: origin[key] for column, key in origin_columns}
        row.update({column: followup.get(key, "") for column, key in followup_columns})
        # Extension-method label written by the static-analysis step, if any.
        row["expansion_index"] = followup.get("static analysis", {}).get("extend_method", None)
        return row

    return pd.DataFrame([to_row(sample) for sample in data["data"]])

def count_image(sample, compare_part2="extend_data_2"):
    """
    Return how many images sample[compare_part2]["image"] refers to.

    The field may be a list of paths, a string holding a Python list literal
    (e.g. "['a.jpg', 'b.jpg']"), a single path string, or missing. Missing or
    blank values, and values that are neither list nor str, count as 0.
    """
    images = sample.get(compare_part2, {}).get("image", "")
    if isinstance(images, list):
        return len(images)
    if not isinstance(images, str):
        return 0
    stripped = images.strip()
    if stripped.startswith("["):
        # Stringified list: parse it as a literal (no eval).
        return len(ast.literal_eval(images))
    return 1 if stripped else 0
def analyze_image_field(samples, compare_part2="extend_data_1"):
    """
    Tally how many samples have no / one / two / three-or-more images.

    Reads sample[compare_part2]["image"] for every sample, using the same
    parsing rules as count_image. Also prints how many of those fields were
    lists or stringified lists.

    :param samples: list of sample dicts (the "data" list of the JSON file)
    :param compare_part2: which stage entry of each sample to inspect
    :return: Counter keyed by "单图" (1), "双图" (2), "多图" (>=3), "无图" (otherwise)
    """

    def parse(field):
        # Returns (number of images, whether the field was list-shaped).
        if isinstance(field, list):
            return len(field), True
        if isinstance(field, str):
            if field.strip().startswith("["):
                return len(ast.literal_eval(field)), True
            return (1 if field.strip() else 0), False
        return 0, False

    def bucket_of(count):
        if count >= 3:
            return "多图"
        return {1: "单图", 2: "双图"}.get(count, "无图")

    print(f"统计该字段图像： {compare_part2}")
    buckets = Counter()
    list_like = 0
    for sample in samples:
        n_images, was_list = parse(sample.get(compare_part2, {}).get("image", ""))
        if was_list:
            list_like += 1
        buckets[bucket_of(n_images)] += 1
    print(f"列表数量：{list_like}")
    return buckets
def determine_difficulty(effort_o, effort_f, *, easy_max=4000, medium_max=6000):
    """
    Bucket the follow-up program's effort score into a difficulty label.

    :param effort_o: effort of the original program. It is currently unused
        and kept for interface compatibility.
    :param effort_f: effort of the follow-up program.
    :param easy_max: inclusive upper bound for "easy" (default 4000).
    :param medium_max: inclusive upper bound for "medium" (default 6000).
    :return: "easy" if effort_f <= easy_max, "medium" if
        easy_max < effort_f <= medium_max, otherwise "hard" (NaN included).
    """
    if effort_f <= easy_max:
        return "easy"
    # Reaching here already implies effort_f > easy_max (or NaN).
    if effort_f <= medium_max:
        return "medium"
    return "hard"
def extend_method_analysis(depth_inc, width_inc):
    """
    Label an extension by which coefficient dominates.

    :return: "depth" if depth_inc > width_inc, "width" if depth_inc < width_inc,
        otherwise "balanced" (equal, or incomparable such as NaN).
    """
    # +1 / -1 / 0 depending on the ordering; NaN makes both comparisons False -> 0.
    direction = (depth_inc > width_inc) - (depth_inc < width_inc)
    return {1: "depth", -1: "width"}.get(direction, "balanced")

def analyze_stage_summary_combined(json_path, stages=None):
    """
    Aggregate the static-analysis labels of all stages into one global summary.

    Reads the samples from json_path (top-level "data" list). For every sample
    and every stage pair it counts:
    - the "difficulty" and "extend_method" stored in
      sample[compare_part2]["static analysis"] (a missing/empty method is
      counted as "UNKNOWN");
    - the sample[compare_part2]["image_count"] bucket.
    It then prints the combined distributions. Data sources are counted once
    per sample, with the aliases OKVQA2/OKVQA3, seedbench2 and MMBench2 merged.

    :param json_path: path of the JSON file to summarise.
    :param stages: list of stage pairs, e.g. [("origin_data", "extend_data_1"), ...];
        defaults to the three extension stages. Only the second element of each
        pair is read.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        samples = json.load(f)["data"]

    if stages is None:
        stages = [
            ("origin_data", "extend_data_1"),
            ("extend_data_1", "extend_data_2"),
            ("extend_data_2", "extend_data_3")
        ]

    source_aliases = {
        "OKVQA2": "OKVQA",
        "OKVQA3": "OKVQA",
        "seedbench2": "SeedBench2",
        "MMBench2": "MMBench",
    }

    def percent(count, denom):
        # Report 0% instead of raising ZeroDivisionError when nothing was analysed.
        return count / denom * 100 if denom else 0.0

    difficulty_counter = Counter()
    extend_method_counter = Counter()
    image_type_counter = Counter()
    data_source_counter = Counter()
    total = 0
    null_method_count = 0
    for sample in samples:
        origin = sample.get("origin_data", {})
        data_source = origin.get("data_source", "未标注")
        data_source = source_aliases.get(data_source, data_source)
        data_source_counter[data_source] += 1
        # .get so a sample without origin_data/source_id cannot crash the warnings.
        source_id = origin.get("source_id")

        for _, compare_part2 in stages:
            part2_data = sample.get(compare_part2, {})
            analysis = part2_data.get("static analysis", {})
            if analysis:
                difficulty = analysis.get("difficulty", None)
                method = analysis.get("extend_method", None)
                if not method:
                    print(f"⚠️ Sample {source_id} has no method in {compare_part2}.")
                    null_method_count += 1
                    method = "UNKNOWN"
                difficulty_counter[difficulty] += 1
                extend_method_counter[method] += 1
                total += 1

            image_count = part2_data.get("image_count", 0)
            if image_count == 0:
                print(f"⚠️ Sample {source_id} has no image in {compare_part2}.")
                image_type_counter["无图"] += 1
            elif image_count == 1:
                image_type_counter["单图"] += 1
            elif image_count == 2:
                image_type_counter["双图"] += 1
            else:
                image_type_counter["多图"] += 1

    print("\n📊 📦 累计所有阶段统计：")
    print(f"\n 分析数据集：{json_path}")
    print("\n📚 数据来源统计：")
    n_samples = sum(data_source_counter.values())
    for source, count in data_source_counter.items():
        print(f"  {source}: {count} 条，占比 {count / n_samples:.1%}")

    # NOTE(review): image buckets are counted for every stage of every sample,
    # while `total` only counts stages carrying a "static analysis" entry, so
    # these shares can exceed 100% — confirm the intended denominator.
    print("\n🖼️ 图像字段统计（按拓展总次数计）：")
    for k, v in image_type_counter.items():
        print(f"  {k}: {v} 条，占比 {percent(v, total):.1f}%")

    print("\n📊 拓展难度分布：")
    for level, count in difficulty_counter.items():
        # str() so a missing (None) difficulty is reported instead of crashing.
        print(f"  {str(level).upper():<6}: {count} 条，占比 {percent(count, total):.1f}%")

    print("\n📊 拓展方法分布：")
    for method, count in extend_method_counter.items():
        label = str(method).upper()
        print(f"  {label:<10}: {count} 条，占比 {percent(count, total):.1f}%")
    print(f"❗ 其中 UNKNOWN（原为 None 或缺失）数量：{null_method_count} 条，占比 {percent(null_method_count, total):.1f}%")
    print(f"\n✅ 总拓展次数（3阶段合并）：{total}")
def code_analysis_from_json(json_path="your_data.json", compare_part1="origin_data", compare_part2="extend_data_1", output_path="filtered_data.json"):
    """
    Run static code analysis for one stage pair and export train/test splits.

    Steps:
    1. Load the samples (top-level "data" list) from json_path. Keep only those
       whose compare_part1 and compare_part2 entries both have a non-empty
       "program".
    2. Shuffle with a fixed seed (42) and split 70/30 into train/test.
    3. For each sample, build dependency graphs for both programs. When both
       are non-empty, annotate sample[compare_part2] with
       "static analysis" = {"difficulty", "extend_method"} and "image_count".
       Data-source aliases are normalised in place on sample["origin_data"].
    4. Print per-split statistics and write <stem>_train<ext> and
       <stem>_test<ext> (derived from output_path). Finally, export combined
       metric plots to ./analysis_output.
    """
    source_aliases = {
        "OKVQA2": "OKVQA",
        "OKVQA3": "OKVQA",
        "seedbench2": "SeedBench2",
        "MMBench2": "MMBench",
    }

    def get_program(sample, part):
        # A missing stage or a null "program" is treated as empty source.
        return (sample.get(part) or {}).get("program") or ""

    def ratio(count, denom):
        # Avoid ZeroDivisionError when a split produced no (valid) samples.
        return count / denom if denom else 0.0

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)["data"]

    # Drop samples whose program is empty on either side before splitting.
    clean_samples = [
        sample
        for sample in data
        if get_program(sample, compare_part1).strip() and get_program(sample, compare_part2).strip()
    ]

    print(f"✅ 可用样本数：{len(clean_samples)} / {len(data)}")

    # Fixed-seed 70/30 train/test split
    random.seed(42)
    random.shuffle(clean_samples)
    split_idx = int(len(clean_samples) * 0.7)
    train_samples = clean_samples[:split_idx]
    test_samples = clean_samples[split_idx:]

    def analyze(samples, split_name):
        """Annotate the analysable samples of one split and print its statistics."""
        valid_samples = []
        extend_index = []
        difficulty_counter = Counter()
        extend_method_counter = Counter()
        data_source_counter = Counter()

        for i, sample in enumerate(tqdm(samples, desc=f"Analyzing {split_name} set")):
            try:
                code1 = get_program(sample, compare_part1)
                code2 = get_program(sample, compare_part2)

                # Normalise data-source aliases in place on the sample.
                origin = sample.get("origin_data", {})
                data_source = origin.get("data_source", "未标注")
                if data_source in source_aliases:
                    data_source = source_aliases[data_source]
                    origin["data_source"] = data_source
                data_source_counter[data_source] += 1

                graph1 = find_dependencies_from_source(code1)
                graph2 = find_dependencies_from_source(code2)

                # Only samples where both programs yield a dependency graph are analysed.
                if graph1 and graph2:
                    num_count = count_image(sample, compare_part2)
                    metrics = is_followup_more_complex(code1, code2)[1]
                    extend_index.append(metrics["halstead_f"])
                    difficulty = determine_difficulty(metrics["effort_o"], metrics["effort_f"])

                    extend_metrics = compute_depth_width_metrics_graph_based(graph1, graph2)
                    extend_index.append(extend_metrics)

                    extend_method = extend_method_analysis(extend_metrics["depth_coef"], extend_metrics["width_coef"])

                    sample[compare_part2]["static analysis"] = {
                        "difficulty": difficulty,
                        "extend_method": extend_method,
                    }
                    sample[compare_part2]["image_count"] = num_count
                    extend_method_counter[extend_method] += 1
                    difficulty_counter[difficulty] += 1
                    valid_samples.append(sample)

            except Exception as e:
                # Best effort: report and skip samples the analysis fails on.
                print(f"⚠️ Error in sample {i}: {e}")
                continue
        print(f"\n🔍 分析 {compare_part2} 拓展数据...")
        total = sum(difficulty_counter.values())
        print(f"\n📊 [{split_name}] 有效样本数：{total}")

        # Data sources are counted for every processed sample, so their shares
        # use that count as denominator (consistent with the other summaries).
        print("\n📚 数据来源统计：")
        n_sources = sum(data_source_counter.values())
        for source, count in data_source_counter.items():
            print(f"  {source}: {count} 条，占比 {ratio(count, n_sources):.1%}")

        # Image field statistics (over valid samples only)
        image_type_counter = analyze_image_field(valid_samples, compare_part2)
        print("\n🖼️ 图像字段统计：")
        for k, v in image_type_counter.items():
            print(f"  {k}: {v} 条，占比 {ratio(v, total):.1%}")

        print("\n📊 拓展难度分布：")
        for level, count in difficulty_counter.items():
            print(f"  {level.upper():<6}: {count} 条，占比 {ratio(count, total) * 100:.1f}%")

        print("\n📊 拓展方法分布：")
        for method, count in extend_method_counter.items():
            print(f"  {method.upper():<6}: {count} 条，占比 {ratio(count, total) * 100:.1f}%")

        return valid_samples, extend_index

    train_valid, train_index = analyze(train_samples, "Train")
    test_valid, test_index = analyze(test_samples, "Test")

    # Derive split paths with splitext: str.replace(".json", ...) left paths
    # without ".json" unchanged, so the test split overwrote the train split.
    stem, ext = os.path.splitext(output_path)
    with open(f"{stem}_train{ext}", "w", encoding="utf-8") as f_train:
        json.dump({"data": train_valid}, f_train, ensure_ascii=False, indent=2)

    with open(f"{stem}_test{ext}", "w", encoding="utf-8") as f_test:
        json.dump({"data": test_valid}, f_test, ensure_ascii=False, indent=2)

    print(f"\n✅ 已保存数据集：训练集 {len(train_valid)} 条，测试集 {len(test_valid)} 条")

    # Optional: combined metric plots
    analyze_and_export_combined(train_index + test_index, "./analysis_output")
def code_analysis_from_json_multi(json_path="your_data.json", output_path="filtered_data.json"):
    """
    Run static code analysis over all three extension stages and export splits.

    Steps:
    1. Load the samples (top-level "data" list). Keep only those where every
       stage pair (origin→extend_1, extend_1→extend_2, extend_2→extend_3) has
       a non-empty "program" on both sides.
    2. Shuffle with a fixed seed (42) and split 70/30 into train/test.
    3. For each split and each stage pair, build dependency graphs for both
       programs. When both are non-empty, annotate sample[compare_part2] with
       "static analysis" = {"difficulty", "extend_method"} and, if absent,
       "image_count". Per-stage statistics are printed.
    4. Write <stem>_train<ext> and <stem>_test<ext> (derived from output_path)
       and the combined set to output_path itself. Finally, export combined
       metric plots to ./analysis_output.

    Every clean sample is written out, including ones whose analysis failed in
    some stage; those stages are simply left unannotated.
    """
    stages = [
        ("origin_data", "extend_data_1"),
        ("extend_data_1", "extend_data_2"),
        ("extend_data_2", "extend_data_3")
    ]

    source_aliases = {
        "OKVQA2": "OKVQA",
        "OKVQA3": "OKVQA",
        "seedbench2": "SeedBench2",
        "MMBench2": "MMBench",
    }

    def get_program(sample, part):
        # A missing stage or a null "program" is treated as empty source.
        return (sample.get(part) or {}).get("program") or ""

    def ratio(count, denom):
        # Avoid ZeroDivisionError when a stage produced no (valid) samples.
        return count / denom if denom else 0.0

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)["data"]

    # Pre-clean: every stage must have code on both sides.
    clean_samples = [
        sample
        for sample in data
        if all(
            get_program(sample, part1).strip() and get_program(sample, part2).strip()
            for part1, part2 in stages
        )
    ]

    print(f"✅ 可用样本数（所有阶段都完整）：{len(clean_samples)} / {len(data)}")

    # Fixed-seed 70/30 train/test split
    random.seed(42)
    random.shuffle(clean_samples)
    split_idx = int(len(clean_samples) * 0.7)
    train_samples = clean_samples[:split_idx]
    test_samples = clean_samples[split_idx:]

    def analyze(samples, split_name):
        """Annotate one split stage by stage; return (samples, per-stage results)."""
        stage_analysis_results = {}

        for compare_part1, compare_part2 in stages:
            print(f"\n📌 开始分析阶段：{compare_part1} → {compare_part2} ...")

            valid_samples = []
            extend_index = []
            difficulty_counter = Counter()
            extend_method_counter = Counter()
            data_source_counter = Counter()

            for i, sample in enumerate(tqdm(samples, desc=f"{split_name} | {compare_part2}")):
                try:
                    code1 = get_program(sample, compare_part1)
                    code2 = get_program(sample, compare_part2)

                    # Normalise data-source aliases in place (idempotent across stages).
                    origin = sample.get("origin_data", {})
                    data_source = origin.get("data_source", "未标注")
                    if data_source in source_aliases:
                        data_source = source_aliases[data_source]
                        origin["data_source"] = data_source
                    data_source_counter[data_source] += 1

                    graph1 = find_dependencies_from_source(code1)
                    graph2 = find_dependencies_from_source(code2)

                    # Only samples where both programs yield a dependency graph are analysed.
                    if graph1 and graph2:
                        num_count = count_image(sample, compare_part2)
                        metrics = is_followup_more_complex(code1, code2)[1]
                        extend_index.append(metrics["halstead_f"])

                        difficulty = determine_difficulty(metrics["effort_o"], metrics["effort_f"])
                        extend_metrics = compute_depth_width_metrics_graph_based(graph1, graph2)
                        extend_index.append(extend_metrics)
                        extend_method = extend_method_analysis(extend_metrics["depth_coef"], extend_metrics["width_coef"])

                        sample[compare_part2]["static analysis"] = {
                            "difficulty": difficulty,
                            "extend_method": extend_method,
                        }
                        # Keep a previously stored image count if there is one.
                        if "image_count" not in sample[compare_part2]:
                            sample[compare_part2]["image_count"] = num_count

                        extend_method_counter[extend_method] += 1
                        difficulty_counter[difficulty] += 1
                        valid_samples.append(sample)

                except Exception as e:
                    # Best effort: report and skip samples the analysis fails on.
                    print(f"⚠️ Error at stage {compare_part1}->{compare_part2}, sample {i}: {e}")
                    continue

            # Stage statistics
            total = sum(difficulty_counter.values())
            print(f"\n📊 [{split_name}] 阶段 {compare_part1} → {compare_part2} 有效样本：{total}")

            # Data sources are counted for every processed sample, so their
            # shares use that count as denominator.
            print("\n📚 数据来源统计：")
            n_sources = sum(data_source_counter.values())
            for source, count in data_source_counter.items():
                print(f"  {source}: {count} 条，占比 {ratio(count, n_sources):.1%}")

            image_type_counter = analyze_image_field(valid_samples, compare_part2)
            print("\n🖼️ 图像字段统计：")
            for k, v in image_type_counter.items():
                print(f"  {k}: {v} 条，占比 {ratio(v, total):.1%}")

            print("\n📊 拓展难度分布：")
            for level, count in difficulty_counter.items():
                print(f"  {level.upper():<6}: {count} 条，占比 {ratio(count, total) * 100:.1f}%")

            print("\n📊 拓展方法分布：")
            for method, count in extend_method_counter.items():
                print(f"  {method.upper():<6}: {count} 条，占比 {ratio(count, total) * 100:.1f}%")

            # Keep per-stage results
            stage_analysis_results[compare_part2] = {
                "valid_samples": valid_samples,
                "extend_index": extend_index
            }

        return samples, stage_analysis_results

    # Analyse train and test splits
    train_valid, train_index_dict = analyze(train_samples, "Train")
    test_valid, test_index_dict = analyze(test_samples, "Test")

    # Derive split paths with splitext: str.replace(".json", ...) left paths
    # without ".json" unchanged, so train, test and full output all collided.
    stem, ext = os.path.splitext(output_path)
    with open(f"{stem}_train{ext}", "w", encoding="utf-8") as f_train:
        json.dump({"data": train_valid}, f_train, ensure_ascii=False, indent=2)

    with open(f"{stem}_test{ext}", "w", encoding="utf-8") as f_test:
        json.dump({"data": test_valid}, f_test, ensure_ascii=False, indent=2)

    print(f"\n✅ 已保存训练集 {len(train_valid)} 条，测试集 {len(test_valid)} 条 → {output_path}")
    # Also save train + test combined
    full_valid = train_valid + test_valid
    with open(output_path, "w", encoding="utf-8") as f_full:
        json.dump({"data": full_valid}, f_full, ensure_ascii=False, indent=2)

    print(f"✅ 已保存完整样本集：{len(full_valid)} 条 → {output_path}")
    # Merge extend_index metrics across splits and stages (train first, then test)
    combined_index = [
        entry
        for index_dict in (train_index_dict, test_index_dict)
        for stage_result in index_dict.values()
        for entry in stage_result["extend_index"]
    ]

    analyze_and_export_combined(combined_index, "./analysis_output")
def analyze_stage_summary_accumulative(json_path):
    """
    Print cumulative statistics for the first 1, 2 and 3 extension stages.

    For k = 1..3, the statistics aggregate over the first k stage pairs
    (origin→extend_1, extend_1→extend_2, extend_2→extend_3). They cover:
    - data sources, counted once per sample with aliases merged;
    - image_count buckets;
    - difficulty / extension-method labels read from
      sample[compare_part2]["static analysis"].
    A missing or empty method is counted as "UNKNOWN".

    :param json_path: JSON file with a top-level "data" list of samples.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        samples = json.load(f)["data"]

    all_stages = [
        ("origin_data", "extend_data_1"),
        ("extend_data_1", "extend_data_2"),
        ("extend_data_2", "extend_data_3")
    ]

    source_aliases = {
        "OKVQA2": "OKVQA",
        "OKVQA3": "OKVQA",
        "seedbench2": "SeedBench2",
        "MMBench2": "MMBench",
    }

    def percent(count, denom):
        # Report 0% instead of raising ZeroDivisionError when nothing was analysed.
        return count / denom * 100 if denom else 0.0

    print(f"\n📂 分析数据集：{json_path}")

    for stage_idx in range(1, 4):  # stages 1, 2, 3 (cumulative prefixes of all_stages)
        current_stages = all_stages[:stage_idx]

        print(f"\n📦 [阶段 {stage_idx}] 比较：", end="")
        print(" + ".join([f"{a}→{b}" for a, b in current_stages]))

        difficulty_counter = Counter()
        extend_method_counter = Counter()
        image_type_counter = Counter()
        data_source_counter = Counter()
        total = 0
        null_method_count = 0

        for sample in samples:
            data_source = sample.get("origin_data", {}).get("data_source", "未标注")
            data_source = source_aliases.get(data_source, data_source)
            data_source_counter[data_source] += 1

            for _, compare_part2 in current_stages:
                part2_data = sample.get(compare_part2, {})
                analysis = part2_data.get("static analysis", {})
                if analysis:
                    difficulty = analysis.get("difficulty", None)
                    method = analysis.get("extend_method", None)
                    if not method:
                        null_method_count += 1
                        method = "UNKNOWN"
                    difficulty_counter[difficulty] += 1
                    extend_method_counter[method] += 1
                    total += 1

                image_count = part2_data.get("image_count", 0)
                if image_count == 0:
                    image_type_counter["无图"] += 1
                elif image_count == 1:
                    image_type_counter["单图"] += 1
                elif image_count == 2:
                    image_type_counter["双图"] += 1
                else:
                    image_type_counter["多图"] += 1

        print("\n📚 数据来源统计：")
        n_samples = sum(data_source_counter.values())
        for source, count in data_source_counter.items():
            print(f"  {source}: {count} 条，占比 {count / n_samples:.1%}")

        # NOTE(review): image buckets include stages without a "static analysis"
        # entry while `total` does not, so these shares can exceed 100%.
        print("\n🖼️ 图像字段统计：")
        for k, v in image_type_counter.items():
            print(f"  {k}: {v} 条，占比 {percent(v, total):.1f}%")

        print("\n📊 拓展难度分布：")
        for level, count in difficulty_counter.items():
            # str() so a missing (None) difficulty is reported instead of crashing.
            print(f"  {str(level).upper():<6}: {count} 条，占比 {percent(count, total):.1f}%")

        print("\n📊 拓展方法分布：")
        for method, count in extend_method_counter.items():
            label = str(method).upper()
            print(f"  {label:<10}: {count} 条，占比 {percent(count, total):.1f}%")
        print(f"❗ 其中 UNKNOWN（原为 None 或缺失）数量：{null_method_count} 条，占比 {percent(null_method_count, total):.1f}%")
        print(f"\n✅ 阶段 {stage_idx} 总分析样本数：{total}")

if __name__ == "__main__":
    # Reshuffle final.json with a fixed seed and write a 70/30 train/test split.
    # NOTE(review): final.json is loaded without ["data"], so it is assumed to
    # be a top-level JSON list of samples — TODO confirm.
    with open("final.json", "r", encoding="utf-8") as src:
        data = json.load(src)

    random.seed(42)
    random.shuffle(data)
    cut = int(len(data) * 0.7)
    for out_name, subset in (("final_train.json", data[:cut]), ("final_test.json", data[cut:])):
        with open(out_name, "w", encoding="utf-8") as out:
            json.dump({"data": subset}, out, ensure_ascii=False, indent=2)

    # NOTE(review): the summaries below read the filtered_data_extend3_* files,
    # not the final_* files written just above — confirm this is intended.
    for split_file in ("filtered_data_extend3_train.json", "filtered_data_extend3_test.json"):
        print(f"分析数据集：{split_file}")
        analyze_stage_summary_accumulative(split_file)
