import json
import argparse

def _read_jsonl(path):
    """Load a JSON Lines file into a list of decoded records.

    Blank or whitespace-only lines (for example a trailing empty line) are
    skipped instead of being handed to the JSON decoder.

    Args:
        path: Path of the JSON Lines file to read.

    Returns:
        The decoded records in file order, or ``None`` if the file does not
        exist or contains a line that is not valid JSON. In both failure
        cases an error message is printed.
    """
    try:
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]
    except FileNotFoundError:
        print(f"Error: File not found: {path}")
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON in {path}")
    return None


def _write_jsonl(path, records):
    """Overwrite *path* with *records*, one JSON document per line."""
    # Serialise everything before opening the file so that a serialisation
    # error cannot leave a truncated file behind.
    payload = ''.join(json.dumps(record) + '\n' for record in records)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(payload)


def _drop_questions(records, questions):
    """Return the records whose 'question' value is not in *questions*.

    Records without a 'question' key cannot match and are always kept.
    """
    return [
        record for record in records
        if 'question' not in record or record['question'] not in questions
    ]


def filter_files(fewshot_file, searchr1_file):
    """Remove no-search questions from two JSON Lines files, in place.

    Every fewshot record whose ``do_search`` value is falsy contributes its
    ``question`` to a removal set (a record without ``do_search`` is treated
    as ``True`` and contributes nothing). All records in both files whose
    ``question`` is in that set are dropped, and both files are rewritten.

    Both files are read before either one is written, so if either file is
    missing or malformed an error is printed and nothing is modified.

    Args:
        fewshot_file: Path to the JSON Lines file carrying ``do_search`` flags.
        searchr1_file: Path to the second JSON Lines file to filter by the
            same set of questions.

    Returns:
        None.
    """
    fewshot_data = _read_jsonl(fewshot_file)
    if fewshot_data is None:
        return

    # Find questions to remove (where do_search is false). Records that have
    # no 'question' key are skipped rather than raising KeyError.
    questions_to_remove = {
        item['question']
        for item in fewshot_data
        if 'question' in item and not item.get('do_search', True)
    }

    print(f"Found {len(questions_to_remove)} questions to remove.")

    searchr1_data = _read_jsonl(searchr1_file)
    if searchr1_data is None:
        return

    filtered_fewshot_data = _drop_questions(fewshot_data, questions_to_remove)
    filtered_searchr1_data = _drop_questions(searchr1_data, questions_to_remove)

    # Overwrite the original files only after both inputs loaded cleanly.
    _write_jsonl(fewshot_file, filtered_fewshot_data)
    _write_jsonl(searchr1_file, filtered_searchr1_data)

    print(f"Filtered {len(fewshot_data) - len(filtered_fewshot_data)} items from {fewshot_file}.")
    print(f"Filtered {len(searchr1_data) - len(filtered_searchr1_data)} items from {searchr1_file}.")
    print("Files have been updated in place.")

if __name__ == "__main__":
    # Command-line entry point: two required positional paths, both of which
    # are rewritten in place by filter_files.
    cli = argparse.ArgumentParser(
        description="Filter jsonl files based on do_search flag."
    )
    positional_args = (
        ("fewshot_file", "Path to the fewshot file to filter."),
        ("searchr1_file", "Path to the SearchR1 file to filter."),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, help=arg_help)
    options = cli.parse_args()

    filter_files(options.fewshot_file, options.searchr1_file)