import json
import re
import pyarrow.parquet as pq
from pathlib import Path
import numpy as np
from typing import Union, List, Dict, Optional, Tuple, Any
import os
import tempfile
import subprocess


def convert_parquet_to_json(parquet_file: str, json_file: str, batch_size: int = 1000, max_rows: int = None) -> None:
    """
    Convert a Parquet file into a JSON Lines file (one JSON object per row).

    The output file is opened in append mode, so rows are added after any
    content that already exists at ``json_file``.

    Args:
        parquet_file: Path of the Parquet file to read.
        json_file: Path of the JSON Lines file to write; missing parent
            directories are created.
        batch_size: Number of rows fetched from the Parquet file per batch.
        max_rows: Stop after this many rows have been written; ``None``
            means no limit.

    Raises:
        FileNotFoundError: If ``parquet_file`` does not exist.
        Exception: Any error raised during conversion is printed and re-raised.
    """
    if not Path(parquet_file).exists():
        raise FileNotFoundError(f"输入文件不存在: {parquet_file}")

    # Make sure the destination directory is there before writing.
    Path(json_file).parent.mkdir(parents=True, exist_ok=True)

    print(f"开始将 {parquet_file} 转换为 {json_file}")

    class NumpyEncoder(json.JSONEncoder):
        """JSON encoder that understands NumPy arrays and NumPy scalars."""

        def default(self, obj):
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, np.generic):
                return obj.item()
            return super().default(obj)

    def _limit_reached(written: int) -> bool:
        # The limit is evaluated after a row is written, never before.
        return max_rows is not None and written >= max_rows

    try:
        reader = pq.ParquetFile(parquet_file)
        total_rows = 0

        with open(json_file, 'a', encoding='utf-8') as sink:
            for record_batch in reader.iter_batches(batch_size=batch_size):
                frame = record_batch.to_pandas()

                for _, record in frame.iterrows():
                    serialized = json.dumps(record.to_dict(), ensure_ascii=False, cls=NumpyEncoder)
                    sink.write(serialized + '\n')
                    total_rows += 1

                    if _limit_reached(total_rows):
                        break

                # Leave the batch loop too once the row limit has been hit.
                if _limit_reached(total_rows):
                    break

        print(f"转换完成! 共处理 {total_rows} 行数据")

    except Exception as e:
        print(f"转换过程中发生错误: {e}")
        raise


def json_file_preprocess(json_path: str, output_json_path: str, process_case_count: int) -> Dict[str, float]:
    """
    Filter a JSON Lines file by verdict and report verdict statistics.

    Rows whose ``verdict`` is ``TIME_LIMIT_EXCEEDED`` are dropped; every other
    row is copied to the output file until ``process_case_count`` rows have
    been written. Blank lines are skipped, and lines that cannot be parsed or
    processed are reported on stdout and skipped.

    Args:
        json_path: Path of the input JSON Lines file.
        output_json_path: Path of the output JSON Lines file (overwritten);
            missing parent directories are created.
        process_case_count: Maximum number of rows to write. Filtered rows do
            not count towards this limit.

    Returns:
        A dict with the keys ``processed_count``, ``filtered_count``,
        ``ok_count``, ``non_ok_count``, ``ok_ratio`` and ``non_ok_ratio``
        (ratios are fractions in ``[0, 1]``).

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
    """
    if not Path(json_path).exists():
        raise FileNotFoundError(f"输入文件不存在: {json_path}")

    Path(output_json_path).parent.mkdir(parents=True, exist_ok=True)

    print(f"开始处理JSON文件: {json_path}")
    print(f"输出文件: {output_json_path}")
    print(f"计划处理行数: {process_case_count}")

    processed_count = 0
    ok_count = 0
    non_ok_count = 0
    filtered_count = 0

    with open(json_path, 'r', encoding='utf-8') as f, \
         open(output_json_path, 'w', encoding='utf-8') as out_f:

        for line_num, line in enumerate(f, 1):
            # Check the limit before touching the next row so that a limit of
            # zero (or less) writes nothing at all.
            if processed_count >= process_case_count:
                break

            line = line.strip()
            if not line:
                continue

            try:
                # Plain JSON in, plain JSON out. Routing lists through
                # np.array() and back would coerce mixed-type lists to a
                # single dtype (e.g. [1, "a"] -> ["1", "a"]) and silently
                # alter the copied rows.
                json_obj = json.loads(line)

                verdict = json_obj.get('verdict', 'UNKNOWN')
                if verdict == 'TIME_LIMIT_EXCEEDED':
                    filtered_count += 1
                    continue

                out_f.write(json.dumps(json_obj, ensure_ascii=False) + '\n')
                processed_count += 1

                if verdict == 'OK':
                    ok_count += 1
                else:
                    non_ok_count += 1

            except json.JSONDecodeError as e:
                print(f"第 {line_num} 行解析错误: {e}")
            except Exception as e:
                # Best effort: report the row (e.g. a JSON value that is not
                # an object) and keep going with the rest of the file.
                print(f"处理第 {line_num} 行时发生错误: {e}")

    total_valid = ok_count + non_ok_count
    ok_ratio = ok_count / total_valid if total_valid > 0 else 0
    non_ok_ratio = non_ok_count / total_valid if total_valid > 0 else 0

    print(f"\n处理完成 - 统计信息:")
    print(f"总处理行数: {processed_count}")
    print(f"过滤的TIME_LIMIT_EXCEEDED行数: {filtered_count}")
    print(f"OK行数: {ok_count}, 比例: {ok_ratio:.2%}")
    print(f"非OK行数: {non_ok_count}, 比例: {non_ok_ratio:.2%}")

    return {
        "processed_count": processed_count,
        "filtered_count": filtered_count,
        "ok_count": ok_count,
        "non_ok_count": non_ok_count,
        "ok_ratio": ok_ratio,
        "non_ok_ratio": non_ok_ratio,
    }


