import datetime
import math
import os
import pathlib
from functools import partial
import warnings
import traceback
import json

import pandas as pd
import torch.multiprocessing as mp
from joblib import Memory
from num2words import num2words
import numpy as np
from omegaconf import OmegaConf
from rich.console import Console
from torch.utils.data import DataLoader
from tqdm import tqdm

from configs import config
from utils import seed_everything
import datasets
from code_analysis import is_followup_more_complex

import time
import matplotlib.pyplot as plt
from PIL import Image
from rich.pretty import pretty_repr
from torch import Tensor

import inspect
import ast
from collections import defaultdict, deque
import gc
import psutil
import copy
import base64
from io import BytesIO
def convert_to_nested_format(df):
    """Convert a per-sample results DataFrame into the nested export structure.

    Each row becomes ``{"origin_data": {...}, "extend_data_1": {...}}``, where
    ``origin_data`` describes the original question/program/answer and
    ``extend_data_1`` the follow-up one. Missing follow-up columns default to
    ``""``. The follow-up's ``extend_method`` is ``"depth"`` when the row's
    ``expansion_index`` dict has ``depth_coef > width_coef``, ``"width"``
    otherwise (ties included), and ``None`` when ``expansion_index`` is absent
    or is not a dict.

    Args:
        df: DataFrame with at least the columns id, query, img_path,
            possible_answers, original_code and original_result.

    Returns:
        ``{"data": [sample, ...]}`` with one sample per row, in row order.
    """
    def as_list(value):
        # A single answer is wrapped so golden_answer is always a list.
        return value if isinstance(value, list) else [value]

    def classify_extension(expansion):
        # Only a dict carries the coefficients; any other placeholder
        # (missing column, NaN, string) means the method is unknown.
        if not isinstance(expansion, dict):
            return None
        return "depth" if expansion["depth_coef"] > expansion["width_coef"] else "width"

    samples = []
    for _, record in df.iterrows():
        origin_data = {
            "data_source": "SeedBench2",  # fixed label; change here for other datasets
            "source_id": str(record["id"]),
            "question": record["query"],
            "image": record["img_path"],
            "golden_answer": as_list(record["possible_answers"]),
            "program": record["original_code"],
            "program_answer": record["original_result"],
            "static analysis": {
                "difficulty": "easy",  # placeholder, not computed here
                "extend_method": None,
                "sample_diversity": None,
            },
        }
        extend_method = classify_extension(record.get("expansion_index", "null"))
        extend_data = {
            "question": record.get("followup_question", ""),
            "image": record["img_path"],
            "program": record.get("followup_code", ""),
            "program_answer": record.get("followup_result", ""),
            "static analysis": {
                "difficulty": "easy",
                "extend_method": extend_method,
                "sample_diversity": "QI",
            },
        }
        # One follow-up per sample; additional ones would be extend_data_2, ...
        samples.append({"origin_data": origin_data, "extend_data_1": extend_data})

    return {"data": samples}

def _encode_jpeg_base64(img):
    """Encode one PIL image as a base64 JPEG string; encoder errors propagate."""
    # JPEG has no alpha channel or palette: Pillow refuses to write modes such as
    # RGBA, LA or P as JPEG ("cannot write mode RGBA as JPEG"), so convert those
    # to RGB first (alpha is dropped). Modes the JPEG writer accepts are untouched.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def pil_to_base64(image):
    """Encode a PIL image, or a list of PIL images, as base64 JPEG strings.

    Args:
        image: a ``PIL.Image.Image`` or a list; list items that are not PIL
            images are skipped.

    Returns:
        List of base64 strings, one per encoded image. Returns ``[]`` if the
        input is neither an image nor a list, or if any image fails to encode.
        Progress and errors are printed to the module-level rich ``console``.
    """
    if isinstance(image, Image.Image):
        images = [image]
    elif isinstance(image, list):
        images = [img for img in image if isinstance(img, Image.Image)]
    else:
        console.print(f"[red]错误：输入不是 PIL.Image，类型为 {type(image)}[/red]")
        return []

    encoded = []
    for img in images:
        try:
            img_str = _encode_jpeg_base64(img)
        except Exception as e:
            # All-or-nothing: a single failure discards the whole batch.
            console.print(f"[red]Base64 编码失败：{e}[/red]")
            return []
        console.print(f"[green]图像转换为 Base64，长度：{len(img_str)}[/green]")
        encoded.append(img_str)
    return encoded
