import json
import argparse
import os


def split_data(data: list, n_parts: int) -> list:
    """
    Split data into n_parts using strided (round-robin) assignment.

    Element ``k`` of ``data`` goes to part ``k % n_parts``, so part sizes
    differ by at most one and relative order inside each part is preserved.
    When ``n_parts`` exceeds ``len(data)`` the trailing parts are empty lists.

    Args:
        data: Sequence to split; it must support extended slicing.
        n_parts: Number of parts to produce; must be at least 1.

    Returns:
        A list of exactly ``n_parts`` slices of ``data``.

    Raises:
        ValueError: If ``n_parts`` is smaller than 1.

    Example:
    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    n_parts = 3
    split_data(data, n_parts) -> [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]]
    """
    # Validate explicitly instead of relying on ``assert``: assertions are
    # stripped under ``python -O``, which would turn a bad ``n_parts`` into a
    # silently empty result (i.e. lost data).
    if n_parts < 1:
        raise ValueError(f"n_parts must be a positive integer, got {n_parts}")

    # Offsets 0..n_parts-1 with step n_parts visit every index exactly once,
    # so the parts together always cover the whole input.
    return [data[i::n_parts] for i in range(n_parts)]

def main():
    """
    Command-line entry point: split a JSON array file into several part files.

    Reads the JSON document at ``--input_path``, splits its top-level array
    into ``--n_parts`` parts with ``split_data``, and writes each part next to
    the input as ``<input stem>_part_<i>.json``. The size of every part and
    the combined size are printed to stdout.

    Raises:
        TypeError: If the top-level JSON value is not an array.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--n_parts", type=int, required=True)
    args = parser.parse_args()

    # Reject a non-positive part count at the CLI boundary with a usage
    # message instead of failing later (or writing nothing at all).
    if args.n_parts < 1:
        parser.error("--n_parts must be a positive integer")

    # JSON text is UTF-8; be explicit so the result does not depend on the
    # platform's locale encoding.
    with open(args.input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Only arrays can be split by slicing; fail with a clear message rather
    # than an obscure slicing error on dicts or scalars.
    if not isinstance(data, list):
        raise TypeError(
            f"Expected a JSON array at the top level of {args.input_path}, "
            f"got {type(data).__name__}"
        )

    splitted_data = split_data(data, args.n_parts)

    # The output stem does not change between parts, so compute it once.
    base_path = os.path.splitext(args.input_path)[0]
    for i, part in enumerate(splitted_data):
        output_path = f"{base_path}_part_{i}.json"
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(part, f)
        print(f"Saved part {i} to {output_path}")
        print(f"Size of part {i}: {len(part)}")

    print(f"Total size of all parts: {sum(len(part) for part in splitted_data)}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()