def load_data_json_lines(json_file_path: str) -> list[dict]:
    """
    Read a JSON Lines file into a list of dicts.

    Every list value found directly inside a JSON object is converted to a
    NumPy array when NumPy accepts it; otherwise the plain list is kept.
    Blank lines are skipped, and lines that fail to parse are reported on
    stdout (with their real line number in the file) and skipped.

    Args:
        json_file_path: Path of the JSON Lines file.

    Returns:
        The parsed rows in file order. An empty list is returned when the
        file does not exist or any other error occurs while reading.
    """

    def _lists_to_arrays(obj: dict) -> dict:
        # object_hook: runs for every decoded JSON object, innermost first.
        for key, value in obj.items():
            if isinstance(value, list):
                try:
                    obj[key] = np.array(value)
                except Exception:
                    # e.g. ragged nested lists: keep the value as a list.
                    pass
        return obj

    try:
        if not Path(json_file_path).exists():
            raise FileNotFoundError(f"文件不存在: {json_file_path}")

        print(f"开始读取JSON Lines文件: {json_file_path}")

        data = []
        with open(json_file_path, 'r', encoding='utf-8') as f:
            # Enumerate physical lines so error messages point at the actual
            # line in the file, including blank and previously failed lines.
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data.append(json.loads(line, object_hook=_lists_to_arrays))
                except json.JSONDecodeError as e:
                    print(f"第 {line_num} 行解析错误: {e}")
        return data

    except FileNotFoundError:
        print(f"错误: 文件 {json_file_path} 不存在")
        return []
    except Exception as e:
        print(f"错误: 读取文件 {json_file_path} 时发生未知错误: {e}")
        return []


# Lines appended after the wrapped Python solution when add_test_code is set.
_PY_TEST_HARNESS_LINES = tuple(r'''def run_method(input_str):
    old_stdin = sys.stdin
    sys.stdin = StringIO(input_str)
    capturedOutput = StringIO()
    sys.stdout = capturedOutput
    try:
        mycode()
    finally:
        sys.stdin = old_stdin
        sys.stdout = sys.__stdout__
    output_str = capturedOutput.getvalue().rstrip('\n')
    return output_str'''.splitlines())

# Headers prepended to a C++ solution when add_test_code is set.
_CPP_TEST_HEADER_LINES = (
    "#include <sstream>",
    "#include <string>",
    "#include <cstdio>",
    "#include <cstring>",
    "#include <streambuf>",
    "#include <iostream>",
    "#include <unistd.h>",
    "#include <sys/types.h>",
    "#include <sys/stat.h>",
    "#include <fcntl.h>",
)

# C++ driver appended when add_test_code is set. The write end of the pipe is
# closed exactly once (right after dup2); closing it a second time after
# mycode() has run would hit an already-closed descriptor, or one that the
# tested code has since reused.
_CPP_TEST_HARNESS_LINES = tuple(r'''extern "C" {
    std::string run_mycode(const std::string& input_str) {
        FILE* orig_stdin = stdin;
        FILE* orig_stdout = stdout;
        std::streambuf* orig_cin = std::cin.rdbuf();
        std::streambuf* orig_cout = std::cout.rdbuf();
        std::string combined_output;
        std::istringstream iss(input_str);
        int pipefd[2];
        pipe(pipefd);
        int stdout_copy = dup(fileno(stdout));
        std::cin.rdbuf(iss.rdbuf());
        FILE* strin = fmemopen(const_cast<char*>(input_str.c_str()),input_str.size(), "r");
        stdin = strin;
        dup2(pipefd[1], fileno(stdout));
        close(pipefd[1]);
        std::ostringstream oss;
        std::cout.rdbuf(oss.rdbuf());
        try {
            mycode();
            fflush(stdout);
            std::cout.flush();
        } catch (...) {
            std::cin.rdbuf(orig_cin);
            std::cout.rdbuf(orig_cout);
            stdin = orig_stdin;
            dup2(stdout_copy, fileno(stdout));
            close(stdout_copy);
            close(pipefd[0]);
            fclose(strin);
            throw;
        }
        std::cin.rdbuf(orig_cin);
        std::cout.rdbuf(orig_cout);
        stdin = orig_stdin;
        dup2(stdout_copy, fileno(stdout));
        close(stdout_copy);
        combined_output = oss.str();
        char buffer[4096];
        ssize_t bytes_read;
        while ((bytes_read = read(pipefd[0], buffer, sizeof(buffer))) > 0) {
            combined_output.append(buffer, bytes_read);
        }
        close(pipefd[0]);
        fclose(strin);
        while (!combined_output.empty() && (combined_output.back() == '\n' || combined_output.back() == '\r')) {
            combined_output.pop_back();
        }
        return combined_output;
    }
}'''.splitlines())


