import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import json
import tiktoken
from model.request_gpt import *
from model.request_gemini import *
from model.prompts import *
from find_threshold import *
from utils import *


def _is_correct(answer, ground_truth, judge_llm):
    """Decide whether an extracted LLM answer matches the ground truth.

    Checks, in order: the "No answer" sentinel (always incorrect), exact
    equality, numeric equality after ``float`` conversion, and finally asks
    ``judge_llm`` with ``check_correct_prompt`` and reads the last word of its
    reply ("true", case-insensitive, means correct).

    Args:
        answer: Answer string returned by ``extract_response_components``.
        ground_truth: Reference answer from the test record.
        judge_llm: Client exposing ``ask(prompt)`` whose response has
            ``.message.content``.

    Returns:
        bool: True when the answer is judged correct.
    """
    if answer == "No answer":
        return False
    if answer == ground_truth:
        return True
    try:
        if float(answer) == float(ground_truth):
            return True
    except (TypeError, ValueError):
        # Not both numeric: deliberately fall through to the LLM judge.
        pass
    check_correct = check_correct_prompt.format(llm_answer=answer, ground_truth=ground_truth)
    check_correct_response = judge_llm.ask(check_correct)
    words = (check_correct_response.message.content or "").strip().split()
    # An empty judge reply is treated as "not correct" instead of crashing.
    return bool(words) and words[-1].lower() == "true"


def _weak_answer_logprob(response, answer, encoding):
    """Sum the log-probabilities of the trailing tokens that spell ``answer``.

    ``answer`` is tokenized with ``encoding`` and that many trailing entries of
    ``response.logprobs.content`` are summed.
    NOTE(review): this assumes the extracted answer is the final span of the
    weak LLM's reply - confirm against extract_response_components.

    Returns:
        tuple: ``(n_answer_tokens, logprob_sum)``. ``logprob_sum`` is None when
        the answer encodes to zero tokens, because ``lps[-0:]`` would otherwise
        select (and sum) the whole response.
    """
    n_answer_tokens = len(encoding.encode(answer, disallowed_special=()))
    if n_answer_tokens == 0:
        return n_answer_tokens, None
    token_lps = [e.logprob for e in response.logprobs.content]
    return n_answer_tokens, sum(token_lps[-n_answer_tokens:])


def _strong_answer_span(token_list):
    """Locate the answer tokens inside the strong LLM's log-prob token list.

    The answer starts right after the LAST "Answer" token that is immediately
    followed by a "]:" token (whitespace-only tokens after the marker are
    skipped) and runs to the end of the list, excluding one trailing
    whitespace-only token.

    Returns:
        tuple | None: Half-open ``(start_index, end_index)`` bounds, or None if
        the marker is absent or the resulting span is empty.
    """
    start_index = None
    # Stop one short of the end so token_list[i + 1] is always valid.
    for i in range(len(token_list) - 1):
        if token_list[i].token.strip() == "Answer" and token_list[i + 1].token.strip() == "]:":
            start_index = i + 2
    if start_index is None:
        return None
    while start_index < len(token_list) and token_list[start_index].token.strip() == "":
        start_index += 1
    end_index = len(token_list) - 1 if token_list[-1].token.strip() == "" else len(token_list)
    if start_index >= end_index:
        # An empty span would sum to log-prob 0 (probability 1) and look
        # maximally confident, so report "not found" instead.
        return None
    return start_index, end_index


