import json
import os
import re
from xml.parsers.expat import model
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
import statistics
from transformers import AutoTokenizer

def get_tokenizer_for_model(model_name):
    """
    Return the tokenizer path to use for a given model name.

    The name is lower-cased and checked, in order, against a table of known
    substrings; the first rule with a matching substring wins. More specific
    names are therefore listed before the generic ``l1-8b`` rule.

    Args:
        model_name: Model name (e.g. as parsed from a result file name).

    Returns:
        str | None: Tokenizer path for the first matching rule, or None when
        no rule matches.
    """
    model_name_lower = model_name.lower()

    models_root = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models"
    verl_8b = f"{models_root}/verl-041-result/verl-qwen3-8b"
    verl_8b_new = f"{models_root}/verl-041-result/verl-qwen3-8b-new"

    # Ordered (substrings, tokenizer path) rules. Every substring must be
    # lower-case because it is matched against the lower-cased model name.
    rules = (
        (("qwen3-8b-grpo-baseline",),
         f"{verl_8b_new}/GRPO-Qwen3-8B-deepscaler-BASELINE/global_step_200/actor/huggingface"),
        (("l1-8b-instruct",),
         f"{verl_8b}/DAPO-Qwen3-8B-Instruct-l1-strategy-deepscaler-luffy-style/global_step_100/actor/huggingface"),
        (("l1-8b-ours-openr1",),
         f"{verl_8b}/DAPO-Qwen3-8B-l1-strategy-openr1-test/global_step_200/actor/huggingface"),
        (("deepscaler-luffy-style-add1k",),
         f"{verl_8b}/DAPO-Qwen3-8B-l1-deepscaler-luffy-style-add1k/global_step_100/actor/huggingface"),
        (("deepscaler-luffy-style-ori-length",),
         f"{verl_8b}/DAPO-Qwen3-8B-l1-strategy-deepscaler-luffy-style-ori-length/global_step_100/actor/huggingface"),
        # This rule used a mixed-case literal and so could never match the
        # lower-cased name. It sits after the "-add1k" / "-ori-length" rules
        # so that names carrying those suffixes keep resolving as before.
        (("l1-8b-ours-deepscaler-luffy-style",),
         f"{verl_8b}/DAPO-Qwen3-8B-l1-strategy-deepscaler-luffy-style/global_step_200/huggingface"),
        # NOTE(review): the "0" substring makes any name containing the digit
        # zero resolve to the l1 tokenizer; kept as-is, confirm it is intended.
        (("l1-8b", "0"), f"{models_root}/l1"),
        (("l1-1.5",), f"{models_root}/l1-1.5b"),
        (("grpo-qwen3-8b-deepscaler-ori-length-l1",),
         f"{verl_8b_new}/GRPO-Qwen3-8B-deepscaler-ori-length-l1/global_step_100/actor/huggingface"),
        (("seed-36b",), f"{models_root}/Seed-OSS-36B-Instruct"),
    )

    for substrings, tokenizer_path in rules:
        if any(substring in model_name_lower for substring in substrings):
            return tokenizer_path

    # No rule matched; callers must handle the missing tokenizer.
    return None

def extract_model_and_budget_from_jsonl(filename):
    """
    Split a JSONL result file name into its model name and budget.

    The ``.jsonl`` extension is stripped, then the remaining name is matched
    against three suffix layouts in order:

    1. ``<model>_<budget>_test``
    2. ``<model>-<budget>_test``
    3. ``<model>-<budget>``

    Args:
        filename: JSONL file name, e.g.
            ``l1-8b-ours-deepscaler-LUFFY-style_512_test.jsonl``.

    Returns:
        tuple: ``(model_name, budget)``. ``budget`` is an int when one of the
        layouts matches; otherwise it is None and ``model_name`` is the file
        name without its extension.
    """
    extension = '.jsonl'
    # Strip only the trailing extension, not every occurrence of ".jsonl".
    name = filename[:-len(extension)] if filename.endswith(extension) else filename

    suffix_patterns = (
        r'_(\d+)_test$',  # model_512_test
        r'-(\d+)_test$',  # model-512_test
        r'-(\d+)$',       # model-512
    )
    for pattern in suffix_patterns:
        match = re.search(pattern, name)
        if match:
            # Everything before the matched suffix is the model name. This
            # also holds for a budget of 0, which a truthiness test on the
            # parsed budget would mistake for "no budget found".
            return name[:match.start()], int(match.group(1))

    return name, None