def test_code_preprocess(code_str: str, add_test_code:bool=False, add_code_block:bool=False, code_type:str="python") -> str:
    """
    Wrap a solution so that it can be invoked as a ``mycode`` function.

    Python: the (stripped) source is indented by four spaces under
    ``def mycode():``. C++: every occurrence of the text ``main`` is replaced
    by ``mycode`` and ``rank`` by ``node_rank``, lines containing
    ``ios_base::sync_with_stdio(0)`` are dropped, and the first definition
    line containing ``mycode()`` up to the next line starting with ``}`` is
    wrapped in ``extern "C" { ... }``.

    Args:
        code_str: Original source code.
        add_test_code: Also emit a driver (``run_method`` for Python,
            ``run_mycode`` for C++) that feeds a string to stdin and returns
            the captured stdout.
        add_code_block: Surround the result with a Markdown code fence.
        code_type: ``"python"`` or ``"cpp"``; any other value yields an empty
            string.

    Returns:
        The processed source code.
    """
    if code_type == "cpp":
        # Plain textual renames, applied in this order to the whole source.
        for old_str, new_str in (("main", "mycode"), ("rank", "node_rank")):
            code_str = code_str.replace(old_str, new_str)
    lines = code_str.strip().split('\n')

    processed_lines = []
    if code_type == "python":
        if add_code_block:
            processed_lines.append("```python")
        if add_test_code:
            processed_lines.extend(("from io import StringIO", "import sys"))
        processed_lines.append("def mycode():")

        for line in lines:
            # Keep the original leading whitespace and add one extra level
            # right after it.
            indent = len(line) - len(line.lstrip())
            processed_lines.append(line[:indent] + '    ' + line[indent:])

        if add_test_code:
            processed_lines.extend(_PY_TEST_HARNESS_LINES)
        if add_code_block:
            processed_lines.append("```")
    elif code_type == "cpp":
        if add_code_block:
            processed_lines.append("```cpp")
        if add_test_code:
            processed_lines.extend(_CPP_TEST_HEADER_LINES)

        in_mycode_block = False
        for line in lines:
            if "ios_base::sync_with_stdio(0)" in line:
                continue
            if "mycode()" in line and "//" not in line and "/*" not in line and not in_mycode_block:
                # First non-comment line mentioning mycode(): open the
                # extern "C" wrapper around the function.
                processed_lines.append('extern "C" {')
                in_mycode_block = True
                processed_lines.append("    " + line)
            elif line.startswith("}") and in_mycode_block:
                # A closing brace in column 0 ends the wrapped function.
                processed_lines.append("    " + line)
                in_mycode_block = False
                processed_lines.append("}")
            elif in_mycode_block:
                processed_lines.append("    " + line)
            else:
                processed_lines.append(line)

        if add_test_code:
            processed_lines.extend(_CPP_TEST_HARNESS_LINES)
        if add_code_block:
            processed_lines.append("```")

    return '\n'.join(processed_lines)


# The name starts with "test_", so pytest would try to collect this helper as
# a test whenever it is imported into a test module; opt out explicitly.
test_code_preprocess.__test__ = False


def get_code_from_response(response_list: list[str], start_str: str="```python", end_str: str="```", action_type:str="code") -> list[str]:
    """
    Extract one fenced code block from each response.

    A block opens on a line that (after stripping) starts with ``start_str``
    and closes on the next line that starts with ``end_str``. A block written
    on a single line (``start_str`` ... ``end_str``) is supported as well.
    Closing markers that appear while no block is open (for example the
    fences of a block in another language) are ignored.

    Args:
        response_list: Response texts to scan.
        start_str: Opening marker of a block.
        end_str: Closing marker of a block.
        action_type: ``"code"`` selects the last extracted block;
            ``"case"`` selects the last extracted block containing
            ``input_str``.

    Returns:
        One entry per response: the selected block, or the error message
        ``"代码提取错误，没有提取到有效信息"`` when nothing could be extracted.
    """
    error_text = "代码提取错误，没有提取到有效信息"
    result = []

    for response in response_list:
        if start_str not in response:
            print(error_text)
            print(response)
            result.append(error_text)
            continue

        in_code_block = False
        current_code = []
        # Reset per response so an unmatched closing marker can never see
        # lines collected for an earlier response (or an unbound name).
        temp_code = []

        for line in response.splitlines():
            stripped_line = line.strip()

            if stripped_line.startswith(start_str):
                # Look for the closing marker only after the opening one;
                # with the default markers the opening fence itself starts
                # with end_str.
                end_pos = stripped_line.find(end_str, len(start_str))
                if end_pos != -1:
                    inner_content = stripped_line[len(start_str):end_pos].strip()
                    if inner_content:
                        current_code.append(inner_content)
                    continue

                temp_code = []
                in_code_block = True
                continue

            if stripped_line.startswith(end_str):
                if in_code_block:
                    current_code.append("\n".join(temp_code))
                    in_code_block = False
                continue

            if in_code_block:
                temp_code.append(line)

        if action_type == "code":
            if current_code:
                result.append(current_code[-1])
            else:
                print(error_text)
                print(response)
                result.append(error_text)
        elif action_type == "case":
            with_input = [block for block in current_code if 'input_str' in block]
            if with_input:
                result.append(with_input[-1])
            else:
                print(error_text)
                print(response)
                result.append(error_text)

    return result


