
def simi_feat_batch(cfg, dataset):
    """Select the samples of one randomly drawn batch that look mislabeled.

    A subset of ``dataset`` is drawn (class-balanced or uniformly at random),
    the label counts among each drawn sample's k nearest neighbours are
    obtained from ``count_knn_distribution``, and samples are flagged
    according to ``cfg.detect_cfg.method``:

    * ``'mv'``   -- flag a sample when the arg-max of its kNN label counts
      differs from its given label.
    * ``'rank'`` -- per class, flag the samples whose ``get_score`` value is
      at or above a percentile threshold derived from the diagonal entry of
      ``cfg.T_given_noisy`` for that class.

    Args:
        cfg: Configuration object. Reads ``num_classes``, ``hoc_cfg``
            (optional attributes ``balance`` and ``sample_size``),
            ``noisy_prior`` (balanced sampling only), ``detect_cfg.k``,
            ``detect_cfg.method`` and ``T_given_noisy`` (``'rank'`` only).
        dataset: Object supporting ``len()`` and exposing ``label`` and
            ``index``, both indexed here with integer / boolean arrays
            (presumably NumPy arrays -- confirm against the dataset class).

    Returns:
        tuple: ``(sel_noisy, sel_idx, suggest_label)`` where ``sel_noisy`` is
        a list of ``dataset.index`` entries of the flagged samples,
        ``sel_idx`` is ``dataset.index`` restricted to every drawn sample,
        and ``suggest_label`` is the list of arg-max kNN labels of the
        flagged samples (aligned with ``sel_noisy``).

    Raises:
        NameError: If ``cfg.detect_cfg.method`` is neither ``'mv'`` nor
            ``'rank'``.
    """

    # Build Feature Clusters --------------------------------------
    num_classes = cfg.num_classes

    # Resolve the optional sampling settings once, before anything
    # dereferences them, so a missing hoc_cfg / balance degrades to plain
    # random sampling instead of crashing later on an undefined index array.
    hoc_cfg = cfg.hoc_cfg
    if hoc_cfg is None or not hasattr(hoc_cfg, 'balance'):
        print("cfg.hoc_cfg.balance not exists")
    balance = getattr(hoc_cfg, 'balance', None)

    if balance:
        # NOTE(review): assumes cfg.noisy_prior holds per-class sample
        # counts (a probability vector would give a size of 0) -- confirm.
        sample_size = int(min(cfg.noisy_prior) * num_classes * 0.9)
    else:
        sample_size = int(len(dataset) * 0.9)

    requested_size = getattr(hoc_cfg, 'sample_size', None)
    if requested_size:
        # Keep a plain Python int: np.random.choice rejects float sizes.
        sample_size = min(int(requested_size), sample_size)

    if balance:
        all_labels = np.asarray(dataset.label)
        per_class = sample_size // num_classes
        # np.flatnonzero yields the positions of each class; a built-in
        # range cannot be indexed with a boolean mask.
        idx = np.concatenate([
            np.random.choice(
                np.flatnonzero(all_labels == cls), per_class, replace=False)
            for cls in range(num_classes)
        ])
        print(f"balanced sampling: {balance}. Length = {len(idx)}")
    else:
        idx = np.random.choice(range(len(dataset)), sample_size, replace=False)
        print(f"random sampling: {balance}. Length = {len(idx)}")

    knn_labels_cnt = count_knn_distribution(
        cfg, dataset=dataset, sample=idx, k=cfg.detect_cfg.k, norm='l2')

    sampled_labels = dataset.label[idx]
    score = get_score(knn_labels_cnt, torch.tensor(sampled_labels))
    score_np = score.cpu().numpy()
    sel_idx = dataset.index[idx]  # raw index

    # Majority label among the neighbours of every drawn sample.
    label_pred = np.argmax(knn_labels_cnt.cpu().numpy(), axis=1).reshape(-1)

    method = cfg.detect_cfg.method
    if method == 'mv':
        # Majority voting: flag wherever the neighbours disagree with the
        # given label.
        disagree = label_pred != sampled_labels
        sel_noisy = (sel_idx[disagree]).tolist()
        suggest_label = label_pred[disagree].tolist()
    elif method == 'rank':
        sel_noisy = []
        suggest_label = []
        for sel_class in range(num_classes):
            noise_rate = 1 - \
                min(1.0 * cfg.T_given_noisy[sel_class][sel_class], 1.0)
            # clip the outliers
            if noise_rate >= 1.0:
                noise_rate = 0.95
            elif noise_rate <= 0.0:
                noise_rate = 0.05

            in_class = sampled_labels == sel_class
            if not np.any(in_class):
                # Class absent from this draw: nothing to rank, and
                # np.percentile is not defined on an empty array.
                continue

            # Samples of this class scoring at or above the
            # (1 - noise_rate) percentile are flagged.
            thre = np.percentile(
                score_np[in_class], 100 * (1 - noise_rate))
            flagged = (score_np >= thre) * (in_class)
            sel_noisy += sel_idx[flagged].tolist()
            suggest_label += label_pred[flagged].tolist()
    else:
        raise NameError('Undefined method')

    # raw index, raw index, suggested true label
    return sel_noisy, sel_idx, suggest_label
