# 读取.pkl文件的东西
# 根据轨迹画图
import pickle
import matplotlib.pyplot as plt
import numpy as np
def plot_trace():
    # 有三个pkl文件
    with open('PSR-0.pkl', 'rb') as f:
        data1 = pickle.load(f)
    with open('PSR-1.pkl', 'rb') as f:
        data2 = pickle.load(f)
    with open('PSR-2.pkl', 'rb') as f:
        data3 = pickle.load(f)
    data1 = data1[5]  # 假设数据在第5个键下
    data2 = data2[5]  # 假设数据在第5个键下
    data3 = data3[5]  # 假设数据在第
    # 假设每个数据都是一个列表，包含多个数值
    # 在一张图上画出三条曲线
    # 改变图中的字体大小
    plt.rcParams.update({'font.size': 20})  # 设置全局字体大小
    plt.figure(figsize=(10, 6))
    plt.plot(data1, label='MSE', color='blue')
    plt.plot(data2, label='MSE+PDE', color='orange')
    plt.plot(data3, label='MSE+MSEDI', color='green')
    plt.xlabel('Generation')
    plt.ylabel('PDE Residual Error')
    plt.title('PDE Residual Error vs Generation')
    plt.legend()
    plt.grid(True)
    plt.savefig('trace_plot.png')  # 保存图像
    plt.show()  # 显示图像

if __name__ == "__main__":
    plot_trace()
    