import ast

def get_case_from_response(response_list: list[str], return_type:str="str") -> list[list[str]]:
    """
    Extract test cases from responses by pairing ``input_str`` and
    ``output_str`` lines.

    Blank lines and lines starting with ``#`` are ignored. Lines found between
    an ``input_str`` line and the next ``output_str`` line are concatenated
    (without separators) onto the input value. A missing closing single quote
    is appended when the line is not already a complete expression.

    Args:
        response_list: Response texts containing the test cases.
        return_type: ``"str"`` yields each case as an indented two-line
            string; ``"set"`` yields an ``(input_value, output_value)`` tuple.

    Returns:
        One list of cases per response. A response equal to the extraction
        error message is passed through as a single-element list.
    """
    input_prefixes = ("input_str = '", "input_str ='", "input_str='")
    output_prefixes = ("output_str = '", "output_str ='", "output_str='")
    result = []

    for response in response_list:
        if response == "代码提取错误，没有提取到有效信息":
            result.append([response])
            continue

        # Stripped, non-empty, non-comment lines only.
        lines = []
        for raw_line in response.splitlines():
            text = raw_line.strip()
            if text and text[0] != '#':
                lines.append(text)

        cases = []
        line_count = len(lines)
        i = 0
        while i < line_count:
            input_line = lines[i]
            if not input_line.startswith(input_prefixes):
                i += 1
                continue

            # Scan forward to the next output_str line; everything skipped on
            # the way belongs to the input value.
            j = i + 1
            while j < line_count and not lines[j].startswith(output_prefixes):
                j += 1
            output_found = j < line_count
            intermediate_lines = lines[i + 1:j]

            if intermediate_lines:
                value_start = next(len(prefix) for prefix in input_prefixes if input_line.startswith(prefix))
                value_part = input_line[value_start:]
                if value_part.endswith("'"):
                    value_part = value_part[:-1]

                value_part += "".join(intermediate_lines)
                if value_part.endswith("'"):
                    value_part = value_part[:-1]

                input_line = f"input_str = '{value_part}'"

            if not output_found:
                # No matching output: move on to the next line.
                i += 1
                continue

            output_line = lines[j]

            # Close the string literal when the quote is missing.
            if not input_line.endswith("'") and not is_complete_expression(input_line):
                input_line += "'"
            if not output_line.endswith("'") and not is_complete_expression(output_line):
                output_line += "'"

            if return_type == "str":
                case = f"        {input_line}\n        {output_line}"
            elif return_type == "set":
                case = (input_line[input_line.find('=') + 1:].lstrip(), output_line[output_line.find('=') + 1:].lstrip())
            cases.append(case)

            # Resume right after the consumed output_str line.
            i = j + 1

        result.append(cases)

    return result

def is_complete_expression(code_line: str) -> bool:
    """
    Report whether ``code_line`` parses with Python's ``ast`` module.

    The line is first parsed as is; if that fails it is parsed again as the
    argument of a ``print(...)`` call. Returns ``True`` as soon as one of the
    two attempts parses, ``False`` when both raise ``SyntaxError``.
    """
    for candidate in (code_line, f"print({code_line})"):
        try:
            ast.parse(candidate)
        except SyntaxError:
            continue
        return True
    return False

def get_result_from_response(response_list: list[str]) -> list[list[str]]:
    """
    Split each response into lines with all spaces removed.

    Args:
        response_list: Response texts, one result per line.

    Returns:
        For every response, the lines of the stripped text with every space
        character removed from each line.
    """
    return [
        [row.replace(' ', '') for row in response.strip().splitlines()]
        for response in response_list
    ]


def get_truth_result_and_code(response_list: list[str]) -> tuple[list[str], list[str]]:
    """
    Extract the "result" and "test_code" fields from JSON responses.

    Currently unused. Triple-double-quoted sections in a response are first
    rewritten as JSON string literals so that the text can be parsed as JSON.
    A response contributes to the output only when both fields are present
    and truthy. Responses that are not valid JSON are reported on stdout
    together with their position in ``response_list``.

    Args:
        response_list: Response texts, each expected to be a JSON object.

    Returns:
        ``(truth_res_list, test_code_list)``: the "result" values and the
        "test_code" values of the accepted responses.
    """
    truth_res_list = []
    test_code_list = []

    # enumerate() gives the true position even when the same response text
    # occurs more than once (list.index would always report the first one).
    for position, response in enumerate(response_list):
        try:
            # Turn """...""" sections into properly escaped JSON strings.
            response_ = re.sub(r'"""([\s\S]*?)"""', lambda m: json.dumps(m.group(1)), response)
            response_dict = json.loads(response_)

            truth_res = response_dict.get("result", "UNKNOWN")
            test_code = response_dict.get("test_code", None)
            if test_code and truth_res:
                truth_res_list.append(truth_res)
                test_code_list.append(test_code)

        except json.JSONDecodeError:
            print(f"输出不合法，编号为{position}")
        except Exception as e:
            # NOTE(review): only truth_res_list receives an entry here, so
            # the two lists stop lining up after such an error - confirm
            # whether a placeholder test_code is wanted.
            truth_res_list.append(f"ERROR: {str(e)}")

    return truth_res_list, test_code_list


