import json
import inspect
import time
import ast
from collections import defaultdict, deque
from PIL import Image
import pandas as pd
from rich.console import Console
from rich.pretty import pretty_repr
from collections import defaultdict, Counter

import matplotlib.pyplot as plt
import os
# Global dicts holding recorded intermediate values and the dependency graph.
intermediate_variables_dict = {}
# Variable -> list of dependency names. Defined here so code referring to the
# module-level name ``dependency_graph`` does not raise NameError.
dependency_graph = {}

def save_intermediate_variable(name, value):
    """Store *value* in the module-level ``intermediate_variables_dict`` under *name*.

    An existing entry with the same name is overwritten.
    """
    intermediate_variables_dict.update({name: value})
def rich_escape(text):
    """Escape *text* so rich prints it literally instead of parsing it as markup.

    Mirrors ``rich.markup.escape``. Only a ``[`` that begins something rich
    would parse as a tag (``[bold]``, ``[/]``, ``[#ff0000]``, ``[@click]``) is
    escaped, and backslashes directly before it are doubled so they stay
    literal. The previous version escaped every ``[`` and ``]``. Rich does not
    treat ``\\]`` or a backslash before a non-tag ``[`` as an escape, so those
    backslashes were printed verbatim.

    Args:
        text: any object; it is converted with ``str()`` first.

    Returns:
        The escaped string.
    """
    import re

    def _escape_tag(match):
        backslashes, tag = match.groups()
        return f"{backslashes}{backslashes}\\{tag}"

    escaped = re.sub(r"(\\*)(\[[a-z#/@][^[]*?])", _escape_tag, str(text))
    # A lone trailing backslash would escape whatever markup follows when the
    # result is concatenated into a larger markup string.
    if escaped.endswith("\\") and not escaped.endswith("\\\\"):
        escaped += "\\"
    return escaped
def topological_sort(dep_graph):
    """Order variables so each one comes after everything it depends on (Kahn's algorithm).

    Args:
        dep_graph: mapping ``var -> list of names var depends on``. Names that
            appear only as dependencies are treated as having no dependencies.
            The mapping is not modified.

    Returns:
        list of variable names. Ties are broken by first appearance: keys in
        mapping order, then dependency-only names in the order they are first
        seen. Nodes on a dependency cycle (including a self-dependency), and
        everything downstream of them, never reach in-degree 0 and are
        omitted from the result.
    """
    # Collect nodes in first-seen order instead of via a set, so the result
    # does not depend on string hashing, and without mutating dep_graph.
    nodes = dict.fromkeys(dep_graph)
    for deps in dep_graph.values():
        nodes.update(dict.fromkeys(deps))

    indegree = dict.fromkeys(nodes, 0)
    dependents = defaultdict(list)
    for var, deps in dep_graph.items():
        for dep in deps:
            dependents[dep].append(var)
            indegree[var] += 1

    queue = deque(node for node, degree in indegree.items() if degree == 0)
    sorted_order = []
    while queue:
        node = queue.popleft()
        sorted_order.append(node)
        for dependent in dependents[node]:
            indegree[dependent] -= 1
            if indegree[dependent] == 0:
                queue.append(dependent)

    return sorted_order

