import pickle
import os
import numpy as np

import argparse
# Command-line interface consumed by main().
parser = argparse.ArgumentParser()
parser.add_argument('--target_size', type=int, required=True,
                    help='number of paths to keep from the end of the ranking')
parser.add_argument('--to_prune', choices=['conf', 'acc', 'random'],
                    required=True, help='criterion used to rank the paths')
parser.add_argument('--pruned_ckpt', type=str, required=True,
                    help='pickle file holding an (accs, paths) pair')
# Default to the current directory: main() joins this value with
# --save_file_name, and os.path.join(None, ...) raises TypeError.
parser.add_argument('--save_dir', type=str, default='.',
                    help='output directory (default: current directory)')
parser.add_argument('--save_file_name', type=str, required=True,
                    help='name of the output pickle file')

def main():
    """Rank the paths stored in a pickled checkpoint and save a subset.

    Loads an ``(accs, paths)`` pair from ``--pruned_ckpt`` and orders the
    pairs according to ``--to_prune``:

    * ``acc``: ascending by accuracy (ties broken by path), so the
      highest-accuracy paths end up last.
    * ``conf``: ascending by ``abs(acc - 0.5)``, so the paths whose value is
      farthest from 0.5 end up last.
    * ``random``: a random permutation drawn from the global ``np.random``
      state (not seeded here).

    The last ``--target_size`` paths of that ordering are pickled, as a list,
    to ``<save_dir>/<save_file_name>``. The current directory is used when
    ``--save_dir`` is not given.

    Raises:
        ValueError: if ``accs`` and ``paths`` have different lengths.
    """
    args = parser.parse_args()
    if args.target_size < 0:
        parser.error('--target_size must be non-negative')

    # NOTE(security): pickle.load can execute arbitrary code; only pass
    # checkpoints that come from a trusted source.
    with open(args.pruned_ckpt, 'rb') as f:
        accs, paths = pickle.load(f)
    if len(accs) != len(paths):
        # zip() below would silently truncate, and the 'random' branch
        # would index past the end of the shorter sequence.
        raise ValueError(
            f'accs and paths differ in length ({len(accs)} vs {len(paths)})')

    # Based on what to prune, order the (score, path) pairs so that the
    # entries to keep are at the end.
    if args.to_prune == 'acc':
        sorted_result = sorted(zip(accs, paths))
    elif args.to_prune == 'conf':
        sorted_result = sorted((abs(a - 0.5), p) for a, p in zip(accs, paths))
    elif args.to_prune == 'random':
        idxes = np.random.permutation(len(accs))
        sorted_result = [(accs[i], paths[i]) for i in idxes]
    else:
        raise NotImplementedError()

    # Keep the last target_size entries. The slice [-target_size:] would
    # return *everything* for target_size == 0 (since -0 == 0), so compute
    # the start index explicitly; clamping at 0 keeps all entries when
    # target_size exceeds the number available.
    start = max(len(sorted_result) - args.target_size, 0)
    prune_paths = [path for _, path in sorted_result[start:]]

    # --save_dir may be None when the argument is omitted; os.path.join
    # would raise TypeError on it.
    save_dir = args.save_dir or os.curdir
    with open(os.path.join(save_dir, args.save_file_name), 'wb') as f:
        pickle.dump(prune_paths, f)

if __name__ == '__main__':
    # Entry point: run main() only when executed as a script, not on import.
    main()