def get_truth_result_from_data(data: list[dict], use_list: bool=True) -> list[str]:
    """
    Map the verdicts stored in ``data`` to ``"OK"`` / ``"FAILED"``.

    Args:
        data: Task dicts. With ``use_list`` the key ``"verdict_list"`` is
            read, otherwise the key ``"verdict"``.
        use_list: When true, return one list of statuses per task (empty if
            the task has no ``"verdict_list"``); when false, return one
            status per task.

    Returns:
        The statuses in task order; anything other than ``"OK"`` becomes
        ``"FAILED"``.
    """
    if not use_list:
        return ["OK" if task.get("verdict", "") == "OK" else "FAILED" for task in data]

    return [
        ['FAILED' if element != 'OK' else 'OK' for element in task.get("verdict_list", "")]
        for task in data
    ]


def calculat_result_match_rate(predicted: list[str], actual: list[str], cov_dict:dict) -> dict:
    """
    Compute and print how well predicted verdicts match the actual ones.

    A pair in which either side is ``"UNKNOWN"`` is counted as invalid and as
    a mismatch. ``"OK"`` is treated as the positive class.

    Args:
        predicted: Predicted verdicts.
        actual: Actual verdicts; must have the same length as ``predicted``.
        cov_dict: Optional coverage figures (``"cov_line"``, ``"cov_branch"``)
            that are echoed in the output; may be ``None`` or lack either key.

    Returns:
        A dict with the keys ``total``, ``pass_num``, ``fail_num``, ``acc``,
        ``recall``, ``negatives_recall``, ``precision``, ``FP radio``,
        ``FN radio``, ``cov_line`` and ``cov_branch``. All rates are
        percentages and are 0 when their denominator is 0.

    Raises:
        ValueError: If the two lists differ in length.
    """
    if len(predicted) != len(actual):
        raise ValueError("预测结果和真实结果的长度不一致")

    def _percent(numerator: int, denominator: int) -> float:
        # Guard every rate the same way so empty input cannot divide by zero.
        return numerator / denominator * 100 if denominator > 0 else 0

    total_pairs = len(predicted)
    invalid_pairs = 0
    correct_matches = 0
    incorrect_matches = 0
    false_positives = 0  # predicted OK, actually not OK
    false_negatives = 0  # predicted not OK, actually OK
    true_positives = 0   # predicted OK, actually OK
    true_negatives = 0   # predicted not OK, actually not OK

    for pred, act in zip(predicted, actual):
        if "UNKNOWN" in (pred, act):
            invalid_pairs += 1
            incorrect_matches += 1
            continue

        if pred == act:
            correct_matches += 1
            if pred == "OK":
                true_positives += 1
            else:
                true_negatives += 1
        else:
            incorrect_matches += 1
            if pred == "OK" and act != "OK":
                false_positives += 1
            elif pred != "OK" and act == "OK":
                false_negatives += 1

    valid_pairs = total_pairs - invalid_pairs
    match_rate = _percent(correct_matches, total_pairs)
    incorrect_rate = _percent(incorrect_matches, total_pairs)
    false_positive_rate = _percent(false_positives, valid_pairs)
    false_negative_rate = _percent(false_negatives, valid_pairs)

    precision = _percent(true_positives, true_positives + false_positives)
    recall = _percent(true_positives, true_positives + false_negatives)
    negatives_recall = _percent(true_negatives, true_negatives + false_positives)
    FP_rate = 100 - negatives_recall
    FN_rate = 100 - recall

    # Coverage is optional, and so is each individual figure.
    cov_line = cov_dict.get("cov_line") if cov_dict else None
    cov_branch = cov_dict.get("cov_branch") if cov_dict else None

    print(f"数据对总个数: {total_pairs}")
    print(f"非法个数: {invalid_pairs}")
    print(f"匹配正确个数: {correct_matches}")
    print(f"匹配正确比例: {match_rate:.2f}%")
    print(f"匹配错误个数: {incorrect_matches}")
    print(f"匹配错误比例: {incorrect_rate:.2f}%")
    print(f"真正率: {recall:.2f}%")
    print(f"真负率: {negatives_recall:.2f}%")
    print(f"假正率: {FP_rate:.2f}%")
    print(f"假负率: {FN_rate:.2f}%")
    print(f"精确率: {precision:.2f}%")
    if cov_line is not None:
        print(f"综合代码行覆盖率: {cov_line:.2f}%")
    if cov_branch is not None:
        print(f"综合分支覆盖率: {cov_branch:.2f}%")

    print(f"错误判断为正确的个数: {false_positives}")
    print(f"错误预测为正确的比例: {false_positive_rate:.2f}%")
    print(f"正确预测为错误的个数: {false_negatives}")
    print(f"正确预测为错误的比例: {false_negative_rate:.2f}%")

    return {
        "total": total_pairs,
        "pass_num": correct_matches,
        "fail_num": incorrect_matches,
        "acc": match_rate,
        "recall": recall,
        "negatives_recall": negatives_recall,
        "precision": precision,
        "FP radio": false_positive_rate,
        "FN radio": false_negative_rate,
        "cov_line": cov_line,
        "cov_branch": cov_branch
    }


