from pathlib import Path

import matplotlib
matplotlib.use('TkAgg')
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import zoom


# Load the Q4 latent dumps for kodim01 and derive per-channel views and a
# channel-averaged map for each of them.
# Every array is indexed as [0, channel, row, col] below, so each file is
# presumably a (1, C, H, W) tensor saved by the model -- TODO confirm against
# the code that wrote these .npy files.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code; keep it only if these files genuinely need it and come
# from a trusted source.


def _per_channel(tensor):
    """Return the first batch item of *tensor* as a list of per-channel 2-D maps."""
    return [tensor[0, ch, :, :] for ch in range(tensor.shape[1])]


def _channel_mean(tensor):
    """Average *tensor* over axis 1, then drop the leading axis (must be size 1)."""
    return tensor.mean(axis=1).squeeze(0)


bpp_tensor, means_tensor, scales_tensor, y_tensor = (
    np.load(path, allow_pickle=True)
    for path in (
        './experiments/latent/q4/bpp/kodim01.npy',
        './experiments/latent/q4/hyper_means/kodim01.npy',
        './experiments/latent/q4/hyper_scales/kodim01.npy',
        './experiments/latent/q4/y/kodim01.npy',
    )
)

_all_tensors = (bpp_tensor, means_tensor, scales_tensor, y_tensor)

bpp_list, means_list, scales_list, y_list = (_per_channel(t) for t in _all_tensors)

bpp_ave, means_ave, scales_ave, y_ave = (_channel_mean(t) for t in _all_tensors)

# Render the channel-averaged latent y (y_ave, computed above) as a heat map,
# save it to disk, then show it interactively.

# Upsampling factor passed to scipy.ndimage.zoom (default spline order 3).
UPSAMPLE_FACTOR = 2
# Multiplier applied before plotting. imshow normalises colours to the data's
# own min/max by default, so this only rescales the colorbar tick labels.
DISPLAY_SCALE = 100
FIGURE_PATH = 'D:/ResearchingRoad/ICLR2025/Attention/Figures/attn_y1.pdf'

y_upsampled = zoom(y_ave, zoom=UPSAMPLE_FACTOR)
plt.figure(figsize=(10, 8))
im = plt.imshow(y_upsampled * DISPLAY_SCALE, cmap="cividis")
cbar = plt.colorbar(im, shrink=0.5)
plt.axis('off')
# savefig raises FileNotFoundError when the target folder does not exist, which
# would abort the script before plt.show(); create the folder first.
Path(FIGURE_PATH).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(FIGURE_PATH, bbox_inches='tight', dpi=350)
plt.show()