import torch
import os.path

def grab_activations_topk_them(acts_path, output_path, descriptor='', k=100, dim=1):
    """Save the top-``k`` activations and their indices from a saved tensor.

    Loads the tensor stored at ``acts_path``, applies ``torch.topk`` along
    ``dim`` and writes two files into ``output_path``:

    * ``{descriptor}top_activations.pt`` -- the ``k`` largest values;
    * ``{descriptor}top_indices.pt`` -- their positions along ``dim``.

    Both output file paths are printed.

    Args:
        acts_path: Path to a file written by ``torch.save`` holding a tensor.
        output_path: Directory to write into. An empty string (what
            ``os.path.split`` yields for a bare filename) means the current
            working directory.
        descriptor: Prefix prepended to both output filenames.
        k: Number of largest entries to keep along ``dim``.
        dim: Dimension to rank along (default 1).

    Returns:
        The ``(top_acts, top_indices)`` pair produced by ``torch.topk``.

    Raises:
        RuntimeError: Raised by ``torch.topk`` when ``k`` exceeds the size of
            ``dim``.
    """
    acts = torch.load(acts_path)

    top_acts, top_indices = torch.topk(acts, k, dim=dim)

    # os.path.join keeps an empty output_path relative (current directory);
    # plain "+ '/'" concatenation would turn it into a path under '/'.
    acts_file = os.path.join(output_path, f'{descriptor}top_activations.pt')
    indices_file = os.path.join(output_path, f'{descriptor}top_indices.pt')

    torch.save(top_acts, acts_file)
    torch.save(top_indices, indices_file)
    print(acts_file)
    print(indices_file)
    return top_acts, top_indices



# Each entry pairs an activations file with the filename prefix for its top-k
# outputs. Keeping path and prefix in one tuple prevents the silent
# mis-pairing that two parallel lists fed through zip() allow (zip stops at
# the shorter list without complaint).
init_runs = [
    # All runs in final_paths below are features_10 runs, so this single
    # initial-model file covers them; processing it more than once would just
    # rewrite identical outputs. Outputs go into alexnet/ next to the other
    # layers' *_init_ files, so the prefix must name the right layer or it
    # would overwrite that layer's results.
    ('/home/a_fuller/projects/Attacking-Interpretability/alexnet/features_10_initial_model_activations.pt', 'f10_init_'),

    # ('/home/a_fuller/projects/Attacking-Interpretability/alexnet/features_0_initial_model_activations.pt', 'f0_init_'),
    # ('/home/a_fuller/projects/Attacking-Interpretability/alexnet/features_3_initial_model_activations.pt', 'f3_init_'),
    # ('/home/a_fuller/projects/Attacking-Interpretability/alexnet/features_8_initial_model_activations.pt', 'f8_init_'),
    # ('/home/a_fuller/projects/Attacking-Interpretability/alexnet/features_10_initial_model_activations.pt', 'f10_init_'),
    # ('/home/a_fuller/projects/Attacking-Interpretability/alexnet/features_10_initial_model_activations.pt', 'f10_c0_init'),
    # ('/home/a_fuller/projects/Attacking-Interpretability/efficientnet/features_7_0_block_3_0_initial_model_activations.pt', 'f7_efficient_net_init'),
]

final_paths = [
    # Conv5 C0 only
    "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_c0_only/results/features_10_final_model_activations.pt",
    # Conv5 C1 only
    "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v13/single_channel/f10_c1_top10_to_zero/results/features_10_final_model_activations.pt",
    # Conv5 C2 only
    "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v13/single_channel/f10_c2_top10_to_zero/results/features_10_final_model_activations.pt",

    # Earlier runs (these used per-run prefixes, noted after each path):
    # '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f0_dataset_top10_to_zero/results/features_0_final_model_activations.pt',  # 'f0_final_'
    # '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f3_dataset_top10_to_zero/results/features_3_final_model_activations.pt',  # 'f3_final_'
    # '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f8_dataset_top10_to_zero/results/features_8_final_model_activations.pt',  # 'f8_final_'
    # '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_dataset_top10_to_zero/results/features_10_final_model_activations.pt',  # 'f10_final_'
    # '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_dataset_refs_to_top/results/features_10_final_model_activations.pt',  # 'f10_final_refs_to_top_'
    # '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_c0_only/results/features_10_final_model_activations.pt',  # 'f10_c0_final'
    # '/home/a_fuller/projects/Attacking-Interpretability/efficientnet/v12/f7_0_b3_0_dataset_top10_to_zero/results/features_7_0_block_3_0_final_model_activations.pt',  # 'f7_efficient_net_final'
]

print(os.path.split(init_runs[0][0])[0])
for my_path, descriptor in init_runs:
    grab_activations_topk_them(my_path, os.path.split(my_path)[0], descriptor=descriptor)

# Every final run writes into its own results/ directory, so the shared
# 'final_' prefix cannot collide between runs.
for my_path in final_paths:
    grab_activations_topk_them(my_path, os.path.split(my_path)[0], descriptor='final_')