import json
from typing import List, Dict, Any, Optional
import re
import os
import matplotlib.pyplot as plt

def read_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file and return the parsed JSON value of every line.

    Blank (whitespace-only) lines are skipped. A leading UTF-8 byte-order
    mark, if present, is ignored.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message names the 1-based line number, and the original decode
            error is chained as ``__cause__``.
    """
    records = []
    # 'utf-8-sig' behaves exactly like 'utf-8' but also drops a leading BOM,
    # which str.strip() would not remove and json.loads would reject.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:  # skip blank lines
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Re-raise with the file line number; keep the original
                # error as the cause for debugging.
                raise json.JSONDecodeError(
                    f"Error parsing JSON on line {line_num}: {e.msg}",
                    e.doc,
                    e.pos,
                ) from e
    return records


from math_verify import parse

def extract_boxed_answer(text):
    """Extract the final (e.g. ``\\boxed{}``) answer from a model response.

    Delegates to ``math_verify.parse``. Any exception raised by ``parse`` is
    treated as "nothing extracted", as is a falsy (e.g. empty) result.

    NOTE(review): ``math_verify.parse`` appears to return a list of
    candidates that is empty when nothing is found, rather than raising;
    normalising that empty result to ``None`` makes this function honour its
    "None means not extracted" contract. Confirm against the installed
    math_verify version.

    Args:
        text: The model response text.

    Returns:
        The (non-empty) result of ``parse(text)``, or ``None`` if no answer
        could be extracted.
    """
    try:
        result = parse(text)
    except Exception:
        # Deliberate best effort: a parse failure means "no answer", not a
        # crash. Catching Exception (not a bare except) lets Ctrl-C and
        # SystemExit propagate.
        return None
    if not result:
        return None
    return result

def is_incomplete(text):
    """Return True if a model response looks unfinished.

    A response counts as incomplete when any of these holds:

    * it is empty or ``None``;
    * it contains no ``\\boxed{...}`` span, i.e. no ``\\boxed{`` followed
      later by a closing ``}``;
    * ``extract_boxed_answer`` yields nothing (``None`` or an empty result).

    Args:
        text: The model response text (may be ``None``).

    Returns:
        True if the response is considered incomplete, otherwise False.
    """
    if not text:
        return True

    # Quick textual pre-check before the full answer extraction.
    if re.search(r'\\boxed\{[^}]*\}', text) is None:
        return True

    extracted = extract_boxed_answer(text)
    # `not extracted` treats both None and an empty extraction result as
    # "no answer"; `is None` alone would count an empty list as complete.
    return not extracted

def calculate_incomplete_ratio(file_path):
    """Return the share of incorrect responses in a JSONL file that are incomplete.

    Each record is read with ``read_jsonl`` and is expected to carry a
    ``responses`` list and an ``accuracies`` list indexed in parallel. A
    response is counted as incorrect when its accuracy equals 0, and as
    incomplete when ``is_incomplete`` returns a truthy value.

    Args:
        file_path: Path to the JSONL results file.

    Returns:
        (incorrect and incomplete) / (incorrect), or 0.0 when there are no
        incorrect responses.
    """
    wrong_total = 0
    wrong_unfinished = 0

    records = read_jsonl(file_path)
    for record in records:
        for idx, response in enumerate(record['responses']):
            # Only wrong answers (accuracy == 0) enter the ratio.
            if record['accuracies'][idx] == 0:
                if is_incomplete(response):
                    wrong_unfinished += 1
                wrong_total += 1

    return wrong_unfinished / wrong_total if wrong_total != 0 else 0.0


def plot_incomplete_ratio(steps, ratios, output_path='incomplete_ratio_plot.png'):
    """Plot the incomplete ratio against step and save the figure to disk.

    Styling is applied through a temporary rc context, so the global
    matplotlib rcParams are left unchanged after the call.

    Args:
        steps: Sequence of step numbers (x values).
        ratios: Sequence of incomplete ratios (y values), same length as
            ``steps``.
        output_path: Path of the image file to write.

    Raises:
        ValueError: If ``steps`` or ``ratios`` is empty.
    """
    if len(steps) == 0 or len(ratios) == 0:
        raise ValueError("steps and ratios must be non-empty")

    style = {
        'font.sans-serif': ['Arial', 'DejaVu Sans', 'sans-serif'],
        'axes.unicode_minus': False,
        'figure.facecolor': 'white',
        'axes.facecolor': 'white',
    }
    top_ratio = max(ratios)

    # Scope the style to this figure (including saving, where fonts and tick
    # labels are rendered) instead of mutating the global rcParams.
    with plt.rc_context(rc=style):
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            # Line plot with point markers
            ax.plot(steps, ratios, marker='o', linewidth=2, markersize=6, color='#2E86AB')

            # Labels and title
            ax.set_xlabel('Step', fontsize=14, fontweight='bold')
            ax.set_ylabel('Incomplete Ratio', fontsize=14, fontweight='bold')
            ax.set_title('Incomplete Ratio vs Step', fontsize=16, fontweight='bold')

            # Grid
            ax.grid(True, alpha=0.3, linestyle='--')

            # Axis limits: one unit of padding on x; 10% headroom on y, or
            # [0, 1] when every ratio is 0.
            ax.set_xlim(left=min(steps) - 1, right=max(steps) + 1)
            ax.set_ylim(bottom=0, top=top_ratio * 1.1 if top_ratio > 0 else 1.0)

            fig.tight_layout()
            fig.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
        finally:
            # Always release this specific figure, even if drawing or
            # saving fails.
            plt.close(fig)

    print(f"图表已保存到: {output_path}")


def _parse_step(file_name: str) -> Optional[int]:
    """Return the step number encoded in a result file name, or None.

    The step is read from the second ``_``-separated field of the name, so
    ``"step_10_16384.jsonl"`` gives 10. Names without a second field, or
    whose second field is not parseable by ``int()``, give None.

    NOTE(review): an earlier comment gave ``"10_16384.jsonl" -> 10`` as the
    example, but under this rule that name yields ``"16384.jsonl"`` and is
    skipped. Confirm the intended file-name format.
    """
    fields = file_name.split("_")
    if len(fields) < 2:
        return None
    try:
        return int(fields[1])
    except ValueError:
        return None


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print(f"用法: python {sys.argv[0]} <dir_path>")
        sys.exit(1)
    dir_path = sys.argv[1]

    # (step, incomplete ratio) for every result file processed successfully
    step_ratios = []

    # Sorted so the per-file log is deterministic (os.listdir order is not).
    for file in sorted(os.listdir(dir_path)):
        if not file.endswith(".jsonl"):
            continue

        step = _parse_step(file)
        if step is None:
            print(f"警告: 无法从文件名 {file} 提取step编号，跳过")
            continue

        file_path = os.path.join(dir_path, file)
        try:
            incomplete_ratio = calculate_incomplete_ratio(file_path)
        except (ValueError, IndexError) as e:
            # Bad file contents (e.g. invalid JSON, a ValueError subclass, or
            # responses/accuracies of different lengths) are skipped with
            # their real cause instead of being blamed on the file name.
            print(f"警告: 处理文件 {file} 时出错 ({e})，跳过")
            continue

        step_ratios.append((step, incomplete_ratio))
        print(f"Step {step}: Incomplete Ratio = {incomplete_ratio:.4f}")

    # Order points by step
    step_ratios.sort(key=lambda pair: pair[0])

    if not step_ratios:
        print("错误: 没有找到有效的数据文件")
    else:
        steps = [step for step, _ in step_ratios]
        ratios = [ratio for _, ratio in step_ratios]

        # Save the plot in the parent of the input directory. normpath drops
        # a trailing separator; otherwise dirname("runs/") returns "runs"
        # itself and the output location would depend on how the path was
        # typed.
        parent_dir = os.path.dirname(os.path.normpath(dir_path))
        output_path = os.path.join(parent_dir, 'incomplete_ratio_plot.png')
        plot_incomplete_ratio(steps, ratios, output_path)