import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import json
from model.request_gemini import *
from model.prompts import *
from utils import *


def _find_answer_span(token_list):
    """Locate the answer tokens inside a list of logprob token objects.

    The answer is taken to start right after the last ``Answer`` token that is
    immediately followed by a ``]:`` token (presumably the answer marker that
    ``zero_prompt_nasa_history`` asks the model to emit -- confirm against the
    prompt). A trailing whitespace-only token is excluded from the span.

    Args:
        token_list: Sequence of objects exposing ``.token`` (str).

    Returns:
        ``(start, end)`` half-open indices into ``token_list``, or ``None`` when
        the marker is not present (including an empty ``token_list``).
    """
    start = None
    # Stop one short of the end so that token_list[i + 1] is always valid.
    for i in range(len(token_list) - 1):
        if token_list[i].token.strip() == "Answer" and token_list[i + 1].token.strip() == "]:":
            start = i + 2  # keep scanning: the last marker wins, as before
    if start is None:
        return None
    end = len(token_list) - 1 if token_list[-1].token.strip() == "" else len(token_list)
    return start, end


def run_single_llm_nasa_history(input_dir, output_dir, strong_llm_model, project_id, problems_indices=None):
    """Solve NASA-history questions with a single (strong) LLM and record results.

    For each selected problem the question is formatted into
    ``zero_prompt_nasa_history`` and sent to a ``GeminiClient`` with token
    log-probabilities requested. The response is parsed with
    ``extract_response_components`` and the parsed answer is compared with the
    problem's ``ground_truth``. When an answer was produced, the
    log-probabilities of the answer tokens are summed. Per-problem results are
    written through ``safe_update_data`` under ``strong_llm_*`` keys, and the
    overall accuracy is printed at the end.

    Args:
        input_dir: Path to a JSON file holding a list of problems, each with
            ``"question"`` and ``"ground_truth"`` keys. Despite the name, this is
            a file path, not a directory.
        output_dir: Destination forwarded to ``safe_update_data`` (a JSON file
            path in this module's ``__main__`` usage).
        strong_llm_model: Gemini model name, e.g. ``"gemini-2.0-flash"``.
        project_id: Project id forwarded to ``GeminiClient``.
        problems_indices: Iterable of indices into the problem list to process;
            ``None`` processes every problem.
    """
    with open(input_dir, "r", encoding="utf-8") as f:
        data = json.load(f)

    if problems_indices is None:
        problems_indices = list(range(len(data)))
    else:
        # Materialize so generators work and len() is available for accuracy.
        problems_indices = list(problems_indices)

    llm = GeminiClient(project_id=project_id, model=strong_llm_model)
    llm_pre = "strong"

    total_correct = 0
    for problem_ind in problems_indices:
        problem = data[problem_ind]
        ground_truth = problem["ground_truth"]
        print(f"\nProcessing problem index: {problem_ind}")
        prompt = zero_prompt_nasa_history.format(question=problem["question"])
        print(f"\nPrompt for llm...\n{prompt}\n")

        llm_response = llm.ask(
            prompt,
            logprobs=True,
            top_logprobs=1
        )
        print(f"LLM response...\n\n{llm_response.message.content}\n")

        llm_strategy, llm_answer = extract_response_components(llm_response)
        is_correct = llm_answer == ground_truth
        total_correct += is_correct
        print(f"\nLLM answer is {llm_answer}, and ground truth is {ground_truth}")
        print(f"Match or not? {is_correct}\n")

        # Reset every iteration. Previously the span indices leaked from an
        # earlier problem whenever the answer marker was missing.
        answer_logprob = None
        answer_token_count = None
        if llm_answer != "No answer":
            token_list = llm_response.logprobs.content or []
            span = _find_answer_span(token_list)
            if span is None:
                print("Warning: answer marker not found in response tokens; answer logprob unavailable.")
            else:
                start_index, end_index = span
                answer_tokens = token_list[start_index:end_index]
                print("Answer token and logprob ...")
                for token_obj in answer_tokens:
                    print(f"  - Token '{token_obj.token}' has log-probability: {token_obj.logprob}")
                answer_logprob = sum(token_obj.logprob for token_obj in answer_tokens)
                answer_token_count = end_index - start_index
        print(f"\nLLM answer has log probability {answer_logprob}")

        safe_update_data(problem_ind, {
            f"{llm_pre}_llm_prompt": prompt,
            f"{llm_pre}_llm_response": llm_response.message.content,
            f"{llm_pre}_llm_answer_tokens": answer_token_count,
            f"{llm_pre}_llm_strategy": llm_strategy,
            f"{llm_pre}_llm_answer": llm_answer,
            f"{llm_pre}_llm_accuracy": is_correct,
            f"{llm_pre}_llm_answer_logprob_sum": answer_logprob,
        }, output_dir)

    # Guard against ZeroDivisionError when no problems were selected.
    if problems_indices:
        print(f"\nAccuracy of strong LLM: {total_correct / len(problems_indices) * 100:.2f}%")
    else:
        print("\nNo problems processed; accuracy of strong LLM is undefined.")

if __name__ == "__main__":
    # Settings for a single strong-LLM run on the NASA-history calibration split.
    strong_llm_model = "gemini-2.0-flash"
    project_id = ""  # forwarded to GeminiClient; NOTE(review): fill in before running
    problems_indices = None  # None -> every problem in the input file

    # Step 1: strong LLM on the calibration dataset.
    # Previously recorded result: accuracy 788/1000 = 78.80%.
    calibration_dir = "data/nasa_history/calibration_nasa_history.json"
    calibration_strong_llm_dir = "data/nasa_history/calibration_nasa_history_strong_gemini2.0flash_llm.json"

    run_single_llm_nasa_history(
        calibration_dir,
        calibration_strong_llm_dir,
        strong_llm_model,
        project_id,
        problems_indices=problems_indices,
    )
    accuracy_report = compute_accuracy(calibration_strong_llm_dir, "strong_llm_accuracy")
    print(accuracy_report)
   

