import argparse
import json
import pyarrow as pa
import pyarrow.parquet as pq

# Column layout shared by every Parquet file this script writes: a task
# identifier followed by the code-generation and test-generation prompts.
# All three are nullable string columns (pa.field defaults to nullable=True).
schema = pa.schema([
    pa.field('task_id', pa.string()),
    pa.field('prompt_codegen', pa.string()),
    pa.field('prompt_testgen', pa.string()),
])


def read_jsonl(fp):
    """Yield the decoded JSON value of each line of a JSON Lines file.

    Blank lines are skipped. Lines that are not valid JSON are reported on
    stderr (with file name and line number) and skipped, so one bad record
    does not abort the whole conversion.

    Args:
        fp: Path of the ``.jsonl`` file to read; decoded as UTF-8.

    Yields:
        The value parsed from each valid line.
    """
    import sys

    # JSON Lines is UTF-8 by definition; don't depend on the platform locale.
    with open(fp, encoding="utf-8") as fin:
        for line_no, line in enumerate(fin, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as ex:
                print(
                    f"{fp}:{line_no}: skipping invalid JSON ({ex}): {line.rstrip()!r}",
                    file=sys.stderr,
                )
            else:
                # Yield outside the try so exceptions raised by the consumer
                # are not mistaken for parse errors.
                yield record

def main_each(input_jsonl, output_parquet):
    """Convert one JSONL file of prompts into a Parquet file.

    Each input record is read via ``read_jsonl`` and its ``idx``,
    ``prompt_codegen`` and ``prompt_test`` keys are used. Records whose
    ``prompt_codegen`` or ``prompt_test`` is missing, null or empty are
    skipped. The remaining rows are written with the module-level ``schema``
    (``task_id``, ``prompt_codegen``, ``prompt_testgen``).

    Args:
        input_jsonl: Path of the JSONL file to read.
        output_parquet: Path of the Parquet file to write.
    """
    task_ids = []
    codegen_prompts = []
    testgen_prompts = []
    for record in read_jsonl(input_jsonl):
        # `or ""` also covers keys that are present but null, which would
        # otherwise make the emptiness check raise TypeError.
        prompt = record.get("prompt_codegen") or ""
        test_prompt = record.get("prompt_test") or ""
        if not prompt or not test_prompt:
            continue
        task_id = record.get("idx")
        # The task_id column is a string column and pyarrow does not coerce
        # ints to strings, so non-null ids are stringified; a missing id
        # stays null.
        task_ids.append(None if task_id is None else str(task_id))
        codegen_prompts.append(prompt)
        testgen_prompts.append(test_prompt)

    table = pa.Table.from_arrays(
        [
            pa.array(task_ids, type=pa.string()),
            pa.array(codegen_prompts, type=pa.string()),
            pa.array(testgen_prompts, type=pa.string()),
        ],
        schema=schema,
    )
    pq.write_table(table, output_parquet)

def main(args):
    """Convert every ``*.jsonl`` file in ``args.input_dir`` to Parquet.

    Each ``<name>.jsonl`` regular file is converted with ``main_each`` and
    written to ``<args.output_dir>/<name>.parquet``. The output directory is
    created if it does not already exist.

    Args:
        args: Parsed arguments providing ``input_dir`` and ``output_dir``.
    """
    import os

    # pq.write_table fails if the target directory is missing.
    os.makedirs(args.output_dir, exist_ok=True)

    # Sorted so files are processed in a deterministic order; os.listdir
    # returns entries in arbitrary order.
    for file_name in sorted(os.listdir(args.input_dir)):
        input_path = os.path.join(args.input_dir, file_name)
        if not file_name.endswith(".jsonl") or not os.path.isfile(input_path):
            continue
        # Swap only the extension: str.replace("jsonl", "parquet") would also
        # rewrite "jsonl" elsewhere in the name (e.g. "jsonl_dump.jsonl").
        stem, _ = os.path.splitext(file_name)
        output_path = os.path.join(args.output_dir, stem + ".parquet")
        main_each(input_path, output_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert each .jsonl file in a directory to a .parquet file."
    )
    parser.add_argument('--input_dir', '-i', type=str, required=True,
                        help='directory containing the input .jsonl files')
    parser.add_argument('--output_dir', '-o', type=str, required=True,
                        help='directory to write the output .parquet files to (created if missing)')
    args = parser.parse_args()
    main(args)
