import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import json
import tiktoken
from rag_setup import *
from model.request_gpt import *
from model.request_gemini import *
from model.prompts import *
from find_threshold import *
from utils import *


def _load_strategy_records(path, strategy_key):
    """Load a calibration JSON list and drop records whose strategy is "No strategy".

    Args:
        path: Path to a UTF-8 JSON file holding a list of dict records.
        strategy_key: Key under which each record stores its extracted strategy.

    Returns:
        The records whose ``strategy_key`` value is not the ``"No strategy"`` sentinel.
    """
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    return [d for d in records if d[strategy_key] != "No strategy"]


def _judge_correctness(llm_answer, ground_truth, judge_llm):
    """Return whether ``llm_answer`` matches ``ground_truth``.

    ``"No answer"`` is always wrong and an exact string match is always right. Any
    other answer is sent to ``judge_llm`` with ``check_correct_prompt``. The reply is
    correct when its last whitespace-separated word, lower-cased, is ``"true"``.
    An empty reply counts as incorrect.
    """
    if llm_answer == "No answer":
        return False
    if llm_answer == ground_truth:
        return True
    check_correct = check_correct_prompt.format(llm_answer=llm_answer, ground_truth=ground_truth)
    reply = judge_llm.ask(check_correct).message.content or ""
    reply_words = reply.strip().split()
    # Guard: ``split()[-1]`` on an empty reply would raise IndexError.
    return bool(reply_words) and reply_words[-1].lower() == "true"


def _locate_strong_answer_span(token_list):
    """Locate the answer tokens that follow the last ``Answer`` / ``]:`` marker pair.

    The marker is a token whose stripped text is ``"Answer"`` immediately followed by
    a token whose stripped text is ``"]:"``. Whitespace-only tokens right after the
    marker are skipped. A whitespace-only final token is excluded.

    Returns:
        ``(start, end)`` half-open token indices, or ``None`` if no marker exists or
        the resulting span is empty.
    """
    start_index = None
    # Stop one short of the end so ``token_list[i + 1]`` is always in range; keep
    # scanning after a hit so the LAST marker wins.
    for i in range(len(token_list) - 1):
        if token_list[i].token.strip() == "Answer" and token_list[i + 1].token.strip() == "]:":
            start_index = i + 2
    if start_index is None:
        return None
    while start_index < len(token_list) and token_list[start_index].token.strip() == "":
        start_index += 1
    end_index = len(token_list) - 1 if token_list[-1].token.strip() == "" else len(token_list)
    if start_index >= end_index:
        # An empty span would sum to logprob 0, i.e. maximal confidence.
        return None
    return start_index, end_index