def calculate_accuracy_and_length_from_jsonl(jsonl_path, model_name, budget):
    """
    Compute accuracy and response-length statistics for one JSONL file.

    Every non-blank line is parsed as a JSON object. Records whose
    ``generated_text`` is missing or empty are ignored. For each remaining
    record the text is tokenized and its token, character and
    whitespace-separated word counts are collected, together with
    ``abs(token_length - budget)``. A record counts as correct when its
    ``correctness`` field is truthy.

    A record is only added to the statistics once all of its metrics have
    been computed, so a line that fails part-way (for example in the
    tokenizer) is skipped entirely instead of being counted partially.

    Args:
        jsonl_path: Path of the JSONL file to read.
        model_name: Model name, used to pick the tokenizer via
            ``get_tokenizer_for_model``.
        budget: Token budget the lengths are compared against.

    Returns:
        dict | None: Accuracy, averages, standard deviations, sample counts,
        the raw per-sample lists and the tokenizer path; None when the
        tokenizer cannot be resolved or loaded, or the file cannot be read.
    """
    try:
        # Pick the tokenizer that belongs to this model.
        tokenizer_model = get_tokenizer_for_model(model_name)
        if tokenizer_model is None:
            # Report the real cause instead of an opaque loader error.
            print(f"Error calculating metrics from {jsonl_path}: "
                  f"no tokenizer configured for model '{model_name}'")
            return None
        print(f"  使用tokenizer: {tokenizer_model}")

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)

        correct_count = 0

        # Per-sample measurements; totals and averages are derived from them.
        token_lengths = []
        char_lengths = []
        word_lengths = []
        diffs = []

        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    # Blank lines carry no record; skip them quietly.
                    continue

                try:
                    data = json.loads(line)

                    generated_text = data.get('generated_text', '')
                    if not generated_text:
                        continue

                    # Compute every metric before touching the accumulators.
                    token_length = len(tokenizer.encode(generated_text))
                    char_length = len(generated_text)
                    word_length = len(generated_text.split())
                    diff = abs(token_length - budget)
                    is_correct = bool(data.get('correctness', None))
                except json.JSONDecodeError:
                    print(f"Warning: Invalid JSON on line {line_num} in {jsonl_path}")
                    continue
                except Exception as e:
                    print(f"Error processing line {line_num} in {jsonl_path}: {e}")
                    continue

                token_lengths.append(token_length)
                char_lengths.append(char_length)
                word_lengths.append(word_length)
                diffs.append(diff)
                if is_correct:
                    correct_count += 1

        total_count = len(token_lengths)

        def _mean(values):
            # 0 for an empty file, mirroring the accuracy fallback.
            return sum(values) / total_count if total_count > 0 else 0

        def _stdev(values):
            # The sample standard deviation needs at least two values.
            return statistics.stdev(values) if len(values) > 1 else 0

        return {
            'accuracy': correct_count / total_count if total_count > 0 else 0,
            'avg_token_length': _mean(token_lengths),
            'avg_char_length': _mean(char_lengths),
            'avg_word_length': _mean(word_lengths),
            'avg_diff': _mean(diffs),
            'token_std': _stdev(token_lengths),
            'char_std': _stdev(char_lengths),
            'word_std': _stdev(word_lengths),
            'diff_std': _stdev(diffs),
            'total_samples': total_count,
            'correct_samples': correct_count,
            'token_lengths': token_lengths,
            'char_lengths': char_lengths,
            'word_lengths': word_lengths,
            'diffs': diffs,
            'tokenizer_used': tokenizer_model
        }

    except Exception as e:
        # Callers treat None as "skip this file", so report and carry on.
        print(f"Error calculating metrics from {jsonl_path}: {e}")
        return None

