import torch
# Remove samples whose label is outside [0, num_class) — e.g. unlabeled data marked as -1.

def data_verification(dataset, num_class=-1, debug=False, debug_size=10):
    """Drop samples with out-of-range labels and optionally shrink each split.

    Args:
        dataset: Mapping of split name -> split object. Each split must support
            ``len()``, integer indexing that returns a row with a ``'label'``
            entry, and ``select(indices)`` returning a new split (presumably a
            Hugging Face ``DatasetDict`` — confirm against callers).
        num_class: Number of valid classes. When not -1, every sample whose
            ``int(label)`` falls outside ``[0, num_class)`` is removed (this
            drops unlabeled samples marked as -1). -1 disables filtering.
        debug: If True, keep at most ``debug_size`` randomly chosen samples
            per split.
        debug_size: Maximum number of samples kept per split in debug mode.

    Returns:
        The same ``dataset`` mapping, modified in place (each split is
        replaced by the result of ``select``).
    """
    if num_class != -1:
        for split_name in dataset.keys():
            split = dataset[split_name]
            # A range comparison replaces `in list(range(num_class))`, which
            # allocated and scanned a list per sample. Build the kept indices
            # in one pass instead of rebuilding a set of excluded indices for
            # every candidate (previously O(n^2)).
            keep = [
                idx
                for idx in range(len(split))
                if 0 <= int(split[idx]['label']) < num_class
            ]
            dataset[split_name] = split.select(keep)

    print("Reduce the dataset size... (Remove wrong samples...)")

    if debug:
        for split_name in dataset.keys():
            # .tolist() yields plain ints (not 0-d tensors) for `select`. The
            # randperm order is kept, so the subset is reproducible under a
            # fixed torch seed.
            include = torch.randperm(len(dataset[split_name]))[:debug_size].tolist()
            dataset[split_name] = dataset[split_name].select(include)

    return dataset
            

