import numpy as np
import matplotlib.pyplot as plt 

# Draw a grouped bar chart (two bars per timestep, with error bars) from the
# arrays stored in data.npz, using a log-scaled y-axis, and save it to 2.png.

# Load the data. np.load on an .npz returns an NpzFile that keeps the archive
# open, so use it as a context manager to release the file handle. Indexing
# a key reads that array into memory, so the arrays stay usable afterwards.
with np.load('outputs/pre_exp/@20240520-090057/data.npz') as data:
    timesteps = data['timesteps']
    diff_pretrain_mean = data['diff_pretrain_mean']
    diff_pretrain_std = data['diff_pretrain_std']
    diff_gen_mean = data['diff_gen_mean']
    diff_gen_std = data['diff_gen_std']

# Groups sit at integer positions (spacing 1), so two bars of this width
# (2 * 0.35 = 0.7) fit side by side without running into the next group.
bar_width = 0.35

# One integer slot per timestep; the timestep values are only used as labels.
x = np.arange(len(timesteps))

# First bar of each group, centred on x.
plt.bar(x, diff_pretrain_mean, width=bar_width, yerr=diff_pretrain_std, label="Real image", linewidth=2, capsize=5)
# Second bar of each group, shifted right by one bar width so it sits next to the first.
plt.bar(x + bar_width, diff_gen_mean, width=bar_width, yerr=diff_gen_std, label="Generated image", linewidth=2, capsize=5)

# Put each tick midway between the two bar centres and label it with the timestep.
plt.xticks(x + bar_width / 2, timesteps)

# Log-scaled y-axis. Alternatives noted by the original author:
# 'linear', 'log', 'symlog', 'asinh', 'logit', 'function', 'functionlog'.
plt.yscale("log")
plt.legend(loc='best')
plt.xlabel('Timesteps')
plt.ylabel('Mse loss of STF and Unet prediction')
plt.savefig('2.png')
print("done!")