def analyze_jsonl_files_performance(jsonl_files_dir, specific_models_budgets=None):
    """
    Collect per-model, per-budget metrics from the ``*_test.jsonl`` files.

    File names are processed in sorted order so the result (and therefore
    the colour/marker assignment in the plots) does not depend on the
    directory listing order.

    Args:
        jsonl_files_dir: Directory that contains the JSONL files.
        specific_models_budgets: Optional list of
            ``{'model': model_name, 'budget': budget_value}`` dicts. When
            given, only files whose parsed (model, budget) pair appears in
            the list are analysed; when None or empty, every file is.

    Returns:
        dict: ``defaultdict(list)`` mapping model name to a list of metric
        records, one per successfully analysed file. Each record also
        carries the per-sample ``token_lengths`` list that
        ``plot_length_distribution_comparison`` reads.
    """
    model_data = defaultdict(list)

    jsonl_files = sorted(
        f for f in os.listdir(jsonl_files_dir) if f.endswith('_test.jsonl')
    )

    print(f"找到 {len(jsonl_files)} 个_test.jsonl文件")

    # Parse every file name once: (file name, model name, budget).
    parsed_files = []
    for jsonl_file in jsonl_files:
        model_name, budget = extract_model_and_budget_from_jsonl(jsonl_file)
        parsed_files.append((jsonl_file, model_name, budget))

    if specific_models_budgets:
        print(f"指定分析 {len(specific_models_budgets)} 个特定的模型-budget组合:")
        for item in specific_models_budgets:
            print(f"  - {item['model']} (budget: {item['budget']})")

        # A set gives one O(1) membership test per file.
        wanted = {(item['model'], item['budget']) for item in specific_models_budgets}
        parsed_files = [
            (jsonl_file, model_name, budget)
            for jsonl_file, model_name, budget in parsed_files
            if budget is not None and (model_name, budget) in wanted
        ]
        print(f"过滤后剩余 {len(parsed_files)} 个文件")

    for jsonl_file, model_name, budget in parsed_files:
        if budget is None:
            print(f"Warning: Could not extract budget from {jsonl_file}")
            continue

        jsonl_path = os.path.join(jsonl_files_dir, jsonl_file)
        print(f"处理文件: {jsonl_file} -> 模型: {model_name}, Budget: {budget}")

        # The budget is passed through so the length deviation can be computed.
        results = calculate_accuracy_and_length_from_jsonl(jsonl_path, model_name, budget)
        if results is None:
            continue

        model_data[model_name].append({
            'budget': budget,
            'accuracy': results['accuracy'],
            'avg_token_length': results['avg_token_length'],
            'avg_char_length': results['avg_char_length'],
            'avg_word_length': results['avg_word_length'],
            'avg_diff': results['avg_diff'],
            'token_std': results['token_std'],
            'char_std': results['char_std'],
            'word_std': results['word_std'],
            'diff_std': results['diff_std'],
            'total_samples': results['total_samples'],
            'correct_samples': results['correct_samples'],
            # Needed by the length-distribution plot, which otherwise
            # never finds any per-sample data.
            'token_lengths': results['token_lengths'],
            'tokenizer_used': results['tokenizer_used'],
            'jsonl_file': jsonl_file
        })

    return model_data

