# find_verifier_disagreements.py

import pandas as pd
import json
import argparse

from verl.utils.experience_maker import preprocess_box_response_for_qwen_prompt
from verl.utils.reward_score import prime_math, math_verify

# Column names and keys read from each row of the input Parquet file.
# Column holding the prompt; read as a sequence whose first item has a "content" entry.
PROMPT_COL = "prompt"
# Column holding the generated responses, which are verified one by one.
RESPONSE_COL = "responses"
# Column holding the ground-truth record (either a JSON string or an already-parsed mapping).
GROUND_TRUTH_COL = "reward_model"
# Key inside the ground-truth record whose value is the reference answer.
GROUND_TRUTH_KEY = "ground_truth"


def find_divergent_samples(input_path: str, output_path: str) -> None:
    """Collect the responses on which the verifiers disagree and save them.

    Reads the Parquet file at ``input_path``, scores every generated response
    against the row's ground-truth answer with each verifier, and records the
    responses for which the verifiers do not all return the same result.  The
    records are written to ``output_path`` as Parquet; nothing is written when
    no disagreement is found.

    Failures are reported with ``print`` instead of being raised: a load or
    save error ends the run, a row that cannot be parsed is skipped, and a
    response whose verification raises is skipped.

    Args:
        input_path: Parquet file containing the prompt, responses and
            ground-truth columns.
        output_path: Destination Parquet file for the divergent samples.
    """
    # Keep the try body to the one call that can fail while loading.
    try:
        df = pd.read_parquet(input_path)
    except Exception as e:
        print(f"Error loading input Parquet file '{input_path}': {e}")
        return
    print(f"Loaded {len(df)} rows from {input_path}")

    # Each verifier is called as verify(response_text, ground_truth_text).
    # NOTE(review): the "qwen" entry is used as imported, so its signature and
    # return type are not visible here -- confirm it accepts two strings and
    # returns a value comparable with the floats of the other two verifiers.
    verifiers = {
        "qwen": preprocess_box_response_for_qwen_prompt,
        "prime": lambda resp, ans: float(prime_math.compute_score(resp, ans)["acc"]),
        "math": lambda resp, ans: float(math_verify.compute_score(resp, ans)["acc"]),
    }

    divergent_samples = []

    for index, row in df.iterrows():
        try:
            # 1. Ground truth: stored either as a JSON string or as a mapping.
            gt_data = row[GROUND_TRUTH_COL]
            gt_dict = json.loads(gt_data) if isinstance(gt_data, str) else gt_data
            ground_truth_answer = str(gt_dict[GROUND_TRUTH_KEY])

            # 2. Original question, assuming a structure like
            #    [{'content': '...', 'role': 'user'}].
            question_content = row[PROMPT_COL][0]["content"]

            # 3. Responses.  A bare string is one response; iterating it
            #    directly would verify every single character separately.
            generated_responses = row[RESPONSE_COL]
            if isinstance(generated_responses, str):
                generated_responses = [generated_responses]

            # 4. Run every verifier on every response and compare the results.
            for response in generated_responses:
                response_str = str(response)
                try:
                    results = {
                        name: verify(response_str, ground_truth_answer)
                        for name, verify in verifiers.items()
                    }

                    # The verifiers agree when every result equals the first one.
                    first_result = next(iter(results.values()))
                    if all(result == first_result for result in results.values()):
                        continue

                    sample = {
                        "question": question_content,
                        "response": response_str,
                        "ground_truth": ground_truth_answer,
                    }
                    # One "<name>_result" column per verifier, in mapping order.
                    for name, result in results.items():
                        sample[f"{name}_result"] = result
                    # Keep track of the originating row.
                    sample["original_index"] = index
                    divergent_samples.append(sample)
                except Exception as verifier_error:
                    # Best effort: a failing verifier only drops this response.
                    print(
                        f"Warning: Error during verification for row {index}, response snippet '{response_str[:50]}...': {verifier_error}"
                    )

        except Exception as e:
            # Best effort: a malformed row must not abort the whole run.
            print(f"Warning: Skipping row {index} due to main processing error: {e}")
            continue

    # --- Save the divergent samples ---
    if not divergent_samples:
        print("\nNo disagreements found between the verifiers.")
        return

    output_df = pd.DataFrame(divergent_samples)
    try:
        output_df.to_parquet(output_path, index=False)
    except Exception as e:
        print(f"\nError saving output Parquet file '{output_path}': {e}")
        return
    print(f"\nFound {len(divergent_samples)} disagreements.")
    print(f"Saved divergent samples to: {output_path}")


if __name__ == "__main__":
    # Command-line entry point: parse the paths and run the comparison.
    parser = argparse.ArgumentParser(
        description="Find samples where different verifiers disagree and save them."
    )
    parser.add_argument(
        "--input_path",
        required=True,
        help="Path to the input Parquet file containing responses and ground truth.",
    )
    parser.add_argument(
        "--output_path",
        default=None,
        help="Path for the output Parquet file "
        "(default: the input path with '.diff.parquet' appended).",
    )
    args = parser.parse_args()

    # Without an explicit destination, write next to the input file.
    output_path = args.output_path
    if output_path is None:
        output_path = f"{args.input_path}.diff.parquet"

    find_divergent_samples(args.input_path, output_path)