def run_pipeline_gsm_plus(input_dir, output_dir, problems_indices, weak_llm_api_key, strong_llm_project_id, threshold_weak, threshold_strong, threshold_similarity, similar_problems, random_retrieve=False):
    """Run the weak→strong LLM cascade with RAG-retrieved strategies on GSM-Plus.

    For each selected problem:

    1. The weak LLM (gpt-3.5-turbo) answers. It gets a few-shot prompt when
       retrieved strategy documents are available, and a zero-shot prompt otherwise.
    2. The weak answer is accepted if its rescaled confidence exceeds
       ``threshold_weak``.
    3. Otherwise the strong LLM (gemini-2.0-flash) answers zero-shot. Its answer is
       accepted if its rescaled confidence exceeds ``threshold_strong``, and the
       accepted strategy is then added to the RAG store.

    Per-problem results are written with ``safe_update_data``. Summary counts are
    printed at the end.

    Args:
        input_dir: Path to the test JSON file. Each record needs ``"question"`` and
            ``"ground_truth"``.
        output_dir: Output path handed to ``safe_update_data``.
        problems_indices: Indices of the records to process; ``None`` means all.
        weak_llm_api_key: API key for ``ChatGPTClient``.
        strong_llm_project_id: Project id for ``GeminiClient``.
        threshold_weak: Acceptance threshold on the weak LLM's rescaled confidence.
        threshold_strong: Acceptance threshold on the strong LLM's rescaled confidence.
        threshold_similarity: Retrieved documents are kept only if their
            ``problem_score`` is strictly greater than this. Unused with random
            retrieval.
        similar_problems: Number of documents to retrieve (``k``).
        random_retrieve: Retrieve random documents instead of by similarity. This
            mode also uses its own RAG store directory.
    """
    with open(input_dir, "r", encoding="utf-8") as f:
        test = json.load(f)

    if problems_indices is None:
        problems_indices = list(range(len(test)))

    # Each retrieval mode keeps its own RAG store directory.
    save_dir = (
        "gsm_plus_solution_strategy_strong_no_strategy_random_strategy"
        if random_retrieve
        else "gsm_plus_solution_strategy_strong_no_strategy"
    )
    rag_db = PersistentRAG(save_dir)

    calibration_gemini = _load_strategy_records(
        "data/gsm_plus/calibration_gsm_plus_strong_gemini2.0flash_llm.json", "strong_llm_strategy"
    )
    # Seed the store from calibration only when it holds no documents yet.
    if not rag_db.documents:
        rag_db.add_documents(calibration_gemini)

    # initialize LLMs
    weak_llm_model = "gpt-3.5-turbo"
    weak_llm = ChatGPTClient(api_key=weak_llm_api_key, model=weak_llm_model)
    weak_encoding = tiktoken.encoding_for_model(weak_llm_model)

    strong_llm_model = "gemini-2.0-flash"
    strong_llm = GeminiClient(project_id=strong_llm_project_id, model=strong_llm_model)

    stats = {
        "total_problems": len(problems_indices),
        "weak_helped_by_strategy": 0,
        "weak_total_correct": 0,
        "weak_llm_used": 0,
        "weak_llm_used_correct": 0,
        "strong_total_correct": 0,
        "strong_llm_used": 0,
        "strong_llm_used_correct": 0,
        "unaccept_problems": 0,
    }

    # Min-max parameters of exp(logprob sum) over each calibration set. inv_minmax01
    # applies them below so that test-time confidences are comparable to the
    # thresholds fitted on calibration.
    strong_log_confidence = [item["strong_llm_answer_logprob_sum"] for item in calibration_gemini]
    _, x_min_gemini, rng_gemini = minmax01(exp_from_logprobs(strong_log_confidence))

    calibration_weak = _load_strategy_records(
        "data/gsm_plus/calibration_gsm_plus_weak_gpt3.5turbo_llm.json", "weak_llm_strategy"
    )
    weak_log_confidence = [item["weak_llm_answer_logprob_sum"] for item in calibration_weak]
    _, x_min_weak, rng_weak = minmax01(exp_from_logprobs(weak_log_confidence))

    for problem_ind in problems_indices:
        print(f"\nProcessing problem index: {problem_ind}")
        question = test[problem_ind]["question"]
        ground_truth = test[problem_ind]["ground_truth"]

        # ---- Weak LLM, optionally helped by retrieved strategies ----
        if random_retrieve:
            relevant_docs = rag_db.random_retrieve(k=similar_problems)
        else:
            relevant_docs = rag_db.retrieve(question, k=similar_problems)
            relevant_docs = [doc for doc in relevant_docs if doc["problem_score"] > threshold_similarity]

        if relevant_docs:
            stats["weak_helped_by_strategy"] += 1
            context = format_context(relevant_docs, strategy_key_name="strong_llm_strategy", answer_key_name="strong_llm_answer")
            prompt = fewshot_prompt_gsm.format(context=context, question=question)
        else:
            prompt = zero_prompt_gsm.format(question=question)
        print(f"\nPrompt for llm...\n{prompt}\n")

        weak_llm_response = weak_llm.ask(
            prompt,
            logprobs=True,  # boolean: return per-token log-probabilities
            top_logprobs=1,  # integer (OpenAI allows 0-20); 1 keeps the payload small
        )
        weak_llm_strategy, weak_llm_answer = extract_response_components(weak_llm_response)
        correctness = _judge_correctness(weak_llm_answer, ground_truth, strong_llm)
        stats["weak_total_correct"] += correctness
        print(f"\nWeak LLM answer is {weak_llm_answer}, and ground truth is {ground_truth}")
        print(f"Match or not? {correctness}\n")

        # Approximate the answer's log-probability as the sum over the last
        # N response tokens, where N is the token length of the re-encoded answer.
        weak_n_answer_tokens = None
        weak_answer_logprob = None
        if weak_llm_answer != "No answer":
            weak_n_answer_tokens = len(weak_encoding.encode(weak_llm_answer, disallowed_special=()))
            # N == 0 must be skipped: ``lps[-0:]`` selects the WHOLE response.
            if weak_n_answer_tokens > 0:
                token_lps = [e.logprob for e in weak_llm_response.logprobs.content]
                weak_answer_logprob = sum(token_lps[-weak_n_answer_tokens:])

        safe_update_data(problem_ind, {
            "weak_llm_helped_by_strategy": bool(relevant_docs),
            "weak_llm_prompt": prompt,
            "weak_llm_response": weak_llm_response.message.content,
            "weak_llm_answer_tokens": weak_n_answer_tokens,
            "weak_llm_strategy": weak_llm_strategy,
            "weak_llm_answer": weak_llm_answer,
            "weak_llm_accuracy": correctness,
            "weak_llm_answer_logprob_sum": weak_answer_logprob,
        }, output_dir)

        confidence_weak = None
        if weak_answer_logprob is not None:
            confidence_weak = inv_minmax01(exp_from_logprobs(weak_answer_logprob), x_min_weak, rng_weak)
            print(f"\nWeak LLM answer has confidence {confidence_weak}, threshold for weak LLM is {threshold_weak}")

        if confidence_weak is not None and confidence_weak > threshold_weak:
            stats["weak_llm_used"] += 1
            stats["weak_llm_used_correct"] += correctness
            print("Accept weak llm answer")
            safe_update_data(problem_ind, {
                "weak_llm_used": True,
                "weak_llm_used_correct": correctness,
            }, output_dir)
            continue

        safe_update_data(problem_ind, {
            "weak_llm_used": False,
            "weak_llm_used_correct": False,
        }, output_dir)

        # ---- Strong LLM fallback: zero-shot, no retrieved context ----
        prompt = zero_prompt_gsm.format(question=question)
        strong_llm_response = strong_llm.ask(
            prompt,
            logprobs=True,
            top_logprobs=1,
        )
        strong_llm_strategy, strong_llm_answer = extract_response_components(strong_llm_response)
        strong_correctness = _judge_correctness(strong_llm_answer, ground_truth, strong_llm)
        stats["strong_total_correct"] += strong_correctness
        print(f"\nStrong LLM answer is {strong_llm_answer}, and ground truth is {ground_truth}")
        print(f"Match or not? {strong_correctness}\n")

        strong_n_answer_tokens = None
        strong_answer_logprob = None
        if strong_llm_answer != "No answer":
            token_list = strong_llm_response.logprobs.content
            span = _locate_strong_answer_span(token_list)
            if span is None:
                # Leave the confidence unknown (-> rejected) rather than scoring
                # tokens that do not belong to this answer.
                print("Could not locate answer tokens after the Answer marker; confidence unknown")
            else:
                start_index, end_index = span
                answer_tokens = token_list[start_index:end_index]
                print("Answer token and logprob ...")
                for tok in answer_tokens:
                    print(f"  - Token '{tok.token}' has log-probability: {tok.logprob}")
                strong_n_answer_tokens = end_index - start_index
                strong_answer_logprob = sum(tok.logprob for tok in answer_tokens)

        safe_update_data(problem_ind, {
            "strong_llm_prompt": prompt,
            "strong_llm_response": strong_llm_response.message.content,
            "strong_llm_answer_tokens": strong_n_answer_tokens,
            "strong_llm_strategy": strong_llm_strategy,
            "strong_llm_answer": strong_llm_answer,
            "strong_llm_accuracy": strong_correctness,
            "strong_llm_answer_logprob_sum": strong_answer_logprob,
        }, output_dir)

        confidence_strong = None
        if strong_answer_logprob is not None:
            confidence_strong = inv_minmax01(exp_from_logprobs(strong_answer_logprob), x_min_gemini, rng_gemini)
            print(f"\nStrong LLM answer has confidence {confidence_strong}, threshold for strong LLM is {threshold_strong}")

        if confidence_strong is not None and confidence_strong > threshold_strong:
            stats["strong_llm_used"] += 1
            stats["strong_llm_used_correct"] += strong_correctness
            print("Accept strong llm answer")
            safe_update_data(problem_ind, {
                "strong_llm_used": True,
                "strong_llm_used_correct": strong_correctness
            }, output_dir)

            # Grow the RAG store with the accepted strategy. Apply the same
            # "No strategy" filter that was used when seeding the store.
            # NOTE(review): this record has no "strong_llm_answer" key even though
            # format_context is called with answer_key_name="strong_llm_answer";
            # confirm format_context tolerates that.
            test[problem_ind]["strong_llm_strategy"] = strong_llm_strategy
            if strong_llm_strategy != "No strategy":
                rag_db.add_documents([test[problem_ind]])
        else:
            stats["unaccept_problems"] += 1
            print("Reject both weak and strong llm answers")
            safe_update_data(problem_ind, {
                "strong_llm_used": False,
                "strong_llm_used_correct": False
            }, output_dir)

    print(f"Total number of problems: {stats['total_problems']}")
    print(f"Uncover problems: {stats['unaccept_problems']}\n")
    print(f"Accept weak llm: {stats['weak_llm_used']}")
    print(f"Number of correct answers of weak llm: {stats['weak_total_correct']}")
    print(f"Number of correct answers of weak llm that are used: {stats['weak_llm_used_correct']}")
    print(f"Number of weak llm helped by strategy: {stats['weak_helped_by_strategy']}\n")
    print(f"Accept strong llm: {stats['strong_llm_used']}")
    print(f"Number of correct answers of strong llm: {stats['strong_total_correct']}")
    print(f"Number of correct answers of strong llm that are used: {stats['strong_llm_used_correct']}")

