import numpy as np
import torch

# Public API of this module.
__all__ = ["drop_disagreement"]

def drop_disagreement(dataset):
    """Keep only samples whose clean label equals their raw label.

    For a ``torch.utils.data.Subset``, ``dataset.indices`` is replaced in
    place by the indices ``i`` (in their original order) for which
    ``dataset.dataset.clean_labels[i] == dataset.dataset.raw_labels[i]``.
    Any other dataset type is left unchanged. In both cases the resulting
    dataset size is printed.

    Args:
        dataset: Dataset to filter. Only ``Subset`` instances are modified;
            their underlying dataset must expose ``clean_labels`` and
            ``raw_labels`` that can be indexed by position (numpy arrays,
            lists, or CPU tensors).

    Returns:
        The same ``dataset`` object (mutated in place), for convenience.
    """
    if isinstance(dataset, torch.utils.data.dataset.Subset):
        # Force an integer dtype: ``np.array([])`` is float64, and indexing
        # with a float array raises IndexError for an empty subset.
        indices = np.asarray(dataset.indices, dtype=np.intp)
        base = dataset.dataset
        # np.asarray lets plain lists (and CPU tensors) support fancy indexing;
        # it does not copy when the labels are already numpy arrays.
        clean = np.asarray(base.clean_labels)[indices]
        raw = np.asarray(base.raw_labels)[indices]
        dataset.indices = indices[clean == raw]
    # NOTE(review): non-Subset datasets pass through unfiltered — confirm
    # this is intended by callers.

    print('The size of filtered dataset is ', len(dataset))
    return dataset