import numpy as np
import os
from matplotlib import pyplot as plt
from pdb import set_trace as bp

folder_path = 'explanation_masks'
cls = 3
save_path = 'masks_viz'


def is_annotation_mask(filename):
    """Return True if *filename*'s stem ends in ``annt``.

    Such files are rendered as-is; all other masks additionally get a
    thresholded rendering. (Presumably ``annt`` marks annotation masks —
    confirm with whoever produces these files.)
    """
    stem, _ = os.path.splitext(filename)
    return stem.endswith('annt')


def binarize_mask(mask, threshold=None):
    """Return a 0/1 copy of *mask*: 1 where ``mask > threshold``, else 0.

    Args:
        mask: numeric array-like.
        threshold: cut-off value; defaults to the mean of *mask*.

    Returns:
        A new array with the same shape and dtype as *mask*. The input is
        not modified.
    """
    mask = np.asarray(mask)
    if threshold is None:
        threshold = np.mean(mask)
    # One comparison on the original values. The previous two-step in-place
    # version (zero the low values, then set values > threshold to 1) turned
    # the freshly zeroed entries into 1s whenever the threshold was negative.
    return (mask > threshold).astype(mask.dtype)


def visualize_masks(folder_path=folder_path, cls=cls, save_path=save_path):
    """Render every ``.npy`` mask in ``<folder_path>/<cls>`` as grayscale PNGs.

    Each mask is saved as ``<save_path>/<stem>.png``. Masks that are not
    annotation masks (see ``is_annotation_mask``) also have their mean
    printed. They are additionally saved, binarized at that mean, as
    ``<save_path>/<stem>_thr.png``.

    Args:
        folder_path: root directory holding one sub-directory per class.
        cls: class sub-directory to process.
        save_path: output directory; created if it does not exist.
    """
    class_dir = os.path.join(folder_path, str(cls))
    os.makedirs(save_path, exist_ok=True)  # imsave does not create dirs

    for name in sorted(os.listdir(class_dir)):
        stem, ext = os.path.splitext(name)
        if ext.lower() != '.npy':
            continue  # skip stray files np.load cannot treat as a mask
        mask = np.load(os.path.join(class_dir, name))

        if is_annotation_mask(name):
            plt.imsave(os.path.join(save_path, stem + '.png'), mask, cmap='gray')
            continue

        mean = np.mean(mask)
        print(f'mean: {mean}')
        plt.imsave(os.path.join(save_path, stem + '.png'), mask, cmap='gray')
        plt.imsave(
            os.path.join(save_path, stem + '_thr.png'),
            binarize_mask(mask, mean),
            cmap='gray',
        )


if __name__ == '__main__':
    visualize_masks()
        