import sys
import numpy as np
from datasets import load_from_disk
from collections import Counter

np.random.seed(0)


def balanced_indices(classes, class_ids=(0, 1, 2)):
    """Return row indices forming a class-balanced subset.

    Every class listed in ``class_ids`` contributes its first ``k`` rows
    (in the order they appear in ``classes``), where ``k`` is the smallest
    entry of ``np.bincount(classes)``.  The result is grouped by class, in
    the order given by ``class_ids``.

    Args:
        classes: Sequence of non-negative integer class ids, one per row.
        class_ids: Class ids to keep, in output order.

    Returns:
        A list of plain Python ints.  It is empty when ``classes`` is empty
        or when any class id below the maximum present has no rows
        (``np.bincount`` reports a zero count for it).
    """
    labels = np.asarray(classes)
    if labels.size == 0:
        return []

    per_class_cap = int(np.bincount(labels).min())
    picked = []
    for class_id in class_ids:
        rows = np.flatnonzero(labels == class_id)[:per_class_cap]
        picked.extend(rows.tolist())
    return picked


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit('usage: {} DATASET_DIR'.format(sys.argv[0]))

    dataset = sys.argv[1]
    data = load_from_disk(dataset)

    # `remove_columns_` (in-place) is deprecated and absent from newer
    # `datasets` releases; prefer the variant that returns a new object and
    # fall back to the in-place call only on old installations.
    if hasattr(data, 'remove_columns'):
        data = data.remove_columns('ins_weight')
    else:
        data.remove_columns_('ins_weight')

    # `shuffle` returns a new dataset rather than shuffling in place, so the
    # result has to be rebound; otherwise the per-class truncation below
    # always keeps the first rows in their original order.
    train_data = data['train'].shuffle(seed=0)

    # Read the class column once instead of iterating over full rows.
    ids = balanced_indices(train_data['entropy_class'])

    data['train'] = train_data.select(ids)

    data.save_to_disk(dataset + '_balanced')

# NOTE: The triple-quoted string below is disabled code kept as a bare string
# literal; it is never executed.  If re-enabled, it would split each of
# `easy_ids`, `med_ids` and `hard_ids` by `label` (0, 1, 2), truncate all nine
# resulting buckets to the size of the smallest one, select those rows from
# `train_data`, and save the result to the fixed path
# 'data/snli_super_balanced'.  It expects `easy_ids`, `med_ids`, `hard_ids`,
# `train_data` and `data` to be defined already.
"""
target_ids = [[idx for idx in easy_ids if train_data[idx]['label'] == t] for t in [0,1,2]]\
        + [[idx for idx in med_ids if train_data[idx]['label'] == t] for t in [0,1,2]]\
        + [[idx for idx in hard_ids if train_data[idx]['label'] == t] for t in [0,1,2]]
min_count = min([len(x) for x in target_ids])
target_ids = [x[:min_count] for x in target_ids]
target_ids = [x for l in target_ids for x in l]

data['train'] = train_data.select(target_ids)
data.save_to_disk('data/snli_super_balanced')
"""