def run_pipeline_metamath(input_dir, output_dir, problems_indices, weak_llm_api_key, strong_llm_project_id, threshold_weak, threshold_strong,
                          calibration_strong_path="data/metamath/calibration_metamath_strong_gemini2.0flash_llm.json",
                          calibration_weak_path="data/metamath/calibration_metamath_weak_gpt3.5turbo_llm.json"):
    """Run the weak -> strong LLM cascade over MetaMath test problems.

    For each problem the weak LLM (gpt-3.5-turbo) answers first. If its
    rescaled answer confidence exceeds ``threshold_weak`` the answer is
    accepted; otherwise the strong LLM (gemini-2.0-flash) answers, and its
    answer is accepted when its rescaled confidence exceeds
    ``threshold_strong``. Otherwise the problem is counted as unaccepted.
    Per-problem results are written with ``safe_update_data`` to
    ``output_dir`` and summary counts are printed at the end.

    Correctness is judged by ``_is_correct``, which may call the strong LLM as
    a judge for non-matching answers.

    Args:
        input_dir: Path to the test JSON file; records are indexed by int and
            must carry "question" and "ground_truth" keys.
        output_dir: Output path passed through to ``safe_update_data``.
        problems_indices: Iterable of problem indices, or None for all.
        weak_llm_api_key: API key for the weak (OpenAI) client.
        strong_llm_project_id: Project id for the strong (Gemini) client.
        threshold_weak: Acceptance threshold on the weak LLM's confidence.
        threshold_strong: Acceptance threshold on the strong LLM's confidence.
        calibration_strong_path: Calibration JSON used to fit the strong
            confidence rescaling.
        calibration_weak_path: Calibration JSON used to fit the weak
            confidence rescaling.
    """
    # read source json file
    with open(input_dir, "r", encoding="utf-8") as f:
        test = json.load(f)

    if problems_indices is None:
        problems_indices = list(range(len(test)))

    # Fit the min-max rescaling of each model's confidence on its calibration
    # set (ignoring "No strategy" records) so it is on the thresholds' scale.
    with open(calibration_strong_path, "r", encoding="utf-8") as f:
        calibration_gemini = json.load(f)
    strong_log_confidence = [
        item["strong_llm_answer_logprob_sum"]
        for item in calibration_gemini
        if item["strong_llm_strategy"] != "No strategy"
    ]
    _, x_min_gemini, rng_gemini = minmax01(exp_from_logprobs(strong_log_confidence))

    with open(calibration_weak_path, "r", encoding="utf-8") as f:
        calibration_weak = json.load(f)
    weak_log_confidence = [
        item["weak_llm_answer_logprob_sum"]
        for item in calibration_weak
        if item["weak_llm_strategy"] != "No strategy"
    ]
    _, x_min_weak, rng_weak = minmax01(exp_from_logprobs(weak_log_confidence))

    # initialize LLMs
    weak_llm_model = "gpt-3.5-turbo"
    weak_llm = ChatGPTClient(api_key=weak_llm_api_key, model=weak_llm_model)
    weak_encoding = tiktoken.encoding_for_model(weak_llm_model)

    strong_llm_model = "gemini-2.0-flash"
    strong_llm = GeminiClient(project_id=strong_llm_project_id, model=strong_llm_model)

    stats = {
        "total_problems": len(problems_indices),
        "weak_helped_by_strategy": 0,
        "weak_total_correct": 0,
        "weak_llm_used": 0,
        "weak_llm_used_correct": 0,
        "strong_total_correct": 0,
        "strong_llm_used": 0,
        "unaccept_problems": 0,
    }

    for problem_ind in problems_indices:
        print(f"Processing problem index: {problem_ind}")
        question = test[problem_ind]["question"]
        ground_truth = test[problem_ind]["ground_truth"]

        # use zero-shot prompt (no RAG retrieval)
        # NOTE(review): a GSM-Symbolic prompt template is used for MetaMath - confirm intended.
        prompt = zero_prompt_gsm_symbolic.format(question=question)
        print(f"\nPrompt for llm...\n{prompt}\n")

        # ---- weak LLM ----
        weak_llm_response = weak_llm.ask(
            prompt,
            logprobs=True,
            top_logprobs=1  # integer 0-20; we only care about the chosen token
        )
        weak_llm_strategy, weak_llm_answer = extract_response_components(weak_llm_response)
        correctness = _is_correct(weak_llm_answer, ground_truth, strong_llm)
        stats["weak_total_correct"] += correctness
        print(f"\nWeak LLM answer is {weak_llm_answer}, and ground truth is {ground_truth}")
        print(f"Match or not? {correctness}\n")

        if weak_llm_answer != "No answer":
            weak_n_answer_tokens, weak_answer_logprob = _weak_answer_logprob(
                weak_llm_response, weak_llm_answer, weak_encoding
            )
        else:
            weak_n_answer_tokens, weak_answer_logprob = None, None

        # update the test data with weak llm answer
        safe_update_data(problem_ind, {
            "weak_llm_helped_by_strategy": False,
            "weak_llm_prompt": prompt,
            "weak_llm_response": weak_llm_response.message.content,
            "weak_llm_answer_tokens": weak_n_answer_tokens,
            "weak_llm_strategy": weak_llm_strategy,
            "weak_llm_answer": weak_llm_answer,
            "weak_llm_accuracy": correctness,
            "weak_llm_answer_logprob_sum": weak_answer_logprob,
        }, output_dir)

        # compute confidence and check threshold
        accept_weak = False
        if weak_answer_logprob is not None:
            confidence_weak = inv_minmax01(exp_from_logprobs(weak_answer_logprob), x_min_weak, rng_weak)
            print(f"\nWeak LLM answer has confidence {confidence_weak}, threshold for weak LLM is {threshold_weak}")
            accept_weak = confidence_weak > threshold_weak

        if accept_weak:
            stats["weak_llm_used"] += 1
            stats["weak_llm_used_correct"] += correctness
            print("Accept weak llm answer")
            safe_update_data(problem_ind, {
                "weak_llm_used": True,
                "weak_llm_used_correct": correctness,
            }, output_dir)
            continue

        safe_update_data(problem_ind, {
            "weak_llm_used": False,
            "weak_llm_used_correct": False,
        }, output_dir)

        # ---- strong LLM (weak answer rejected) ----
        strong_llm_response = strong_llm.ask(
            prompt,
            logprobs=True,
            top_logprobs=1
        )
        strong_llm_strategy, strong_llm_answer = extract_response_components(strong_llm_response)
        strong_correctness = _is_correct(strong_llm_answer, ground_truth, strong_llm)
        stats["strong_total_correct"] += strong_correctness
        print(f"\nStrong LLM answer is {strong_llm_answer}, and ground truth is {ground_truth}")
        print(f"Match or not? {strong_correctness}\n")

        # The span is recomputed for every problem, so a missing marker can
        # no longer reuse the previous problem's indices.
        span = None
        if strong_llm_answer != "No answer":
            token_list = strong_llm_response.logprobs.content
            span = _strong_answer_span(token_list)
            if span is None:
                print("Could not locate answer tokens in strong llm log-probs")

        if span is not None:
            start_index, end_index = span
            answer_tokens = token_list[start_index:end_index]
            print("Answer token and logprob ...")
            for token_obj in answer_tokens:
                print(f"  - Token '{token_obj.token}' has log-probability: {token_obj.logprob}")
            strong_n_answer_tokens = end_index - start_index
            strong_answer_logprob = sum(token_obj.logprob for token_obj in answer_tokens)
        else:
            strong_n_answer_tokens = None
            strong_answer_logprob = None

        # update the test data with strong llm answer
        safe_update_data(problem_ind, {
            "strong_llm_prompt": prompt,
            "strong_llm_response": strong_llm_response.message.content,
            "strong_llm_answer_tokens": strong_n_answer_tokens,
            "strong_llm_strategy": strong_llm_strategy,
            "strong_llm_answer": strong_llm_answer,
            "strong_llm_accuracy": strong_correctness,
            "strong_llm_answer_logprob_sum": strong_answer_logprob,
        }, output_dir)

        accept_strong = False
        if strong_answer_logprob is not None:
            confidence_strong = inv_minmax01(exp_from_logprobs(strong_answer_logprob), x_min_gemini, rng_gemini)
            print(f"\nStrong LLM answer has confidence {confidence_strong}, threshold for strong LLM is {threshold_strong}")
            accept_strong = confidence_strong > threshold_strong

        if accept_strong:
            stats["strong_llm_used"] += 1
            print("Accept strong llm answer")
            safe_update_data(problem_ind, {
                "strong_llm_used": True,
            }, output_dir)
        else:
            stats["unaccept_problems"] += 1
            print("Reject both weak and strong llm answers")
            safe_update_data(problem_ind, {
                "strong_llm_used": False,
            }, output_dir)

    print(f"Total number of problems: {stats['total_problems']}")
    print(f"Uncover problems: {stats['unaccept_problems']}\n")
    print(f"Accept weak llm: {stats['weak_llm_used']}")
    print(f"Number of correct answers of weak llm: {stats['weak_total_correct']}")
    print(f"Number of correct answers of weak llm that are used: {stats['weak_llm_used_correct']}")
    print(f"Number of weak llm helped by strategy: {stats['weak_helped_by_strategy']}\n")
    print(f"Accept strong llm: {stats['strong_llm_used']}")
    print(f"Number of correct answers of strong llm: {stats['strong_total_correct']}")