if __name__ == "__main__":

    # Credentials come from the environment so secrets need not be edited into
    # source. Both fall back to an empty string when the variable is unset.
    # weak llm: GPT-3.5 Turbo
    weak_llm_api_key = os.environ.get("OPENAI_API_KEY", "")
    # strong llm: Gemini 2.0 Flash (Google Cloud project id)
    strong_llm_project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", "")

    data_dir = "data/gsm_plus/test_gsm_plus.json"

    # Threshold calibration shared by both runs (the calls pass the unrounded values):
    #   strong llm: alpha=0.2, delta=0.8 => threshold: 0.51, coverage: 0.93
    #   weak llm:   alpha=0.6, delta=0.6 => threshold: 0.48, coverage: 0.49

    # 1. Strategy based on similarity score
    data_our_pipeline_dir = "data/gsm_plus/test_gsm_plus_our_pipeline_strong_no_strategy.json"
    problems_indices = None  # None = all problems; e.g. list(range(1000)) for a subset
    run_pipeline_gsm_plus(data_dir, data_our_pipeline_dir, problems_indices, weak_llm_api_key, strong_llm_project_id, threshold_weak=0.482, threshold_strong=0.508, threshold_similarity=0, similar_problems=2)
    stats_pipeline(data_our_pipeline_dir)
    # Recorded output:
    #   Total number of problems: 9504
    #   Uncover problems: 495
    #   Accept weak llm: 5926
    #   Number of correct answers of weak llm: 3483
    #   Number of correct answers of weak llm that are used: 3219
    #   Number of weak llm helped by strategy: 9504
    #   Accept strong llm: 3083
    #   Number of correct answers of strong llm: 2062
    #   Number of correct answers of strong llm that are used: 2034
    # Per-category breakdown:
    #   stats_pipeline_category(data_our_pipeline_dir, category_name="perturbation_type")

    # 2. Strategy from random selection
    data_our_pipeline_dir = "data/gsm_plus/test_gsm_plus_our_pipeline_strong_no_strategy_random_strategy.json"
    problems_indices = None  # None = all problems; e.g. list(range(4008, 9504)) to resume
    run_pipeline_gsm_plus(data_dir, data_our_pipeline_dir, problems_indices, weak_llm_api_key, strong_llm_project_id, threshold_weak=0.482, threshold_strong=0.508, threshold_similarity=0, similar_problems=2, random_retrieve=True)
    stats_pipeline(data_our_pipeline_dir)
    # Recorded output:
    #   Total number of problems: 9504
    #   Uncover problems: 561
    #   Accept weak llm: 4795
    #   Number of correct answers of weak llm: 2281
    #   Number of correct answers of weak llm that are used: 2001
    #   Number of weak llm helped by strategy: 9504
    #   Accept strong llm: 4148
    #   Number of correct answers of strong llm: 2827
    #   Number of correct answers of strong llm that are used: 2795