def compute_depth_width_metrics_graph_based(dep_graph_old, dep_graph_new):
    """Measure how much a follow-up dependency graph widens and deepens the original.

    Each graph maps a variable name to a list of dependencies. A dependency may
    be a plain name or a ``(name, kind)`` pair, as produced by
    ``find_dependencies_from_source``. Only the name is used to follow edges.

    - Width: number of variables in ``dep_graph_new`` whose number of distinct
      dependency entries is larger than in ``dep_graph_old`` (+1 per variable).
    - Depth: number of nodes on the longest dependency chain, found by DFS.
      An edge back to a node already on the current path is skipped, so on
      cyclic graphs the value is a heuristic.

    Args:
        dep_graph_old: dependency graph of the original code (``None`` or
            empty counts as depth 0).
        dep_graph_new: dependency graph of the follow-up code (same rules).

    Returns:
        dict with ``width_increase``, ``depth_increase`` (clamped at 0),
        ``depth_old``, ``depth_new``, and ``width_coef`` / ``depth_coef``: each
        increase's share of their sum, rounded to 2 decimals (0.0 when neither
        increased).
    """
    dep_graph_old = dep_graph_old or {}
    dep_graph_new = dep_graph_new or {}

    def dep_name(dep):
        # Unwrap (name, kind) pairs. Looking the tuple up as a node never
        # matched a key, so every chain stopped after one hop (depth <= 2).
        return dep[0] if isinstance(dep, tuple) else dep

    def longest_path_length(dep_graph):
        memo = {}
        on_path = set()

        def dfs(node):
            if node in memo:
                return memo[node]
            on_path.add(node)
            best = 0
            for dep in dep_graph.get(node, []):
                neighbor = dep_name(dep)
                if neighbor not in on_path:
                    best = max(best, dfs(neighbor))
            on_path.remove(node)
            memo[node] = best + 1
            return memo[node]

        return max((dfs(node) for node in dep_graph), default=0)

    # Depth
    depth_old = longest_path_length(dep_graph_old)
    depth_new = longest_path_length(dep_graph_new)
    depth_increase = max(depth_new - depth_old, 0)

    # Width: +1 for every variable whose dependency count grew.
    width_increase = sum(
        1
        for var, deps in dep_graph_new.items()
        if len(set(deps)) > len(set(dep_graph_old.get(var, [])))
    )

    total = (width_increase + depth_increase) or 1  # avoid division by zero
    width_coef = round(width_increase / total, 2)
    depth_coef = round(depth_increase / total, 2)

    return {
        "width_increase": width_increase,
        "depth_increase": depth_increase,
        "depth_old": depth_old,
        "depth_new": depth_new,
        "width_coef": width_coef,
        "depth_coef": depth_coef,
    }
# 依赖分析
def _iter_name_ids(node):
    """Yield the id of every ``ast.Name`` in the subtree rooted at *node*."""
    for child in ast.walk(node):
        if isinstance(child, ast.Name):
            yield child.id


def _iter_assigned_names(stmts):
    """Yield the id of every plain-name ``ast.Assign`` target nested anywhere in *stmts*."""
    for stmt in stmts:
        for sub_node in ast.walk(stmt):
            if isinstance(sub_node, ast.Assign):
                for target in sub_node.targets:
                    if isinstance(target, ast.Name):
                        yield target.id


def _add_value_dependencies(dependencies, target_var, value):
    """Record what ``target_var = <value>`` depends on, tagged by syntactic role.

    Every name in *value* is tagged "assign". A name is additionally tagged
    "subscript" when it is the base of an index or a bare-name index,
    "attribute" when it is the base of an attribute access, and "call" when it
    is the called function or appears in a positional call argument.
    """
    for child in ast.walk(value):
        if isinstance(child, ast.Name):
            dependencies[target_var].add((child.id, "assign"))
        elif isinstance(child, ast.Subscript):
            if isinstance(child.value, ast.Name):
                dependencies[target_var].add((child.value.id, "subscript"))
            if isinstance(child.slice, ast.Name):
                dependencies[target_var].add((child.slice.id, "subscript"))
        elif isinstance(child, ast.Attribute):
            if isinstance(child.value, ast.Name):
                dependencies[target_var].add((child.value.id, "attribute"))
        elif isinstance(child, ast.Call):
            if isinstance(child.func, ast.Name):
                dependencies[target_var].add((child.func.id, "call"))
            for arg in child.args:
                for arg_id in _iter_name_ids(arg):
                    dependencies[target_var].add((arg_id, "call"))