def calculate_result_pass_rate(pred_list: list[str], actual_list: list[str], cov_dict:dict) -> dict:
    """
    Compute and print per-task agreement between predicted and actual
    verdict lists.

    Each element of ``pred_list`` / ``actual_list`` is itself a sequence of
    verdicts for one task. A task is invalid when its prediction is empty or
    when either side contains anything other than ``"OK"`` / ``"FAILED"``;
    an invalid task whose prediction contains ``"COMPILATION_ERROR"`` also
    counts as a compilation failure.

    Args:
        pred_list: Predicted verdict sequences, one per task.
        actual_list: Actual verdict sequences; same length as ``pred_list``.
        cov_dict: Optional coverage figures (``"cov_line"``, ``"cov_branch"``)
            that are echoed in the output; may be ``None`` or lack either key.

    Returns:
        A dict with the keys ``total``, ``compilation_pass_ratio``,
        ``match_radio`` (exact matches), ``P_P_radio`` (every actual OK
        predicted OK but some actual FAILED missed), ``F_x_radio`` (an actual
        OK predicted as not OK), ``cov_line`` and ``cov_branch``. Ratios are
        percentages of all tasks.

    Raises:
        ValueError: If the two lists differ in length.
    """
    if len(pred_list) != len(actual_list):
        raise ValueError("预测列表和实际列表的长度必须相同")

    total_tasks = len(pred_list)
    invalid_pairs = 0
    compilation_fail_count = 0
    perfect_match_count = 0
    ok_correct_failed_wrong = 0
    ok_misclassified = 0

    for pred, actual in zip(pred_list, actual_list):
        if len(pred) == 0 or any(p not in ["OK", "FAILED"] for p in pred) or any(a not in ["OK", "FAILED"] for a in actual):
            invalid_pairs += 1
            if "COMPILATION_ERROR" in pred:
                compilation_fail_count += 1
            continue

        if pred == actual:
            perfect_match_count += 1
            continue

        all_ok_correct = True
        has_failed = False
        failed_all_correct = True

        for p, a in zip(pred, actual):
            if a == 'OK':
                if p != 'OK':
                    all_ok_correct = False
                    ok_misclassified += 1
                    break  # one misjudged OK settles the task's category
            elif a == 'FAILED':
                has_failed = True
                if p != 'FAILED':
                    failed_all_correct = False

        if all_ok_correct and has_failed and not failed_all_correct:
            ok_correct_failed_wrong += 1

    perfect_ratio = (perfect_match_count / total_tasks * 100) if total_tasks > 0 else 0.0
    ok_correct_ratio = (ok_correct_failed_wrong / total_tasks * 100) if total_tasks > 0 else 0.0
    ok_misclassified_ratio = (ok_misclassified / total_tasks * 100) if total_tasks > 0 else 0.0
    compilation_pass_ratio = ((1 - compilation_fail_count / total_tasks) * 100) if total_tasks > 0 else 0.0

    # Coverage is optional: cov_dict may be None, and either key may be absent.
    cov_line = cov_dict.get("cov_line") if cov_dict else None
    cov_branch = cov_dict.get("cov_branch") if cov_dict else None

    print(f"数据对总个数: {total_tasks}")
    print(f"非法个数: {invalid_pairs}")
    print(f"编译失败个数: {compilation_fail_count}")
    print(f"编译成功比例: {compilation_pass_ratio:.2f}%")
    print(f"正确匹配比例: {perfect_ratio:.2f}%")
    print(f"样例不全面比例: {ok_correct_ratio:.2f}%")
    print(f"样例无效比例: {ok_misclassified_ratio:.2f}%")
    if cov_line is not None:
        print(f"综合代码行覆盖率: {cov_line:.2f}%")
    if cov_branch is not None:
        print(f"综合分支覆盖率: {cov_branch:.2f}%")

    return {
        "total": total_tasks,
        "compilation_pass_ratio": compilation_pass_ratio,
        "match_radio": perfect_ratio,
        "P_P_radio": ok_correct_ratio,
        "F_x_radio": ok_misclassified_ratio,
        "cov_line": cov_line,
        "cov_branch": cov_branch
    }

