import json
import os
import re
import subprocess
import sys
from typing import List, Optional, Tuple


def load_json(file):
    """Parse the UTF-8 encoded JSON document stored at ``file``.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load``.
    """
    with open(file, mode='r', encoding="utf8") as handle:
        return json.load(handle)
    
def write_json(file, dict):
    """Serialize ``dict`` as JSON into ``file``, overwriting any existing content.

    The output is UTF-8 encoded, indented by four spaces, and non-ASCII
    characters are written verbatim instead of being escaped.

    Args:
        file: Path of the file to write.
        dict: JSON-serializable object to store (the name shadows the builtin
            but is kept because it is part of the public signature).
    """
    with open(file, mode="w", encoding="utf8") as handle:
        json.dump(dict, handle, ensure_ascii=False, indent=4)


def run_code(code: str) -> Tuple[int, str, str]:
    """Execute Python source in a separate interpreter process.

    The code is passed to the current interpreter (``sys.executable``) via
    ``-c`` with a 10 second timeout; stdout and stderr are captured as text.

    Args:
        code: Python source code to execute.

    Returns:
        A ``(returncode, stdout, stderr)`` tuple. Two sentinel results are
        used when the process could not complete normally:
        ``(-1, "", "Timeout")`` if the timeout expired, and
        ``(-2, "", <message>)`` if any other exception was raised.
    """
    command = [sys.executable, "-c", code]
    try:
        completed = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=10,
        )
    except subprocess.TimeoutExpired:
        return -1, "", "Timeout"
    except Exception as exc:
        return -2, "", str(exc)
    else:
        return completed.returncode, completed.stdout, completed.stderr


def replace_interpreter_blocks(text: str, new_outputs: List[str]) -> Optional[str]:
    """Replace the body of every ``<interpreter>`` block with a new output.

    The i-th ``<interpreter>...</interpreter>`` block found in ``text`` is
    rewritten as ``<interpreter>\\n{new_outputs[i]}\\n</interpreter>``; all
    text outside the blocks is left untouched.

    Args:
        text: Text containing ``<interpreter>`` blocks.
        new_outputs: Replacement content, one entry per block, in order of
            appearance.

    Returns:
        The rewritten text, or ``None`` (after printing a warning) when the
        number of blocks differs from ``len(new_outputs)``.
    """
    interpreter_pattern = r'<interpreter>\s*(.*?)\s*</interpreter>'
    matches = list(re.finditer(interpreter_pattern, text, re.DOTALL))

    if len(matches) != len(new_outputs):
        print(f"Warning: interpreter number({len(matches)})code number({len(new_outputs)})mismatch")
        return None  # None signals the count mismatch to the caller

    # Rebuild the text in a single left-to-right pass: copy the untouched
    # segment before each match, then the replacement block. All spans refer
    # to the original text, so no offset bookkeeping is needed.
    pieces = []
    cursor = 0
    for match, output in zip(matches, new_outputs):
        pieces.append(text[cursor:match.start()])
        pieces.append(f'<interpreter>\n{output}\n</interpreter>')
        cursor = match.end()
    pieces.append(text[cursor:])

    print("replaced")
    return "".join(pieces)


def check_code_format_validity(text: str) -> bool:
    """Check that every ``<code>`` tag wraps exactly one ```python fenced block.

    Args:
        text: Text to inspect.

    Returns:
        ``True`` when the text has no ``<code>`` tags, or when each tag
        contains exactly one ```python ... ``` block. ``False`` as soon as a
        tag with zero or with more than one such block is found (a warning
        naming the 1-based tag position is printed first).
    """
    fenced_python = r'```python\s*(.*?)\s*```'
    tag_bodies = re.findall(r'<code>(.*?)</code>', text, re.DOTALL)

    # With no <code> tags the loop body never runs and the text is valid.
    for position, body in enumerate(tag_bodies, start=1):
        block_count = len(re.findall(fenced_python, body, re.DOTALL))
        if block_count == 0:
            print(f"警告: 第 {position} 个 <code> 标签中没有找到 ```python 代码块")
            return False
        if block_count > 1:
            print(f"警告: 第 {position} 个 <code> 标签中找到 {block_count} 个 ```python 代码块")
            return False

    return True


