import numpy as np
import os
from matplotlib import pyplot as plt
from pdb import set_trace as bp

folder_path = 'explanation_masks'  # root directory; masks are read from folder_path/<cls>/
cls = 3  # selects the sub-folder of folder_path whose masks are rendered
save_path = 'masks_viz'  # output directory for the rendered PNGs

# Collect the mask files of the selected sub-folder. Reuse folder_path rather
# than re-hard-coding the directory name, and keep only .npy files so stray
# entries (e.g. .DS_Store) never reach np.load. Sorting makes the processing
# order deterministic (os.listdir order is arbitrary). Paths keep the '/'
# separator because downstream code splits them on '/'.
class_dir = f'{folder_path}/{cls}'
masks_names = sorted(name for name in os.listdir(class_dir) if name.endswith('.npy'))
masks_paths = [f'{class_dir}/{name}' for name in masks_names]

# plt.imsave does not create missing directories, so ensure the output exists.
os.makedirs(save_path, exist_ok=True)

# for i, m in enumerate(masks_paths):
#     mask = np.load(m)
#     mean = np.mean(mask)
#     print(f'mean: {mean}')
#     plt.imsave(save_path+'/'+str(i)+'.png', mask, cmap='gray')
#     mask[mask<=mean] = 0
#     mask[mask>mean] = 1
#     plt.imsave(save_path+'/'+str(i)+'_thr.png', mask, cmap='gray')
# Render every mask in masks_paths as a grayscale PNG in save_path.
# Masks whose file-name stem ends in 'annt' (presumably annotation masks —
# confirm) are saved as-is. Every other mask is also binarised at its own
# mean and saved a second time with a '_thr' suffix.
for mask_path in masks_paths:
    # File name without directory and extension: '.../img_annt.npy' -> 'img_annt'.
    stem = os.path.splitext(os.path.basename(mask_path))[0]
    mask = np.load(mask_path)

    if stem.endswith('annt'):
        plt.imsave(f'{save_path}/{stem}.png', mask, cmap='gray')
        continue

    mean = np.mean(mask)
    print(f'mean: {mean}')
    plt.imsave(f'{save_path}/{stem}.png', mask, cmap='gray')

    # Binarise with a single comparison: 1 where mask > mean, else 0, in the
    # mask's own dtype. The former two-step in-place assignment
    # (<= mean -> 0, then > mean -> 1) re-flipped the new zeros to ones
    # whenever mean < 0, producing an all-ones image.
    binary = (mask > mean).astype(mask.dtype)
    plt.imsave(f'{save_path}/{stem}_thr.png', binary, cmap='gray')
        