# Merge multiple JSON datasets (each a list of samples) into one larger dataset.
import json
import argparse
import os
from tqdm import tqdm

def merge_datasets(input_paths, output_path):
    """Concatenate several JSON list datasets into one file and record stats.

    Each input file must hold a top-level JSON array; its elements are
    appended, in input order, to a single merged list written to
    ``output_path``. A companion statistics file named
    ``<output path without extension>_stats.json`` records the sample count of
    every input file and of the merged output.

    Args:
        input_paths: Iterable of paths to JSON files, each containing a list.
        output_path: Destination path for the merged JSON list. Missing parent
            directories are created.

    Raises:
        TypeError: If an input file's top-level JSON value is not a list.
    """
    merged_data = []
    input_stats = []
    for input_path in tqdm(input_paths, desc="Loading datasets"):
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # list.extend on a dict/str would silently merge keys/characters.
        if not isinstance(data, list):
            raise TypeError(
                f"Expected a JSON list in {input_path}, got {type(data).__name__}"
            )
        print(f"Loaded {len(data)} samples from {input_path}")
        merged_data.extend(data)
        # Record the count now so the stats never require re-reading the file.
        input_stats.append({'path': input_path, 'n_samples': len(data)})

    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(merged_data, f)
    print(f"Merged {len(merged_data)} samples")

    # Save statistics to a json file. Derive the name from the real extension:
    # str.replace would also rewrite '.json' inside directory names and, for a
    # path without '.json', would return output_path itself (overwriting it).
    stats_path = os.path.splitext(output_path)[0] + '_stats.json'
    with open(stats_path, 'w', encoding='utf-8') as file:
        json.dump({
            "n_samples_input": input_stats,
            "n_samples_output": len(merged_data),
        }, file, indent=2)
    

def main(argv=None):
    """Parse command-line arguments and merge the requested datasets.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps normal CLI behaviour;
            passing a list lets callers and tests run the merge directly.
    """
    parser = argparse.ArgumentParser(
        description='Merge multiple JSON list datasets into a single file.'
    )
    parser.add_argument('--input_paths', type=str, nargs='+', required=True, help='Paths for the input files')
    parser.add_argument('--output_path', type=str, required=True, help='Path to the output file')
    # parse_args(None) falls back to sys.argv[1:], preserving CLI behaviour.
    args = parser.parse_args(argv)

    print(f"Merging datasets: {args.input_paths} into {args.output_path}")

    merge_datasets(args.input_paths, args.output_path)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import. main()
    # returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())