import os

from main_simple_lib import *
import ast
import csv

class ComplexCodeTransformer(ast.NodeTransformer):

    def visit_For(self, node):
        # 先处理循环内部的节点
        self.generic_visit(node)

        # 获取迭代器名称和循环体
        iter_name = ast.unparse(node.target)

        # iter_name是muffin_patch
        # 构建新的print语句，输出迭代器名称和当前值
        print_code = f'print("item", "<t>","{iter_name}" , "<t>", {iter_name}, "<t>", ":" , "<END>")'
        print_stmt = ast.parse(print_code).body[0]
        
        # 将print语句添加到循环体的开始处
        node.body.insert(0, print_stmt)
        
        return node

    def visit_If(self, node):
        # 处理if语句，插入额外的输出
        self.generic_visit(node)  # 继续处理嵌套结构
        
        condition_str = ast.unparse(node.test)
        action_str = "; ".join(ast.unparse(n) for n in node.body)
        
        # 插入输出if条件和动作
        print_code = f'print("condition" , "<t>", "{condition_str}" , "<t>", ":", "<END>")' #f'print("if {condition_str}: {action_str}")'
        print_stmt = ast.parse(print_code).body[0]
        
        # 新的if语句体包含原动作和额外的print
        new_body = [print_stmt] + node.body
        node.body = new_body
        return node
    
    def visit_Assign(self, node):
        
        # 继续处理可能的嵌套结构
        self.generic_visit(node)

        # 处理可能的多目标赋值
        targets = [ast.unparse(t) for t in node.targets]
        target_str = ", ".join(targets)

        # 构建新的print语句，输出变量名和值
        # 注意：这里假设只有一个赋值目标，对于多目标赋值情况可能需要调整
        value_str = ast.unparse(node.value)
        func_name = ''
        if isinstance(node.value, ast.Call):
            # 是函数调用，打印函数调用信息
            func_name = self.get_func_name(node.value.func)

        # print_code = f'print("assign {target_str} =", {value_str}, "{func_name}")'
        print_code = f'print("assign", "<t>", \"{target_str}\" , "<t>", str({value_str}), "<t>", \"{func_name}\", "<END>")'
        print_stmt = ast.parse(print_code).body[0]

        # 返回一个包含原赋值语句和新print语句的代码块
        return [node, print_stmt]
    
    def visit_Expr(self, node):
        # 确保访问所有子节点
        self.generic_visit(node)
        
        # 检查当前节点是否表示一个方法调用
        if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Attribute):
            # 检查方法名是否为"append"
            if node.value.func.attr in ['append','sort']:
                # 获取列表变量的名称
                list_name = self.get_func_name(node.value.func.value)
                
                # 如果我们成功获取了列表名称
                if list_name:
                    # 创建一个新的print语句来打印列表的内容
                    print_code = f'print("assign" ,"<t>", \"{list_name}\" ,"<t>", str({list_name}) ,"<t>", \"{node.value.func.attr}\", "<END>")'
                    print_stmt = ast.parse(print_code).body[0]
                    
                    # 将原始的.append()调用和新的打印语句封装在一个ast.Block中（如果Python版本不支持ast.Block，请改用适当的替代方法）
                    # 注意：Python标准AST并没有直接的Block节点，这里使用的是一种概念性的表达
                    # 实际上，这里需要返回一个修改后的节点列表，或者在更高层次上处理这种节点的添加
                    
                    # 返回修改后的节点（或节点列表）
                    return [node, print_stmt]  # 注意：这个返回值可能需要根据Python版本和上下文进行调整
        return node
    
    def get_func_name(self, func):
        # 根据函数调用节点的类型获取函数名
        if isinstance(func, ast.Name):
            return func.id
        elif isinstance(func, ast.Attribute):
            return func.attr
        return None
    
    def visit_Return(self, node):
        # 继续处理可能的嵌套结构
        self.generic_visit(node)

        # 只在return语句有返回值时添加print语句
        if node.value:
            # 构建新的print语句，输出返回值
            value_str = ast.unparse(node.value)
            print_code = f'print("Return value =","<t>", {value_str}, "<END>")'
            print_stmt = ast.parse(print_code).body[0]

            # 创建一个新的代码块，包含print语句和原return语句
            new_node = ast.If(test=ast.Constant(value=True), body=[print_stmt, node], orelse=[])
            ast.copy_location(new_node, node)  # 复制位置信息
            ast.fix_missing_locations(new_node)  # 修复任何缺失的位置信息
            return ast.copy_location(new_node, node)
        else:
            # 对于没有返回值的return语句，不做修改
            return node
    
    def generic_visit(self, node):
        # 对所有节点进行通用处理，确保遍历整个AST
        ast.NodeTransformer.generic_visit(self, node)
        return node


