import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import json
import tiktoken
from model.request_gpt import *
from model.request_gemini import *
from model.prompts import *
from rag_setup import *
from utils import *


def _load_rag_db(context_dir, rag_name):
    """Open the persistent strategy RAG store, building it only on first use.

    If the folder ``f"{rag_name}_rag_database"`` already exists, the store is
    reopened without re-indexing. Otherwise the records in ``context_dir`` whose
    ``strong_llm_strategy`` is not ``"No strategy"`` are loaded and added.
    """
    # NOTE(review): assumes PersistentRAG(rag_name) persists into
    # f"{rag_name}_rag_database" -- this matches the original hard-coded pair.
    folder_path = f"{rag_name}_rag_database"
    # Check for the folder *before* constructing PersistentRAG, since the
    # constructor may itself create the folder.
    if os.path.isdir(folder_path):
        return PersistentRAG(rag_name)

    # Read the context file before creating the store, so a failed read does not
    # leave behind an empty folder that would skip indexing on the next run.
    with open(context_dir, "r", encoding="utf-8") as f:
        data_context = json.load(f)
    data_context = [
        d for d in data_context
        if d["strong_llm_strategy"] != "No strategy"
    ]
    rag_db = PersistentRAG(rag_name)
    rag_db.add_documents(data_context)
    return rag_db


def _result_prefix(context_type, context_dir, threshold_similarity, k):
    """Return the key prefix under which this run's results are stored."""
    if context_type == "strategy" and context_dir is not None:
        if threshold_similarity is not None:
            return f"weak_with_context_sim{threshold_similarity}_k{k}"
        return f"weak_with_context_no_sim_k{k}"
    return "weak"


def _judge_correctness(check_llm, llm_answer, ground_truth):
    """Decide whether ``llm_answer`` matches ``ground_truth``.

    A missing answer ("No answer") is wrong. An exact match is right. Otherwise
    the checker LLM is asked, and its final word is read as the verdict.
    """
    if llm_answer == "No answer":
        return False
    if llm_answer == ground_truth:
        return True

    check_correct = check_correct_prompt.format(llm_answer=llm_answer, ground_truth=ground_truth)
    check_correct_response = check_llm.ask(check_correct)
    words = (check_correct_response.message.content or "").split()
    if not words:
        # An empty judge reply would otherwise raise IndexError; count it as wrong.
        return False
    # Strip surrounding punctuation/markdown so "True." or "**True**" still count.
    correctness_str = words[-1].strip(".,;:!*\"'`").lower()
    return correctness_str == "true"


def _answer_logprob(llm_response, llm_answer, encoding):
    """Return ``(n_answer_tokens, answer_logprob)`` for the extracted answer.

    The answer is assumed to be the trailing tokens of the response. Its log
    probability is the sum of the per-token logprobs of those last
    ``n_answer_tokens`` tokens. Both values are None when there is no answer.
    """
    if llm_answer == "No answer":
        return None, None

    # count how many tokens the answer takes under the model's tokenizer
    n_answer_tokens = len(encoding.encode(llm_answer, disallowed_special=()))

    # per-token log probabilities of the whole response
    token_lps = [e.logprob for e in llm_response.logprobs.content]

    # token_lps[-0:] would select *every* token, so an empty answer must be
    # special-cased to an empty window instead of the full response.
    answer_token_lps = token_lps[-n_answer_tokens:] if n_answer_tokens else []
    return n_answer_tokens, sum(answer_token_lps)


