import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import json
from model.request_gemini import *
from model.prompts import *
from utils import *


def _judge_correctness(llm, llm_answer, ground_truth):
    """Decide whether ``llm_answer`` matches ``ground_truth``.

    An exact match is accepted directly. Otherwise the LLM itself is asked,
    via ``check_correct_prompt``, and the last word of its reply is compared
    (case-insensitively) against ``"true"``.

    Args:
        llm: Client exposing ``ask(prompt)``.
        llm_answer: Extracted answer; the sentinel ``"No answer"`` means none.
        ground_truth: Reference answer from the dataset.

    Returns:
        ``True`` if judged correct, else ``False``.
    """
    if llm_answer == "No answer":
        return False
    if llm_answer == ground_truth:
        return True
    check_correct = check_correct_prompt.format(llm_answer=llm_answer, ground_truth=ground_truth)
    check_correct_response = llm.ask(check_correct)
    words = (check_correct_response.message.content or "").strip().split()
    # An empty judge reply counts as "not correct" instead of raising IndexError.
    return bool(words) and words[-1].lower() == "true"


def _answer_token_span(token_list):
    """Return half-open ``(start, end)`` bounds of the answer tokens, or ``None``.

    The answer starts right after the LAST ``"Answer"`` token that is
    immediately followed by a ``"]:"`` token. Whitespace-only tokens are then
    skipped. The span runs to the end of the stream, excluding a single
    trailing whitespace-only token.

    Args:
        token_list: Sequence of token objects exposing a ``.token`` string.

    Returns:
        ``(start_index, end_index)`` with ``start_index <= end_index``, or
        ``None`` when the marker is not present.
    """
    start_index = None
    # Scan the whole stream and keep the last match. Stop one short of the end
    # so ``token_list[i + 1]`` is always in range.
    for i in range(len(token_list) - 1):
        if token_list[i].token.strip() == "Answer" and token_list[i + 1].token.strip() == "]:":
            start_index = i + 2
    if start_index is None:
        return None
    while start_index < len(token_list) and token_list[start_index].token.strip() == "":
        start_index += 1
    end_index = len(token_list)
    if token_list[-1].token.strip() == "":
        end_index -= 1
    # An all-whitespace tail gives an empty span, not a negative length.
    return start_index, max(end_index, start_index)


def run_single_llm_gsm_plus(input_dir, output_dir, strong_llm_model, project_id, problems_indices=None):
    """
    Run with a single LLM to solve problems.

    For each selected problem, this function:
      * prompts the model with ``zero_prompt_gsm``;
      * extracts the strategy and answer;
      * judges correctness (exact match, else an LLM check);
      * sums the log-probabilities of the answer tokens;
      * persists the results with ``safe_update_data``.

    Args:
        input_dir: Path to a JSON file. It is indexed by problem index, and each
            entry has ``"question"`` and ``"ground_truth"`` keys.
        output_dir: Output path passed to ``safe_update_data``.
        strong_llm_model: Model name passed to ``GeminiClient``.
        project_id: Project id passed to ``GeminiClient``.
        problems_indices: Iterable of indices to run. ``None`` means all.

    Per-problem fields written (prefix ``strong_llm_``): ``prompt``, ``response``,
    ``answer_tokens``, ``strategy``, ``answer``, ``accuracy``,
    ``answer_logprob_sum``. ``answer_tokens`` and ``answer_logprob_sum`` are
    ``None`` when there is no answer or the answer marker cannot be found in
    the returned tokens.
    """
    with open(input_dir, "r", encoding="utf-8") as f:
        data = json.load(f)

    if problems_indices is None:
        problems_indices = list(range(len(data)))
    else:
        # Materialize so len() works below even if a generator is passed.
        problems_indices = list(problems_indices)

    llm = GeminiClient(project_id=project_id, model=strong_llm_model)
    llm_pre = "strong"

    total_correct = 0
    for problem_ind in problems_indices:
        print(f"\nProcessing problem index: {problem_ind}")
        question = data[problem_ind]["question"]
        ground_truth = data[problem_ind]["ground_truth"]
        prompt = zero_prompt_gsm.format(question=question)
        print(f"\nPrompt for llm...\n{prompt}\n")

        llm_response = llm.ask(
            prompt,
            logprobs=True,
            top_logprobs=1
        )
        print(f"LLM response...\n\n{llm_response.message.content}\n")

        llm_strategy, llm_answer = extract_response_components(llm_response)
        correctness = _judge_correctness(llm, llm_answer, ground_truth)
        total_correct += correctness
        print(f"\nLLM answer is {llm_answer}, and ground truth is {ground_truth}")
        print(f"Match or not? {correctness}\n")

        # Reset on every problem so a missing marker can never reuse the
        # previous problem's span (the original leaked ``start_index``).
        answer_tokens = None
        answer_logprob = None
        if llm_answer != "No answer":
            logprobs = llm_response.logprobs
            token_list = (logprobs.content if logprobs is not None else None) or []
            span = _answer_token_span(token_list)
            if span is None:
                print("Answer marker not found in response tokens; answer logprob unavailable.")
            else:
                start_index, end_index = span
                answer_token_objs = token_list[start_index:end_index]
                print("Answer token and logprob ...")
                for token_obj in answer_token_objs:
                    print(f"  - Token '{token_obj.token}' has log-probability: {token_obj.logprob}")
                answer_tokens = end_index - start_index
                answer_logprob = sum(token_obj.logprob for token_obj in answer_token_objs)
        print(f"\nLLM answer has log probability {answer_logprob}")

        safe_update_data(problem_ind, {
            f"{llm_pre}_llm_prompt": prompt,
            f"{llm_pre}_llm_response": llm_response.message.content,
            f"{llm_pre}_llm_answer_tokens": answer_tokens,
            f"{llm_pre}_llm_strategy": llm_strategy,
            f"{llm_pre}_llm_answer": llm_answer,
            f"{llm_pre}_llm_accuracy": correctness,
            f"{llm_pre}_llm_answer_logprob_sum": answer_logprob,
        }, output_dir)

    if problems_indices:
        print(f"\nAccuracy of strong LLM: {total_correct / len(problems_indices) * 100:.2f}%")
    else:
        # Avoid ZeroDivisionError when no problems were selected.
        print("\nNo problems processed; accuracy of strong LLM is undefined.")

if __name__ == "__main__":

    # Read the project id from the conventional Google Cloud environment
    # variable instead of a hard-coded placeholder. The fallback "" keeps the
    # previous behavior when the variable is unset.
    project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", "")
    strong_llm_model = "gemini-2.0-flash"

    # (input dataset, output results file) pairs, run in order.
    runs = [
        # 1. strong llm on calibration dataset: accuracy 771/1048=73.57%
        ("data/gsm_plus/calibration_gsm_plus.json",
         "data/gsm_plus/calibration_gsm_plus_strong_gemini2.0flash_llm.json"),
        # 2. strong llm on test dataset: accuracy 6590/9504=69.34%
        ("data/gsm_plus/test_gsm_plus.json",
         "data/gsm_plus/test_gsm_plus_strong_gemini2.0flash_llm.json"),
    ]
    for input_path, output_path in runs:
        run_single_llm_gsm_plus(input_path, output_path, strong_llm_model, project_id, problems_indices=None)
        print(compute_accuracy(output_path, "strong_llm_accuracy"))