# ==================循环开始=================
filename = 'data/sub_results370.csv'
# 使用csv.reader读取文件，并跳过第一行
with open(filename, 'r') as file:
    GQA_datas = csv.reader(file, skipinitialspace=True)
    next(GQA_datas, None)  # 跳过第一行
    for GQA_data in GQA_datas:
        # 读取文本内容
        answer = GQA_data[0]
        GT = GQA_data[1]
        code = GQA_data[2]
        id = GQA_data[3]
        query = GQA_data[4]
        image_path = GQA_data[5].split('/')[-1]  

        # 读路径，保存中间变量用
        path = 'intermediate_variable/GQA_' + id
        
        # 改造代码
        code_prefix = "def execute_command(im):"
        code_postfix = "execute_command(image)"
        # code_with_save_path = f"ImagePatch(image,directory='{path}')"

        code = code.split("\n")
        code[0] = code_prefix
        code.append(code_postfix)

        code = "\n".join(code)
        # source_code = code.replace("ImagePatch(image)",code_with_save_path)
        source_code = re.sub(r"ImagePatch\((.*?)\)", r"ImagePatch(\1,directory='{}')".format(path), code)

        # 读图片
        image = load_image("./Data/GQA/images/images/" + image_path)
        # ------------------准备工作-----------------

        # ------------------AST部分-----------------
        # 解析AST、修改AST
        parsed_ast = ast.parse(source_code)
        transformer = ComplexCodeTransformer()
        modified_ast = transformer.visit(parsed_ast)
        # print(ast.unparse(modified_ast))

        # 重定向输出
        import sys
        original_stdout = sys.stdout 
        symbolic_trace = path + '/symbolic_trace.txt'
        if not os.path.exists(path):
            os.makedirs(path)
            file_st = open( symbolic_trace , 'w')
        else:
            raise ValueError("File already exists!!!")
        sys.stdout = file_st

        # 剪枝开始
        # 将修改后的AST编译并执行
        try:
            exec(compile(modified_ast, filename="<ast>", mode="exec"))
        except Exception as e:
            print(e)
            file_st.close()
            continue
        # print(ast.dump(ast.parse(source_code),indent=4))
        # 剪枝结束

        file_st.close()
        # 重定向结束
        # ------------------AST部分-----------------

        # ------------------symbolic_trace-----------------
        # 合并开始
        # symbolic_trace处理
        with open(symbolic_trace, 'r', encoding='utf-8') as file_st_af_me:
            text = file_st_af_me.read()

        text = text.replace('\n', '')
        text = text.replace('<END>', '\n')

        lines = text.split('\n')

        def process_lines_correct_order(lines):
            final_lines = []  # Final list to maintain the correct order
            variables_assignment = {}  # Tracks the latest assignment line index for each variable

            for i, line in enumerate(lines):
                if line.startswith("assign"):
                    _, var, _, _ = line.split("<t>" , maxsplit=3)
                    # Remove previous assignment if exists
                    if var in variables_assignment:
                        # Remove the previous line by index
                        # 这个prev_index表示的应该是，当前变量上一次出现是在第几行
                        prev_index = variables_assignment[var]
                        # print(prev_index)
                        final_lines[prev_index] = None  # Mark the previous line for removal
                    # Record the new index of the assignment
                    variables_assignment[var] = len(final_lines)
                    final_lines.append(line)
                else:
                    final_lines.append(line)
                # print(variables_assignment)

            return final_lines

        correct_order_lines = process_lines_correct_order(lines)
        correct_order_lines = [line if line is not None else '' for line in correct_order_lines]
        non_empty_lines = [line for line in correct_order_lines if line.strip() != '']
        import re
        def remove_spaces_in_parentheses(line):
                return re.sub(r'ImagePatch\((\d+),\s+(\d+),\s+(\d+),\s+(\d+)\)', r'ImagePatch(\1,\2,\3,\4)', line)
        processed_lines = [remove_spaces_in_parentheses(line) for line in non_empty_lines]
        joined_str = '\n'.join(processed_lines).replace("<t>","")

        symbolic_trace_after_merge = path + "/symbolic_trace_after_merge.txt"
        with open(symbolic_trace_after_merge, 'w', encoding='utf-8') as file:
            file.write(joined_str)
        # symbolic_trace处理结束
        # 合并结束
        # ------------------symbolic_trace-----------------