if __name__ == "__main__":

    # Credentials come from the environment so they never have to be pasted
    # into source; an unset variable falls back to "" exactly as before.
    # weak llm using GPT-3.5 Turbo
    weak_llm_api_key = os.environ.get("OPENAI_API_KEY", "")
    # strong llm using Gemini 2.0 (Google Cloud project id)
    strong_llm_project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", "")

    data_dir = "data/metamath/test_metamath.json"
    data_jung_pipeline_dir = "data/metamath/test_metamath_jung_pipeline.json"
    problems_indices = None  # None runs every problem; e.g. range(1000) runs a subset

    # weak llm: alpha=0.4, delta=0.6 => threshold: 0.61, coverage: 0.52
    run_pipeline_metamath(data_dir, data_jung_pipeline_dir, problems_indices, weak_llm_api_key, strong_llm_project_id, threshold_weak=0.612, threshold_strong=0)
    stats_pipeline(data_jung_pipeline_dir)

    # Reference output recorded from a previous run:
    #   Total number of problems: 10000
    #   Uncover problems: 3
    #   Accept weak llm: 5075
    #   Number of correct answers of weak llm: 3543
    #   Number of correct answers of weak llm that are used: 2852
    #   Number of weak llm helped by strategy: 0
    #   Accept strong llm: 4922
    #   Number of correct answers of strong llm: 3740
    #   Number of correct answers of strong llm that are used: 3740