def find_dependencies_from_source(source_code):
    """Statically extract typed variable dependencies from Python source.

    Handles these statement shapes:
      * ``a = <expr>``         -> a depends on names in expr (see
        ``_add_value_dependencies`` for the kinds)
      * ``x[i] = <expr>``      -> x depends on i ("subscript") and on names in
        expr ("assign"); only when ``x`` is a bare name
      * ``c += <expr>``        -> c depends on itself and names in expr
        ("augassign")
      * ``if <cond>: ...``     -> every name assigned in the body/else depends on
        the names in cond ("control")
      * ``for ... in <it>:``   -> every name assigned in the body depends on the
        names in it ("loop")

    Args:
        source_code: Python source text.

    Returns:
        dict mapping variable name -> sorted list of ``(dependency, kind)``
        tuples. If analysis fails, the error is printed and whatever was
        collected so far is returned (often ``{}``).
    """
    dependencies = defaultdict(set)

    try:
        tree = ast.parse(source_code)

        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        _add_value_dependencies(dependencies, target.id, node.value)
                    # Only bare-name containers are tracked. Previously a target
                    # such as ``obj.attr[0] = y`` hit ``target.value.id`` on a
                    # non-Name node and aborted the whole analysis.
                    elif isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name):
                        container = target.value.id
                        if isinstance(target.slice, ast.Name):
                            dependencies[container].add((target.slice.id, "subscript"))
                        for name_id in _iter_name_ids(node.value):
                            dependencies[container].add((name_id, "assign"))

            elif isinstance(node, ast.AugAssign):
                if isinstance(node.target, ast.Name):
                    target_var = node.target.id
                    dependencies[target_var].add((target_var, "augassign"))
                    for name_id in _iter_name_ids(node.value):
                        dependencies[target_var].add((name_id, "augassign"))

            elif isinstance(node, ast.If):
                condition_vars = set(_iter_name_ids(node.test))
                for assigned in _iter_assigned_names(node.body + node.orelse):
                    for cond in condition_vars:
                        dependencies[assigned].add((cond, "control"))

            elif isinstance(node, ast.For):
                loop_vars = set(_iter_name_ids(node.iter))
                for assigned in _iter_assigned_names(node.body):
                    for lv in loop_vars:
                        dependencies[assigned].add((lv, "loop"))

    # Best effort: callers feed arbitrary cells (e.g. from CSV), which may be
    # unparsable or not even strings; report and return partial results.
    except Exception as e:
        print(f"Error analyzing source code: {e}")

    return {k: sorted(v) for k, v in dependencies.items()}


def show_all(lineno, value, valuename=None, fig=None, usefig=True, disp=True, console_in=None, time_wait_between_lines=None, lastlineno=[-1]):
    """Record *value* in ``intermediate_variables_dict``, recursing into lists and dicts.

    Args:
        lineno: traced line number. When it differs from the line of the
            previous rule, a rich horizontal rule labelled with it is printed.
            ``None`` skips the rule (used for nested elements).
        value: value to record. PIL images are stored as ``"Image"``. Lists and
            dicts are stored as ``"List of len N"`` / ``"Dict of len N"``, and
            each element is recorded under ``name[i]`` / ``name['key']``.
            Anything else is stored as-is.
        valuename: key to store under. If ``None``, it is looked up by identity
            among the caller's local variables (``"unknown"`` if not found).
        fig, usefig, disp: forwarded to nested calls; not otherwise used.
        console_in: rich ``Console`` used for the line rule; created on demand.
        time_wait_between_lines: seconds to sleep at the start of every call,
            nested ones included. ``None`` keeps the historical 0.1 s.
        lastlineno: deliberately mutable default that persists the last line
            number a rule was printed for across calls; do not pass it.
    """
    # The delay parameter used to be ignored; None preserves the old 0.1 s.
    time.sleep(0.1 if time_wait_between_lines is None else time_wait_between_lines)

    # Infer the variable name from the caller's locals by identity.
    if valuename is None:
        frame = None
        try:
            frame = inspect.currentframe().f_back
            valuename = next((k for k, v in frame.f_locals.items() if v is value), 'unknown')
        except Exception:
            valuename = "unknown"
        finally:
            # Release the frame reference promptly to avoid a reference cycle.
            del frame

    # Print a separator rule once per new line number.
    if lineno is not None and lineno != lastlineno[0]:
        if console_in is None:
            console_in = Console(highlight=False)
        console_in.rule(f"[bold]Line {lineno}[/bold]", style="chartreuse2")
        lastlineno[0] = lineno

    # Forward the console and delay so nested elements reuse them.
    nested_kwargs = dict(fig=fig, usefig=usefig, disp=disp, console_in=console_in,
                         time_wait_between_lines=time_wait_between_lines)

    if isinstance(value, Image.Image):
        intermediate_variables_dict[valuename] = "Image"
    elif isinstance(value, list):
        for i, item in enumerate(value):
            show_all(None, item, f"{valuename}[{i}]", **nested_kwargs)
        intermediate_variables_dict[valuename] = f"List of len {len(value)}"
    elif isinstance(value, dict):
        for k, v in value.items():
            show_all(None, v, f"{valuename}['{k}']", **nested_kwargs)
        intermediate_variables_dict[valuename] = f"Dict of len {len(value)}"
    else:
        intermediate_variables_dict[valuename] = value


