
import os
import sys

import numpy as np

# Parse the three required positional arguments:
#   <input_filename> <num_shards> <outdir>
try:
    input_filename = sys.argv[1]
    num_shards = int(sys.argv[2])
    outdir = sys.argv[3]
except (IndexError, ValueError):
    # IndexError: too few arguments; ValueError: <num_shards> is not an integer.
    # Exit here -- continuing would only crash later with a NameError on the
    # unset globals. Usage goes to stderr with a non-zero exit status.
    print("Usage: python shard_file.py <input_filename> <num_shards> <outdir>",
          file=sys.stderr)
    sys.exit(1)

def read_lines(input_filename):
    """Return all lines of *input_filename* as a list of strings.

    Each line keeps its trailing newline (if any); the file is opened in
    text mode and closed before returning.
    """
    with open(input_filename, "r") as handle:
        return list(handle)

def write_to(output_filename, lines):
    """Write the strings in *lines* to *output_filename*, truncating it first.

    No separators are added between items, so each string should already
    carry its own line ending.
    """
    with open(output_filename, "w+") as out_file:
        out_file.writelines(lines)

# Split the input file's lines into `num_shards` contiguous, near-equal shards
# and write shard i (1-based) to <outdir>/<basename>.<i>-<num_shards>.
lines = read_lines(input_filename)
N = len(lines)
# np.array_split yields `num_shards` contiguous index ranges whose lengths
# differ by at most one (the first N % num_shards are one longer); it raises
# ValueError when num_shards < 1. Shards may be empty when N < num_shards.
all_shard_indices = np.array_split(np.arange(N), num_shards)
basename = os.path.basename(input_filename)
# Create the output directory instead of failing on the first write. An empty
# `outdir` means "current directory" (os.path.join("", x) == x), which exists.
if outdir:
    os.makedirs(outdir, exist_ok=True)
start = 0
for i, shard_indices in enumerate(all_shard_indices):
    # Shards are contiguous, so each one is a plain slice of `lines`; there is
    # no need to materialise an index list and gather line by line.
    end = start + len(shard_indices)
    output_filename = os.path.join(outdir, basename + ".{}-{}".format(i+1, num_shards))
    write_to(output_filename, lines[start:end])
    start = end