from algo.reasoning_adapter import Reasoning_Adapter
from algo.task_adapters.gpt_sampler import ChatCompletionSampler
from utils.loggers import loggers
import os
import re
from algo.task_adapters.benchmark_utils import ANSWER_PATTERN, last_boxed_only_string, remove_boxed, extract_answer

# Empty prompt constant; not referenced in the visible part of this module
# (NOTE(review): presumably kept for external importers — verify before removing).
PROMPT = ""
# Question + four lettered options. get_accuracy fills it from each ground-truth
# example and passes the result to is_correct, which forwards it to the
# LLM-based extract_answer fallback.
QUERY_TEMPLATE_API = "{Question}\nAnswer Choices:\n(A) {choice1}\n(B) {choice2}\n(C) {choice3}\n(D) {choice4}"

def write_file(file_path, content):
    """Append ``content`` to the text file at ``file_path``, creating it if needed.

    The file is written as UTF-8 explicitly. Callers dump raw model output
    here, and relying on the platform's locale-dependent default encoding can
    raise ``UnicodeEncodeError`` for non-ASCII text on some systems.
    """
    with open(file_path, "a", encoding="utf-8") as f:
        f.write(content)

def _as_choice_letter(text):
    """Return ``"A"``-``"D"`` if *text* is just an option letter, else ``None``.

    Matching ignores case and tolerates surrounding whitespace and parentheses,
    e.g. ``"b"``, ``" C "``, ``"(D)"`` and ``"A)"`` are all accepted.
    Non-string input yields ``None``.
    """
    if not isinstance(text, str):
        return None
    letter = text.strip().strip("()").strip().upper()
    return letter if letter in ("A", "B", "C", "D") else None


def is_correct(a, gt_example, sampler, question):
    r"""Decide whether a model completion selects the ground-truth GPQA option.

    The completion is reduced to a single option letter in stages:

    1. If chat-template markers are present, keep the text after the last
       ``<|im_start|>answer\n`` marker, then after the last remaining
       ``<|im_start|>`` marker, dropping that segment's first line
       (presumably the role header).
    2. If the text contains a ``\boxed{...}`` expression, use its contents.
       Otherwise use the last match of ``ANSWER_PATTERN``, if any.
    3. Accept a bare option letter in any case, optionally wrapped in
       whitespace or parentheses.
    4. Otherwise, if ``sampler`` is given, ask ``extract_answer`` to map the
       text to an option.
    5. Anything still unresolved falls back to ``"A"`` with a printed warning.

    Args:
        a: Raw completion text.
        gt_example: Ground-truth answer compared against the extracted letter.
            Presumably one of ``"A"``-``"D"``; confirm against the dataset.
        sampler: Optional sampler passed to ``extract_answer``. ``None``
            disables the LLM fallback.
        question: Formatted question text passed to ``extract_answer``.

    Returns:
        ``True`` if the extracted letter equals ``gt_example``. Returns
        ``False`` otherwise, including when ``remove_boxed`` fails; in that
        case the text is appended to ``error_gpqa.txt``.
    """
    answer_marker = "<|im_start|>answer\n"
    turn_marker = "<|im_start|>"

    if answer_marker in a:
        a = a.split(answer_marker)[-1]
    if turn_marker in a:
        a = a.split(turn_marker)[-1]
        if "\n" in a:
            # Drop the first line of the final turn (presumably the role name).
            a = a.partition("\n")[2]

    if (box := last_boxed_only_string(a)) is not None:
        try:
            a = remove_boxed(box)
        except Exception:
            # Malformed \boxed{...}: keep the text for later inspection and
            # count the sample as wrong. Not a bare `except:`, so Ctrl-C and
            # SystemExit still propagate.
            write_file("error_gpqa.txt", f"{a}\n\n\n{box}")
            return False
    # re.DOTALL is key such that newlines are included e.g. if it does `Answer: Here is the solution:\n\n10`
    elif (matches := re.findall(ANSWER_PATTERN, a, re.DOTALL)) != []:
        a = matches[-1]  # Get the last match

    letter = _as_choice_letter(a)
    if letter is None and sampler is not None:
        a = extract_answer(sampler, question, a)
        letter = _as_choice_letter(a)
    # TODO: Maybe add back legacy processing when no sampler is configured.

    if letter is None:
        print(f"Warning: Default to A as extracted {a}")
        letter = "A"

    return letter == gt_example

