import torch
import numpy as np


def gen_imbalanced_data(
    data, targets, num_class, imb_type="exp", imb_factor=0.01, is_cub=False
):
    """Subsample ``data`` so its per-class counts follow an imbalance profile.

    The requested count for each class comes from ``get_img_num_per_cls``
    with ``img_max = len(data) / num_class``. Classes are the sorted unique
    values of ``targets``. The i-th sorted class is paired with the i-th
    count (``zip`` semantics, so surplus classes or counts are ignored). At
    most that many of the class's samples are kept, chosen by shuffling its
    indices with the global ``np.random`` state.

    Args:
        data: Samples to subsample. When ``is_cub`` is False it is indexed as
            ``data[index_array, ...]``, and the per-class pieces are joined
            along axis 0 (presumably a numpy array -- TODO confirm for other
            containers). When ``is_cub`` is True, elements are taken one at a
            time and a list is returned.
        targets: Class label of each sample; must be convertible to int64.
        num_class: Number of classes, used to derive ``img_max``.
        imb_type: ``"original"`` returns ``data`` unchanged. Any other value
            is forwarded to ``get_img_num_per_cls``.
        imb_factor: Imbalance factor forwarded to ``get_img_num_per_cls``.
        is_cub: Use element-wise (list) handling of ``data``.

    Returns:
        ``(new_data, new_targets)``, where ``new_targets`` is an int64 tensor
        aligned sample-for-sample with ``new_data``.
    """
    if imb_type == "original":
        # Build the int64 tensor directly (and copy): a float32 round-trip
        # through torch.Tensor would corrupt labels above 2**24.
        return data, torch.from_numpy(np.array(targets, dtype=np.int64))

    img_max = len(data) / num_class
    img_num_per_cls = get_img_num_per_cls(num_class, imb_type, imb_factor, img_max)

    targets_np = np.array(targets, dtype=np.int64)
    classes = np.unique(targets_np)  # sorted ascending

    new_data = []
    new_targets = []
    for the_class, the_img_num in zip(classes, img_num_per_cls):
        idx = np.where(targets_np == the_class)[0]
        np.random.shuffle(idx)
        selec_idx = idx[:the_img_num]
        if is_cub:
            new_data.extend(data[t] for t in selec_idx)
        else:
            new_data.append(data[selec_idx, ...])
        # Label with the number actually kept; a class may have fewer
        # samples than requested.
        new_targets.extend([the_class] * len(selec_idx))

    if not is_cub:
        # Concatenate (not vstack) so 1-D data stays 1-D and row-aligned
        # with new_targets. With no classes, return an empty slice instead
        # of letting the join raise.
        new_data = np.concatenate(new_data, axis=0) if new_data else data[:0]

    return new_data, torch.from_numpy(np.array(new_targets, dtype=np.int64))


def get_img_num_per_cls(cls_num, imb_type, imb_factor, img_max):
    """Return the per-class sample counts for an imbalance profile.

    Args:
        cls_num: Number of classes.
        imb_type: Profile name. Supported values:
            ``"exp"``: counts decay geometrically from ``img_max`` (first
            class) to ``img_max * imb_factor`` (last class).
            ``"expr"``: the same curve reversed, rising from
            ``img_max * imb_factor`` to ``img_max``.
            ``"step"``: the first ``cls_num // 2`` classes get ``img_max``
            and all remaining classes get ``img_max * imb_factor``.
        imb_factor: Ratio of the smallest to the largest count (before
            ``int`` truncation).
        img_max: Count assigned to the largest class.

    Returns:
        A list of ``cls_num`` ints, each count truncated with ``int``.

    Raises:
        NotImplementedError: If ``imb_type`` is not supported.
    """
    # With a single class, (cls_num - 1) would be zero. Clamping to 1 gives
    # exponent 0, i.e. that class receives img_max.
    denom = max(cls_num - 1.0, 1.0)
    if imb_type == "exp":
        return [int(img_max * imb_factor ** (i / denom)) for i in range(cls_num)]
    if imb_type == "expr":
        # Exact mirror of "exp": the exponent runs from 1 down to 0, so the
        # last class gets exactly img_max.
        return [
            int(img_max * imb_factor ** ((cls_num - 1 - i) / denom))
            for i in range(cls_num)
        ]
    if imb_type == "step":
        num_head = cls_num // 2
        # The remainder (including the odd class out) goes to the tail, so
        # the list always has exactly cls_num entries.
        num_tail = cls_num - num_head
        return [int(img_max)] * num_head + [int(img_max * imb_factor)] * num_tail
    raise NotImplementedError("You have chosen an unsupported imb type.")
