import json
import os

def load_json(file):
    """Read the UTF-8 encoded JSON file at *file* and return the decoded object."""
    with open(file, encoding="utf8") as handle:
        return json.load(handle)
    
def write_json(file, dict):
    """Serialize *dict* to *file* as UTF-8 JSON with 4-space indentation.

    Non-ASCII characters are written verbatim rather than as \\u escapes.
    The parameter is named ``dict`` (shadowing the builtin) for compatibility
    with existing keyword callers.
    """
    payload = dict
    with open(file, mode="w", encoding="utf8") as out:
        json.dump(payload, out, ensure_ascii=False, indent=4)




import re
import subprocess
import sys
from typing import Tuple, List
from difflib import SequenceMatcher

def run_code(code: str, timeout: float = 10) -> Tuple[int, str, str]:
    """
    Execute *code* as a Python program in a fresh interpreter subprocess.

    The child is started as ``sys.executable -c code``; its stdout and stderr
    are captured as text.

    SECURITY: *code* runs with this process's privileges and working directory,
    with no sandboxing. Only pass code you are willing to run on this machine.

    Args:
        code: Python source to execute.
        timeout: Seconds to wait before the child is killed (default 10).

    Returns:
        (returncode, stdout, stderr). On timeout returns ``(-1, "", "Timeout")``;
        if the subprocess could not be run at all returns ``(-2, "", <error text>)``.
    """
    try:
        result = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True,
            timeout=timeout,
            text=True,
            # Replace undecodable bytes instead of raising, which would discard
            # the whole output and be reported as a -2 failure.
            errors="replace",
        )
    except subprocess.TimeoutExpired:
        return -1, "", "Timeout"
    except Exception as e:  # deliberate best-effort: report any launch failure as -2
        return -2, "", str(e)
    return result.returncode, result.stdout, result.stderr


# Optional minus sign, then digits with an optional fractional part ("12", "12.",
# "12.5") or a leading-dot decimal (".5"), then an optional exponent ("e-3").
_NUMBER_RE = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')


def extract_floats(text: str) -> List[float]:
    """Return every numeric literal in *text*, in order of appearance, as floats.

    Recognizes integers, decimals (including the leading-dot form ``.5``) and
    scientific notation. A ``-`` immediately before a number is treated as its
    sign, so ``"3-2"`` yields ``[3.0, -2.0]``.
    """
    floats = []
    for match in _NUMBER_RE.findall(text):
        try:
            floats.append(float(match))
        except ValueError:
            # Defensive only: the pattern should always produce a literal
            # that float() accepts.
            continue
    return floats


def _count_float_mismatches(actual_floats: List[float], expected_floats: List[float]) -> int:
    """Return how many index-aligned float pairs disagree between two sequences.

    A position present in only one sequence counts as a mismatch. A paired
    value counts as a mismatch only when BOTH its absolute difference and its
    relative difference exceed 1e-9. The relative difference is taken against
    |expected|, floored at 1e-10 to avoid dividing by zero.
    """
    # Each surplus element on the longer side has no partner and counts once.
    mismatched = abs(len(actual_floats) - len(expected_floats))
    for actual_val, expected_val in zip(actual_floats, expected_floats):
        diff = abs(actual_val - expected_val)
        if diff > 1e-9 and diff / max(abs(expected_val), 1e-10) > 1e-9:
            mismatched += 1
    return mismatched


