import json
import re
from utils.gpt_utils import ChatBot
from extract_final_answer import extract_final_question
from utils.basic import remove_prefix
from output_preprocessing import extract_yes_or_no


def custom_verification_prompt(reasoning_chain):
    """Build the few-shot prompt that asks the model to verify a reasoning chain.

    The prompt consists of a task instruction, two worked examples that both
    end in "Final Verdict: No" (a condition error and a logic error), and
    finally ``reasoning_chain`` wrapped in triple double quotes followed by
    the instruction to answer "Yes" or "No".

    Args:
        reasoning_chain: The reasoning text to be verified. It is inserted
            with default f-string formatting, so non-string values are
            rendered via their ``__format__``.

    Returns:
        The full prompt as one string, terminated by a newline.
    """
    prompt_lines = (
        # Task instruction.
        'Check a  reasoning chain step-by-step and judge whether it is overall correct.',
        '',
        'Here are two examples:',
        '',
        # Worked example 1: the premise itself is misquoted.
        'Example 1 (Condition Error):',
        'Reasoning Chain:',
        '1. The question is "What is 5 + 7?"',
        '2. The student claims "5 + 8 = 13".',
        '3. Therefore, the answer is 13.',
        '',
        'Analysis:',
        "The student's initial setup is wrong: they mistakenly used 5 + 8 instead of 5 + 7. Thus, the condition is incorrect.",
        '',
        'Final Verdict: No',
        '',
        '',
        # Worked example 2: correct setup, wrong arithmetic.
        'Example 2 (Logic Error):',
        'Reasoning Chain:',
        '1. The question is "What is the perimeter of a square with side length 4?"',
        '2. The student correctly identifies that the perimeter formula is 4 × side length.',
        '3. However, the student calculates 4 × 4 = 20.',
        '4. Therefore, the answer is 20.',
        '',
        'Analysis:',
        'The setup is correct, but the calculation (logic) is wrong. 4 × 4 should be 16, not 20.',
        '',
        'Final Verdict: No',
        '',
        '',
        # The actual reasoning chain to verify.
        "Now, please verify the following student's reasoning:",
        '',
        'Reasoning Chain:',
        f'"""{reasoning_chain}"""',
        '',
        'At the end, simply answer "Yes" if the reasoning is correct and aligned, or "No" if there is any mistake (either condition or logic).',
        # Trailing empty entry yields the final newline of the prompt.
        '',
    )
    return "\n".join(prompt_lines)



def call_model(prompt):
    """
    Forward ``prompt`` to ``ChatBot.call_chat_gpt`` with fixed decoding options.

    Args:
        prompt: The prompt passed as the first positional argument.

    Returns:
        Whatever ``ChatBot.call_chat_gpt`` returns, unmodified.
        ``return_list=True`` is requested, so presumably a list of
        generations -- confirm against ``utils.gpt_utils``.
    """
    # Sampling configuration: one sampled sequence, no beam search,
    # at most 512 new tokens, temperature 0.7.
    generation_options = {
        "eos_pattern": None,
        "max_new_tokens": 512,
        "early_stopping": True,
        "do_sample": True,
        "return_list": True,
        "temperature": 0.7,
        "num_beams": 1,
        "num_return_sequences": 1,
        "decoding_args": None,
    }
    return ChatBot.call_chat_gpt(prompt, **generation_options)


def run_simple_cot_verification(model_output, question):
    """
    Run a simple CoT verification by checking the full reasoning chain.

    Builds the verification prompt for ``model_output``, queries the model,
    and reduces the extracted Yes/No results to a single pass/fail flag.

    Args:
        model_output: The reasoning chain (candidate answer) to verify.
        question: Not used by this function; kept so existing callers that
            pass it continue to work.

    Returns:
        A ``(verification_trace, final_pass)`` tuple where
        ``verification_trace`` is a dict holding the prompt and the raw
        model response, and ``final_pass`` is ``True`` only when at least
        one verdict was extracted and every extracted verdict is truthy.
    """
    prompt = custom_verification_prompt(model_output)
    model_response = call_model(prompt)

    # Extract Yes/No decision.
    # NOTE(review): presumably the second element is a sequence of boolean
    # verdicts, one per model output -- confirm against output_preprocessing.
    _, filtered_results = extract_yes_or_no(model_response)
    verdicts = list(filtered_results)

    verification_trace = {
        "custom_verification_prompt": prompt,
        "model_outputs": model_response
    }

    # all([]) is True, so an empty verdict list must be rejected explicitly;
    # otherwise a response with no extractable Yes/No would count as a pass.
    final_pass = bool(verdicts) and all(verdicts)

    return verification_trace, final_pass


if __name__ == "__main__":
    # Script entry point: verify every example in the input file, store the
    # per-example verdict and trace, then report agreement with the
    # "overall_correct" field of each item.
    input_path = "data/verification_test_50/du_50_test.json"
    output_path = "data/verification_test_50/normal_du.json"

    # 1. Load data
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # 2. Verification process
    verified = []
    for idx, item in enumerate(data, 1):
        question = item["question"]
        candidate_answer = item["candidate_answer"]

        print(f"Processing example {idx}...")

        verification_trace, final_pass = run_simple_cot_verification(candidate_answer, question)

        # Results are attached to the loaded item in place.
        item["overall_pass"] = final_pass
        item["verification_trace"] = verification_trace

        verified.append(item)

    # 3. Save results
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(verified, f, indent=2, ensure_ascii=False)

    # 4. Compute overall accuracy
    total = len(verified)
    # An item counts as a match when the verifier's verdict equals the
    # item's "overall_correct" value (None if that key is absent).
    overall_match = sum(
        1 for item in verified
        if item.get("overall_pass") == item.get("overall_correct")
    )

    print("\n===== Overall Pass Accuracy =====")
    print(f"Total samples: {total}")
    if total:
        print(f"Overall Pass Accuracy: {overall_match / total:.2%} ({overall_match}/{total})")
    else:
        # Empty input: avoid ZeroDivisionError and report that there is
        # nothing to score.
        print(f"Overall Pass Accuracy: N/A ({overall_match}/{total})")
    print(f"Verification completed. Results saved to {output_path}")
