import argparse
import torch
from collections import Counter, OrderedDict

import numpy as np

# Command-line interface. Parsing happens at module level, so the arguments
# are read (and required) as soon as this file is executed or imported.
parser = argparse.ArgumentParser()
# --label: required string option, exposed as ``args.label``.
parser.add_argument("--label", required=True)
# --num-classes: required integer option, exposed as ``args.num_classes``.
# NOTE(review): not referenced anywhere in the code visible here.
parser.add_argument("--num-classes", dest="num_classes", required=True, type=int)
args = parser.parse_args()


def save_filter_tail_class(label: str, filter_num_examples: int):
    """Write the tail-class list and a class -> sample-index map for a label file.

    Loads the array stored at ``label`` with ``np.load`` and counts how often
    each value occurs. Two files are written next to it:

    * ``{label}_tailclass_lessthan_{filter_num_examples}.npy`` -- array of the
      classes that occur strictly fewer than ``filter_num_examples`` times.
    * ``{label}_class2list.pt`` -- an ``OrderedDict`` (saved with
      ``torch.save``) mapping every class to the list of positions at which
      it occurs in the label array.

    Both outputs list classes in ``Counter.most_common()`` order, i.e. from
    the most frequent class to the least frequent one.

    Args:
        label: Path of the ``.npy`` file holding the labels. It is also used
            verbatim as the prefix of both output file names.
        filter_num_examples: Threshold; a class is a tail class when its
            number of examples is strictly below this value.
    """
    np_label = np.load(label)
    label_counter: Counter = Counter(np_label)
    # Compute the frequency ordering once and reuse it for both outputs.
    classes_by_frequency = label_counter.most_common()

    list_tail_class = [
        idx_class
        for idx_class, num_examples in classes_by_frequency
        if num_examples < filter_num_examples
    ]
    # Pin the dtype to the label dtype: without it an empty tail list would be
    # saved as a float64 array instead of an array of class ids.
    np_tail_class = np.array(list_tail_class, dtype=np_label.dtype)
    np.save(f"{label}_tailclass_lessthan_{filter_num_examples}.npy", np_tail_class)

    # Pre-create the keys so the mapping keeps the frequency ordering rather
    # than the order in which classes first appear in the label array.
    class2list = OrderedDict((idx_class, []) for idx_class, _ in classes_by_frequency)
    for idx, idx_class in enumerate(np_label):
        class2list[idx_class].append(idx)

    torch.save(class2list, f"{label}_class2list.pt")


if __name__ == "__main__":
    # Script entry point: process the file given by --label, treating classes
    # with fewer than 4 examples as tail classes.
    save_filter_tail_class(label=args.label, filter_num_examples=4)