def analyze_and_export_combined(dict_list, output_dir="analysis_output"):
    """Tally value frequencies per field and export them as one CSV and one chart.

    For every key appearing in the dicts of *dict_list*, count how often each
    value occurs. Two files are written into *output_dir*, which is created if
    needed:
      * ``combined_counts.csv`` with columns ``value``, ``count``, ``field``,
        ``percentage``
      * ``all_fields_bar_chart.png`` with one bar subplot per field, each bar
        labelled with its percentage

    Args:
        dict_list: list of flat dicts with hashable values (e.g. the metric
            dicts from ``compute_depth_width_metrics_graph_based``).
        output_dir: destination directory.

    Percentages are relative to ``len(dict_list)``, so per field they sum to
    100 only when every dict contains that field. If there is nothing to count
    (empty list, or only empty dicts), nothing is written; previously this
    crashed in ``pd.concat([])``.
    """
    field_value_counter = defaultdict(Counter)
    for record in dict_list:
        for field, value in record.items():
            field_value_counter[field][value] += 1

    if not field_value_counter:
        print("No data to analyze; nothing exported.")
        return

    os.makedirs(output_dir, exist_ok=True)
    total_samples = len(dict_list)

    # === Build one long-format table with the counts of every field ===
    per_field_frames = []
    for field, counter in field_value_counter.items():
        df = pd.DataFrame(list(counter.items()), columns=["value", "count"])
        df["field"] = field
        df["percentage"] = (df["count"] / total_samples * 100).round(2)
        per_field_frames.append(df)

    combined_df = pd.concat(per_field_frames, ignore_index=True)
    combined_csv_path = os.path.join(output_dir, "combined_counts.csv")
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"📄 合并统计表已保存到: {combined_csv_path}")

    # === One bar subplot per field ===
    fields = combined_df["field"].unique()
    # squeeze=False always yields a 2-D axes array, even for a single field.
    fig, axes = plt.subplots(len(fields), 1, figsize=(8, len(fields) * 3), squeeze=False)

    for ax, field in zip(axes[:, 0], fields):
        sub_df = combined_df[combined_df["field"] == field].sort_values(by="value")
        bars = ax.bar([str(x) for x in sub_df["value"]], sub_df["count"].tolist())
        ax.set_title(f"{field}")
        ax.set_ylabel("Count")
        ax.set_xlabel("Value")
        ax.tick_params(axis='x', rotation=45)

        # Percentage label just above each bar.
        for bar, pct in zip(bars, sub_df["percentage"].tolist()):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2, height + 0.05, f"{pct:.1f}%", ha='center', va='bottom', fontsize=9)

    fig.tight_layout()
    merged_plot_path = os.path.join(output_dir, "all_fields_bar_chart.png")
    fig.savefig(merged_plot_path)
    print(f"🖼️ 合并图表已保存为: {merged_plot_path}")
    plt.close(fig)

