import pickle
import os
import numpy as np

import argparse
# Command-line interface for building the keep-lists.
# --save_dir defaults to the current directory; without a default it was None
# and os.path.join(None, ...) raised TypeError when the outputs were written.
parser = argparse.ArgumentParser(
    description='Build acc/conf/random keep-lists of sample paths from a pruning checkpoint.')
parser.add_argument('--target_size', type=int, required=True,
                    help='sample budget, divided evenly over all classes in the checkpoint '
                         'and the two halves of each class')
parser.add_argument('--pruned_ckpt', type=str, required=True,
                    help='pickle file holding (all_accs, all_paths), each indexed by class')
parser.add_argument('--save_dir', type=str, default='.',
                    help='directory for the output pickles (default: current directory)')
parser.add_argument('--save_file_name', type=str, required=True,
                    help='output prefix; files are written as <prefix>_acc, <prefix>_conf, '
                         '<prefix>_random')
parser.add_argument('--specific_classes', type=int, nargs='+',
                    help='only process these class indices (they are appended to the file name)')

def _take_last(items: list, k: int) -> list:
    """Return the last ``k`` elements of ``items`` (none if ``k <= 0``, all if ``k > len``).

    Plain ``items[-k:]`` is wrong for ``k == 0`` because ``items[-0:]`` is the
    whole list. ``items[len(items) - k:]`` is wrong for ``k > len(items)``
    because the negative start wraps around. Clamping the start at 0 handles
    both cases.
    """
    return items[max(len(items) - k, 0):]


def main():
    """Write acc/conf/random keep-lists of sample paths for each class.

    Loads ``(all_accs, all_paths)`` from the pickle given by ``--pruned_ckpt``.
    Both are indexed by class, and ``all_accs[c][j]`` is paired with
    ``all_paths[c][j]``. Each processed class is split into two equal halves
    (source / target), and ``target_size // (num_classes * 2)`` paths are kept
    from each half under three strategies:

    * ``acc``: the entries with the largest score (ties broken by path),
    * ``conf``: the entries whose score is farthest from 0.5 (ties broken by path),
    * ``random``: a random subset drawn with ``np.random.permutation``
      (unseeded, so not reproducible across runs).

    NOTE(review): scores are presumably per-sample accuracies in [0, 1]
    (``conf`` measures distance from 0.5); confirm against whatever
    produces the checkpoint.

    The three lists are pickled to
    ``<save_dir>/<save_file_name>[_<classes>]_{acc,conf,random}``. The class
    suffix is added only when ``--specific_classes`` is given, in which case
    only those classes are processed.

    Raises:
        ValueError: if the checkpoint contains no classes.
    """
    args = parser.parse_args()
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # checkpoints from trusted sources.
    with open(args.pruned_ckpt, 'rb') as f:
        all_accs, all_paths = pickle.load(f)
    num_classes = len(all_accs)
    if num_classes == 0:
        raise ValueError(f'{args.pruned_ckpt} contains no classes')
    # The budget is divided over every class in the checkpoint (even when only
    # --specific_classes are processed) and over the two halves of each class.
    # It can be 0 when target_size < 2 * num_classes; _take_last then keeps nothing.
    num_per_class_to_keep = args.target_size // (num_classes * 2)
    wanted = None if args.specific_classes is None else set(args.specific_classes)

    acc_prune_list = []
    conf_prune_list = []
    random_prune_list = []
    for cls_idx in range(num_classes):
        if wanted is not None and cls_idx not in wanted:
            continue
        accs = all_accs[cls_idx]
        paths = all_paths[cls_idx]
        # split into 2 for source / target; with an odd count the last
        # sample belongs to neither half and is dropped.
        size = len(accs) // 2
        for half in range(2):
            curr_accs = accs[half * size:(half + 1) * size]
            curr_paths = paths[half * size:(half + 1) * size]
            acc_path_list = list(zip(curr_accs, curr_paths))

            # Ascending sort, so the highest scores are at the end.
            acc_result = sorted(acc_path_list)
            acc_prune_list.extend(
                path for _, path in _take_last(acc_result, num_per_class_to_keep))

            # Distance from 0.5: the largest values are the most extreme scores.
            conf_result = sorted((abs(acc - 0.5), path) for acc, path in acc_path_list)
            conf_prune_list.extend(
                path for _, path in _take_last(conf_result, num_per_class_to_keep))

            # Always draw the permutation, even when nothing is kept, so the
            # global RNG stream advances exactly as before.
            idxes = np.random.permutation(len(acc_path_list))
            random_result = [acc_path_list[j] for j in idxes]
            random_prune_list.extend(
                path for _, path in _take_last(random_result, num_per_class_to_keep))

    save_file_name = args.save_file_name
    if args.specific_classes is not None:
        save_file_name += '_' + ','.join(map(str, args.specific_classes))
    # --save_dir is optional; fall back to the current directory instead of
    # letting os.path.join(None, ...) raise TypeError.
    save_dir = args.save_dir if args.save_dir is not None else '.'
    outputs = (
        ('acc', acc_prune_list),
        ('conf', conf_prune_list),
        ('random', random_prune_list),
    )
    for suffix, keep_list in outputs:
        with open(os.path.join(save_dir, f'{save_file_name}_{suffix}'), 'wb') as f:
            pickle.dump(keep_list, f)

# Run the CLI only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()
