import argparse
import os

import numpy as np

from pile_index import data_to_dict


def main(args):
    """Write a random sample of records from a data file to samples/samples.txt.

    Args:
        args: Parsed command-line namespace. Reads ``args.data_file`` (a file
            name resolved under ``pile/train``) and ``args.nsamples`` (how many
            records to draw).

    The output file is overwritten on every run. Each sampled record is
    preceded by a double dashed-line separator. Records are drawn without
    replacement; if ``nsamples`` exceeds the number of available records,
    every record is written once (in random order) instead of raising.
    """
    output_path = os.path.join("samples", "samples.txt")
    # The output directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    # Clear previous samples up front so a failure while loading the data
    # below does not leave stale samples that look like fresh output.
    open(output_path, "w").close()

    data_path = os.path.join("pile/train", args.data_file)
    # NOTE(review): presumably data_to_dict returns a mapping keyed 0..len-1
    # with str values (that is how it is indexed and written below) — confirm
    # against pile_index.
    data_dict = data_to_dict(data_path)

    # Sampling without replacement cannot draw more items than exist.
    nsamples = min(args.nsamples, len(data_dict))
    # tolist() yields plain Python ints, so the dict lookups below do not rely
    # on numpy integer hashing/equality.
    random_indices = np.random.choice(
        len(data_dict), nsamples, replace=False
    ).tolist()
    data_items = [data_dict[i] for i in random_indices]

    separator = "\n\n" + 50 * "-" + "\n\n" + 50 * "-" + "\n\n"
    # Explicit UTF-8: corpus text is arbitrary Unicode and the platform default
    # encoding (e.g. cp1252 on Windows) may not be able to encode it.
    with open(output_path, "w", encoding="utf-8") as file:
        for data_item in data_items:
            file.write(separator)
            file.write(data_item)


if __name__ == "__main__":
    # Command-line entry point: choose the data file and the sample count,
    # then hand the parsed namespace straight to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--data_file", default="02.jsonl", type=str)
    cli.add_argument("--nsamples", default=100, type=int)
    main(cli.parse_args())
