from tree_sitter import Language, Parser
import os
import subprocess

# Load the C grammar from the compiled tree-sitter language library and bind
# it to the module-level parser that convert_loops() uses.
# The library path is relative, so it is resolved against the current
# working directory of the process.
C_LANGUAGE = Language('build/my-languages.so', 'c')
parser = Parser()
parser.set_language(C_LANGUAGE)

# Read a list of paths and rewrite the loops of every listed file
def process_files_from_paths(paths_file) -> None:
    """Convert the loops of every C source file listed in *paths_file*, in place.

    Args:
        paths_file: Path of a text file holding one source-file path per
            line. Surrounding whitespace is ignored and blank lines are
            skipped.

    Each listed file is read, passed through ``convert_loops`` and then
    overwritten with the result. Lines that do not name an existing regular
    file are reported on stdout and skipped. Progress is printed to stdout.
    """
    with open(paths_file, 'r') as file:
        # Strip whitespace and drop empty lines so that a trailing newline or
        # a spacer line is not reported as an invalid (empty) path.
        file_paths = [line.strip() for line in file if line.strip()]

    for file_path in file_paths:
        # Only regular files can be converted.
        if not os.path.isfile(file_path):
            print(f"Invalid file path: {file_path}")
            continue

        print(f"Processing file: {file_path}")
        # convert_loops() works on the UTF-8 encoding of the text, so read and
        # write the source with that same encoding instead of the locale default.
        with open(file_path, 'r', encoding='utf-8') as f:
            original_code = f.read()

        # Rewrite the for / while / do-while loops found in the source.
        converted_code = convert_loops(original_code)

        # NOTE: the clang-format pass (format_code) is currently not applied,
        # and the result overwrites the input file instead of going to a copy.
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(converted_code)
        print(f"Converted and formatted file saved: {file_path}")

# Convert for loops to while loops, and while / do-while loops to for loops
def convert_loops(code: str) -> str:
    """Return *code* with its outermost C loops rewritten in another loop form.

    The source is parsed with the module-level tree-sitter ``parser`` and the
    following rewrites are applied:

    * ``for (init; cond; update) body`` becomes
      ``init; while (cond) { body update; }``; a missing condition
      (``for (;;)``) becomes ``while (1)``.
    * ``while (cond) body`` becomes ``for (; (cond);) body``.
    * ``do body while (cond);`` becomes ``for (; (cond) ;){ body }``.

    Only the outermost loop of a nest is rewritten; loops nested inside a
    rewritten loop are copied verbatim as part of its body text.

    NOTE(review): the rewrites are textual and not always equivalent to the
    original program: a ``continue`` inside a converted ``for`` body skips
    the update that was moved to the end of the body, the ``for`` initializer
    is hoisted out of the loop and so leaks into the enclosing scope, and a
    converted do-while no longer runs its body before the first condition
    check. Confirm whether this is acceptable for the intended use.

    Args:
        code: C source text.

    Returns:
        The rewritten source text.
    """
    code_bytes = code.encode('utf-8')
    root = parser.parse(code_bytes).root_node

    def text(node):
        # Node offsets are byte offsets, so slice the encoded source rather
        # than the str (they differ as soon as a non-ASCII character appears).
        if node is None:
            return ""
        return code_bytes[node.start_byte:node.end_byte].decode('utf-8')

    def inner_body(body):
        # Return the body text without its single outer pair of braces.
        # Non-compound bodies (e.g. `for (...) x++;`) are returned unchanged,
        # and inner braces such as a trailing `}}` are left intact.
        snippet = text(body)
        if body is None or body.type != 'compound_statement':
            return snippet
        if snippet.startswith('{'):
            snippet = snippet[1:]
        if snippet.endswith('}'):
            snippet = snippet[:-1]
        return snippet

    def for_to_while(node):
        init = node.child_by_field_name('initializer')
        cond = node.child_by_field_name('condition')
        update = node.child_by_field_name('update')
        body = node.child_by_field_name('body')

        parts = []
        if init is not None:
            # A declaration initializer already ends with ';', an expression
            # initializer does not: normalise to exactly one.
            parts.append(text(init).rstrip(';') + ";\n")
        # An omitted condition means "loop forever" in C.
        parts.append("while (" + (text(cond) if cond is not None else "1") + ") {\n")
        parts.append(inner_body(body) + "\n")
        if update is not None:
            parts.append(text(update).rstrip(';') + ";\n")
        parts.append("}")
        return "".join(parts)

    def while_to_for(node):
        cond = node.child_by_field_name('condition')
        body = node.child_by_field_name('body')

        rewritten = f"for (; {text(cond)};)"
        if body is not None:
            rewritten += text(body).strip()
        return rewritten

    def do_to_for(node):
        condition = node.child_by_field_name('condition')
        body = node.child_by_field_name('body')

        if condition is not None:
            rewritten = "for (; " + text(condition) + " ;){"
        else:
            # No condition available: emit an empty one, still opening the
            # brace so that the closing '}' below stays balanced.
            rewritten = "for (;;){"
        if body is not None:
            rewritten += "\n" + inner_body(body)
        return rewritten + "\n}"

    builders = {
        'for_statement': for_to_while,
        'while_statement': while_to_for,
        'do_statement': do_to_for,
    }

    # Pre-order traversal with an explicit stack (no recursion-depth limit).
    # Children are pushed in reverse so nodes are visited in source order.
    replacements = []
    pending = [root]
    while pending:
        node = pending.pop()
        build = builders.get(node.type)
        if build is None:
            pending.extend(reversed(node.children))
            continue
        # Do not descend into a rewritten loop: its nested loops are part of
        # the replacement text already.
        replacements.append((node.start_byte, node.end_byte, build(node)))

    # Splice on bytes. The replacements are in source order and never overlap
    # (a rewritten node is not descended into), so one forward pass suffices.
    pieces = []
    cursor = 0
    for start, end, replacement in replacements:
        pieces.append(code_bytes[cursor:start])
        pieces.append(replacement.encode('utf-8'))
        cursor = end
    pieces.append(code_bytes[cursor:])

    return b"".join(pieces).decode('utf-8')


# Format code
def format_code(code):
    """Return *code* formatted by the ``clang-format`` executable.

    The source is sent to ``clang-format`` on stdin, encoded as UTF-8, and the
    formatted text is read back from its stdout.

    Args:
        code: Source text to format.

    Returns:
        The formatted source text.

    Raises:
        subprocess.CalledProcessError: If ``clang-format`` exits with a
            non-zero status. Its captured stderr is available on the
            exception, so a failed run is no longer mistaken for an empty
            formatting result.
    """
    completed = subprocess.run(
        ['clang-format'],
        input=code.encode('utf-8'),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        check=True,
    )
    return completed.stdout.decode('utf-8')

# Script entry: convert every file listed in the paths file below.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so this also
# runs (and rewrites the listed files) when the module is merely imported.
paths_file = './data/testThree-paths.txt'
process_files_from_paths(paths_file)
