import json
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import argparse
import os
import re

def plot_error_vs_depth_barchart(results, output_dir, dataset_name, args):
    """
    绘制误差-深度分组柱状图。
    """
    print("\n--- Generating Error vs. Depth Bar Chart ---")
    
    model_names = list(results.keys())
    
    # --- 数据处理部分 (与原代码相同) ---
    # 从模型名称中解析层数作为 x 轴
    depths = []
    for name in model_names:
        match = re.search(r'(\d+)L_', name)
        if match:
            depths.append(int(match.group(1)))
        else:
            # 如果找不到匹配，则按顺序编号，确保代码健壮性
            depths.append(len(depths))
    
    # 根据深度进行排序
    sorted_indices = np.argsort(depths)
    depths = np.array(depths)[sorted_indices]
    model_names = np.array(model_names)[sorted_indices]

    baseline_model_name = model_names[0]
    baseline_errors = results[baseline_model_name]['avg_errors']
    
    improvements_complex = []
    improvements_simple = []
    p_values_complex = []

    for name in model_names:
        current_errors = results[name]['avg_errors']
        imp_complex = (baseline_errors['complex'] - current_errors['complex']) / (baseline_errors['complex'] + 1e-9)
        imp_simple = (baseline_errors['simple'] - current_errors['simple']) / (baseline_errors['simple'] + 1e-9)
        improvements_complex.append(imp_complex * 100)
        improvements_simple.append(imp_simple * 100)
        
        # 提取p值用于显著性标记
        if name != baseline_model_name and 'stats' in results[name] and f'vs_{baseline_model_name}' in results[name]['stats']:
            p_values_complex.append(results[name]['stats'][f'vs_{baseline_model_name}']['p_value'])
        else:
            p_values_complex.append(1.0) # p值为1.0表示不显著

    # --- 开始绘图 (核心修改部分) ---
    plt.style.use('seaborn-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7)) # 调整尺寸以适应柱状图

    # 设置柱状图的宽度和位置
    x = np.arange(len(depths))  # x轴刻度的位置
    width = 0.35  # 每个柱子的宽度

    # 绘制分组柱状图
    rects1 = ax.bar(x - width/2, improvements_complex, width, 
                    label='Complex Regions', color='firebrick')
    rects2 = ax.bar(x + width/2, improvements_simple, width, 
                    label='Simple Regions', color='cornflowerblue')

    # 在'Complex'区域的柱子顶端添加统计显著性星号
    # for i, p in enumerate(p_values_complex):
    #     if p < 0.05:
    #         # 获取当前柱子的高度，并在此基础上增加一点偏移量
    #         bar_height = rects1[i].get_height()
    #         # 动态调整星号位置，使其始终在柱子上方
    #         offset = 2 if bar_height >= 0 else -4 
    #         ax.text(rects1[i].get_x() + rects1[i].get_width() / 2, bar_height + offset,
    #                 '*', ha='center', va='bottom', color='gold', fontsize=25,
    #                 # 为星号添加一个假的图例条目，只需一次
    #                 label='Statistically Significant (p<0.05)' if 'Statistically Significant (p<0.05)' not in plt.gca().get_legend_handles_labels()[1] else "")

    # --- 设置图表标题和标签 (与原代码类似，但更适用于柱状图) ---
    ax.set_ylabel("Relative MSE Reduction (%)", fontsize=24)
    ax.set_xlabel("Number of Network Layers", fontsize=24)
    ax.set_title(f'Airfoil', fontsize=32)
    
    # 设置x轴刻度标签为模型的深度
    ax.set_xticks(x)
    custom_labels = ['6', '10', '12']
    ax.set_xticklabels(custom_labels)
    
    ax.tick_params(axis='both', which='major', labelsize=24)
    
    # 添加图例
    ax.legend(loc='upper left',fontsize=24)

    # 添加 y=0 的参考线
    ax.axhline(0, color='black', linewidth=0.8, linestyle=':')
    ax.grid(True, which='major', axis='y', linestyle='--', linewidth=0.7)
    
    # 自动调整布局，防止标签重叠
    fig.tight_layout()

    # --- 保存图片 ---
    save_path = os.path.join(output_dir, f'error_vs_depth_barchart_{dataset_name}.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Bar chart saved to: {save_path}")
    plt.close(fig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Visualize Error vs. Depth curve from evaluation results.")
    # --- 命令行参数保持不变 ---
    parser.add_argument('--results_path', type=str, required=True,
                        help='Path to the .json file generated by evaluate_by_complexity.py.')
    parser.add_argument('--output_dir', type=str, default='./results/plots',
                        help='Directory to save the output plots.')
    
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    print(f"Loading results from: {args.results_path}")
    with open(args.results_path, 'r') as f:
        results_data = json.load(f)
    
    try:
        # 尝试从文件名解析数据集名称
        dataset_name = os.path.basename(args.results_path).split('_')[2].split('.')[0]
    except:
        dataset_name = "Unknown Dataset"

    if results_data:
        # 调用新的柱状图绘制函数
        plot_error_vs_depth_barchart(results_data, args.output_dir, dataset_name, args)
        print("\nPlot has been generated successfully.")
    else:
        print("Error: The results file is empty.")