def calculate_result_code_success_rate(pred_list: list[str]) -> dict:
    """
    Compute and print the share of ``"OK"`` and ``"FAILED"`` predictions.

    Args:
        pred_list: Predicted verdicts; anything other than ``"OK"`` or
            ``"FAILED"`` is counted as invalid.

    Returns:
        A dict with ``total``, ``pass_radio`` and ``non_pass_radio``
        (percentages of all predictions, 0 for an empty list).
    """

    def _bucket(pred) -> str:
        # Classify a single prediction into one of the three tallies.
        if pred == "OK":
            return "OK"
        if pred == "FAILED":
            return "FAILED"
        return "UNKNOWN"

    tally = {"OK": 0, "FAILED": 0, "UNKNOWN": 0}
    for pred in pred_list:
        tally[_bucket(pred)] += 1

    total = len(pred_list)
    if total > 0:
        ok_ratio = tally["OK"] / total * 100
        non_ok_ratio = tally["FAILED"] / total * 100
    else:
        ok_ratio = 0
        non_ok_ratio = 0

    print(f"数据对总个数: {total}")
    print(f"非法个数: {tally['UNKNOWN']}")
    print(f"代码通过比例: {ok_ratio:.2f}%")
    print(f"代码未通过比例: {non_ok_ratio:.2f}%")

    return {
        "total": total,
        "pass_radio": ok_ratio,
        "non_pass_radio": non_ok_ratio
    }


def data_to_file(
    data: List[Dict],
    prompt_list: List[str],
    response_list: List[str],
    action_log_list: Optional[List[str]],
    output_dir: str,
    result_data: Dict,
    predict_result: List[str],
    non_test_code: bool=False
) -> None:
    """
    Persist per-task records, error records, metrics and raw predictions.

    Four files are written into ``output_dir`` (created if missing):

    * ``run.json``: one JSON object per line, one line per task.
    * ``error.json``: one JSON object per line for every task that has at
      least one prediction other than "OK"/"FAILED".
    * ``final_metrics.json``: selected entries of ``result_data`` under the
      ``"custom"`` key.
    * ``predict_result.txt``: ``predict_result`` rendered as ``[a, b, ...]``.

    Args:
        data: Task dictionaries (``problem``, ``contestId``, ``index``,
            ``name`` and, depending on the mode, ``test_cases`` or
            ``verdict_list``/``code_list`` are read).
        prompt_list: Prompt for each task.
        response_list: Model response for each task.
        action_log_list: Execution log for each task, or None when no logs
            were collected; in that case the ``action_log`` key is omitted
            from every record.
        output_dir: Directory that receives the output files.
        result_data: Metric dictionary used for ``final_metrics.json``.
        predict_result: Prediction per task. With ``non_test_code`` each
            entry is a single verdict, otherwise an iterable of verdicts
            (one per candidate in ``code_list``).
        non_test_code: Selects the record layout and metric set (see above).

    Raises:
        IndexError: If ``prompt_list`` or ``predict_result`` is shorter than
            the number of processed tasks.
    """
    # Defined once (not per loop iteration): lets json serialise NumPy values
    # that may come from parquet-backed task data.
    class NumpyEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, np.generic):
                return obj.item()
            return super().default(obj)

    valid_verdicts = ("OK", "FAILED")
    has_logs = action_log_list is not None

    # exist_ok makes a separate existence check unnecessary.
    os.makedirs(output_dir, exist_ok=True)

    # 1. run.json / error.json
    run_json_path = os.path.join(output_dir, "run.json")
    error_json_path = os.path.join(output_dir, "error.json")
    written = 0
    with open(run_json_path, "w", encoding="utf-8") as f, \
        open(error_json_path, "w", encoding="utf-8") as error_f:
        for i, item in enumerate(data):
            # Only measure the log list when it exists; len(None) used to
            # crash here and made the "no logs" layout unreachable.
            if i >= len(response_list) or (has_logs and i >= len(action_log_list)):
                print(f"警告: 索引{i}超出响应/日志列表长度，跳过")
                continue

            # Fields shared by every record; insertion order is the on-disk
            # key order, so it mirrors the historical layout.
            task_data = {
                "problem": item.get("problem", ""),
                "contestId": item.get("contestId", ""),
                "index": item.get("index", ""),
                "task name": item.get("name", ""),
            }
            if non_test_code:
                task_data["predict_result"] = predict_result[i]
                task_data["test_cases"] = str(item.get("test_cases", ""))
            else:
                # Any verdict other than "OK" is normalised to "FAILED".
                task_data["ground_truth_list"] = [
                    "OK" if verdict == "OK" else "FAILED"
                    for verdict in item.get("verdict_list", [])
                ]
                task_data["predict_result"] = predict_result[i]
                task_data["code_list"] = list(item.get("code_list", ""))
            task_data["prompt"] = prompt_list[i]
            task_data["response"] = response_list[i]
            if has_logs:
                task_data["action_log"] = action_log_list[i]

            f.write(json.dumps(task_data, ensure_ascii=False, cls=NumpyEncoder) + "\n")
            written += 1

            # Collect the predictions that are neither "OK" nor "FAILED".
            error_predict_result = []
            error_code_list = []
            error_response_list = []
            if non_test_code:
                if predict_result[i] not in valid_verdicts:
                    error_predict_result.append(predict_result[i])
                    # NOTE(review): in this mode the model response is stored
                    # under "code_list" and the task's test cases under
                    # "response" - kept as-is, confirm this is intended.
                    error_code_list.append(response_list[i])
                    error_response_list.append(item.get("test_cases", ""))
            else:
                for idx, pred in enumerate(predict_result[i]):
                    if pred not in valid_verdicts:
                        error_predict_result.append(pred)
                        error_code_list.append(item["code_list"][idx])
                        error_response_list.append(response_list[i])

            if error_predict_result:
                error_data = {
                    "problem": item.get("problem", ""),
                    "contestId": item.get("contestId", ""),
                    "index": item.get("index", ""),
                    "task name": item.get("name", ""),
                    "predict_result": error_predict_result,
                    "code_list": error_code_list,
                    "response": error_response_list,
                }
                if has_logs:
                    # One log entry per offending prediction, parallel to the
                    # other lists.
                    error_data["action_log"] = [action_log_list[i]] * len(error_predict_result)
                error_f.write(json.dumps(error_data, ensure_ascii=False, cls=NumpyEncoder) + "\n")

    # 2. final_metrics.json
    metrics_json_path = os.path.join(output_dir, "final_metrics.json")
    if non_test_code:
        metric_defaults = {
            "total": 0,
            "pass_radio": 0.0,
            "non_pass_radio": 0.0,
        }
    else:
        # NOTE(review): "compilation_pass_ratio" is not exported here even if
        # present in result_data - confirm whether that is deliberate.
        metric_defaults = {
            "total": 0,
            "match_radio": 0.0,
            "P_P_radio": 0.0,
            "F_x_radio": 0.0,
            "cov_branch": 0.0,
            "cov_line": 0.0,
        }
    metrics_data = {
        "custom": {
            key: result_data.get(key, default)
            for key, default in metric_defaults.items()
        }
    }

    with open(metrics_json_path, "w", encoding="utf-8") as f:
        json.dump(metrics_data, f, ensure_ascii=False, indent=2)

    # 3. predict_result.txt
    predict_result_txt_path = os.path.join(output_dir, "predict_result.txt")
    content = "[" + ", ".join(str(item) for item in predict_result) + "]"
    with open(predict_result_txt_path, "w", encoding="utf-8") as f:
        f.write(content)

    print(f"数据已成功写入: {output_dir}")
    # Report what was actually written; skipped tasks are not counted.
    print(f"run.json: 包含{written}条任务记录")
    print(f"final_metrics.json: 包含完整统计指标")
    print(f"predict_result.txt: 包含预测结果")