def compare_code_interpreter(text: str) -> Tuple[float, str]:
    """Run every ``<code>`` block and reconcile it with its ``<interpreter>`` block.

    Code blocks and interpreter blocks are paired by order of appearance.
    Each code block is executed with ``run_code``; its stripped stdout is
    compared with the stripped content of the paired interpreter block using
    ``difflib.SequenceMatcher``. If the average similarity is not 100, all
    interpreter blocks are rewritten with the real outputs.

    Args:
        text: Text containing ``<code>`` and ``<interpreter>`` tags.

    Returns:
        ``(average similarity in percent, resulting text)``. When there are no
        code blocks the result is ``(100.0, text)``. Sentinel results:

        - ``(-1, "")``: a code block did not exit with return code 0.
        - ``(-2, "")``: ``check_code_format_validity`` rejected the text.
        - ``(-3, "")``: the number of code blocks and interpreter blocks
          differ.
    """
    # Imported once per call rather than once per code block.
    from difflib import SequenceMatcher

    # Validate the code formatting before doing any work.
    if not check_code_format_validity(text):
        print("警告: 存在 <code> 标签但没有正确的 ```python 包裹")
        return -2, ""

    # Python source of each <code> block (must be wrapped in a ```python fence).
    code_pattern = r'<code>\s*```python\s*(.*?)\s*```\s*</code>'
    code_blocks = re.findall(code_pattern, text, re.DOTALL)

    # Recorded output of each <interpreter> block.
    interpreter_pattern = r'<interpreter>\s*(.*?)\s*</interpreter>'
    interpreter_blocks = re.findall(interpreter_pattern, text, re.DOTALL)

    if len(code_blocks) != len(interpreter_blocks):
        print(f"警告: 代码块数量({len(code_blocks)})与解释器块数量({len(interpreter_blocks)})不匹配")
        return -3, ""

    # Nothing to execute: the text is trivially consistent.
    if not code_blocks:
        return 100.0, text

    new_outputs = []
    total_string_similarity = 0.0

    for index, (code, recorded) in enumerate(zip(code_blocks, interpreter_blocks), start=1):
        expected_output = recorded.strip()

        returncode, stdout, stderr = run_code(code)
        if returncode != 0:
            print(f"代码块 {index} 执行失败:")
            print(f"错误: {stderr}")
            return -1, ""  # abort on the first failing block

        actual_output = stdout.strip()
        new_outputs.append(actual_output)

        matcher = SequenceMatcher(None, actual_output, expected_output)
        total_string_similarity += matcher.ratio() * 100

    # code_blocks is non-empty here, so the division is safe.
    avg_string_similarity = total_string_similarity / len(code_blocks)

    if avg_string_similarity == 100:
        # Every recorded output already matches; keep the text unchanged.
        return avg_string_similarity, text

    new_text = replace_interpreter_blocks(text, new_outputs)
    if new_text is None:  # the replacement step found a count mismatch
        return -3, ""
    return avg_string_similarity, new_text


# Load the input data
data = load_json('')

output_file = ''

print("loaded data: ", len(data))

# Resume support: items whose id is already present in the output file are
# skipped. Any problem reading or interpreting that file means "no usable
# checkpoint" and processing starts from scratch.
try:
    processed_data = load_json(output_file)
    # A set keeps the "already processed?" check O(1) per input item.
    processed_id = {r['id'] for r in processed_data}
    to_be_processed = [d for d in data if d['id'] not in processed_id]

    print("already processed: ", len(processed_data))

except (OSError, ValueError, KeyError, TypeError) as e:
    # OSError: file missing/unreadable; ValueError: invalid JSON or encoding;
    # KeyError/TypeError: records without the expected shape.
    # KeyboardInterrupt and SystemExit are deliberately not caught.
    print(f"No usable checkpoint ({type(e).__name__}: {e}), starting from scratch")
    processed_data = []
    to_be_processed = data

print("to be processed: ", len(to_be_processed))


# Process the data

failed_count = 0
invalid_format_count = 0
mismatch_count = 0
success_count = 0
save_interval = 200

for i, d in enumerate(to_be_processed):

    formal_cot = d.get("formalized_cot")

    # Skip items whose formalized_cot is missing or not a string
    if not isinstance(formal_cot, str):
        print(f"Item {d['id']}: 'formalized_cot' is not a string (got {type(formal_cot)}), skipping...")
        invalid_format_count += 1
        continue

    print(f"Processing item {i+1}/{len(to_be_processed)}")

    match_rate, new_formal_answer = compare_code_interpreter(formal_cot)

    if match_rate == -1:  # execution failed
        failed_count += 1
        print(f"Item {d['id']} execution failed, skipping...")
        continue
    if match_rate == -2:  # invalid code block format
        invalid_format_count += 1
        print(f"Item {d['id']} has invalid code format (missing ```python wrapper), skipping...")
        continue
    if match_rate == -3:  # code/interpreter count mismatch
        mismatch_count += 1
        print(f"Item {d['id']} has mismatched code/interpreter blocks, skipping...")
        continue

    # Build the output record: corrected text plus its match rate, without
    # the "cot" field.
    new_item = d.copy()
    new_item["formalized_cot"] = new_formal_answer
    new_item["match_rate"] = match_rate
    # Default avoids a KeyError (and losing unsaved progress) when an item
    # has no "cot" field.
    new_item.pop("cot", None)
    processed_data.append(new_item)
    success_count += 1
    print(f"Item {d['id']} processed successfully, match rate: {match_rate:.2f}%")

    # Save a checkpoint every save_interval successful items
    if success_count % save_interval == 0:
        print(f"Saving checkpoint after {success_count} successful items...")
        write_json(output_file, processed_data)



# Print summary statistics
print(f"\n{'='*50}")
print(f"成功处理: {success_count} 项")
print(f"执行失败并删除: {failed_count} 项")
print(f"格式无效并删除: {invalid_format_count} 项")
print(f"数量不匹配并删除: {mismatch_count} 项")
print(f"总删除: {failed_count + invalid_format_count + mismatch_count} 项")
print(f"处理后剩余: {len(processed_data)} 项")
print(f"{'='*50}\n")


# Save the final result to the output file

write_json(output_file, processed_data)


print(f"处理完成，结果已保存到: {output_file}")