import argparse
import glob
import json
import os
import re

import numpy as np
import torch
from tqdm import tqdm

def _no_progress(iterable, **_kwargs):
    """Identity stand-in for ``tqdm`` when no progress bar is wanted."""
    return iterable


def chunk_sort_key(path, prefix):
    """Return a key that orders chunk files by their numeric chunk index.

    Chunk files are named ``<prefix>_<index>...npz``. A plain ``sorted()``
    compares names lexicographically and would put ``_10`` before ``_2``.
    Because offsets accumulate chunk by chunk, that would give every later
    chunk the wrong global indices.

    Args:
        path: Path of a chunk file as returned by ``glob``.
        prefix: The ``--mapping_path`` prefix the chunk files were globbed with.

    Returns:
        ``(index, path)``; the path breaks ties between equal indices.

    Raises:
        ValueError: If no digits directly follow ``<prefix>_`` in ``path``.
    """
    # Drop "<prefix-basename>_" and read the run of ASCII digits that follows.
    suffix = os.path.basename(path)[len(os.path.basename(prefix)) + 1:]
    match = re.match(r"[0-9]+", suffix)
    if match is None:
        raise ValueError(f"{path!r} is not a chunk file of {prefix!r}")
    return int(match.group()), path


def merge_mappings(mapping_files, progress=_no_progress):
    """Merge per-chunk ``.npz`` mappings into one dict with global values.

    Each chunk's values were accumulated relative to that chunk only, so they
    are shifted by a running offset to make them global. After each chunk the
    offset becomes one past the largest shifted value seen so far. This keeps
    chunks from overlapping even when a chunk's arrays are not sorted. Empty
    chunks and empty arrays leave the offset unchanged.

    Args:
        mapping_files: Chunk ``.npz`` paths, already in chunk order.
        progress: ``tqdm``-like wrapper, called as ``progress(iterable, total=n)``.

    Returns:
        ``{key: list_of_global_values}`` with plain Python lists, ready for
        ``json.dump``. If a key occurs in several chunks, the later chunk's
        values win (same as ``dict.update``).
    """
    merged = {}
    offset = 0
    for mapping_file in mapping_files:
        # NOTE(review): allow_pickle is kept from the original; only safe for
        # trusted chunk files.
        with np.load(mapping_file, allow_pickle=True) as npz:
            # NpzFile is read-only and lazy; load every array before it closes.
            chunk = dict(npz)
        next_offset = offset
        for key in progress(chunk, total=len(chunk)):
            values = np.asarray(chunk[key]) + offset
            if values.size:
                next_offset = max(next_offset, int(values.max()) + 1)
            merged[key] = values.tolist()
        offset = next_offset
    return merged


def main():
    """Merge ``<mapping_path>_<i>.npz`` chunks into ``<mapping_path>.json``.

    The chunk files are deleted only after the merged JSON has been written.
    """
    parser = argparse.ArgumentParser(
        description="Merge <mapping_path>_<i>.npz chunk mappings into "
                    "<mapping_path>.json and delete the chunks.")
    parser.add_argument("--mapping_path", type=str, default=".",
                        help="Common prefix of the chunk files.")
    args = parser.parse_args()

    # Escape the user-supplied prefix so characters like '[' are literal.
    pattern = glob.escape(args.mapping_path) + "_[0-9]*.npz"
    mapping_files = sorted(
        glob.glob(pattern),
        key=lambda path: chunk_sort_key(path, args.mapping_path),
    )
    if not mapping_files:
        parser.error(f"no chunk files match {pattern!r}")

    mappings = merge_mappings(mapping_files, progress=tqdm)

    with open(args.mapping_path + ".json", "w", encoding="utf-8") as f:
        json.dump(mappings, f)

    # Only remove the chunks once the merged result is safely on disk.
    for mapping_file in mapping_files:
        os.remove(mapping_file)


if __name__ == "__main__":
    main()
