from datasets import load_dataset
from tqdm import tqdm

from ds_src.initialize.datasets.SlimPajama import SlimPajama

# Maps each sample-source name to the number of samples seen with that name.
# Filled in by the counting loop over `dataset` later in this script.
classes = {}

# NOTE: disabled alternative that tallies samples per
# `meta['redpajama_set_name']` directly on the raw Hugging Face dataset,
# bypassing the SlimPajama wrapper. Kept for reference only; it is the sole
# user of the `load_dataset` import at the top of this file.
# dataset = load_dataset(
#     path='DKYoon/SlimPajama-6B',
#     split='train',
# )

# for data_idx, data in enumerate(tqdm(iterable=dataset)):
#     if data['meta']['redpajama_set_name'] not in classes.keys():
#         classes[data['meta']['redpajama_set_name']] = 1
#     else:
#         classes[data['meta']['redpajama_set_name']] += 1

# print(classes)

# Per-source sample counts handed to the SlimPajama loader as
# `subset_data_number` (presumably how many samples to keep from each
# RedPajama source -- confirm against SlimPajama).
subset_data_number = {
    'RedPajamaC4': 347859,
    'RedPajamaStackExchange': 31700,
    'RedPajamaCommonCrawl': 199727,
    'RedPajamaGithub': 22742,
    'RedPajamaWikipedia': 28978,
    'RedPajamaArXiv': 1621,
    'RedPajamaBook': 205,
}

# Total used to rescale the observed per-source ratios below. Derived from the
# counts above (currently 632832) instead of being hard-coded, so it stays in
# sync if they change. Computed before the loader sees the dict in case the
# loader modifies it.
target_total = sum(subset_data_number.values())

dataset = SlimPajama(
    config={},
    dataset_config={
        'path': 'DKYoon/SlimPajama-6B',
        'split': 'train',
        'subset_data_number': subset_data_number,
    },
)

# Tally how many samples carry each `source` label. Add to `classes` rather
# than rebinding it, so anything already recorded there is preserved.
for data in tqdm(iterable=dataset):
    source = data['source']
    classes[source] = classes.get(source, 0) + 1

print(classes)

# For every source print (1) its fraction of the dataset and (2) that fraction
# scaled to `target_total` samples. The ' .5f' spec reserves a leading space
# for the sign; it is kept as-is so the output format is unchanged.
total = len(dataset)
for one_class, count in classes.items():
    print(f'{one_class}: {(count / total): .5f}')
    print(f'{one_class}: {target_total * count / total: .5f}')
