from __future__ import annotations

from mcts import load_root_nodes, print_mcts_tree, MCTSNode, SubResult
import argparse
import math
from eval.util import last_boxed_only_string, remove_boxed
from qvalue_encoder_v2 import QValueEncoder
from utils import get_state, get_model_device, build_sets_batch, extract_parentheses

import json
from tqdm import tqdm
import re


def dfs_max_reward(path: list[MCTSNode]) -> list[tuple[float, int, MCTSNode | None]]:
    """Collect one record per leaf reachable from the last node of *path*.

    The subtree rooted at ``path[-1]`` is walked depth-first, visiting
    children in their stored order, and a flat list of
    ``(reward, visits, node)`` triples is returned:

    * terminal node: ``(node.reward, len(node.cum_rewards), node)`` -- the
      caller treats the second field as a visit count;
    * non-terminal node whose ``children`` is ``None``:
      ``(-math.inf, 0, None)`` as a placeholder;
    * non-terminal node with an empty ``children`` list: contributes nothing.

    Only ``path[-1]`` is inspected; earlier entries of *path* are ignored.
    """
    results = []
    # An explicit stack avoids copying the path at every level of the tree.
    stack = [path[-1]]
    while stack:
        cur = stack.pop()
        if cur.is_terminal:
            results.append((cur.reward, len(cur.cum_rewards), cur))
        elif cur.children is None:
            results.append((-math.inf, 0, None))
        else:
            # Push in reverse so children are popped (visited) left to right,
            # which keeps the result order of a recursive pre-order walk.
            stack.extend(reversed(list(cur.children)))
    return results


def extract_parentheses_new(text):
    """Return the content of the last ``The correct answer is (...)`` in *text*.

    Only the text between the parentheses is returned. If the phrase never
    occurs, an empty string is returned.
    """
    captured = re.findall(r'The correct answer is \((.*?)\)', text)
    if not captured:
        return ''
    # The final occurrence wins when the phrase appears several times.
    return captured[-1]


def get_letter_list(num_letters=26):
    """Return the first *num_letters* uppercase ASCII letters as a list.

    The count is capped at 26; a zero or negative count gives an empty list.
    """
    count = min(num_letters, 26)
    first = ord("A")
    letters = []
    for offset in range(count):
        letters.append(chr(first + offset))
    return letters


def extract_answer_new(completion: str, task="math"):
    """Extract the final answer from a model *completion* for the given *task*.

    * ``"math"`` / ``"gsm8k"``: the content of the last boxed expression, as
      produced by ``remove_boxed(last_boxed_only_string(completion))``.
    * ``"mmlu"`` / ``"mmlupro"`` / ``"arc"`` / ``"gpqa"``: the result of
      ``extract_parentheses``; if that is not a single uppercase letter
      (after stripping), ``extract_parentheses_new`` is tried instead.
      ``None`` is returned when neither yields a valid letter.

    Raises:
        NotImplementedError: for any other *task*.
    """
    if task in ("math", "gsm8k"):
        return remove_boxed(last_boxed_only_string(completion))
    if task not in ("mmlu", "mmlupro", "arc", "gpqa"):
        raise NotImplementedError

    def is_valid_choice(candidate):
        # A usable choice is a non-None string that strips to one of A..Z.
        return candidate is not None and candidate.strip() in get_letter_list()

    choice = extract_parentheses(completion)
    if not is_valid_choice(choice):
        # Fall back to the explicit "The correct answer is (X)" phrasing.
        choice = extract_parentheses_new(completion)
    return choice if is_valid_choice(choice) else None


if __name__ == "__main__":
    # For every MCTS tree loaded from --pickle_path, group the completions of
    # its terminal nodes by extracted answer, pick the answer with the largest
    # total visit count, and write that answer's first completion (and its
    # reward) into the matching line of the JSONL file at --output_path.

    parser = argparse.ArgumentParser(description="Evaluate large language models on critical datasets.")
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="JSONL file with one item per line; rewritten in place with 'reward' and 'completion' set.",
    )
    parser.add_argument(
        "--pickle_path",
        type=str,
        required=True,
        help="Path handed to load_root_nodes to load the MCTS root nodes.",
    )
    args = parser.parse_args()

    # The task family is inferred from the output file name; anything that
    # does not look like a multiple-choice benchmark is treated as math.
    task_type = "math"
    if any(tag in args.output_path for tag in ("mmlu", "gpqa", "arc")):
        task_type = "mmlu"

    root_list = load_root_nodes(args.pickle_path)

    reward_list, output_list, answers_list = [], [], []
    for root_idx, root in enumerate(tqdm(root_list)):
        results = dfs_max_reward([root])

        # answer -> [(reward, visits, completion), ...] in DFS order.
        answers = {}
        for reward, visit, node in results:
            # Skip placeholder leaves and terminal nodes that were never visited.
            if node is None or visit == 0:
                continue
            completion = "".join(
                f"### Step {sub_result.idx}: {sub_result.sub_question}\n{sub_result.sub_answer}\n\n"
                for sub_result in node.state
            ).strip()

            answer = extract_answer_new(completion, task_type)
            answers.setdefault(answer, []).append((reward, visit, completion))

        if answers:
            # Vote by total visit count. Per the original author's notes the
            # alternatives scored: sum of rewards 61.0, sum of reward*visits
            # 60.2, plain completion count 59.6 (this choice: 60.2).
            # max() returns the first maximal entry, matching a stable
            # descending sort followed by taking element 0.
            _, candidates = max(answers.items(), key=lambda kv: sum(n for _, n, _ in kv[1]))
            reward, _, output = candidates[0]
        else:
            # No visited terminal node in this tree: keep the lists aligned
            # with the JSONL rows instead of crashing on an empty vote.
            tqdm.write(f"Warning: tree {root_idx} has no visited terminal node; writing an empty completion.")
            reward, output = None, ""

        reward_list.append(reward)
        output_list.append(output)
        answers_list.append(answers)

    with open(args.output_path, encoding="utf-8") as f:
        # Tolerate blank lines (e.g. a trailing newline at end of file).
        items = [json.loads(line) for line in f if line.strip()]

    if len(items) != len(reward_list):
        # zip() below stops at the shorter sequence, so surface the mismatch.
        print(
            f"Warning: {len(items)} items in {args.output_path} but {len(reward_list)} trees in "
            f"{args.pickle_path}; only the first {min(len(items), len(reward_list))} items are updated."
        )

    for item, reward, completion in zip(items, reward_list, output_list):
        item["reward"] = reward
        item["completion"] = completion

    with open(args.output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in items)


