
import numpy as np
import pandas as pd

from collections import Counter

def classificationdatasets(task, imbalance_type, imbalance_ratio):
    """Load a text-classification dataset and make its training split class-imbalanced.

    The training split is subsampled per class via ``get_img_num_per_cls`` and
    ``gen_imbalanced_data``; the test split is returned exactly as read.

    Args:
        task: Dataset name. Only ``'agnews'`` is supported; it reads
            ``agnews_train.jsonl`` / ``agnews_test.jsonl`` (JSON Lines) from the
            current working directory.
        imbalance_type: ``'exp'``, ``'step'``, or anything else for a balanced
            subsample (forwarded to ``get_img_num_per_cls``).
        imbalance_ratio: Imbalance factor forwarded to ``get_img_num_per_cls``.

    Returns:
        ``(dat_train, raw_dat_test)`` where ``dat_train`` is a DataFrame with
        ``text`` and ``label`` columns and ``raw_dat_test`` is the unmodified
        test DataFrame.

    Raises:
        ValueError: if ``task`` is not a supported dataset name.
    """
    if task == 'agnews':
        raw_dat_train = pd.read_json("agnews_train.jsonl", orient="records", lines=True)
        raw_dat_test = pd.read_json("agnews_test.jsonl", orient="records", lines=True)
        class_num = 4
    else:
        # Fail fast: previously this only printed and then crashed below with an
        # unrelated UnboundLocalError on ``raw_dat_train``.
        raise ValueError(f"ERROR DATALOADER: unknown task {task!r}")

    img_num_per_cls = get_img_num_per_cls(raw_dat_train['label'], class_num, imbalance_type, imbalance_ratio)
    text, label = gen_imbalanced_data(raw_dat_train['text'], raw_dat_train['label'], img_num_per_cls)
    dat_train = pd.DataFrame({"text": text, 'label': label})

    return dat_train, raw_dat_test



def get_img_num_per_cls(data, cls_num, imb_type, imb_factor):
    """Compute how many samples to keep for each class index ``0..cls_num-1``.

    The baseline count is the size of the *smallest* class in ``data``, so every
    class can actually supply its requested number of samples.

    Args:
        data: Iterable of class labels (one per sample).
        cls_num: Number of classes to produce counts for.
        imb_type: ``'exp'`` for exponential decay from the baseline down to
            ``baseline * imb_factor``; ``'step'`` for a majority half at the
            baseline and a minority half at ``baseline * imb_factor``; any other
            value yields a balanced list of baseline counts.
        imb_factor: Ratio between the smallest and largest target count.

    Returns:
        A list of exactly ``cls_num`` integer counts.
    """
    img_max = min(Counter(data).values())

    if imb_type == 'exp':
        if cls_num == 1:
            # A single class has nothing to decay towards; avoid dividing by zero.
            return [int(img_max)]
        return [
            int(img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))))
            for cls_idx in range(cls_num)
        ]

    if imb_type == 'step':
        # The minority group takes the remainder so an odd ``cls_num`` still
        # yields ``cls_num`` entries (otherwise the last class is silently dropped
        # when the counts are zipped against the classes).
        n_major = cls_num // 2
        return [int(img_max)] * n_major + [int(img_max * imb_factor)] * (cls_num - n_major)

    return [int(img_max)] * cls_num
    

def gen_imbalanced_data(text, label, img_num_per_cls):
    """Randomly subsample each class to a target size.

    Classes are the sorted unique values of ``label``; the i-th class is paired
    with ``img_num_per_cls[i]`` (extra classes or counts beyond the shorter of
    the two are ignored). Within each class, samples are shuffled with NumPy's
    global RNG before the first ``count`` are kept.

    Args:
        text: Sequence (list, array, or pandas Series) of samples.
        label: Sequence of integer labels aligned positionally with ``text``.
        img_num_per_cls: Number of samples to keep for each sorted class.

    Returns:
        ``(texts, labels)`` as flat lists, grouped by class in sorted-class order.
    """
    # Materialise as plain lists so ``[i]`` is positional, matching the
    # positional indices returned by ``np.where`` even when a pandas Series has
    # a non-default index.
    texts = list(text)
    labels = list(label)
    targets_np = np.array(labels, dtype=np.int64)
    classes = np.unique(targets_np)

    new_data = []
    new_targets = []
    for the_class, the_img_num in zip(classes, img_num_per_cls):
        idx = np.where(targets_np == the_class)[0]
        np.random.shuffle(idx)
        selec_idx = idx[:the_img_num]
        # extend() instead of collecting sublists and sum(..., []), which is quadratic.
        new_data.extend(texts[i] for i in selec_idx)
        new_targets.extend(labels[i] for i in selec_idx)

    return new_data, new_targets