import argparse
import json


def _read_jsonl(path):
    """Return the JSON records stored one per line in *path*.

    Blank (whitespace-only) lines are skipped so that a trailing empty line
    does not abort the run with a JSON decode error.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _augmented_path(path):
    """Return the default output path for the input file *path*.

    ``foo.jsonl`` maps to ``foo_augmented.jsonl``. Only a trailing ``.jsonl``
    is treated as the suffix; a path without it gets ``_augmented.jsonl``
    appended, so the result is never equal to the input path.
    """
    suffix = '.jsonl'
    stem = path[:-len(suffix)] if path.endswith(suffix) else path
    return f'{stem}_augmented{suffix}'


def _attach_short_answer(model_out, question_data):
    """Validate one model output against the source data and annotate it.

    Sets ``model_out['short_answer']`` to the source entry's ``'answer'``.

    Raises:
        ValueError: if ``model_out['qid']`` is not a key of *question_data*.
        AssertionError: if the question texts differ after stripping
            surrounding whitespace and lower-casing.
    """
    question_id = model_out['qid']
    if question_id not in question_data:
        raise ValueError(f"Question ID {question_id} not found in source data.")
    source = question_data[question_id]
    if model_out['question'].strip().lower() != source['question'].strip().lower():
        # Raised explicitly (instead of via an ``assert`` statement) so the
        # check still runs under ``python -O``; the exception type is kept.
        raise AssertionError(
            f"Check failed ({question_id}): {model_out['question']} vs {source['question']}"
        )
    model_out['short_answer'] = source['answer']


def main():
    """Add a ``short_answer`` field to every record of the given JSONL files.

    Command-line arguments:
        --input      One or more JSONL files of model outputs; each record
                     must carry ``'qid'`` and ``'question'`` keys.
        --src-data   JSON file mapping question IDs to objects with
                     ``'question'`` and ``'answer'`` keys.
        --overwrite  Write the result back to each input file instead of a
                     sibling ``*_augmented.jsonl`` file.

    Each file is fully loaded and validated before anything is written, so a
    failed check never leaves a partially written output for that file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, nargs='+', required=True, help='Input JSONL files')
    parser.add_argument('--src-data', type=str, required=True, help='Question data JSON file')
    parser.add_argument('--overwrite', action='store_true', help='Overwrite existing files')
    args = parser.parse_args()

    # Load the source data
    with open(args.src_data, 'r', encoding='utf-8') as f:
        question_data = json.load(f)

    for path in args.input:
        model_outs = _read_jsonl(path)
        for model_out in model_outs:
            _attach_short_answer(model_out, question_data)

        # Save the augmented data
        output_file = path if args.overwrite else _augmented_path(path)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(model_out) + '\n' for model_out in model_outs)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()