import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os



# Build one composite figure per matching class directory: a grid of CAM
# images read from ./output, with the plain CAMs in the top row and the
# "_dnm" CAMs in the bottom row, saved as composite_<label>.png.
dataset = 'cub'
model = 'icarl'

# Only class directories whose name contains this substring are processed.
class_filter = '028'
# Numeric suffixes appended to the model name in the CAM file names
# (one grid column per suffix).
checkpoint_ids = range(5, 10)
n_cols = 5

# rcParams is global matplotlib state, so set the font once instead of
# once per figure.
plt.rcParams['font.family'] = 'Times New Roman'

imagename_list = os.listdir('/home/admin/data/cub/train/')
for label in imagename_list:
    if class_filter not in label:
        continue

    # Plain CAMs first, then the "_dnm" variants; with five files per
    # variant and n_cols == 5 each variant fills exactly one row.
    # File names look like ./output/<dataset>_<model><k>_<label>_cam.jpg
    image_paths = [
        f'./output/{dataset}_{model}{ckpt}_{label}{suffix}'
        for suffix in ('_cam.jpg', '_dnm_cam.jpg')
        for ckpt in checkpoint_ids
    ]

    n_images = len(image_paths)
    n_rows = (n_images + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False keeps `axes` two-dimensional for every grid shape, so
    # the [row, col] indexing below never needs a special case.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(3 * n_cols, 3.3 * n_rows),
        squeeze=False,
    )

    for idx, image_path in enumerate(image_paths):
        row, col = divmod(idx, n_cols)
        ax = axes[row, col]

        print(image_path)
        ax.imshow(mpimg.imread(image_path))

        # Remove the ticks instead of switching the axis off, so the
        # x/y labels set below are still drawn.
        ax.set_xticks([])
        ax.set_yticks([])

        # Panels after the first row get an x label: T_6 ... T_10.
        if idx >= n_cols:
            ax.set_xlabel(f"$T_{{{idx+1}}}$", fontsize=28)

    # Row captions go on the left-most panel of each row.
    axes[0, 0].set_ylabel("iCaRL", fontsize=28)
    axes[1, 0].set_ylabel("iCaRL w/ DeL", fontsize=28)

    fig.tight_layout()
    fig.savefig(f'composite_{label}.png', dpi=300)
    # Release the figure; pyplot otherwise keeps every figure alive until
    # the process exits, which grows memory with each processed class.
    plt.close(fig)
    