import os 
import sys 
import torch
import numpy as np 







def select_prunable(state_dict):
    """Return the entries of *state_dict* whose tensors are 2-D or 4-D.

    Only tensors with exactly 2 or 4 dimensions take part in mask
    generation; every other entry is ignored. Insertion order of the
    input mapping is preserved.
    """
    return {
        name: tensor
        for name, tensor in state_dict.items()
        if tensor.dim() in (2, 4)
    }


def build_masks(state_dict, density, *, verbose=False):
    """Build global-threshold binary masks for the prunable tensors.

    All selected tensors (see ``select_prunable``) are flattened and
    ranked together, and the ``int(total * density)`` largest entries
    are kept.

    Args:
        state_dict: Mapping from parameter name to tensor.
        density: Fraction of entries to keep, in ``[0, 1]``.
        verbose: When true, print each selected key followed by the
            total element count, the number kept and the threshold.

    Returns:
        A dict mapping ``name + '_mask'`` to a float tensor of the same
        shape as the parameter, holding 1.0 for kept entries and 0.0
        for pruned ones.

    Raises:
        ValueError: If *density* lies outside ``[0, 1]`` or no tensor in
            *state_dict* is 2-D or 4-D.
    """
    # Written as a negated chain so that NaN is rejected as well.
    if not 0.0 <= density <= 1.0:
        raise ValueError(f"density must be within [0, 1], got {density!r}")

    prunable = select_prunable(state_dict)
    if not prunable:
        raise ValueError("state dict contains no 2-D or 4-D tensors to mask")

    if verbose:
        for name in prunable:
            print(name)

    flat = torch.cat([tensor.reshape(-1) for tensor in prunable.values()])
    total = flat.shape[0]
    num_remain = int(total * density)

    # NOTE(review): entries are ranked by their raw signed value, not by
    # magnitude. If magnitude pruning is intended, abs() would be needed
    # here -- confirm against whatever produces the checkpoint.
    if num_remain == 0:
        # Indexing with [-0] would select the *smallest* value and keep
        # nearly everything, so the "keep nothing" case is handled
        # explicitly.
        threshold = None
        masks = {
            name + '_mask': torch.zeros_like(tensor, dtype=torch.float)
            for name, tensor in prunable.items()
        }
    else:
        # The num_remain-th largest value. The comparison below is
        # inclusive so that this entry is itself kept; a strict '>'
        # would retain only num_remain - 1 entries. Values tied with
        # the threshold are all kept.
        threshold = torch.sort(flat).values[-num_remain]
        masks = {
            name + '_mask': (tensor >= threshold).float()
            for name, tensor in prunable.items()
        }

    if verbose:
        print(total, num_remain, threshold)
    return masks


def main(argv=None):
    """Command-line entry point: ``<checkpoint> <density> <output>``.

    Loads ``checkpoint['state_dict']`` on the CPU, builds the masks with
    ``build_masks`` and saves the resulting dict to ``<output>``.
    Arguments beyond the first three are ignored.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 3:
        sys.exit(f"usage: {sys.argv[0]} <checkpoint> <density> <output>")
    checkpoint_path, density_text, output_path = args[:3]

    density = float(density_text)
    # SECURITY: torch.load unpickles the file; only use it on checkpoints
    # from a trusted source.
    weight = torch.load(checkpoint_path, map_location='cpu')['state_dict']

    masks = build_masks(weight, density, verbose=True)
    torch.save(masks, output_path)


if __name__ == "__main__":
    main()



# path = os.listdir(sys.argv[1])

# for model in path:
#     print(model)
#     masks = torch.load(os.path.join(sys.argv[1], model), map_location='cpu')
#     new_masks = {}
#     for key in masks.keys():
#         new_masks[key+'_mask'] = masks[key]
#     torch.save(new_masks, os.path.join(sys.argv[1], model))





