import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Reverse-diffusion step whose score maps are drawn in the heatmap figures.
STEP = 25

save_pdf = True  # when True, every figure is also written to a PDF file
num_steps = 50   # number of entries in the noise schedule

# Noise schedule: spaced linearly in sqrt(beta) and then squared, so beta
# rises from 1e-4 to 0.5 across `num_steps` entries.
beta = np.square(np.linspace(0.0001 ** 0.5, 0.5 ** 0.5, num_steps))

# NOTE: in this script `alpha_hat` is the per-step factor 1 - beta and
# `alpha` is its running (cumulative) product.
alpha_hat = 1 - beta
alpha = np.cumprod(alpha_hat)

# Per-step weight -1 / sqrt(1 - alpha); every entry is negative.
coefficent = -1.0 / np.sqrt(1.0 - alpha)

# Product of 1 / sqrt(alpha_hat) and (1 - alpha_hat) / sqrt(1 - alpha).
# Presumably the noise-prediction weight of a DDPM-style reverse step — confirm.
coeff1 = 1.0 / np.sqrt(alpha_hat)
coeff2 = (1.0 - alpha_hat) / np.sqrt(1.0 - alpha)
coefficent1 = coeff1 * coeff2
################################################

def _load_score_map(path):
    """Load the ``data`` array of an .npz file and reduce it to per-step 20x20 maps.

    The array is averaged over its first axis. Axes 1 and 2 of the result are
    then swapped, and each entry along the remaining leading axis is reshaped
    into a 20x20 grid.

    Raises whatever ``np.load`` raises for a missing or unreadable file, and
    ``ValueError`` if the trailing axes do not hold exactly 400 values.
    """
    # np.load on an .npz returns an NpzFile that keeps its file handle open;
    # the context manager closes it once the array has been read into memory.
    with np.load(path) as archive:
        data = archive['data']
    data = np.mean(data, axis=0)
    # The transpose requires a 4-D array after the mean, i.e. 5-D input data.
    data = data.transpose(0, 2, 1, 3)
    return data.reshape(data.shape[0], 20, 20)


csdi_gt = _load_score_map("save_score_based_function_Best_for_CSDI_Synthetic.npz")
gt = _load_score_map("save_score_based_function_Best_for_Ground_truth_Synthetic.npz")
vae_gt = _load_score_map("save_score_based_function_Best_for_Latent_diffusion_Synthetic.npz")

###############
# Per-step mean absolute deviation of each model's score map from the ground
# truth, weighted by the reversed `coefficent1` schedule.
mae_csdi = np.abs(np.mean(np.abs(csdi_gt - gt), axis=(1, 2)) * coefficent1[::-1])
mae_vae = np.abs(np.mean(np.abs(vae_gt - gt), axis=(1, 2)) * coefficent1[::-1])

# The style must be set before the figure/axes exist: seaborn's set_style only
# updates rcParams, which already-created axes do not pick up.
sns.set_style("whitegrid")
# Create exactly one figure (a second plt.figure() call used to leave a blank
# figure that plt.show() also displayed).
fig = plt.figure(figsize=(18, 10))
ax1 = fig.add_subplot(111)

# Plot the running sum of the per-step deviations.
ax1.plot(np.cumsum(mae_csdi), label='CSDI ', marker='d', markersize=10)
ax1.plot(np.cumsum(mae_vae), label='HSGM', marker='h', markersize=10)

# Enlarge the automatically placed tick labels rather than overwriting them with
# hard-coded strings: those were assigned by tick position, not by value, and
# could show numbers that do not match the plotted data.
ax1.tick_params(axis='both', labelsize=50)

ax1.set_xlabel("Diffusion reverse step ", fontsize=60)
ax1.set_title("Accumulation bias in the reverse sampling stage", fontsize=55)
ax1.legend(fontsize=40)
ax1.grid(True)
fig.tight_layout()
# Save before showing: a blocking show() closes the figure when its window is
# dismissed, so saving first is the documented safe order.
if save_pdf:
    fig.savefig('sythetic_dataset_accumulation_deviation.pdf')
plt.show()

# #################################################################

# Per-step mean absolute deviation of each model's score map from the ground
# truth, scaled by the reversed `coefficent` schedule. Those weights are
# negative, hence the outer abs.
mae_csdi1 = np.abs(np.mean(np.abs(csdi_gt - gt), axis=(1, 2)) * coefficent[::-1])
mae_vae1 = np.abs(np.mean(np.abs(vae_gt - gt), axis=(1, 2)) * coefficent[::-1])

# Set the style before creating the axes so that it actually applies to them.
sns.set_style("whitegrid")
# Create exactly one figure (the former extra plt.figure() was left blank).
fig = plt.figure(figsize=(18, 10))
ax1 = fig.add_subplot(111)

ax1.plot(mae_csdi1, label='CSDI ', marker='X', markersize=10)
ax1.plot(mae_vae1, label='HSGM', marker='o', markersize=10)

# Enlarge the automatic tick labels instead of overwriting them with
# position-based hard-coded strings that may not match the data range.
ax1.tick_params(axis='both', labelsize=50)

ax1.set_xlabel("Diffusion reverse step ", fontsize=60)
ax1.set_title("MAE of score function bias in reverse stage", fontsize=55)
ax1.legend(fontsize=40)
ax1.grid(True)
fig.tight_layout()
# Save before showing; a blocking show() closes the figure on exit.
if save_pdf:
    fig.savefig('sythetic_dataset_deviation.pdf')
plt.show()
#############################################################

scale = 0.5  # upper end of the colour range (vmax) for every heatmap
sns.set_style("whitegrid")
cmaps = ['RdBu_r', 'viridis']  # only cmaps[0] is used below


def _show_heatmap(step_maps, title, pdf_name):
    """Draw ``step_maps[STEP]`` as a standalone 6x6-inch heatmap.

    The colour range is fixed to [0, scale] so the three panels are directly
    comparable. The figure is saved to ``pdf_name`` when ``save_pdf`` is set,
    and then shown.

    Returns the figure and the image artist.
    """
    figure = plt.figure(figsize=(6, 6))
    image = plt.imshow(step_maps[STEP, :, :], vmin=0.0, vmax=scale, cmap=cmaps[0])
    plt.title(title, fontsize=20)
    plt.xlabel("4 x Sensors evaluate time length", fontsize=16)
    plt.ylabel("Evaluate time length", fontsize=16)
    plt.colorbar(image)
    plt.tight_layout()
    plt.grid(False)
    if save_pdf:
        figure.savefig(pdf_name, bbox_inches='tight')
    plt.show()
    return figure, image


# One heatmap each for CSDI, the ground truth and HSGM.
fig1, im0 = _show_heatmap(csdi_gt, "CSDI", f'sythetic_dataset_heatmap_CSDI_{STEP}.pdf')
fig2, im1 = _show_heatmap(gt, "Ground Truth", f'sythetic_dataset_heatmap_Ground_truth_{STEP}.pdf')
fig3, im2 = _show_heatmap(vae_gt, "HSGM", f'sythetic_dataset_heatmap_HSGM_{STEP}.pdf')
