import pdb
import pickle
import argparse
from random import random
import torch
import numpy as np
import os

def get_prune_idx(distance, low, high):
    """Return the indices of *distance* that fall outside the [low, high) rank band.

    Entries are ranked by ascending distance. The first ``round(n * low)``
    ranked indices are returned, followed by every ranked index from
    ``round(n * high)`` onward. The middle band of roughly ``high - low`` of
    the entries is therefore excluded.

    Args:
        distance: 1-D distances. This can be a CPU torch tensor, a NumPy
            array, or anything else ``np.asarray`` accepts.
        low: Fraction in [0, 1]; this leading share of the lowest-distance
            entries is returned.
        high: Fraction in [0, 1]; entries ranked at or after
            ``round(n * high)`` are returned.

    Returns:
        np.ndarray of integer indices into *distance*. The low-distance group
        comes first, then the high-distance group.

    Raises:
        ValueError: If ``0 <= low <= high <= 1`` does not hold.
    """
    if not 0.0 <= low <= high <= 1.0:
        # Out-of-range fractions would become negative or overlapping slice
        # bounds and silently yield wrong or duplicated indices.
        raise ValueError(
            f"expected 0 <= low <= high <= 1, got low={low}, high={high}"
        )
    # np.asarray accepts CPU torch tensors (via __array__) as well as NumPy
    # arrays and plain sequences, whereas ``distance.numpy()`` is torch-only.
    sorted_idx = np.asarray(distance).argsort()
    n = sorted_idx.shape[0]
    # NOTE: built-in round() rounds halves to even, so when n * low or
    # n * high lands exactly on .5 the two cut points may round in
    # different directions.
    low_idx = round(n * low)
    high_idx = round(n * high)
    return np.concatenate((sorted_idx[:low_idx], sorted_idx[high_idx:]))

def main(args, rate):
    """Compute the prune indices for one *rate* from ``<input_dir>/distance.bin``.

    The pickled file is expected to hold a mapping whose ``"distance"`` entry
    can be passed to ``torch.from_numpy``, so it is presumably a NumPy array.
    The indices come from ``get_prune_idx`` using a band of width *rate*
    centred on 0.5.

    Args:
        args: Parsed CLI namespace; only ``args.input_dir`` is read.
        rate: Width of the central band handed to ``get_prune_idx``.

    Returns:
        Whatever ``get_prune_idx`` returns for that band.
    """
    source_path = os.path.join(args.input_dir, "distance.bin")
    with open(source_path, "rb") as handle:
        # NOTE(review): pickle.load can execute arbitrary code, so only point
        # --input_dir at trusted data.
        payload = pickle.load(handle)

    distance = torch.from_numpy(payload["distance"])
    print("distance size: ", distance.shape[0])
    half_width = rate / 2
    return get_prune_idx(distance, 0.5 - half_width, 0.5 + half_width)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_dir", default="./",
                        help="directory the prune-<rate> files are written to")
    parser.add_argument("--input_dir", default="./",
                        help="directory containing distance.bin")
    # Previously hard-coded; the default reproduces the original rates.
    parser.add_argument("--rates", type=float, nargs="+",
                        default=[0.2, 0.3, 0.4],
                        help="band widths to compute prune indices for")
    cli_args = parser.parse_args()

    # Create the output directory up front so open(..., "wb") cannot fail
    # on a missing path.
    os.makedirs(cli_args.save_dir, exist_ok=True)
    for rate in cli_args.rates:
        # Named prune_ids rather than `id` to avoid shadowing the builtin.
        prune_ids = main(cli_args, rate)
        out_path = os.path.join(cli_args.save_dir, f"prune-{rate}")
        with open(out_path, "wb") as f:
            pickle.dump(prune_ids, f)