def get_accuracy(results):
    """Score GPQA completions against their ground-truth examples.

    The ``PROCESSOR`` environment variable selects the fallback answer
    extractor. ``"gpt-4o-mini"`` builds a ``ChatCompletionSampler``; any other
    value prints a notice and disables the LLM fallback (``sampler=None``).

    Args:
        results: Mapping with two parallel sequences. ``"completions"`` holds
            the raw model outputs. ``"ground_truths"`` holds the example
            dicts, each read for ``"Question"``, ``"choice1"``..``"choice4"``
            and ``"answer"``. Pairs are formed with ``zip``, so trailing items
            of the longer sequence are ignored.

    Returns:
        ``(overall_accuracy, 0.0)``. ``overall_accuracy`` is the fraction of
        completions that ``is_correct`` accepts, or ``0.0`` when there is
        nothing to score; previously that case raised ``ZeroDivisionError``.
        The second element is always ``0.0`` here. NOTE(review): presumably a
        placeholder for a secondary metric expected by callers; confirm.
    """
    if os.getenv("PROCESSOR", "") == "gpt-4o-mini":
        sampler = ChatCompletionSampler(model="gpt-4o-mini")
    else:
        print(f"Unknown processor: {os.getenv('PROCESSOR')}; set 'PROCESSOR=gpt-4o-mini' and 'OPENAI_API_KEY=YOUR_KEY' for best results.")
        sampler = None

    n_correct = 0
    n_total = 0
    for ans, gt in zip(results["completions"], results["ground_truths"]):
        question = QUERY_TEMPLATE_API.format(
            Question=gt["Question"],
            choice1=gt["choice1"],
            choice2=gt["choice2"],
            choice3=gt["choice3"],
            choice4=gt["choice4"],
        )
        correct = is_correct(ans, gt["answer"], sampler, question)
        n_correct += correct  # bool counts as 0/1
        n_total += 1
        loggers["eval"].info(f"\n{'-'*20}\ncompletion:\n{ans}\nground-truth:\n{gt}\ncorrect: {correct}")

    if n_total == 0:
        # Nothing to score: report 0.0 instead of dividing by zero.
        loggers["eval"].info("No completions to score; reporting accuracy 0.0")
        overall_accuracy = 0.0
    else:
        overall_accuracy = n_correct / n_total
    loggers["eval"].info(f"\n{'-'*20}\nOverall accuracy: {overall_accuracy}")
    return overall_accuracy, 0.0
    

def stop_criterion(input_text):
    """Report whether generation should stop.

    Generation stops once the text contains the ``<|im_end|>`` marker. The
    match is case-insensitive because the text is lower-cased first.
    """
    lowered = input_text.lower()
    return lowered.find('<|im_end|>') != -1

class GPQA_Adapter(Reasoning_Adapter):
    """Reasoning adapter for the GPQA multiple-choice benchmark.

    Attaches the module-level GPQA helpers (``stop_criterion``, ``is_correct``,
    ``get_accuracy``) to each instance and renders GPQA examples as prompts.
    """

    def __init__(self, prompt, config, accelerator):
        # These are assigned before the base initialiser runs, matching the
        # original ordering. The base class may read or overwrite them
        # (its implementation is not visible here).
        self.config = config
        self.prompt = prompt
        self.accelerator = accelerator

        super().__init__(config=config, prompt=prompt, accelerator=accelerator)

        # Task hooks, stored as plain function attributes on the instance.
        self.stop_criterion = stop_criterion
        self.is_correct = is_correct
        self.get_accuracy = get_accuracy
        # Raises KeyError if the config has no "qa_template" entry.
        self.qa_template = config["qa_template"]

    # TODO: reformat the ground truth solution
    def get_positive_ans(self, b):
        """Placeholder that is not implemented yet; always returns ``None``."""
        return None

    # TODO: need to change this
    def formulate_question(self, b):
        """Render example ``b`` as its question, a blank line, then options A)-D).

        Reads ``b["Question"]`` followed by ``b["choice1"]``..``b["choice4"]``.
        """
        stem = b["Question"]
        option_lines = "\n".join(
            f"{label}) {b[field]}"
            for label, field in zip("ABCD", ("choice1", "choice2", "choice3", "choice4"))
        )
        return f"{stem}\n\n{option_lines}"

    def extract_ground_truth(self, b):
        """Return the example ``b`` itself, unchanged.

        NOTE(review): presumably the full record is kept so get_accuracy can
        read both the question fields and ``"answer"``; confirm with the caller.
        """
        return b
    
    
