import json

input_file = './gsm8k_reasoning_verified.json'
output_file = './gsm8k_reasoning_processed.json'

# Fields copied unchanged from every input record, in output order.
KEPT_FIELDS = (
    "question",
    "answer",
    "final_answer",
    "index_number",
    "model_input",
    "n_responses",
    "voting_results",
    "voting_corrects",
    "discard_scores",
)


def is_voting_correct(item):
    """Return True when the record's ``voting_corrects`` list starts with True.

    A missing, ``None`` or empty ``voting_corrects`` counts as incorrect.
    """
    voting_corrects = item.get("voting_corrects", [])
    return bool(voting_corrects) and voting_corrects[0] is True


def process_item(item):
    """Build the output record for one input record.

    The output keeps only ``KEPT_FIELDS`` (missing fields become ``None``)
    and adds two derived fields:

    * ``verified_answer``: the key of ``discard_scores`` with the lowest
      score (the first such key on ties), or ``None`` when
      ``discard_scores`` is missing or empty.
    * ``verified_corrects``: a one-element list holding whether
      ``verified_answer`` equals ``final_answer``.

    Args:
        item: One record (a dict) from the input JSON.

    Returns:
        A ``(filtered_item, is_verified_correct)`` tuple.
    """
    filtered_item = {field: item.get(field) for field in KEPT_FIELDS}

    discard_scores = item.get("discard_scores", {})
    if discard_scores:
        verified_answer = min(discard_scores, key=lambda k: discard_scores[k])
        # Answers are compared by their string form, so this cannot raise.
        # NOTE(review): the comparison is textual, so "18" and "18.0" do not
        # match -- confirm this agrees with how voting_corrects was computed.
        is_verified_correct = str(verified_answer) == str(item.get("final_answer"))
    else:
        verified_answer = None
        is_verified_correct = False

    filtered_item["verified_answer"] = verified_answer
    filtered_item["verified_corrects"] = [is_verified_correct]
    return filtered_item, is_verified_correct


def process_records(data):
    """Process every record and tally the two correctness counters.

    Args:
        data: Iterable of input records (dicts).

    Returns:
        A ``(processed_data, voting_corrects_count, verified_corrects_count)``
        tuple, where ``processed_data`` is the list of output records in
        input order.
    """
    processed_data = []
    voting_corrects_count = 0
    verified_corrects_count = 0

    for item in data:
        filtered_item, is_verified_correct = process_item(item)
        processed_data.append(filtered_item)
        if is_voting_correct(item):
            voting_corrects_count += 1
        if is_verified_correct:
            verified_corrects_count += 1

    return processed_data, voting_corrects_count, verified_corrects_count


def main():
    """Read ``input_file``, write the processed records, and print accuracy."""
    # Load the original JSON data
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    processed_data, voting_corrects_count, verified_corrects_count = (
        process_records(data)
    )

    # Save the new JSON data
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(processed_data, f, indent=2, ensure_ascii=False)

    # Accuracy stats (guarded so an empty input does not divide by zero)
    total = len(data)
    voting_accuracy = voting_corrects_count / total if total > 0 else 0
    verified_accuracy = verified_corrects_count / total if total > 0 else 0

    print(f"Processing complete. Results saved to {output_file}")
    print(f"Voting Corrects Accuracy: {voting_accuracy:.2%}")
    print(f"Verified Corrects Accuracy: {verified_accuracy:.2%}")


if __name__ == "__main__":
    main()