def plot_budget_analysis_detailed(model_data, save_dir='plots'):
    """
    Plot accuracy, mean token length and mean budget deviation per budget.

    One line per model is drawn on each of three side-by-side panels that
    use the same log-scaled budget axis with fixed ticks. The figure is
    saved as ``budget_analysis_detailed_1.png`` inside ``save_dir`` and
    shown; afterwards a per-model statistics table is printed.

    Args:
        model_data: Mapping of model name to a list of per-budget records
            holding at least ``budget``, ``accuracy``, ``avg_token_length``,
            ``avg_char_length``, ``avg_word_length``, ``avg_diff``,
            ``total_samples``, ``correct_samples`` and ``tokenizer_used``.
        save_dir: Directory the image is written to; created when missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Fonts for the (English) labels.
    plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans', 'Liberation Sans']
    plt.rcParams['axes.unicode_minus'] = False

    # Panels, left to right: accuracy, token length, |token length - budget|.
    fig, (accuracy_ax, length_ax, deviation_ax) = plt.subplots(1, 3, figsize=(20, 6))

    palette = plt.cm.tab10(np.linspace(0, 1, len(model_data)))
    marker_cycle = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h']
    line_style = dict(linewidth=3, markersize=10, alpha=0.8,
                      markeredgecolor='white', markeredgewidth=1.5)
    panel_metrics = (
        (accuracy_ax, 'accuracy'),
        (length_ax, 'avg_token_length'),
        (deviation_ax, 'avg_diff'),
    )

    for idx, (model_name, records) in enumerate(model_data.items()):
        by_budget = sorted(records, key=lambda rec: rec['budget'])
        budgets = [rec['budget'] for rec in by_budget]
        # Gather all three series first, then draw them.
        series = [[rec[metric] for rec in by_budget] for _, metric in panel_metrics]

        for (axis, _), values in zip(panel_metrics, series):
            axis.plot(budgets, values,
                      color=palette[idx],
                      marker=marker_cycle[idx % len(marker_cycle)],
                      label=model_name,
                      **line_style)

    tick_positions = [512, 1024, 2048, 4096, 8192]
    tick_labels = ['512', '1K', '2K', '4K', '8K']

    # (axis, y label, title, draw legend?, fixed y limits or None)
    panel_layout = (
        (accuracy_ax, 'Accuracy', 'Model Accuracy vs Budget', False, (0.3, 0.8)),
        (length_ax, 'Average Token Length',
         'Model Response Token Length vs Budget', True, None),
        (deviation_ax, 'Average |Token Length - Budget|',
         'Model Response Length Deviation vs Budget', True, None),
    )
    for axis, y_label, title, show_legend, y_limits in panel_layout:
        axis.set_xlabel('Budget', fontsize=14, fontweight='bold')
        axis.set_ylabel(y_label, fontsize=14, fontweight='bold')
        axis.set_title(title, fontsize=16, fontweight='bold', pad=20)
        axis.grid(True, alpha=0.3, linestyle='--')
        if show_legend:
            axis.legend(fontsize=12, frameon=True, fancybox=True, shadow=True, loc='upper left')
        axis.set_xscale('log')
        if y_limits is not None:
            axis.set_ylim(*y_limits)

        # Explicit, human-readable budget ticks.
        axis.set_xticks(tick_positions)
        axis.set_xticklabels(tick_labels, fontsize=12, fontweight='bold')

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.3)

    plt.savefig(os.path.join(save_dir, 'budget_analysis_detailed_1.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # Statistics table, one section per model.
    banner = "=" * 120
    rule = "-" * 110
    print("\n" + banner)
    print("Detailed Budget Analysis Statistics")
    print(banner)

    for model_name, records in model_data.items():
        print(f"\nModel: {model_name}")
        print(rule)
        print(f"{'Budget':<8} {'Accuracy':<10} {'Token Len':<10} {'Char Len':<10} {'Word Len':<10} {'Avg Diff':<10} {'Samples':<8} {'Correct':<8} {'Tokenizer':<25}")
        print(rule)

        for rec in sorted(records, key=lambda rec: rec['budget']):
            accuracy = format(rec['accuracy'], '.4f')
            token_len = format(rec['avg_token_length'], '.1f')
            char_len = format(rec['avg_char_length'], '.1f')
            word_len = format(rec['avg_word_length'], '.1f')
            avg_diff = format(rec['avg_diff'], '.1f')
            # Last path component of the tokenizer, cut to 20 characters.
            tokenizer = rec['tokenizer_used'].rsplit('/', 1)[-1][:20]

            print(f"{rec['budget']:<8} {accuracy:<10} {token_len:<10} {char_len:<10} {word_len:<10} {avg_diff:<10} {rec['total_samples']:<8} {rec['correct_samples']:<8} {tokenizer:<25}")

def plot_length_distribution_comparison(model_data, save_dir='plots'):
    """
    Draw one token-length histogram per model on a grid of subplots.

    For each model the ``token_lengths`` lists of all its records are pooled
    into a single 30-bin histogram with mean and median marker lines. The
    grid has at most three columns; unused cells are hidden. The figure is
    saved as ``length_distribution_comparison.png`` inside ``save_dir`` and
    shown. Nothing is drawn when no record carries token lengths.

    Args:
        model_data: Mapping of model name to a list of records; records
            without a non-empty ``token_lengths`` entry are ignored.
        save_dir: Directory the image is written to; created when missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Fonts for the (English) labels.
    plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans', 'Liberation Sans']
    plt.rcParams['axes.unicode_minus'] = False

    # Bail out early when there is nothing to plot at all.
    total_data_points = sum(
        len(record['token_lengths'])
        for records in model_data.values()
        for record in records
        if record.get('token_lengths')
    )
    if total_data_points == 0:
        print("Warning: No token length data found for distribution plot")
        return

    # At most three columns; rows grow as needed (one row for up to 3 models).
    n_models = len(model_data)
    cols = min(n_models, 3)
    rows = (n_models + 2) // 3

    # squeeze=False always yields a 2-D array, so one flatten covers all shapes.
    fig, axis_grid = plt.subplots(rows, cols, figsize=(6*cols, 5*rows), squeeze=False)
    axes = axis_grid.flatten()

    for ax, (model_name, records) in zip(axes, model_data.items()):
        # Pool the token lengths of every record of this model.
        pooled_lengths = []
        for record in records:
            lengths = record.get('token_lengths')
            if lengths:
                pooled_lengths.extend(lengths)

        if not pooled_lengths:
            print(f"Warning: No token length data for model {model_name}")
            ax.text(0.5, 0.5, f'No data for\n{model_name}',
                    ha='center', va='center', transform=ax.transAxes, fontsize=14)
            ax.set_title(f'{model_name}\nNo Data Available', fontsize=14, fontweight='bold')
            continue

        print(f"Model {model_name}: {len(pooled_lengths)} data points, range: {min(pooled_lengths):.1f}-{max(pooled_lengths):.1f}")

        ax.hist(pooled_lengths, bins=30, alpha=0.7, color='skyblue', edgecolor='navy', linewidth=1)
        ax.set_xlabel('Token Length', fontsize=12, fontweight='bold')
        ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
        ax.set_title(f'{model_name}\nToken Length Distribution\n({len(pooled_lengths)} samples)',
                     fontsize=14, fontweight='bold', pad=15)
        ax.grid(True, alpha=0.3, linestyle='--')

        # Mark the mean and the median on the histogram.
        mean_len = np.mean(pooled_lengths)
        median_len = np.median(pooled_lengths)
        ax.axvline(mean_len, color='red', linestyle='--', linewidth=2,
                   label=f'Mean: {mean_len:.1f}')
        ax.axvline(median_len, color='green', linestyle='--', linewidth=2,
                   label=f'Median: {median_len:.1f}')
        ax.legend(fontsize=10, frameon=True, fancybox=True, shadow=True)

    # Hide the grid cells that have no model assigned.
    for spare_ax in axes[len(model_data):]:
        spare_ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'length_distribution_comparison.png'), dpi=300, bbox_inches='tight')
    plt.show()