def report_memory(prefix=""):
    """Print this process's resident set size (RSS) in MB, labelled with ``prefix``."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    megabytes = rss_bytes / (1024 ** 2)
    print(f"💾 {prefix} Memory Usage: {megabytes:.2f} MB")

class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder for the non-standard value types this pipeline produces.

    Anything not handled below falls through to the standard ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, pathlib.Path):
            return str(obj)
        if isinstance(obj, (set, frozenset)):
            return list(obj)
        # np.bool_ is not a subclass of np.integer/np.floating, so list it explicitly.
        if isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if hasattr(obj, "to_dict"):  # generic structured objects (e.g. pandas)
            return obj.to_dict()
        if isinstance(obj, Tensor):
            return obj.tolist()
        if obj.__class__.__name__ == "ImagePatch":
            # Matched by class name so this module need not import image_patch.
            return str(obj)
        return super().default(obj)

# See https://github.com/pytorch/pytorch/issues/11201, https://github.com/pytorch/pytorch/issues/973
# Not for dataloader, but for multiprocessing batches
mp.set_sharing_strategy('file_system')
# Result queue of the current process; assigned per pool worker by worker_init()
# and read by run_program() when building the queues for ImagePatch/llm_query.
queue_results = None

# joblib disk cache under 'cache/' when config.use_cache is set; a None location
# disables caching. main() calls cache.clear() when config.clear_cache is set.
cache = Memory('cache/' if config.use_cache else None, verbose=0)
# NOTE(review): not referenced in the visible part of this file.
runs_dict = {}
# Runs at import time; presumably seeds the RNGs (utils.seed_everything) — confirm.
seed_everything()
# Values captured by show_all()/save_intermediate_variable() while a generated
# program runs; run_program() copies and clears it after each sample.
intermediate_variables_dict = {}
# NOTE(review): declared global in show_all()/run_program() but never written in the visible code.
dependency_graph = {}
# Shared rich console, used e.g. by pil_to_base64() for status messages.
console = Console(highlight=False)

def save_intermediate_variable(name, value):
    """Store ``value`` under ``name`` in the module-level ``intermediate_variables_dict``."""
    intermediate_variables_dict.update({name: value})
def rich_escape(text):
    """Escape ``text`` so rich console markup prints it literally.

    Mirrors ``rich.markup.escape``. Rich only interprets a ``[`` that opens
    something tag-like (``[bold]``, ``[/]``, ``[#ff0000]``, ``[@click]``), so
    only those get a backslash, and any backslashes already in front of such
    a bracket are doubled so they stay literal. Other brackets (``[1, 2]``,
    a closing ``]``) are left alone, because rich prints a backslash placed
    before them verbatim.

    Args:
        text: any object; it is converted with ``str()`` first.

    Returns:
        The escaped string.
    """
    import re

    escaped = re.sub(
        r"(\\*)(\[[a-z#/@][^[]*?])",
        lambda m: f"{m.group(1)}{m.group(1)}\\{m.group(2)}",
        str(text),
    )
    # A lone trailing backslash would escape whatever markup follows this text.
    if escaped.endswith("\\") and not escaped.endswith("\\\\"):
        escaped += "\\"
    return escaped
def topological_sort(dep_graph):
    """Order variables so each one comes after the variables it depends on.

    Uses Kahn's algorithm. Names that only appear as dependencies are included
    as roots. Self-dependencies (from ``x += y`` or ``x = x + y`` inside
    loops) do not constrain the order and are ignored. Variables caught in a
    longer cycle, and everything downstream of them, are omitted.

    Args:
        dep_graph: mapping ``var -> iterable of vars it depends on`` (the
            format produced by ``find_dependencies_from_source``). It is not
            modified.

    Returns:
        List of variable names in dependency order; ``[]`` for an empty/None graph.
    """
    if not dep_graph:
        return []

    # Every node in deterministic first-seen order: explicit keys first, then
    # names that only appear as dependencies (without mutating the caller's dict).
    indegree = {var: 0 for var in dep_graph}
    for deps in dep_graph.values():
        for dep in deps:
            indegree.setdefault(dep, 0)

    dependents = defaultdict(list)
    for var, deps in dep_graph.items():
        for dep in deps:
            if dep == var:
                continue  # self-loop: would otherwise keep var out of the order forever
            dependents[dep].append(var)
            indegree[var] += 1

    queue = deque(var for var, degree in indegree.items() if degree == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for dependent in dependents[node]:
            indegree[dependent] -= 1
            if indegree[dependent] == 0:
                queue.append(dependent)
    return order
def compute_depth_width_metrics_graph_based(dep_graph_old, dep_graph_new):
    """Measure how much a dependency graph grew in depth and width.

    Both graphs map ``var -> iterable of vars it depends on`` (the format
    produced by ``find_dependencies_from_source``); an empty/None graph
    counts as depth 0 and width 0.

    * depth: number of variables on the longest dependency chain, found with
      a memoised DFS that skips variables already on the current path (so
      cycles terminate).
    * width: the largest number of times any single variable appears as a
      dependency of others (its in-degree in the "is depended on by" sense).

    Increases are clamped at 0 and normalised into coefficients (rounded to
    two decimals) that roughly sum to 1, or are both 0 when nothing grew.

    Returns:
        dict with width_increase, depth_increase, depth_old, depth_new,
        width_old, width_new, width_coef, depth_coef.
    """

    def _max_in_degree(graph):
        tally = {}
        for deps in graph.values():
            for dep in deps:
                tally[dep] = tally.get(dep, 0) + 1
        return max(tally.values(), default=0)

    def _longest_chain(graph):
        cache = {}
        on_stack = set()

        def walk(var):
            if var in cache:
                return cache[var]
            on_stack.add(var)
            deepest = 0
            for dep in graph.get(var, ()):
                if dep in on_stack:
                    continue  # back-edge of a cycle
                deepest = max(deepest, walk(dep))
            on_stack.discard(var)
            cache[var] = deepest + 1
            return cache[var]

        return max((walk(var) for var in graph), default=0)

    depth_old = _longest_chain(dep_graph_old) if dep_graph_old else 0
    depth_new = _longest_chain(dep_graph_new) if dep_graph_new else 0
    width_old = _max_in_degree(dep_graph_old) if dep_graph_old else 0
    width_new = _max_in_degree(dep_graph_new) if dep_graph_new else 0

    depth_increase = max(depth_new - depth_old, 0)
    width_increase = max(width_new - width_old, 0)
    # Both increases are non-negative, so this only guards the all-zero case.
    denominator = max(width_increase + depth_increase, 1)

    return {
        "width_increase": width_increase,
        "depth_increase": depth_increase,
        "depth_old": depth_old,
        "depth_new": depth_new,
        "width_old": width_old,
        "width_new": width_new,
        "width_coef": round(width_increase / denominator, 2),
        "depth_coef": round(depth_increase / denominator, 2),
    }
def find_dependencies_from_source(source_code):
    """Statically extract variable dependencies from Python source code.

    Every AST node is visited and these patterns are recorded (each tagged
    internally with a dependency kind):

    * ``a = expr``     -> ``a`` depends on every name in ``expr`` (assign; also
      subscript/attribute/call tags for names in those positions)
    * ``x[k] = expr``  -> ``x`` depends on ``k`` (if a plain name) and on every
      name in ``expr``; targets whose container is not a plain name
      (``self.x[i]``, ``a[i][j]``) are skipped
    * ``c += expr``    -> ``c`` depends on itself and every name in ``expr``
    * ``if cond:``     -> every plain-name ``Assign`` target nested in the body
      or else-branch depends on the names in ``cond`` (control)
    * ``for _ in it:`` -> every plain-name ``Assign`` target nested in the body
      depends on the names in ``it`` (loop)

    Args:
        source_code: Python source text to analyse.

    Returns:
        dict mapping each dependent variable to a sorted list of the unique
        names it depends on (the kind tags are dropped). Variables whose
        recorded right-hand sides contain no names are absent. Best-effort: on
        a parse/analysis error the error is printed and whatever was collected
        so far is returned.
    """
    dependencies = defaultdict(set)

    try:
        tree = ast.parse(source_code)

        for node in ast.walk(tree):

            # 1. Assignment: a = b + c, a = b[c], a = f(b), ...
            if isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        target_var = target.id
                        for child in ast.walk(node.value):
                            if isinstance(child, ast.Name):
                                dependencies[target_var].add((child.id, "assign"))
                            elif isinstance(child, ast.Subscript):
                                if isinstance(child.value, ast.Name):
                                    dependencies[target_var].add((child.value.id, "subscript"))
                                if isinstance(child.slice, ast.Name):
                                    dependencies[target_var].add((child.slice.id, "subscript"))
                            elif isinstance(child, ast.Attribute):
                                if isinstance(child.value, ast.Name):
                                    dependencies[target_var].add((child.value.id, "attribute"))
                            elif isinstance(child, ast.Call):
                                if isinstance(child.func, ast.Name):
                                    dependencies[target_var].add((child.func.id, "call"))
                                for arg in child.args:
                                    for arg_node in ast.walk(arg):
                                        if isinstance(arg_node, ast.Name):
                                            dependencies[target_var].add((arg_node.id, "call"))
                    elif isinstance(target, ast.Subscript):  # indexed assignment x[b] = c
                        # Only a plain-name container has a variable to attach the
                        # dependencies to; for self.x[i] or a[i][j] reading `.id`
                        # would raise AttributeError and abort the whole analysis.
                        if isinstance(target.value, ast.Name):
                            x_var = target.value.id
                            if isinstance(target.slice, ast.Name):
                                dependencies[x_var].add((target.slice.id, "subscript"))
                            for child in ast.walk(node.value):
                                if isinstance(child, ast.Name):
                                    dependencies[x_var].add((child.id, "assign"))

            # 2. Augmented assignment: c += d (c also depends on its previous value)
            elif isinstance(node, ast.AugAssign):
                if isinstance(node.target, ast.Name):
                    target_var = node.target.id
                    dependencies[target_var].add((target_var, "augassign"))
                    for child in ast.walk(node.value):
                        if isinstance(child, ast.Name):
                            dependencies[target_var].add((child.id, "augassign"))

            # 3. Control dependency: assignments under `if a > 0:` depend on `a`
            elif isinstance(node, ast.If):
                condition_vars = set()
                for child in ast.walk(node.test):
                    if isinstance(child, ast.Name):
                        condition_vars.add(child.id)
                for stmt in node.body + node.orelse:
                    for sub_node in ast.walk(stmt):
                        if isinstance(sub_node, ast.Assign):
                            for target in sub_node.targets:
                                if isinstance(target, ast.Name):
                                    for cond in condition_vars:
                                        dependencies[target.id].add((cond, "control"))

            # 4. Loop dependency: assignments under `for x in range(n):` depend on `n`
            elif isinstance(node, ast.For):
                loop_vars = set()
                for child in ast.walk(node.iter):
                    if isinstance(child, ast.Name):
                        loop_vars.add(child.id)
                for stmt in node.body:
                    for sub_node in ast.walk(stmt):
                        if isinstance(sub_node, ast.Assign):
                            for target in sub_node.targets:
                                if isinstance(target, ast.Name):
                                    for lv in loop_vars:
                                        dependencies[target.id].add((lv, "loop"))

    except Exception as e:
        print(f"Error analyzing source code: {e}")
    # Callers only consume the names; the kind tags stay internal.
    return {var: sorted({src for src, _kind in deps}) for var, deps in dependencies.items()}

def show_all(lineno, value, valuename=None, fig=None, usefig=True, disp=True, console_in=None,
             time_wait_between_lines=None, lastlineno=[-1]):
    """Record ``value`` in the module-level ``intermediate_variables_dict``.

    Lists and dicts are expanded recursively under derived names
    (``name[0]``, ``name['key']``), and the container itself is recorded as a
    summary string (``"List of len N"`` / ``"Dict of len N"``). PIL images are
    recorded as ``"Image"``, and any other value is stored as-is. When
    ``lineno`` differs from the previous call's, a rule labelled with the line
    number is printed.

    Args:
        lineno: source line being traced, or None (used for nested elements).
        value: the value to record.
        valuename: key to record it under; if None it is looked up by identity
            among the caller's local variables (``'unknown'`` if not found).
        fig, usefig, disp: not used here; forwarded to nested calls.
        console_in: rich Console for the line rule; one is created only when a
            rule actually has to be printed and none was supplied.
        time_wait_between_lines: seconds to sleep before recording; None keeps
            the historical 0.1 s. Nested elements are recorded without
            further delay.
        lastlineno: one-element list that deliberately persists across calls
            (mutable default) and holds the last line a rule was printed for.
            Callers should not pass it.
    """
    time.sleep(0.1 if time_wait_between_lines is None else time_wait_between_lines)

    # Infer the variable name from the caller's locals.
    if valuename is None:
        frame = None
        try:
            frame = inspect.currentframe().f_back
            valuename = next((k for k, v in frame.f_locals.items() if v is value), 'unknown')
        except Exception:
            valuename = "unknown"
        finally:
            # Release the frame promptly to avoid a reference cycle (see inspect docs).
            del frame

    if lineno is not None and lineno != lastlineno[0]:
        if console_in is None:
            console_in = Console(highlight=False)
        console_in.rule(f"[bold]Line {lineno}[/bold]", style="chartreuse2")
        lastlineno[0] = lineno

    if isinstance(value, Image.Image):
        intermediate_variables_dict[valuename] = "Image"
    elif isinstance(value, list):
        for i, item in enumerate(value):
            show_all(None, item, f"{valuename}[{i}]", fig=fig, disp=disp, usefig=usefig,
                     console_in=console_in, time_wait_between_lines=0)
        intermediate_variables_dict[valuename] = f"List of len {len(value)}"
    elif isinstance(value, dict):
        for k, v in value.items():
            show_all(None, v, f"{valuename}['{k}']", fig=fig, disp=disp, usefig=usefig,
                     console_in=console_in, time_wait_between_lines=0)
        intermediate_variables_dict[valuename] = f"Dict of len {len(value)}"
    else:
        intermediate_variables_dict[valuename] = value
def analyze_and_export_combined(dict_list, output_dir="analysis_output"):
    """Count value frequencies per field and export them as a CSV and bar charts.

    For every key found in the dicts of ``dict_list``, occurrences of each
    value are counted. ``percentage`` is relative to the total number of dicts,
    not to the number of dicts that contain that key. The function writes
    ``combined_counts.csv`` (columns value, count, field, percentage) and
    ``all_fields_bar_chart.png`` (one bar subplot per field, values sorted)
    into ``output_dir``.

    Args:
        dict_list: list of dicts; every value must be hashable.
        output_dir: output directory, created if missing.

    Returns:
        None. When there is nothing to count, a message is printed and no
        files are written.
    """
    from collections import Counter  # not imported at module level

    os.makedirs(output_dir, exist_ok=True)
    total_samples = len(dict_list)

    # === Counting ===
    field_value_counter = defaultdict(Counter)
    for record in dict_list:
        for field, value in record.items():
            field_value_counter[field][value] += 1

    if not field_value_counter:
        # pd.concat([]) would raise and plt.subplots(0, 1) is invalid.
        print(f"⚠️ No data to analyze; nothing exported to: {output_dir}")
        return

    # === Per-field tables merged into one ===
    combined_rows = []
    for field, counter in field_value_counter.items():
        df = pd.DataFrame(list(counter.items()), columns=["value", "count"])
        df["field"] = field
        df["percentage"] = (df["count"] / total_samples * 100).round(2)
        combined_rows.append(df)

    combined_df = pd.concat(combined_rows, ignore_index=True)
    combined_csv_path = os.path.join(output_dir, "combined_counts.csv")
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"📄 合并统计表已保存到: {combined_csv_path}")

    # === One bar chart per field in a single figure ===
    fields = combined_df["field"].unique()
    # squeeze=False always returns a 2-D array, so one field needs no special case.
    fig, axes = plt.subplots(len(fields), 1, figsize=(8, len(fields) * 3), squeeze=False)
    for ax, field in zip(axes[:, 0], fields):
        sub_df = combined_df[combined_df["field"] == field]
        try:
            sub_df = sub_df.sort_values(by="value")
        except TypeError:
            # Mixed value types (e.g. int and str) are not mutually orderable; sort by text.
            sub_df = sub_df.sort_values(by="value", key=lambda col: col.map(str))
        labels = [str(x) for x in sub_df["value"]]
        bars = ax.bar(labels, sub_df["count"].tolist())
        ax.set_title(f"{field}")
        ax.set_ylabel("Count")
        ax.set_xlabel("Value")
        ax.tick_params(axis='x', rotation=45)

        # Percentage label just above each bar.
        for bar, pct in zip(bars, sub_df["percentage"].tolist()):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2, height + 0.05, f"{pct:.1f}%",
                    ha='center', va='bottom', fontsize=9)

    fig.tight_layout()
    merged_plot_path = os.path.join(output_dir, "all_fields_bar_chart.png")
    fig.savefig(merged_plot_path)
    print(f"🖼️ 合并图表已保存为: {merged_plot_path}")
    plt.close(fig)

def split_codex_output(text):
    """Split a model response into ``(original_code, followup_question, followup_code)``.

    Expected layout::

        Original Code: <code>          ("Original Answer:" is accepted too)
        Follow-up Question: <question>
        Follow-up Code: <code>

    Every part is stripped. Without an "Original Code:"/"Original Answer:"
    marker, the whole stripped ``text`` is the original code and both
    follow-up parts are empty. If a marker repeats, only the segment between
    its first and second occurrence is used.
    """
    normalized = text.replace("Original Answer:", "Original Code:")
    if "Original Code:" not in normalized:
        return text.strip(), "", ""

    body = normalized.split("Original Code:")[1]
    code_segment, *question_segments = body.split("Follow-up Question:")

    followup_question = ""
    followup_code = ""
    if question_segments:
        question_text, *code_segments = question_segments[0].split("Follow-up Code:")
        followup_question = question_text.strip()
        if code_segments:
            followup_code = code_segments[0].strip()
    return code_segment.strip(), followup_question, followup_code
def my_collate(batch):
    """DataLoader ``collate_fn`` that turns a list of sample dicts into a dict of lists.

    Nothing is stacked (images can differ in size); each key taken from the
    first sample maps to that field's values in batch order.
    """
    field_names = list(batch[0])
    return {field: [sample[field] for sample in batch] for field in field_names}
def run_program(parameters, queues_in_, input_type_, retrying=False):
    """Compile and execute one generated program for a single sample.

    The program text is prefixed with a header defining
    ``execute_command_<sample_id>(...)``. That function is compiled into this
    module's globals, called with the sample inputs and queue-bound
    ImagePatch/VideoSegment/llm_query helpers, and then removed again.

    If the program fails to compile, the code in ``config.fixed_code_file`` is
    used instead. If it raises at run time, the fixed code is run once as a
    fallback: the function recurses with a deliberately uncompilable program
    and ``retrying=True``.

    SECURITY: this exec()s model-generated code in the module namespace without
    any sandboxing; only run it on trusted/isolated infrastructure.

    Args:
        parameters: ``(code, sample_id, image, possible_answers, query)``.
        queues_in_: passed as the first element of ``queues`` to
            ImagePatch/VideoSegment/llm_query.
        input_type_: name of the generated function's first parameter
            (main() passes ``dataset.input_type``).
        retrying: True for the fallback attempt; a run-time failure then
            returns immediately instead of retrying again.

    Returns:
        ``(result, code, intermediate_variables, dependency_graph)`` on every
        path. On failure the last two are None and ``result`` is None (or the
        fallback program's result).
    """
    from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno
    from video_segment import VideoSegment

    code, sample_id, image, possible_answers, query = parameters
    func_name = f'execute_command_{sample_id}'
    code_header = f'def {func_name}(' \
                  f'{input_type_}, possible_answers, query, ' \
                  f'ImagePatch, VideoSegment, ' \
                  'llm_query, bool_to_yesno, distance, best_image_match):\n' \
                  f'    # Answer is:'
    code = code_header + code

    try:
        exec(compile(code, 'Codex', 'exec'), globals())
    except Exception as e:
        print(f'Sample {sample_id} failed at compilation time with error: {e}')
        try:
            with open(config.fixed_code_file, 'r') as f:
                fixed_code = f.read()
            code = code_header + fixed_code
            exec(compile(code, 'Codex', 'exec'), globals())
        except Exception as e2:
            print(f'Not even the fixed code worked. Sample {sample_id} failed at compilation time with error: {e2}')
            # Same 4-tuple shape as every other exit: callers index [2] and [3].
            return None, code, None, None

    queues = [queues_in_, queue_results]
    image_patch_partial = partial(ImagePatch, queues=queues)
    video_segment_partial = partial(VideoSegment, queues=queues)
    llm_query_partial = partial(llm_query, queues=queues)

    try:
        result = globals()[func_name](
            image, possible_answers, query,
            image_patch_partial, video_segment_partial,
            llm_query_partial, bool_to_yesno, distance, best_image_match)
    except Exception as e:
        traceback.print_exc()
        # Drop whatever the failed run captured so it cannot leak into the
        # next sample's intermediate variables.
        intermediate_variables_dict.clear()
        if retrying:
            return None, code, None, None
        print(f'Sample {sample_id} failed with error: {e}. Next you will see an "expected an indented block" error. ')
        # "[" never compiles, which routes the retry straight to config.fixed_code_file.
        new_code = "["
        result = run_program((new_code, sample_id, image, possible_answers, query), queues_in_, input_type_,
                             retrying=True)[0]
        return result, code, None, None
    finally:
        # exec() must define the function in module globals (a private dict loses
        # the imported libraries), so always remove it once the call is over.
        globals().pop(func_name, None)
    print(f'Finished sample {sample_id} with result: {result}')

    intermediate_variables_temp = intermediate_variables_dict.copy()
    intermediate_variables_dict.clear()
    dependency_graph_temp = find_dependencies_from_source(code)
    return result, code, intermediate_variables_temp, dependency_graph_temp
def worker_init(queue_results_):
    """Pool initializer: bind this worker process to one of the result queues.

    The worker's pool number (the private ``mp.current_process()._identity[0]``)
    is taken modulo ``len(queue_results_)`` to pick a queue. The chosen queue is
    stored in the module-level ``queue_results`` that run_program() reads.
    """
    global queue_results
    worker_number = mp.current_process()._identity[0]
    slot = worker_number % len(queue_results_)
    queue_results = queue_results_[slot]
def main():
    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        pass  # 已经设置则跳过

    from vision_processes import queues_in, finish_all_consumers, forward, manager
    from datasets import get_dataset

    batch_size = config.dataset.batch_size
    num_processes = min(batch_size, 50)

    if config.multiprocessing:
        queue_results_main = manager.Queue()
        queues_results = [manager.Queue() for _ in range(batch_size)]
    else:
        queue_results_main = None
        queues_results = [None for _ in range(batch_size)]

    model_name_codex = 'codellama' if config.codex.model == 'codellama' else 'codex'
    codex = partial(forward, model_name=model_name_codex, queues=[queues_in, queue_results_main])

    if config.clear_cache:
        cache.clear()

    if config.wandb:
        import wandb
        wandb.init(project="viper", config=OmegaConf.to_container(config))
        # log the prompt file
        wandb.save(config.codex.prompt)

    dataset = get_dataset(config.dataset)

    with open(config.codex.prompt) as f:
        base_prompt = f.read().strip()

    codes_all = None
    if config.use_cached_codex:
        results = pd.read_csv(config.cached_codex_path)
        codes_all = [r.split('# Answer is:')[1] for r in results['code']]
    # python -c "from joblib import Memory; cache = Memory('cache/', verbose=0); cache.clear()"
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True,
                            collate_fn=my_collate)
    input_type = dataset.input_type

    all_results = []
    all_answers = []
    all_codes = []
    all_ids = []
    all_queries = []
    all_img_paths = []
    all_possible_answers = []
    all_query_types = []
    all_followup_questions = []
    all_followup_codes = []
    all_followup_results = []
    all_keep=[]
    all_halstead_o = []
    all_halstead_f = []
    all_cyclomatic_complexity_o = []
    all_cyclomatic_complexity_f = []
    all_codebert_similarity = []
    all_intermediate_variables_dict_original=[]
    all_intermediate_variables_dict_followup=[]
    all_dependency_graph_original=[]
    all_dependency_graph_followup=[]
    all_topological_sort_original=[]
    all_topological_sort_followup=[]
    all_expansion_index = []



    with mp.Pool(processes=num_processes, initializer=worker_init, initargs=(queues_results,)) \
            if config.multiprocessing else open(os.devnull, "w") as pool:
        try:
            n_batches = len(dataloader)

            for i, batch in tqdm(enumerate(dataloader), total=n_batches):
                gc.collect()
                # print("🔍 Sample Info:",batch)
                # Combine all queries and get Codex predictions for them
                # TODO compute Codex for next batch as current batch is being processed

                original_codes = []
                followup_questions = []
                followup_codes = []
                original_results = []
                followup_results = []
                extend_index = []
                if not config.use_cached_codex:
                    prompt,prompt0=[],[]
                    for query,answer, img, img_path in zip(batch['query'],batch['answer'], batch['image'], batch['image_name']):
                        prompt.append({
                            "query": query,
                            "answer":answer,
                            "image": pil_to_base64(img),
                            "image_name": img_path})
                        prompt0.append({
                            "query": query,
                            "answer": answer})
                    codes = codex(prompt=prompt, base_prompt=base_prompt, input_type=input_type)
                    if "def execute_command(image):\n" not in codes[0] and "I'm sorry." in codes[0]:
                        print("Codex refused to answer.Retrying")
                        codes = codex(prompt=prompt0, base_prompt=base_prompt, input_type=input_type)
                    for code_text in codes:
                        original_code, followup_question, followup_code = split_codex_output(code_text)
                        original_codes.append(original_code)
                        followup_questions.append(followup_question)
                        followup_codes.append(followup_code)
                else:
                    codes = codes_all[i * batch_size:(i + 1) * batch_size]
                    for code_text in codes:
                        original_code, followup_question, followup_code = split_codex_output(code_text)
                        original_codes.append(original_code)
                        followup_questions.append(followup_question)
                        followup_codes.append(followup_code)

                print(f'Batch {i+1}/{n_batches} - Codex generated')
                print(f'Query: {batch["query"]}')
                print(f'Original code: {original_codes[0]}')
                print(f'Follow-up question: {followup_questions[0]}')
                print(f'Follow-up code: {followup_codes[0]}')
                if config.execute_code:
                    if not config.multiprocessing:
                        for oc, fc, sample_id, img, possible_answers, query, followup_q in \
                                zip(original_codes, followup_codes, batch['sample_id'], batch['image'],
                                    batch['possible_answers'], batch['query'], followup_questions):
                            result_o = run_program([oc, sample_id, img, possible_answers, query], queues_in, input_type)
                            print(f"result_o: {result_o[0]}\n Intermediate_variables_dict o: {result_o[2]}\n"
                                  f"Dependency graph o: {json.dumps(result_o[3])}")
                            original_results.append(result_o)
                            if fc:
                                result_f = run_program([fc, sample_id, img, possible_answers, followup_q], queues_in,
                                                   input_type)
                                print(f"result_f: {result_f[0]}\n Intermediate_variables_dict f: {result_f[2]}\n"
                                      f"Dependency graph f: {json.dumps(result_f[3])}")
                                comparison=compute_depth_width_metrics_graph_based(result_o[3], result_f[3])
                                extend_index.append(comparison)
                                print(f"Expansion index:{json.dumps(comparison, indent=2)}")

                                followup_results.append(result_f)
                            else:
                                followup_results.append((None, fc, None, find_dependencies_from_source(fc)))
                    else:
                        original_results = list(pool.imap(partial(
                            run_program, queues_in_=queues_in, input_type_=input_type),
                            zip(original_codes, batch['sample_id'], batch['image'], batch['possible_answers'],
                                batch['query'])
                        ))
                        for result_o in original_results:
                            print(f"result_o: {result_o[0]}\n Intermediate_variables_dict o: {result_o[2]}\n"
                                  f"Dependency graph o: {json.dumps(result_o[3])}")


                        if followup_codes:
                            followup_results = list(pool.imap(partial(
                                run_program, queues_in_=queues_in, input_type_=input_type),
                                zip(followup_codes, batch['sample_id'], batch['image'], batch['possible_answers'],
                                    followup_questions)
                            ))
                            for result_f in followup_results:
                                print(f"result_f: {result_f[0]}\n Intermediate_variables_dict f: {result_f[2]}\n"
                                    f"Dependency graph f: {json.dumps(result_f[3])}")

                            comparison = compute_depth_width_metrics_graph_based(dependency_graph_o, dependency_graph_f)
                            extend_index.append(comparison)
                            print(f" Expansion index:{json.dumps(comparison, indent=2)}")

                        else:
                            followup_results = [(None, c, None,find_dependencies_from_source(c)) for c in followup_codes]
                            intermediate_variables_dict_f = intermediate_variables_dict.copy()
                            for result_f in followup_results:
                                print(f"result_f: {result_f[0]}\n Intermediate_variables_dict f: {result_f[2]}\n"
                                    f"Dependency graph f: {json.dumps(result_f[3])}")

                            comparison = compute_depth_width_metrics_graph_based(dependency_graph_o, dependency_graph_f)
                            extend_index.append(comparison)
                            print(f" Expansion index:{json.dumps(comparison, indent=2)}")
                else:
                    original_results = [(None, c, None,find_dependencies_from_source(c)) for c in original_codes]
                    followup_results = [(None, c, None,find_dependencies_from_source(c)) for c in followup_codes]

                    for oc,fc in zip(original_codes, followup_codes):
                        comparison = compute_depth_width_metrics_graph_based(find_dependencies_from_source(oc), find_dependencies_from_source(fc))
                        extend_index.append(comparison)
                        print(f"Expansion index:{json.dumps(comparison, indent=2)}")

                    warnings.warn(
                        "Not executing code! This is only generating the code. Set 'execute_code' to True if you want to run it.")

                all_results += [r[0] for r in original_results]
                all_codes += [r[1] for r in original_results]
                all_ids += batch['sample_id']
                all_answers += batch['answer']
                all_possible_answers += batch['possible_answers']
                all_query_types += batch['query_type']
                all_queries += batch['query']
                all_img_paths += [dataset.get_sample_path(idx) for idx in batch['index']]
                all_intermediate_variables_dict_original+=[r[2] for r in original_results]
                all_dependency_graph_original +=[r[3] for r in original_results]
                all_topological_sort_original += [topological_sort(r[3]) for r in original_results]

                # 存储拓展代码
                if config.followup:
                    all_followup_questions += followup_questions
                    all_followup_results += [r[0] for r in followup_results]
                    all_followup_codes += [r[1] for r in followup_results]
                    all_intermediate_variables_dict_followup += [r[2] for r in followup_results]
                    all_dependency_graph_followup += [r[3] for r in followup_results]
                    all_topological_sort_followup += [topological_sort(r[3]) for r in followup_results]
                    all_expansion_index += extend_index

                # if i % config.log_every == 0:
                #     try:
                #         accuracy = dataset.accuracy(all_results, all_answers, all_possible_answers, all_query_types)
                #         print(f'Accuracy at Batch {i}/{n_batches}: {accuracy}')
                #     except Exception as e:
                #         print(f'Error computing accuracy: {e}')

        except Exception as e:
            traceback.print_exc()
            print(f'Exception: {e}')
            print("Completing logging and exiting...")

        try:
            accuracy = dataset.accuracy(all_results, all_answers, all_possible_answers, all_query_types)
            print(f'Final accuracy: {accuracy}')
        except Exception as e:
            print(f'Error computing accuracy: {e}')

        # loop over each code pair
        # 计算 Halstead 复杂度
        if config.followup:
            for oc, fc in zip(all_codes, all_followup_codes):
                original_code_str = oc
                followup_code_str = fc
                keep, details = is_followup_more_complex(original_code_str, followup_code_str)
                # print("🧠 Complexity comparison result:", "✔️ Keep" if keep else "❌ Reject")
                all_keep.append(keep)
                all_halstead_o.append(details.get("halstead_o", -1))
                all_halstead_f.append(details.get("halstead_f", -1))
                all_cyclomatic_complexity_o.append(details.get("cc_o", -1))
                all_cyclomatic_complexity_f.append(details.get("cc_f", -1))
                all_codebert_similarity.append(details.get("similarity", -1))
            print(len(all_codes), len(all_halstead_o))

        if config.save:
            results_dir = pathlib.Path(config['results_dir']) / config.dataset.split
            results_dir.mkdir(parents=True, exist_ok=True)

            if not config.save_new_results:
                filename = 'results.csv'
            else:
                existing_files = list(results_dir.glob('results_*.csv'))
                if len(existing_files) == 0:
                    filename = 'results_0.csv'
                else:
                    filename = 'results_' + str(max([int(ef.stem.split('_')[-1]) for ef in existing_files if
                                                     str.isnumeric(ef.stem.split('_')[-1])]) + 1) + '.csv'
            print(len(all_results), len(all_answers), len(all_codes), len(all_ids), len(all_queries), len(all_img_paths),len(all_intermediate_variables_dict_original),len(all_dependency_graph_original))
            print('Saving results to', filename)
            df = pd.DataFrame({
                'id': all_ids,
                'query': all_queries,
                'img_path': all_img_paths,
                'possible_answers': all_possible_answers,
                'answer': all_answers,
                'original_code': all_codes,
                'original_result': all_results,
                'intermediate_variables_original': all_intermediate_variables_dict_original,
                'dependency_graph_original': all_dependency_graph_original,
                'topological_sort_original': all_topological_sort_original
            })
            df['original_result'] = df['original_result'].apply(str)
            df['img_path'] = df['img_path'].apply(str)
            df['intermediate_variables_original'] = df['intermediate_variables_original'].apply(
                lambda x: json.dumps(x, cls=CustomJSONEncoder))
            if config.followup:
                print("Saving followup results ")
                print(len(all_followup_questions),len(all_followup_codes), len(all_followup_results),  )
                df['followup_question'] =all_followup_questions
                df['followup_code'] =all_followup_codes
                df['followup_result'] =all_followup_results
                df['followup_result'] = df['followup_result'].apply(str)
                try:
                    print(len(all_intermediate_variables_dict_original), len(all_dependency_graph_followup), len(all_topological_sort_followup), len(all_expansion_index))
                    df['intermediate_variables_followup'] =all_intermediate_variables_dict_followup
                    df['dependency_graph_followup'] = all_dependency_graph_followup
                    df['topological_sort_followup'] = all_topological_sort_followup
                    df['expansion_index'] = all_expansion_index
                    df['intermediate_variables_followup'] = df['intermediate_variables_followup'].apply(
                        lambda x: json.dumps(x, cls=CustomJSONEncoder))
                except Exception as e:
                    print(f'Error saving intermediate variables or dependency graph: {e}')
                try:
                    print(len(all_halstead_o), len(all_halstead_f), len(all_cyclomatic_complexity_o), len(all_cyclomatic_complexity_f), len(all_codebert_similarity),len(all_keep))
                    df['halstead_original'] = all_halstead_o
                    df['halstead_followup'] = all_halstead_f
                    df['cyclomatic_complexity_original'] = all_cyclomatic_complexity_o
                    df['cyclomatic_complexity_followup'] = all_cyclomatic_complexity_f
                    df['codebert_similarity'] = all_codebert_similarity
                    df['expansion_index'] = all_expansion_index
                    df['whether_to_keep'] = all_keep
                except Exception as e:
                    print(f'Error saving halstead analysis: {e}')

            df.to_csv(results_dir / filename, header=True, index=False, encoding='utf-8')
            # 转为字典列表格式并保存为 JSON 文件
            # json_data = df.to_dict(orient='records')
            # with open(results_dir / (filename.replace('.csv', '.json')), 'w', encoding='utf-8') as f:
            #     json.dump(json_data, f, ensure_ascii=False, indent=2, cls=CustomJSONEncoder)

            data = convert_to_nested_format(df)
            with open(results_dir / (filename.replace('.csv', '.json')), "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=4)

            if config.wandb:
                import wandb
                wandb.log({'accuracy': accuracy})
                wandb.log({'results': wandb.Table(dataframe=df, allow_mixed_types=True)})

        finish_all_consumers()



if __name__ == '__main__':
    # Script entry point: run the pipeline implemented by main().
    #
    # NOTE: to re-export an already-saved results CSV as nested JSON without
    # rerunning the pipeline, load it with pd.read_csv(...), pass the DataFrame
    # to convert_to_nested_format(), and json.dump the returned list with
    # ensure_ascii=False, indent=4 (the same settings main() uses when saving).
    main()