import numpy as np
import matplotlib.pyplot as plt
import os

def set_plot_style():
    """设置全局绘图样式"""
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['axes.unicode_minus'] = False
    plt.rcParams['mathtext.fontset'] = 'stix'  # 数学公式字体
    plt.rcParams['xtick.major.size'] = 0  # 删除刻度线
    plt.rcParams['ytick.major.size'] = 0  # 删除刻度线
    plt.rcParams['axes.spines.top'] = False  # 移除上边框
    plt.rcParams['axes.spines.right'] = False  # 移除右边框
    plt.rcParams['axes.spines.left'] = False  # 保留左边框
    plt.rcParams['axes.spines.bottom'] = False  # 保留下边框
    plt.rcParams['axes.facecolor'] = '#f0f0f0'  # 设置非常淡的灰色背景

def load_probability_data(file_path):
    """加载概率数据文件"""
    probs = []
    with open(file_path, 'r') as f:
        lines = f.readlines()[1:]  # 跳过标题行
        for line in lines:
            prob = float(line.strip().split('\t')[1])
            probs.append(prob)
    return np.array(probs)

def plot_probability_comparison():
    """绘制概率对比图"""
    # 设置绘图样式
    set_plot_style()
    
    # 创建正方形图
    plt.figure(figsize=(8, 8))
    
    # 加载数据
    honest_probs = load_probability_data("probability/problem16_action0_probability.txt")
    evdm_probs = load_probability_data("probability/problem16_action4_probability.txt")
    
    # 重采样数据到100个点
    num_points = 100
    honest_x = np.linspace(0, 100, len(honest_probs))
    evdm_x = np.linspace(0, 100, len(evdm_probs))
    
    # 创建新的等间距采样点
    percentages = np.linspace(0, 100, num_points)
    
    # 对两组数据进行插值
    honest_probs_resampled = np.interp(percentages, honest_x, honest_probs)
    evdm_probs_resampled = np.interp(percentages, evdm_x, evdm_probs)
    
    # 设置线型和标记
    line_styles = ['-', '-']  # 实线
    markers = ['o', 'v']  # 不同的标记
    colors = ['#1f77b4', '#9467bd']  # Honest用蓝色，E-VDM*用紫色
    
    # 绘制曲线
    plt.plot(percentages, honest_probs_resampled,
            label='Honest',
            linestyle=line_styles[0],
            marker=markers[0],
            color=colors[0],
            linewidth=2,
            markersize=8,
            markevery=5)  # 每5%显示一个点
            
    plt.plot(percentages, evdm_probs_resampled,
            label='E-VDM*',
            linestyle=line_styles[1],
            marker=markers[1],
            color=colors[1],
            linewidth=2,
            markersize=8,
            markevery=5)  # 每5%显示一个点
    
    # 设置图表属性
    plt.xlabel('Trajectory Progress (%)', fontsize=20)
    plt.ylabel('Probability of real goal state', fontsize=20)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=20)
    plt.tick_params(axis='both', which='major', labelsize=20)
    
    # 设置坐标轴范围
    plt.ylim(-0.05, 1.05)
    plt.xlim(-2, 102)
    
    # 创建输出目录
    os.makedirs("probability_comparison", exist_ok=True)
    
    # 调整坐标轴和刻度线的颜色
    plt.gca().spines['left'].set_color('black')
    plt.gca().spines['bottom'].set_color('black')
    plt.gca().spines['left'].set_linewidth(1.5)
    plt.gca().spines['bottom'].set_linewidth(1.5)
    
    # # 获取当前图形对象并添加边框
    # fig = plt.gcf()
    # fig.patch.set_linewidth(10)  # 设置边框宽度为10
    # fig.patch.set_edgecolor('#666666')  # 设置边框为浅灰色

    # # 保存图表
    # plt.savefig(os.path.join("probability_comparison", "honest_vs_evdm_comparison.pdf"),
    #             bbox_inches='tight', 
    #             format='pdf',
    #             facecolor=fig.get_facecolor(),
    #             edgecolor=fig.get_edgecolor(),
    #             dpi=300)
    # plt.close()
        # 获取当前图形对象并添加边框

    # 保存图表
    plt.savefig(os.path.join("probability_comparison", "honest_vs_evdm_comparison.pdf"),
                bbox_inches='tight', 
                format='pdf')
    plt.close()

if __name__ == "__main__":
    plot_probability_comparison()
