import numpy as np
import pdb
import cv2

import matplotlib.pyplot as plt

data = np.load('./Polymnist/5_views.npy',allow_pickle=True).item()

test_data = data['test']

test_label = data['test_label']


# for i in range(len(test_data)):
#     data_i = test_data[i][0].reshape(3,28,28)*255
#     #data_i = (data_i * 255).astype(np.uint8)

#     # 交换数据的轴，将通道数置于最后
#     data_i = np.moveaxis(data_i, 0, -1)
#     cv2.imwrite('./Polymnist/output_image_{}.jpg'.format(i), data_i)
#     print(test_label[0])
#     #pdb.set_trace()
image_list = []

for i in range(10):
    row_index = np.where(test_label == i)[0][0]
    for j in range(5):
        data_i = test_data[j][row_index].reshape(3,28,28)
        image_list.append(data_i)

        


# 创建一个10行5列的子图布局
fig, axes = plt.subplots(10, 5, figsize=(10, 20))

# 遍历所有子图，并将图片数据显示在每个子图中
for i, ax in enumerate(axes.flat):
    image = image_list[i].transpose(1, 2, 0)  # 调整数组维度顺序
    ax.imshow(image)
    ax.axis('off')

# 调整子图之间的间距
plt.tight_layout()

plt.savefig('./Polymnist/show_polymnist.png',dpi=300)