def get_specific_models_budgets_config():
    """
    Return the model/budget combinations that should be analysed.

    The result is the cross product of the active model names and the
    budget list (currently 3 models x 5 budgets = 15 combinations).

    Returns:
        list: ``{'model': model_name, 'budget': budget}`` dicts, grouped by
        model in list order, with budgets ascending inside each model.
    """
    # Each entry needs its own trailing comma: adjacent string literals
    # without one are silently concatenated into a single name.
    models = [
        # Other known model names, currently disabled:
        # 'l1-1.5',
        # 'l1-8b',
        # 'l1-8b-ours-deepscaler-LUFFY-style',
        # 'l1-8b-ours-deepscaler-luffy-style-valid-topp0.6-temp1.0',
        # 'grpo-qwen3-8b-deepscaler-ori-length-l1',
        # 'Qwen3-8b-grpo-baseline',
        # 'l1-8b-ours-openr1',
        # 'seed-36b',
        # '10',
        # 'l1-ours-deepscaler-luffy-style-ori-length',
        # 'l1-ours-deepscaler-luffy-style-add1k',
        # 'l1-8b-instruct-ours-deepscaler-luffy-style',
        "Qwen3-4B-polaris-1k",
        "Qwen3-4B-polrias30k-17k-temp1.4-step100-polaris-1k",
        "Qwen3-4B-polrias30k-16k-analysis-step100-polaris-1k",
        # "l1-8b-ours-deepscaler-luffy-style-tmp1.0-valid",
        # 'l1-8b-ours-deepscaler-luffy-style-tmp1.0',
    ]

    budgets = [512, 1024, 2048, 4096, 8192]

    return [
        {'model': model, 'budget': budget}
        for model in models
        for budget in budgets
    ]

