import os
import json
import pickle
import argparse
import numpy as np

from tqdm import tqdm
from utils import args_parser
from data import create_dataloaders


def _load_dataset_info(path):
    """Load and return the JSON document at ``path``, closing the file handle."""
    with open(path) as f:
        return json.load(f)


def _accumulate_token_frequency(frequency, dataloader):
    """Add the token-id counts of every batch in ``dataloader`` to ``frequency`` in place.

    Both ``item["query_input_ids"]`` and ``item["passage_input_ids"]`` of each batch
    are flattened and counted.

    Raises:
        IndexError: if a batch contains a token id >= ``len(frequency)``.
        ValueError: if a batch contains a negative token id (``np.bincount`` rejects it).
    """
    vocab_size = frequency.shape[0]
    for item in tqdm(dataloader):
        for key in ("query_input_ids", "passage_input_ids"):
            ids = np.asarray(item[key]).reshape(-1)
            # bincount is far faster than np.add.at and rejects negative ids
            # instead of silently wrapping them onto the end of the vocabulary.
            counts = np.bincount(ids, minlength=vocab_size)
            if counts.shape[0] > vocab_size:
                raise IndexError(
                    f"token id {counts.shape[0] - 1} out of range for vocab_size={vocab_size}"
                )
            frequency += counts


def _greedy_balance(frequency, num_experts):
    """Map every token id to an expert bucket so bucket token totals are balanced.

    Token ids are visited in descending frequency order; each token that occurred
    at least once goes to the bucket with the smallest running total (greedy
    longest-processing-time assignment, ties -> lowest bucket index). Tokens that
    never occurred get a uniformly random bucket from the global ``np.random`` state.

    Returns:
        list[int] of length ``len(frequency)`` with values in ``[0, num_experts)``.
    """
    frequency_ind = np.argsort(frequency)[::-1]
    # Reuse the argsort order instead of sorting a second time; values are identical.
    frequency_sorted = frequency[frequency_ind].tolist()
    frequency_ind = frequency_ind.tolist()

    balance_list = [0] * frequency.shape[0]
    # Explicit int64: the platform default int may be int32 and overflow on large corpora.
    bucket_size = np.zeros(num_experts, dtype=np.int64)
    for freq, ind in tqdm(zip(frequency_sorted, frequency_ind), total=len(frequency_ind)):
        if freq == 0:
            balance_list[ind] = np.random.randint(low=0, high=num_experts)
        else:
            bucket_ind = int(np.argmin(bucket_size))
            bucket_size[bucket_ind] += freq
            # Plain int keeps the pickled list free of numpy scalar types.
            balance_list[ind] = bucket_ind
    return balance_list


def get_balance_list(args):
    """Build a frequency-balanced token -> expert assignment list.

    Token frequencies are counted over the pretraining dataset
    (``<data_dir>/NomicEmbedFullDataset``) and the finetuning dataset
    (``<data_dir>/NomicEmbedFTDataset``). The vocabulary is then split into
    ``args.num_experts`` buckets with roughly equal total token counts.

    Args:
        args: parsed arguments; reads ``data_dir``, ``folder_name``, ``batch_size``,
            ``num_workers``, ``dataset_info_path``, ``vocab_size`` and ``num_experts``.

    Returns:
        list[int] of length ``args.vocab_size``; entry ``i`` is the bucket of token id ``i``.
    """
    print("Creating pretraining dataloader...")
    train_dataloader1 = create_dataloaders(
        data_dir=os.path.join(args.data_dir, "NomicEmbedFullDataset"),
        folder_name=args.folder_name,
        do_pretrain=True,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dataset_info=_load_dataset_info(
            os.path.join(args.dataset_info_path, "pretrain_dataset_info_orig.json")
        ),
        load_hash=False,
    )
    print("Creating finetuning dataloader...")
    train_dataloader2 = create_dataloaders(
        data_dir=os.path.join(args.data_dir, "NomicEmbedFTDataset"),
        folder_name=args.folder_name,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dataset_info=_load_dataset_info(
            os.path.join(args.dataset_info_path, "finetune_dataset_info_orig.json")
        ),
        load_hash=False,
    )

    print("Computing token frequency...")
    frequency = np.zeros(args.vocab_size, dtype=np.int64)
    for dataloader in (train_dataloader1, train_dataloader2):
        _accumulate_token_frequency(frequency, dataloader)

    print("Computing balanced hash list...")
    return _greedy_balance(frequency, args.num_experts)


def main():
    """Compute the balanced token -> expert list and pickle it to disk.

    Output file:
    ``<data_dir>/NomicEmbedFullDataset/hash_lists/balance_hash_bucket_<num_experts>_1M.pkl``
    (the ``_1M`` suffix is hard-coded; presumably a dataset-size tag — confirm).
    """
    args = args_parser()

    out_dir = os.path.join(args.data_dir, "NomicEmbedFullDataset", "hash_lists")
    # Create the output directory before the expensive computation so a missing
    # or unwritable location fails fast instead of after counting every token.
    os.makedirs(out_dir, exist_ok=True)
    name = os.path.join(out_dir, f"balance_hash_bucket_{args.num_experts}_1M.pkl")

    balance_list = get_balance_list(args)
    with open(name, "wb") as file:
        pickle.dump(balance_list, file)
    print("Completed!")


if __name__ == "__main__":
    # Run the script only when executed directly, not when this module is imported.
    main()