import os
from tqdm import tqdm
import pickle
import json

# Tag embedded in the output filenames; expected to be 'class' or 'total'.
# In this script it only selects file names, it does not change the math.
mode = 'total'
# Both summary files share one prefix and differ only in their suffix.
_stats_prefix = f'./results/shap_statistics/shap_statistics_{mode}'
only_mean_path = _stats_prefix + '_mean.json'
only_std_path = _stats_prefix + '_std.json'

# Root of the per-run SHAP output tree.
shap_dir = './results/shap'
shap_list = sorted(os.listdir(shap_dir))

# Collect every directory exactly three levels below shap_dir; the per-run
# pickle files are looked up inside these. Non-directory entries (e.g. stray
# files such as .DS_Store) are skipped at every level, because calling
# os.listdir on a regular file raises NotADirectoryError.
shap_sub_list = []
for sub in shap_list:
    sub_path = os.path.join(shap_dir, sub)
    if not os.path.isdir(sub_path):
        continue
    for sub_dir in os.listdir(sub_path):
        sub_dir_path = os.path.join(sub_path, sub_dir)
        if not os.path.isdir(sub_dir_path):
            continue
        for subsub_dir in os.listdir(sub_dir_path):
            subsub_path = os.path.join(sub_dir_path, subsub_dir)
            if os.path.isdir(subsub_path):
                shap_sub_list.append(subsub_path)

# One candidate pickle per leaf directory; sorted so the JSON keys are
# produced in a deterministic order.
shap_path = sorted(os.path.join(shap, 'attack_no.pkl') for shap in shap_sub_list)

# results keeps all three statistics per run (variance is stored under the
# 'val' key); the only_* dicts each hold a single statistic.
results = {}
only_mean_results = {}
only_std_results = {}
for path in tqdm(shap_path, desc='Load shap pkl'):
    if not os.path.exists(path):
        # Leaf directories without an attack_no.pkl are skipped silently.
        continue

    # Split on the platform separator instead of a hard-coded '/': on Windows
    # os.path.join emits '\\', which would make '/'-based indexing pick the
    # wrong components (or raise IndexError).
    parts = os.path.normpath(path).split(os.sep)
    net, ar, me = parts[-4], parts[-3], parts[-2]
    key = f'{net}_{ar}_{me}'

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only run this on result files you produced yourself.
    with open(path, 'rb') as f:
        shap_data = pickle.load(f)

    # Average over axis 0 first, then summarise the remaining values.
    # Assumes shap_data is a numpy array (anything with .mean/.var/.std that
    # accepts axis=) — TODO confirm against the producer of attack_no.pkl.
    shap_data = shap_data.mean(axis=0)
    mean, val, std = shap_data.mean(), shap_data.var(), shap_data.std()

    results[key] = {'mean': f'{mean:.6f}', 'val': f'{val:.6f}', 'std': f'{std:.6f}'}
    only_std_results[key] = f'{std:.6f}'
    only_mean_results[key] = f'{mean:.6f}'

# The output directory lives outside the SHAP input tree and may not exist yet
# on a fresh checkout; create it so open(..., 'w') does not fail. A bare
# filename has no directory component, and makedirs('') would raise, hence
# the guard.
for out_path in (only_mean_path, only_std_path):
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

with open(only_mean_path, 'w', encoding='utf-8') as f:
    json.dump(only_mean_results, f, indent=4)

with open(only_std_path, 'w', encoding='utf-8') as f:
    json.dump(only_std_results, f, indent=4)

print(f'Saved at {only_mean_path}')
print(f'Saved at {only_std_path}')