def main():
    """
    Run the budget analysis for the configured model/budget combinations.

    Reads the ``*_test.jsonl`` files from the results directory, prints a
    summary of the data found per model and draws the detailed budget plot.
    Returns early, with a message, when the directory is missing or no
    usable data is found.
    """
    # Directory holding the JSONL result files.
    jsonl_files_dir = "../results_sep"
    # Alternative location:
    # jsonl_files_dir = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-8b/DAPO-Qwen3-8B-l1-strategy-deepscaler-luffy-style-ori-length-new"

    # Model/budget combinations to analyse.
    specific_models_budgets = get_specific_models_budgets_config()

    # Report the real counts instead of hard-coded numbers.
    n_models = len({item['model'] for item in specific_models_budgets})
    n_budgets = len({item['budget'] for item in specific_models_budgets})
    print(f"开始分析特定的{n_models}个模型和{n_budgets}个budget的{len(specific_models_budgets)}个文件...")
    print(f"JSONL文件目录: {jsonl_files_dir}")
    print("将根据模型名称自动选择合适的tokenizer")

    # The path must be a directory; a regular file would fail when listed.
    if not os.path.isdir(jsonl_files_dir):
        print(f"错误: 目录 {jsonl_files_dir} 不存在")
        print("请将_test.jsonl文件放在正确的目录中，或修改jsonl_files_dir变量")
        return

    print(f"\n将分析 {len(specific_models_budgets)} 个特定的模型-budget组合:")
    for item in specific_models_budgets:
        print(f"  - {item['model']} (budget: {item['budget']})")

    model_data = analyze_jsonl_files_performance(jsonl_files_dir, specific_models_budgets)

    if not model_data:
        print("没有找到有效的_test.jsonl文件数据")
        return

    print(f"\n找到 {len(model_data)} 个模型的数据:")
    for model_name, data in model_data.items():
        print(f"  - {model_name}: {len(data)} 个budget配置")
        # Show which tokenizer was used for this model.
        if data:
            tokenizer = data[0]['tokenizer_used']
            print(f"    使用tokenizer: {tokenizer}")

    plot_budget_analysis_detailed(model_data)

    print("\n分析完成！图表已保存到plots/目录")
    print("生成的文件:")
    # Must match the file name written by plot_budget_analysis_detailed.
    print("  - budget_analysis_detailed_1.png (三个子图的详细分析)")


if __name__ == "__main__":
    main()