from algo.reasoning_adapter import Reasoning_Adapter
from algo.task_adapters.gpt_sampler import ChatCompletionSampler
from utils.loggers import loggers
import os
import re
from algo.task_adapters.benchmark_utils import ANSWER_PATTERN, last_boxed_only_string, remove_boxed, extract_answer_idx

# Module-level prompt constant, left empty here. Nothing in this file's visible
# code reads it; NOTE(review): confirm whether other modules import it.
PROMPT = ""

def write_file(file_path, content):
    """Append ``content`` to the file at ``file_path``, creating it if needed.

    The file is always written as UTF-8 so that text containing non-ASCII
    characters can be dumped regardless of the platform's default locale
    encoding.

    Args:
        file_path: Path of the file to append to.
        content: Text to append; written as-is, with no separator added.
    """
    with open(file_path, "a", encoding="utf-8") as f:
        f.write(content)

def is_correct(a, gt_example, sampler):
    """Decide whether a model completion matches the ground-truth answer.

    The completion is reduced to a candidate answer in stages:

    1. Text up to the last ``<|im_start|>answer\\n`` / ``<|im_start|>`` marker
       is discarded.
    2. If ``last_boxed_only_string`` finds a boxed expression, its content
       (via ``remove_boxed``) becomes the candidate; otherwise the last match
       of ``ANSWER_PATTERN`` is used when there is one.
    3. If both candidate and ground truth are decimal strings, the candidate
       is canonicalized through ``int``; otherwise, when ``sampler`` is given,
       ``extract_answer_idx`` is asked to map the candidate onto the ground
       truth.

    Args:
        a: Completion text produced by the model.
        gt_example: Ground-truth answer; compared through its ``str`` form.
        sampler: Object forwarded to ``extract_answer_idx``, or ``None`` to
            skip that matching step.

    Returns:
        ``True`` if the extracted answer equals the ground-truth string,
        ``False`` otherwise (including when un-boxing raises).
    """
    # Canonicalize numeric ground truths so "023" and "23" compare equal.
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as "²" that int() rejects with ValueError.
    if isinstance(gt_example, str) and gt_example.isdecimal():
        gt_answer = str(int(gt_example))
    else:
        gt_answer = str(gt_example)

    answer_marker = "<|im_start|>answer\n"
    turn_marker = "<|im_start|>"

    if answer_marker in a:
        a = a.split(answer_marker)[-1]
    if turn_marker in a:
        a = a.split(turn_marker)[-1]
        # Drop the rest of the marker's own line (presumably the role name).
        if "\n" in a:
            a = "\n".join(a.split("\n")[1:])

    if (box := last_boxed_only_string(a)) is not None:
        try:
            a = remove_boxed(box)
        except Exception:
            # Keep the offending inputs for later inspection and count the
            # example as wrong. Catching Exception (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            write_file("error_s1k_a.txt", a)
            write_file("error_s1k_box.txt", box)
            return False
    # re.DOTALL is key such that newlines are included e.g. if it does `Answer: Here is the solution:\n\n10`
    elif (matches := re.findall(ANSWER_PATTERN, a, re.DOTALL)) != []:
        a = matches[-1]  # Get the last match

    if a.isdecimal() and gt_answer.isdecimal():
        a = str(int(a))  # 023 -> 23
    elif sampler is not None:
        # The ground truth is offered as the only option; a returned 1-based
        # index of 1 therefore means "matches the ground truth".
        options = [gt_answer]
        options_str = "[" + ", ".join(["'" + str(o) + "'" for o in options]) + "]"
        idx = extract_answer_idx(sampler, options_str, a)
        if idx != "-1":
            if idx.isdigit():
                idx = int(idx) - 1
                if len(options) > idx >= 0:
                    a = options[idx]
                else:
                    print("Warning: Index out of bounds; leaving answer unchanged\n", a, "\noptions", options_str, "\ndoc['answer']", gt_answer, "\nidx", idx)
            else:
                print("Warning: Processing did not produce integer index\na", a, "\noptions", options_str, "\ndoc['answer']", gt_answer)

    return a == gt_answer

