import json
import argparse
import os

def process_jsonl(input_path, output_path):
    """
    Processes a JSONL file to update the 'do_search' flag based on the content of 'sequences_str'.

    For each record, the text from the LAST '<|im_start|>assistant' marker in
    'sequences_str' onward is inspected; 'do_search' is set to True if that text
    contains '<search>' and False otherwise. Records without the marker keep
    their fields unchanged (they are still re-serialized with json.dumps).

    Lines that are blank, not valid JSON, not a JSON object, missing
    'sequences_str', or whose 'sequences_str' is not a string are copied to the
    output verbatim. Every such line except a blank one is reported with a warning.

    Args:
        input_path: Path of the JSONL file to read (UTF-8).
        output_path: Path of the JSONL file to write (UTF-8); truncated if it exists.

    Returns:
        The number of records whose 'do_search' flag was (re)computed.
    """
    assistant_start_marker = '<|im_start|>assistant'
    print(f"Processing {input_path}...")
    updated_lines = 0
    with open(input_path, 'r', encoding='utf-8') as infile, open(output_path, 'w', encoding='utf-8') as outfile:
        for line_no, line in enumerate(infile, start=1):
            if not line.strip():
                # Keep blank lines (e.g. a trailing empty line) as-is, without a warning.
                outfile.write(line)
                continue
            try:
                data = json.loads(line)
                # Subscripting a non-dict JSON value (list, str, number, null) raises TypeError.
                original_sequence = data['sequences_str']
                if not isinstance(original_sequence, str):
                    raise TypeError("'sequences_str' is not a string")
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                print(f"Warning: Could not process line {line_no} due to {type(e).__name__}. Writing it as is. Line: {line.strip()}")
                outfile.write(line)
                continue

            assistant_start_index = original_sequence.rfind(assistant_start_marker)
            if assistant_start_index != -1:
                # Only the final assistant turn is inspected; '<search>' occurring
                # earlier in the sequence (e.g. in the prompt) is ignored.
                data['do_search'] = '<search>' in original_sequence[assistant_start_index:]
                updated_lines += 1

            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')

    print(f"Processing complete. Updated {updated_lines} lines. Output written to {output_path}")
    return updated_lines

def main():
    """
    Command-line entry point: parse arguments and rewrite the JSONL file.

    The result is first written to '<output_file>.tmp' and then moved over the
    destination with os.replace, so the destination (possibly the input file
    itself) is only replaced once processing has finished. If processing fails
    or is interrupted, the temporary file is removed and the error is re-raised.
    """
    parser = argparse.ArgumentParser(
        description="Update 'do_search' flag in a JSONL file based on '<search>' tag in 'sequences_str'."
    )
    parser.add_argument("input_file", type=str, help="Path to the input JSONL file.")
    parser.add_argument("-o", "--output_file", type=str, help="Path to the output JSONL file. If not provided, it will overwrite the input file.")

    args = parser.parse_args()

    input_file = args.input_file
    output_file = args.output_file

    if not os.path.isfile(input_file):
        # Report a clean CLI error (exit code 2) instead of a traceback.
        parser.error(f"input file '{input_file}' does not exist or is not a regular file")

    if not output_file:
        output_file = input_file
        print(f"Warning: No output file provided. The input file '{input_file}' will be overwritten.")

    # To avoid data loss when overwriting, we write to a temporary file first.
    tmp_output_file = f"{output_file}.tmp"

    try:
        process_jsonl(input_file, tmp_output_file)
    except BaseException:
        # Don't leave a half-written temporary file behind on error or Ctrl-C.
        if os.path.exists(tmp_output_file):
            os.remove(tmp_output_file)
        raise

    # os.replace (unlike os.rename) also overwrites an existing destination on Windows.
    os.replace(tmp_output_file, output_file)
    print(f"Successfully updated file: {output_file}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main() 