def compare_code_interpreter(text: str, *, verbose: bool = False) -> Tuple[float, int, int]:
    """
    Re-execute each <code> block in *text* and compare its output with the paired <interpreter> block.

    Blocks are paired by position: the i-th <code> block goes with the i-th
    <interpreter> block. When the counts differ, a warning is printed and the
    surplus blocks are not compared. Each code block runs in its own fresh
    subprocess via run_code.
    NOTE(review): state is not shared between blocks, so a block that relies on
    variables defined in an earlier block will fail.

    Args:
        text: Text containing <code>```python ...```</code> and
            <interpreter>...</interpreter> tags.
        verbose: If True, print actual vs. expected output and per-block scores.

    Returns:
        tuple of:
        - the average string similarity in percent over the compared pairs;
        - the total number of mismatched floats;
        - the number of <code> blocks found in *text*.

        A code block whose execution returns a non-zero code is reported. It
        contributes 0% similarity and no float mismatches, but it still counts
        in the average's denominator.
    """
    # Python source inside <code>```python ... ```</code> fences.
    code_blocks = re.findall(r'<code>\s*```python\s*(.*?)\s*```\s*</code>', text, re.DOTALL)
    # Recorded interpreter output inside <interpreter>...</interpreter>.
    interpreter_blocks = re.findall(r'<interpreter>\s*(.*?)\s*</interpreter>', text, re.DOTALL)

    if len(code_blocks) != len(interpreter_blocks):
        print(f"警告: 代码块数量({len(code_blocks)})与解释器块数量({len(interpreter_blocks)})不匹配")

    total_string_similarity = 0.0
    total_mismatched_floats = 0
    comparison_count = min(len(code_blocks), len(interpreter_blocks))

    for i, (code, interpreter_output) in enumerate(zip(code_blocks, interpreter_blocks)):
        expected_output = interpreter_output.strip()

        returncode, stdout, stderr = run_code(code)
        if returncode != 0:
            print(f"代码块 {i+1} 执行失败:")
            print(f"错误: {stderr}")
            continue

        actual_output = stdout.strip()

        # 1. Character-level similarity of the two outputs, in percent.
        string_similarity = SequenceMatcher(None, actual_output, expected_output).ratio() * 100
        total_string_similarity += string_similarity

        # 2. Numeric agreement: compare the floats of both outputs position by position.
        actual_floats = extract_floats(actual_output)
        expected_floats = extract_floats(expected_output)
        mismatched_count = _count_float_mismatches(actual_floats, expected_floats)
        total_mismatched_floats += mismatched_count

        if verbose:
            print(f"\n=== 代码块 {i+1} ===")
            print("-----------real-------------")
            print(actual_output)
            print("-----------llm-------------")
            print(expected_output)
            print("---------------------------")
            print(f"字符串匹配度: {string_similarity:.2f}%")
            print(f"浮点数不匹配个数: {mismatched_count}")
            print(f"实际提取的浮点数: {actual_floats}")
            print(f"期望的浮点数: {expected_floats}")

    avg_string_similarity = (
        total_string_similarity / comparison_count if comparison_count > 0 else 0.0
    )
    return avg_string_similarity, total_mismatched_floats, len(code_blocks)


# Evaluation driver. For every record, re-execute the code in its
# "formal_answer" and measure how closely the real output matches the
# transcript's <interpreter> output.
# NOTE(review): the input path is blank and must be filled in before running.
data = load_json('')


count_mismatch = 0    # total mismatched floats over all records
total_code_snip = 0   # total <code> blocks over all records
total_match_rate = 0  # sum of per-record similarity, weighted by each record's <code> block count

count_correct = 0     # records whose average similarity is exactly 100%
for d in data:
    match_rate, code_mismatch, code_num = compare_code_interpreter(d["formal_answer"])

    # match_rate is already a per-record average. Weighting it by code_num makes
    # total_match_rate / total_code_snip a per-block mean. Summing raw averages
    # and dividing by the block count would under-report: one record with
    # 3 perfect blocks would give 33.3 instead of 100.
    total_match_rate += match_rate * code_num
    count_mismatch += code_mismatch
    total_code_snip += code_num

    print("match rate: ", match_rate)
    if match_rate == 100:
        count_correct += 1
    else:
        print(d['id'])

print(count_correct)

if total_code_snip > 0:
    print(total_match_rate / total_code_snip)  # mean similarity (%) per code block
    print(count_mismatch / total_code_snip)    # mean mismatched floats per code block
else:
    # Avoid ZeroDivisionError when no record contains a <code> block.
    print("no <code> blocks found; averages undefined")