def get_accuracy(results):
    """Score every completion against its ground truth and return the accuracy.

    Args:
        results: Mapping with a ``"completions"`` sequence and a parallel
            ``"ground_truths"`` sequence; each ground truth must be indexable
            by ``"answer"``.

    Returns:
        A tuple ``(overall_accuracy, 0.0)``. The first element is the fraction
        of completions judged correct by ``is_correct`` (``0.0`` when there is
        nothing to evaluate); the second element is always ``0.0``.

    Raises:
        ValueError: If the ``PROCESSOR`` environment variable is not
            ``"gpt-4o-mini"``.
    """
    if os.getenv("PROCESSOR", "") == "gpt-4o-mini":
        sampler = ChatCompletionSampler(model="gpt-4o-mini")
    else:
        print(f"Unknown processor: {os.getenv('PROCESSOR')}; set 'PROCESSOR=gpt-4o-mini' and 'OPENAI_API_KEY=YOUR_KEY' for best results.")
        raise ValueError("MATH requires PROCESSOR atm. AIME is fine without it.")

    num_correct = 0
    num_total = 0
    for ans, gt in zip(results["completions"], results["ground_truths"]):
        correct = is_correct(ans, gt["answer"], sampler)
        num_correct += correct
        num_total += 1
        loggers["eval"].info(f"\n{'-'*20}\ncompletion:\n{ans}\nground-truth:\n{gt}\ncorrect: {correct}")

    # Guard the division: an empty evaluation set would otherwise raise
    # ZeroDivisionError instead of producing a result.
    overall_accuracy = num_correct / num_total if num_total else 0.0
    loggers["eval"].info(f"\n{'-'*20}\nOverall accuracy: {overall_accuracy}")
    return overall_accuracy, 0.0
    

def stop_criterion(input_text):
    """Return True once the end-of-turn marker appears in ``input_text``.

    The check is case-insensitive: the text is lower-cased before looking for
    the marker.
    """
    end_marker = '<|im_end|>'
    lowered_text = input_text.lower()
    return end_marker in lowered_text

class S1K_Adapter(Reasoning_Adapter):
    """``Reasoning_Adapter`` subclass wired to this module's evaluation hooks.

    The instance exposes the module-level ``stop_criterion``, ``is_correct``
    and ``get_accuracy`` functions as attributes, plus the ``qa_template``
    entry read from ``config``.
    """

    def __init__(self, prompt, config, accelerator):
        """Store the arguments, initialise the base adapter, attach the hooks.

        Args:
            prompt: Forwarded to the base class and kept on ``self.prompt``.
            config: Mapping that must contain ``"qa_template"``; also
                forwarded to the base class and kept on ``self.config``.
            accelerator: Forwarded to the base class and kept on
                ``self.accelerator``.
        """
        # These are assigned before the base initialiser runs.
        self.config = config
        self.prompt = prompt
        self.accelerator = accelerator

        super().__init__(config=config, prompt=prompt, accelerator=accelerator)

        # Attach the module-level functions as plain instance attributes.
        self.stop_criterion = stop_criterion
        self.is_correct = is_correct
        self.get_accuracy = get_accuracy
        self.qa_template = config["qa_template"]

    # TODO: reformat the ground truth solution
    def get_positive_ans(self, b):
        """Return the assistant text from ``b["filtered_text"]`` as a one-item list.

        The text is what follows the last ``<|im_start|>assistant`` marker,
        cut at the first ``<|im_end|>`` after it. If a marker is absent, no
        cut is made on that side.
        """
        _, _, after_assistant = b["filtered_text"].rpartition("<|im_start|>assistant")
        assistant_text, _, _ = after_assistant.partition("<|im_end|>")
        return [assistant_text]

    def formulate_question(self, b):
        """Return the ``'question'`` entry of the example ``b``."""
        question = b['question']
        return question

    def extract_ground_truth(self, b):
        """Return the example ``b`` unchanged as its own ground truth."""
        return b
    
    