def run_single_llm_gsm_plus(input_dir, output_dir, weak_llm_model, llm_api_key,
                            problems_indices=None, context_type=None, context_dir=None,
                            threshold_similarity=None, k=2,
                            rag_name="test_gsm_plus_strong_gemini2.0flash_llm"):
    """
    Run with a single LLM to solve problems.

    Each selected problem from ``input_dir`` is sent to the weak LLM. The answer
    is graded (exact match, else a Gemini checker) and its answer-token log
    probability is computed. Results are written via ``safe_update_data`` into
    ``output_dir``.

    Args:
        input_dir: path to a JSON list of problems with "question" and
            "ground_truth" keys.
        output_dir: path passed to ``safe_update_data`` for the results.
        weak_llm_model: OpenAI model name, also used to pick the tiktoken encoding.
        llm_api_key: API key for ``ChatGPTClient``.
        problems_indices: indices into the problem list; defaults to all problems.
        context_type: "strategy" enables few-shot RAG context (requires ``context_dir``).
        context_dir: path to a JSON list of strong-LLM strategy records.
        threshold_similarity: if set, retrieved docs need ``problem_score`` above it.
        k: number of documents to retrieve per question.
        rag_name: name of the persistent RAG store (default keeps the original store).
    """
    # read json files
    with open(input_dir, "r", encoding="utf-8") as f:
        data = json.load(f)

    if problems_indices is None:
        problems_indices = list(range(len(data)))

    use_context = context_type == "strategy" and context_dir is not None
    rag_db = _load_rag_db(context_dir, rag_name) if use_context else None

    # initialize LLMs
    llm = ChatGPTClient(api_key=llm_api_key, model=weak_llm_model)
    encoding = tiktoken.encoding_for_model(weak_llm_model)
    check_llm = GeminiClient(project_id="", model="gemini-2.0-flash")

    # the result-key prefix depends only on the run configuration
    llm_pre = _result_prefix(context_type, context_dir, threshold_similarity, k)

    total_correct = 0
    for problem_ind in problems_indices:
        print(f"\nProcessing problem index: {problem_ind}")
        question = data[problem_ind]["question"]
        ground_truth = data[problem_ind]["ground_truth"]

        # generate the prompt
        if use_context:
            relevant_docs = rag_db.retrieve(question, k=k)
            if threshold_similarity is not None:
                relevant_docs = [doc for doc in relevant_docs if doc["problem_score"] > threshold_similarity]
            context = format_context(relevant_docs, strategy_key_name="strong_llm_strategy", answer_key_name="strong_llm_answer")
            prompt = fewshot_prompt_gsm.format(context=context, question=question)
        else:
            prompt = zero_prompt_gsm.format(question=question)
        print(f"\nPrompt for llm...\n{prompt}\n")

        # get the llm response
        llm_response = llm.ask(
            prompt,
            logprobs=True,  # ask for log-probs
            top_logprobs=1  # we only care about the chosen token
        )
        print(f"LLM response...\n\n{llm_response}\n")

        llm_strategy, llm_answer = extract_response_components(llm_response)
        correctness = _judge_correctness(check_llm, llm_answer, ground_truth)
        total_correct += correctness
        print(f"\nLLM answer is {llm_answer}, and ground truth is {ground_truth}")
        print(f"Match or not? {correctness}\n")

        n_answer_tokens, answer_logprob = _answer_logprob(llm_response, llm_answer, encoding)
        print(f"LLM answer has log probability {answer_logprob}")

        # update the test data with llm answer
        safe_update_data(problem_ind, {
            f"{llm_pre}_llm_prompt": prompt,
            f"{llm_pre}_llm_response": llm_response.message.content,
            f"{llm_pre}_llm_answer_tokens": n_answer_tokens,
            f"{llm_pre}_llm_strategy": llm_strategy,
            f"{llm_pre}_llm_answer": llm_answer,
            f"{llm_pre}_llm_accuracy": correctness,
            f"{llm_pre}_llm_answer_logprob_sum": answer_logprob,
        }, output_dir)

    if problems_indices:
        print(f"\nAccuracy of weak LLM: {total_correct / len(problems_indices) * 100:.2f}%")
    else:
        # avoid ZeroDivisionError when no problems were selected
        print("\nNo problems were processed; accuracy is undefined.")
                
if __name__ == "__main__":

    # Read the API key from the environment rather than hard-coding it, so that
    # secrets never get committed. Falls back to "" exactly like the original literal.
    weak_llm_api_key = os.environ.get("OPENAI_API_KEY", "")
    weak_llm_model = "gpt-3.5-turbo"
    problems_indices = None  # None -> evaluate every problem in the file

    # weak llm on calibration dataset: accuracy 241/1048=23.00%
    calibration_dir = "data/gsm_plus/calibration_gsm_plus.json"
    calibration_weak_llm_dir = "data/gsm_plus/calibration_gsm_plus_weak_gpt3.5turbo_llm.json"
    run_single_llm_gsm_plus(calibration_dir, calibration_weak_llm_dir, weak_llm_model, weak_llm_api_key, problems_indices=problems_indices)
    # "weak_llm_accuracy" matches the "weak" prefix used when no context is supplied
    print(compute_accuracy(calibration_weak_llm_dir, "weak_llm_accuracy"))