def get_result_from_code(task_data: Dict, test_code: str):
    """
    执行pytest测试代码对，获取响应和日志
    
    参数:
        task_data: 包含任务信息的字典
        test_code: 待测代码
        
    返回:
        list: 包含响应和日志的列表
    """
    # import pdb; pdb.set_trace()
    code_list = task_data.get('code_list', None)
    result_list = []
    for origin_code in code_list:
        if origin_code is not None:
            origin_code = test_code_preprocess(origin_code, add_test_code=True)
        else:
            origin_code = ""
        # import pdb;pdb.set_trace()
        with tempfile.TemporaryDirectory() as temp_dir:
            # 创建 mycode.py 文件
            code_path = os.path.join(temp_dir, "tempcode.py")
            with open(code_path, "w", encoding="utf-8") as f:
                f.write(origin_code)

            # 创建 test_mycode.py 文件
            test_path = os.path.join(temp_dir, "test_tempcode.py")
            with open(test_path, "w", encoding="utf-8") as f:
                f.write(test_code)

            cmd = [
                "pytest", 
                "--cov=tempcode", 
                "--cov-branch",
                "--cov-report=term-missing",
                "-v",
                test_path
            ]
            try:
                result = subprocess.run(
                    cmd,
                    cwd=temp_dir,
                    capture_output=True,
                    text=True,
                    timeout=30
                )
            except subprocess.TimeoutExpired as e:
                result = f"执行测试时发生超时错误: {str(e)}"
            except Exception as e:
                result = f"执行测试时发生未知错误: {str(e)}"
            finally:
                result_list.append(result)
    return result_list 

      
def get_code_from_case(test_cases):
    """
    将测试用例列表转换为 pytest 格式的测试代码字符串
    
    参数:
    test_cases (list): 测试用例列表，每个元素是包含 input 和 output 的字典
    
    返回:
    str: 格式化后的 pytest 测试代码
    """
    code = [
        "from tempcode import run_method",
        "import pytest",
        "",
        "class TestCode():",
    ]
    
    for i, case in enumerate(test_cases, 1):
        # 转义输入和输出字符串中的换行符
        input_str = case["input"].replace("\n", "\\n")
        output_str = case["output"].replace("\n", "\\n")
        
        # 添加测试方法
        code.extend([
            f"    def test_{i}(self):",
            f"        input_str = '{input_str}'",
            f"        output_str = '{output_str}'",
            f"        assert run_method(input_str) == output_str",
            ""
        ])
    
    # 移除最后一个空行并连接所有行
    return "\n".join(code[:-1])