import json
import os
import re

from dotenv import load_dotenv
from openai import AzureOpenAI

#############  SETUP  #############

load_dotenv()  # read variables from a local .env file into the process environment

# Credentials and endpoint come from (lower-case) environment variables;
# any that are unset come back as None and are passed through unchanged.
api_key, api_version, azure_endpoint = (
    os.getenv(var_name) for var_name in ("api_key", "api_version", "azure_endpoint")
)

# One shared client; completion_with_backoff_mcopenai sends every request through it.
client = AzureOpenAI(
    azure_endpoint=azure_endpoint,
    api_key=api_key,
    api_version=api_version,
)

# This is your wrapper for calling Azure GPT-4, as shown in your snippet
def completion_with_backoff_mcopenai(**kwargs):
    result = client.chat.completions.create(
        model="gpt-4o-mini",
        **kwargs,
    )
    return result


def build_judge_prompt(problem_id, ground_truth_list, ground_truth_cost, llama_response):
    """Assemble the chat messages that ask the judge model to grade one answer.

    The system message keeps the judge on task. The user message contains:
    the problem id; the ground-truth assignment (``ground_truth_list``
    serialized with ``json.dumps``) and its cost; and the raw
    ``llama_response``. It asks for a one-sentence rationale followed by a
    CORRECT/WRONG label.

    Returns:
        A two-element list of ``{"role": ..., "content": ...}`` dicts
        (system first, then user), ready to pass as ``messages`` to the
        chat-completions API.
    """
    system_text = (
        "You are a helpful assistant that checks if one solution to an assignment problem "
        "matches the known correct solution. Please only output 'CORRECT' or 'WRONG' "
        "with a short one-sentence explanation."
    )

    # One entry per output line; the empty strings produce the blank separator lines.
    user_lines = [
        "Your task is to label the model's solution to an assignment problem as 'CORRECT' or 'WRONG'.",
        "",
        f"Problem ID: {problem_id}",
        "",
        "GROUND TRUTH:",
        f"  solution_list: {json.dumps(ground_truth_list)}",
        f"  solution_cost: {ground_truth_cost}",
        "",
        "MODEL'S SOLUTION:",
        f"{llama_response}",
        "",
        "First, give a single short sentence explaining your reasoning, and then output CORRECT or WRONG. "
        "If the cost and matching pairs match exactly, label CORRECT. Otherwise, label WRONG.",
    ]

    # Optional lenient grading (currently disabled): appending
    # "If the output includes a reasonable code solution, even if the cost is wrong, you can also label it as 'CORRECT'.\n\n"
    # to the user message would also accept a reasonable code solution with a wrong cost.

    return [
        {"role": "system", "content": system_text},
        {"role": "user", "content": "\n".join(user_lines)},
    ]

# A label must be a whole word, so "INCORRECT"/"CORRECTLY"/"WRONGLY" don't count.
_LABEL_RE = re.compile(r"\b(CORRECT|WRONG)\b")
# A label that closes the reply, ignoring trailing punctuation/markdown like "." or "**".
_TRAILING_LABEL_RE = re.compile(r"\b(CORRECT|WRONG)\W*$")
# A label that opens the reply, ignoring leading punctuation/markdown.
_LEADING_LABEL_RE = re.compile(r"^\W*(CORRECT|WRONG)\b")


def _extract_verdict(content):
    """Return "CORRECT" or "WRONG" parsed from the judge's reply, or None if ambiguous."""
    text = content.upper()
    found = set(_LABEL_RE.findall(text))
    if not found:
        return None
    if len(found) == 1:
        return found.pop()
    # Both words occur, typically because the explanation mentions one of them.
    # The user prompt asks for the label after the explanation, so prefer a
    # closing label. The system prompt's wording may instead produce a label
    # first, so fall back to an opening label.
    trailing = _TRAILING_LABEL_RE.search(text)
    if trailing:
        return trailing.group(1)
    leading = _LEADING_LABEL_RE.match(text)
    if leading:
        return leading.group(1)
    return None


def judge_solution(problem_id, ground_truth_list, ground_truth_cost, llama_response):
    """Ask the judge model whether ``llama_response`` matches the ground truth.

    Args:
        problem_id: Identifier echoed into the prompt and into error messages.
        ground_truth_list: Reference assignment; must be JSON-serializable
            (``build_judge_prompt`` passes it through ``json.dumps``).
        ground_truth_cost: Reference solution cost.
        llama_response: Raw text produced by the model under evaluation.

    Returns:
        "CORRECT" or "WRONG". Any exception while calling the judge or reading
        its reply, as well as a reply with no clear label, yields "WRONG".
    """
    messages = build_judge_prompt(problem_id, ground_truth_list, ground_truth_cost, llama_response)

    # temperature=0 to make the judge as repeatable as the API allows.
    try:
        response = completion_with_backoff_mcopenai(
            messages=messages,
            temperature=0,
            max_tokens=100,
        )
        # message.content can be None (no text returned); treat it as empty.
        content = (response.choices[0].message.content or "").strip()
    except Exception as e:
        # Deliberately best-effort: one failed request must not abort the whole
        # evaluation run, so it is reported and scored as WRONG.
        print(f"Error in GPT-4 judge for problem {problem_id}: {e}")
        return "WRONG"

    print(content.upper())
    verdict = _extract_verdict(content)
    if verdict is None:
        # No label, or conflicting labels with neither at the start nor the end.
        print("ambiguous")
        return "WRONG"
    return verdict

#############  MAIN SCRIPT  #############

def main(input_file="result/assignment_model_responses.json"):
    """Judge every saved model response and report the overall accuracy.

    Args:
        input_file: Path to a JSON array of objects. Each object has the keys
            ``problem_id``, ``model_response``, ``solution_list_ground_truth``
            and ``solution_cost_ground_truth``. Defaults to the path the
            script has always read.

    Returns:
        The fraction of records judged "CORRECT" (0.0 when the file is empty).
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale encoding.
    with open(input_file, "r", encoding="utf-8") as f:
        model_responses = json.load(f)

    total = len(model_responses)
    correct_count = 0

    for entry in model_responses:
        problem_id = entry["problem_id"]
        # A missing or null response is judged as an empty answer.
        llama_response = entry.get("model_response") or ""
        gt_list = entry["solution_list_ground_truth"]
        gt_cost = entry["solution_cost_ground_truth"]

        judge_label = judge_solution(problem_id, gt_list, gt_cost, llama_response)
        if judge_label == "CORRECT":
            correct_count += 1

        print(f"Problem {problem_id} => GPT-4 Judge: {judge_label}")

    # Guard against division by zero on an empty input file.
    accuracy = correct_count / total if total > 0 else 0.0
    print(f"\nFinal Accuracy: {accuracy:.2%} ({correct_count}/{total} correct)")
    return accuracy


if __name__ == "__main__":
    main()
