import argparse
import json

from tqdm import tqdm


def _query_meets_threshold(query, threshold):
    """Return True if ``query`` is a dict whose numeric score is >= ``threshold``."""
    if not isinstance(query, dict):
        return False
    # A missing score counts as 0, so such queries only pass when the
    # threshold is 0 or lower.
    score = query.get("score", 0)
    # A null or string score cannot be compared to the threshold; treat the
    # query as not passing instead of letting the comparison raise TypeError.
    return isinstance(score, (int, float)) and score >= threshold


def filter_queries_above_threshold(input_file, output_file, threshold=50):
    """Keep only the queries scoring at or above ``threshold`` and save them.

    Reads ``input_file`` as JSONL (one JSON value per line). For every line
    that is a JSON object holding a ``"relevant_queries"`` list, the list is
    reduced to its dict entries whose ``"score"`` is a number
    ``>= threshold``. Records that still have at least one query are written
    to ``output_file`` as JSONL with the reduced list; every other record is
    dropped.

    Blank lines are skipped silently. Lines that are not valid JSON are
    reported on stdout and skipped. Valid JSON of an unexpected shape (a
    non-object record, a non-list ``relevant_queries``, a non-numeric score)
    is skipped rather than raising ``TypeError``.

    Args:
        input_file: Path of the JSONL file to read (UTF-8).
        output_file: Path of the JSONL file to write (UTF-8, overwritten).
        threshold: Inclusive minimum score a query needs in order to be kept.
    """
    # The input is read completely before the output is opened: the progress
    # bar gets a known total, and the input is never read while it is being
    # overwritten if both paths name the same file.
    with open(input_file, "r", encoding="utf-8") as jsonl_file:
        lines = [line for line in jsonl_file if line.strip()]

    with open(output_file, "w", encoding="utf-8") as output_writer:
        for line in tqdm(lines, desc="Filtering Queries"):
            # Only the parse can raise JSONDecodeError, so only it is guarded.
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON: {e}")
                continue

            if not isinstance(data, dict):
                continue

            queries = data.get("relevant_queries")
            if not isinstance(queries, list):
                continue

            filtered_queries = [
                query for query in queries if _query_meets_threshold(query, threshold)
            ]

            # Records left without any qualifying query are not written.
            if filtered_queries:
                data["relevant_queries"] = filtered_queries
                output_writer.write(json.dumps(data, ensure_ascii=False) + "\n")


def count_total_queries(file_path):
    """Count the queries stored in a JSONL file.

    Each non-blank line of ``file_path`` is parsed as JSON. For every line
    that is a JSON object whose ``"relevant_queries"`` value is a list, the
    length of that list is added to the total.

    Blank lines are skipped silently. Lines that are not valid JSON are
    reported on stdout and skipped. Valid JSON of an unexpected shape (a
    non-object record or a non-list ``relevant_queries``) contributes
    nothing instead of raising ``TypeError``.

    Args:
        file_path: Path of the JSONL file to read (UTF-8).

    Returns:
        The summed length of all ``"relevant_queries"`` lists in the file.
    """
    # Lines are materialised up front so the progress bar has a known total.
    with open(file_path, "r", encoding="utf-8") as f:
        lines = [line for line in f if line.strip()]

    total_queries = 0
    for line in tqdm(lines, desc="Counting Queries"):
        # Only the parse can raise JSONDecodeError, so only it is guarded.
        try:
            data = json.loads(line)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
            continue

        if not isinstance(data, dict):
            continue

        queries = data.get("relevant_queries")
        if isinstance(queries, list):
            total_queries += len(queries)

    return total_queries


def main(argv=None):
    """Run the command-line workflow: filter a JSONL file, then count the result.

    Parses the input path, output path and score threshold, writes the
    filtered records to the output path, and prints the total number of
    queries found in the file that was just written.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Filter queries based on relevancy score"
    )
    parser.add_argument(
        "-i", "--input_file", type=str, required=True, help="Path to input JSONL file"
    )
    parser.add_argument(
        "-o", "--output_file", type=str, required=True, help="Path to output JSONL file"
    )
    parser.add_argument(
        "-t",
        "--threshold",
        type=int,
        required=True,
        help="Relevancy score threshold (queries scoring at or above it are kept)",
    )

    args = parser.parse_args(argv)

    # Filter queries and save to file
    filter_queries_above_threshold(args.input_file, args.output_file, args.threshold)
    print(f"Filtered results saved to {args.output_file}")

    # Count total number of queries in the filtered file
    total_queries = count_total_queries(args.output_file)
    print(f"Total number of queries (Q): {total_queries}")


if __name__ == "__main__":
    main()