def main():
    """Demo: analyse two small snippets, trace their values, and compare the graphs.

    For each snippet this prints its static dependency graph, a topological
    order of its variables, and the values recorded through ``show_all``. It
    then prints the width/depth comparison of the two graphs.
    """
    # Two small snippets that are analysed statically and then "executed" below.
    code1 = """
a, b = 1, 2
result = a + b
c = result * 2
show_all(1, result, "result")
show_all(2, c, "c")
d = c - b
show_all(3, d, "d")
"""
    code2 = """
a, b, e = 1, 2, 3
result = a + b + e
c = result * 2
show_all(1, result, "result")
show_all(2, c, "c")
d = c - b
f= d + e
show_all(3, d, "d")
show_all(4, f, "f")
"""

    def names_only(graph):
        # find_dependencies_from_source yields (name, kind) pairs, while
        # topological_sort expects plain dependency names.
        return {var: [dep for dep, _kind in deps] for var, deps in graph.items()}

    target_list1, target_list2 = [], []

    # --- Snippet 1 ---
    graph1 = find_dependencies_from_source(code1)
    print("🔍 Dependency Graph:")
    print(json.dumps(graph1, indent=2))
    target_list1.append(graph1)

    a, b = 1, 2
    result = a + b
    c = result * 2
    show_all(None, result, "result")
    show_all(None, c, "c")
    d = c - b
    show_all(None, d, "d")

    sorted_vars = topological_sort(names_only(graph1))
    print("📈 Topological Order:", sorted_vars)
    print("\n🧠 Intermediate Variables:")
    print(json.dumps(intermediate_variables_dict, indent=2))
    # Store a snapshot: the global dict is cleared next, which used to empty
    # this entry too.
    target_list2.append(dict(intermediate_variables_dict))
    intermediate_variables_dict.clear()

    # --- Snippet 2 ---
    graph2 = find_dependencies_from_source(code2)
    print("🔍 Dependency Graph:")
    print(json.dumps(graph2, indent=2))
    target_list1.append(graph2)

    a, b, e = 1, 2, 3
    result = a + b + e
    c = result * 2
    show_all(None, result, "result")
    show_all(None, c, "c")
    d = c - b
    f = d + e
    show_all(None, d, "d")
    show_all(None, f, "f")

    sorted_vars = topological_sort(names_only(graph2))
    print("📈 Topological Order:", sorted_vars)
    print("\n🧠 Intermediate Variables:")
    print(json.dumps(intermediate_variables_dict, indent=2))
    target_list2.append(dict(intermediate_variables_dict))

    comparison = compute_depth_width_metrics_graph_based(graph1, graph2)
    print("\n📊 Metrics Comparison:")
    print(json.dumps(comparison, indent=2))

    print(target_list1, target_list2)

def code_analysis(original_csv=r'./results/matched_answer.csv',
                  followup_csv=r'./results/extend_index.csv',
                  output_dir="./analysis_output"):
    """Compare dependency graphs of original vs. follow-up code and export statistics.

    The ``original_code`` column of *original_csv* and the ``followup_code``
    column of *followup_csv* are paired by row position. Width/depth metrics
    are computed for every pair in which both snippets yield a non-empty
    dependency graph. The resulting metric dicts are passed to
    ``analyze_and_export_combined`` together with *output_dir*.

    Args:
        original_csv: CSV path with an ``original_code`` column.
        followup_csv: CSV path with a ``followup_code`` column.
        output_dir: directory for the exported table and chart.
    """
    df1 = pd.read_csv(original_csv)
    df2 = pd.read_csv(followup_csv)
    code1 = df1['original_code'].tolist()
    code2 = df2['followup_code'].tolist()
    print(len(code1), len(code2))
    if len(code1) != len(code2):
        # Indexing code2 by code1's length used to raise IndexError here.
        print(f"Warning: row counts differ; comparing the first {min(len(code1), len(code2))} pairs only.")

    extend_index = []
    for original, followup in zip(code1, code2):
        graph1 = find_dependencies_from_source(original)
        graph2 = find_dependencies_from_source(followup)
        # Pairs where either snippet yields no dependencies (e.g. it failed to
        # parse) are skipped.
        if graph1 and graph2:
            extend_index.append(compute_depth_width_metrics_graph_based(graph1, graph2))
    print(len(extend_index))

    analyze_and_export_combined(extend_index, output_dir)


if __name__ == "__main__":
    # Batch analysis over the CSV results; call main() instead for the small demo.
    code_analysis()