def get_img_num_per_cls(img_max, num_class, imb_type, imb_factor):
    """Return the per-class sample counts for an imbalanced split.

    Args:
        img_max: Count assigned to the first (largest) class.
        num_class: Number of classes; the returned list has this many entries.
        imb_type: ``"exp"`` decays counts exponentially from ``img_max``
            (class 0) to ``img_max * imb_factor`` (last class).
            ``"step"`` gives the first ``num_class // 2`` classes ``img_max``
            and every remaining class ``img_max * imb_factor``.
            Any other value yields a balanced split of ``img_max`` per class.
        imb_factor: Multiplier applied to ``img_max`` to obtain the count of
            the smallest class.

    Returns:
        A list of ``num_class`` ints; each count is truncated with ``int``.
    """
    if imb_type == "exp":
        # With zero or one class there is nothing to interpolate, and the
        # exponent below would be 0/0 for a single class.
        if num_class <= 1:
            return [int(img_max)] * num_class
        denom = num_class - 1.0
        return [
            int(img_max * (imb_factor ** (cls_idx / denom)))
            for cls_idx in range(num_class)
        ]
    if imb_type == "step":
        num_head = num_class // 2
        # For an odd num_class the extra class goes to the tail, so the list
        # still has exactly num_class entries (previously one was dropped).
        num_tail = num_class - num_head
        return [int(img_max)] * num_head + [int(img_max * imb_factor)] * num_tail
    return [int(img_max)] * num_class
    