
def get_sample_from_class(dataset, class_label, sample_index):
    """Return the ``sample_index``-th sample whose label equals ``class_label``.

    Items of ``dataset`` are unpacked as ``(sample, label)`` pairs while
    iterating, and the sample part of the matching item is returned.

    Args:
        dataset: Any iterable of ``(sample, label)`` pairs. Indexing support
            is not required; the dataset is iterated exactly once.
        class_label: Label value to select; compared with ``==``.
        sample_index: Position among the samples of ``class_label``. A
            non-negative value counts from the first match. A negative value
            counts from the last match, as with Python sequence indexing.

    Returns:
        The sample element of the selected ``(sample, label)`` pair.

    Raises:
        IndexError: If ``class_label`` has fewer matching samples than
            ``sample_index`` requires (in either direction).
    """
    if sample_index >= 0:
        # Stop at the requested match instead of scanning (and loading) the
        # whole dataset. The sample is returned straight from the iteration,
        # so it is not loaded a second time.
        seen = 0
        for sample, label in dataset:
            if label == class_label:
                if seen == sample_index:
                    return sample
                seen += 1
        raise IndexError("Sample index out of range")

    # A negative index needs the tail of the match list. Keep only the last
    # |sample_index| matches so memory stays bounded.
    from collections import deque

    wanted = -sample_index
    tail = deque(maxlen=wanted)
    for sample, label in dataset:
        if label == class_label:
            tail.append(sample)
    if len(tail) < wanted:
        raise IndexError("Sample index out of range")
    # With maxlen == wanted, the oldest retained match is the |k|-th from the end.
    return tail[0]