from __future__ import annotations

import argparse
import itertools
import json
import math
import os
from typing import NamedTuple, Optional, Callable

import numpy as np
import torch
import torch.nn as nn
from tqdm import trange
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

import pickle

from eval.util import last_boxed_only_string, remove_boxed
from qvalue_encoder_v2 import QValueEncoder
from utils import get_state, get_model_device, build_sets_batch, extract_parentheses

# The action that asks the model to finish the solution in one go. A child created
# with this action is marked terminal (MCTS.expand), and it is the only action offered
# once a node reaches depth max_depth - 1.
TERMINAL = "Complete the solution and present the final answer within \\boxed{}."


def print_mcts_tree(node: MCTSNode, indent: str = "", last: bool = True) -> None:
    """Print the subtree rooted at ``node`` as an ASCII tree.

    Each node prints one line with its id, action and Q value. Q is ``calc_q`` over
    ``cum_rewards``, or 0 when the node has none. Below it come one line per
    ``SubResult`` in its state, then its children (recursively).

    Args:
        node: subtree root.
        indent: prefix accumulated from ancestors.
        last: whether ``node`` is the last child of its parent (picks the connector).
    """
    connector = "└── " if last else "├── "
    q_value = node.calc_q(node.cum_rewards) if node.cum_rewards else 0
    print(f"{indent}{connector}Node {node.id} (Action: {node.action}, Q: {q_value:.2f})")

    # State lines and children share the same continuation prefix; a vertical bar is
    # drawn only while more siblings follow.
    continuation = indent + ("    " if last else "│   ")

    for sub in node.state or []:
        print(f"{continuation}└─→ [{sub.idx}] Q: {sub.sub_question} | A: {sub.sub_answer}")

    kids = node.children or []
    for position, kid in enumerate(kids):
        print_mcts_tree(kid, continuation, position == len(kids) - 1)


def extract_answer(completion: str, task="math"):
    """Extract the final answer from a model completion.

    For every supported task the answer is first taken with
    ``remove_boxed(last_boxed_only_string(completion))`` (presumably the contents of the
    last boxed expression). For multiple-choice tasks, a missing or blank result falls
    back to ``extract_parentheses(completion)``.

    Raises:
        NotImplementedError: for tasks outside the two supported families.
    """
    boxed_only_tasks = ("math", "gsm8k")
    multiple_choice_tasks = ("mmlu", "mmlupro", "arc", "gpqa")

    if task not in boxed_only_tasks and task not in multiple_choice_tasks:
        raise NotImplementedError

    answer = remove_boxed(last_boxed_only_string(completion))
    if task in multiple_choice_tasks and (answer is None or answer.strip() == ""):
        answer = extract_parentheses(completion)
    return answer


class SubResult(NamedTuple):
    """One completed reasoning step stored in an MCTS node's ``state``."""

    idx: int  # 1-based step number (MCTS.expand uses len(parent.state) + 1)
    sub_question: str  # the action/instruction that prompted this step
    sub_answer: str  # the text the model generated for that action


class MCTSNode:
    """A node of the MCTS reasoning tree.

    ``state`` lists every step taken from the root to this node; a child built by
    ``MCTS.expand`` copies its parent's list and appends one ``SubResult``.
    ``reward`` is the immediate score given when the node is created, and
    ``cum_rewards`` collects values back-propagated through the node. ``Q`` summarises
    ``cum_rewards`` with ``calc_q``.
    """

    # Class-wide source of unique node ids; ``reset_id`` restarts it.
    id_iter = itertools.count()

    @classmethod
    def reset_id(cls):
        """Restart node-id numbering from 0."""
        cls.id_iter = itertools.count()

    def __init__(
            self,
            state: list[SubResult],
            action: str,
            parent: MCTSNode = None,
            task: str = None,
            reward: float = 0.,
            is_terminal: bool = False,
            calc_q: Callable[[list[float]], float] = np.mean
    ):
        """Create a node.

        Args:
            state: steps from the root up to and including this node.
            action: the action that produced this node.
            parent: parent node, or ``None`` for a root.
            task: problem text; required for a root, inherited from ``parent`` otherwise.
            reward: immediate reward of this node.
            is_terminal: whether this node ends the solution.
            calc_q: aggregator turning ``cum_rewards`` into ``Q``.
        """
        assert parent is not None or task is not None, "Either 'parent' or 'task' must be provided, but both are None"

        # The id is drawn only after validation, so a rejected node does not consume one.
        self.id = next(MCTSNode.id_iter)

        is_root = parent is None
        self.parent = parent
        self.depth = 0 if is_root else parent.depth + 1
        self.task = task if is_root else parent.task

        self.state = state
        self.action = action
        self.reward = reward
        self.is_terminal = is_terminal
        self.calc_q = calc_q

        self.cum_rewards: list[float] = []
        # None until the node is expanded.
        self.children: list[MCTSNode] = None

    @property
    def Q(self) -> float:
        """``calc_q(cum_rewards)`` once anything was back-propagated, else ``reward``."""
        if len(self.cum_rewards) > 0:
            return self.calc_q(self.cum_rewards)
        return self.reward


def get_incontext_examples(method="rstar"):
    r"""Return few-shot (problem, solution, critique) examples for critique prompts.

    Each example is a dict with keys ``"problem"``, ``"solution"`` (written in the
    "### Step k:" format the search produces) and ``"critique"`` (ending with
    "The solution is \boxed{correct}." or "The solution is \boxed{incorrect}.").

    Every LaTeX command in the literals below is written with an escaped backslash
    (``"\\frac"``, ``"\\left"``, ``"\\("`` ...). With a single backslash, ``\f``,
    ``\r`` and ``\a`` are Python escapes (form feed, carriage return, bell) that
    silently corrupted the formulas, and the remaining ones trigger
    invalid-escape-sequence warnings.

    Args:
        method: only ``"rstar"`` is supported.

    Returns:
        A list of example dicts, or ``None`` for any other ``method``.
    """
    if method != "rstar":
        return None
    return [
        # {
        #     # True
        #     "problem": """Compute the range of $y=|x+7|-|x-2|$.""",
        #     "solution": """### Step 1: Propose an one-step thought.\nTo find the range of $y=|x+7|-|x-2|$, we'll analyze the behavior of the absolute value expressions for different intervals of $x$.\n\n### Step 2: Propose next sub-question along with its answer.\nQ1: What are the critical points where the behavior of the absolute value expressions changes?\nA1: The critical points are $x=-7$ and $x=2$, where the expressions inside the absolute values change sign.\n\n### Step 3: Propose next sub-question along with its answer.\nQ2: How does the expression behave for $x<-7$?\nA2: For $x<-7$, both $x+7$ and $x-2$ are negative, so $y=-(x+7)-(-x+2)=-9$.\n\n### Step 4: Propose next sub-question along with its answer.\nQ3: How does the expression behave for $-7\leq x<2$?\nA3: For $-7\leq x<2$, $x+7$ is non-negative and $x-2$ is negative, so $y=(x+7)-(-x+2)=2x+5$.\n\n### Step 5: Propose next sub-question along with its answer.\nQ4: How does the expression behave for $x\geq 2$?\nA4: For $x\geq 2$, both $x+7$ and $x-2$ are non-negative, so $y=(x+7)-(x-2)=9$.\n\n### Step 6: Propose next sub-question along with its answer.\nQ5: What is the range of $y$?\nA5: The range of $y$ is the set of all possible values, which is $[-9, 9]$.\n\n### Step 7: Complete the solution and present the final answer within \\boxed{}.\nWe analyzed the behavior of the absolute value expressions for different intervals of $x$.\nThe expression behaves differently for $x<-7$, $-7\leq x<2$, and $x\geq 2$.\nThe range of $y$ is the set of all possible values, which is $[-9, 9]$.\n\nThe final answer is: $\\boxed{[-9, 9]}$""",
        #     "critique": """The solution accurately identifies the critical points and correctly analyzes the expression \( y = |x+7| - |x-2| \) across all relevant intervals. The calculations for each interval are precise, and the conclusion that the range of \( y \) is \([-9, 9]\) is well-supported. There are no evident errors or omissions in the reasoning.\n\nThe solution is \\boxed{correct}."""
        # },
        {
            # True
            "problem": """Let $f(x)=3x-2$, and let $g(x)=f(f(f(f(x))))$.  If the domain of $g$ is $0\\leq x\\leq 2$, compute the range of $g$.""",
            "solution": """### Step 1: Propose an one-step thought.\nTo find the range of $g$, we'll first determine the expression for $g(x)$ by composing $f(x)$ with itself four times.\n\n### Step 2: Propose next sub-question along with its answer.\nQ1: What is the expression for $f(f(x))$?\nA1: We have $f(f(x)) = f(3x-2) = 3(3x-2)-2 = 9x-8$.\n\n### Step 3: Propose next sub-question along with its answer.\nQ2: What is the expression for $f(f(f(x)))$?\nA2: We have $f(f(f(x))) = f(9x-8) = 3(9x-8)-2 = 27x-26$.\n\n### Step 4: Propose next sub-question along with its answer.\nQ3: What is the expression for $g(x) = f(f(f(f(x))))$?\nA3: We have $g(x) = f(27x-26) = 3(27x-26)-2 = 81x-80$.\n\n### Step 5: Propose next sub-question along with its answer.\nQ4: What is the range of $g$ when the domain is $0\\leq x\\leq 2$?\nA4: We substitute the domain bounds into $g(x) = 81x-80$ to find the range:\n- For $x=0$, $g(0) = 81(0)-80 = -80$.\n- For $x=2$, $g(2) = 81(2)-80 = 82$.\nSo, the range of $g$ is $-80\\leq g(x)\\leq 82$.\n\n### Step 6: Complete the solution and present the final answer within \\boxed{}.\nWe found the expression for $g(x)$ by composing $f(x)$ with itself four times.\nThen, we determined the range of $g$ by substituting the domain bounds into $g(x)$.\nThe range of $g$ is $-80\\leq g(x)\\leq 82$.\n\nThe final answer is: $\\boxed{[-80, 82]}$""",
            "critique": """The solution correctly composes the function \\( f(x) = 3x - 2 \\) four times to obtain \\( g(x) = 81x - 80 \\). It accurately substitutes the domain endpoints \\( x = 0 \\) and \\( x = 2 \\) into \\( g(x) \\) to find the range as \\([-80, 82]\\). The reasoning is clear and mathematically sound.\n\nThe solution is \\boxed{correct}."""
        },
        {
            # True
            "problem": """Find the coordinates of the point halfway between the points $(3,7)$ and $(5,1)$.""",
            "solution": """### Step 1: Propose an one-step thought.\nTo find the coordinates of the point halfway between the points $(3,7)$ and $(5,1)$, we'll use the midpoint formula.\n\n### Step 2: Propose next sub-question along with its answer.\nQ1: What is the midpoint formula?\nA1: The midpoint formula is given by $(x_m, y_m) = \\left( \\frac{x_1 + x_2}{2}, \\frac{y_1 + y_2}{2} \\right)$, where $(x_1, y_1)$ and $(x_2, y_2)$ are the coordinates of the two points.\n\n### Step 3: Propose next sub-question along with its answer.\nQ2: How can we apply the midpoint formula to the given points?\nA2: We can substitute the coordinates of the given points into the midpoint formula: $(x_m, y_m) = \\left( \\frac{3 + 5}{2}, \\frac{7 + 1}{2} \\right)$.\n\n### Step 4: Propose next sub-question along with its answer.\nQ3: What are the coordinates of the midpoint?\nA3: Evaluating the expression, we get $(x_m, y_m) = \\left( \\frac{8}{2}, \\frac{8}{2} \\right) = (4, 4)$.\n\n### Step 5: Complete the solution and present the final answer within \\boxed{}.\nWe used the midpoint formula to find the coordinates of the point halfway between the points $(3,7)$ and $(5,1)$.\nThe midpoint formula is given by $(x_m, y_m) = \\left( \\frac{x_1 + x_2}{2}, \\frac{y_1 + y_2}{2} \\right)$.\nSubstituting the coordinates of the given points, we found the midpoint to be $(4, 4)$.\nThe final answer is: $\\boxed{(4, 4)}$""",
            "critique": """The solution correctly applies the midpoint formula by averaging the x-coordinates and the y-coordinates of the given points. Each step is clearly explained and logically leads to the final answer. The calculations are accurate, resulting in the midpoint $(4, 4)$. There are no errors or omissions in the solution.\n\nThe solution is \\boxed{correct}."""
        },
        {
            # False
            "problem": """Tim wants to invest some money in a bank which compounds quarterly with an annual interest rate of $7\\%$. To the nearest dollar, how much money should he invest if he wants a total of $\\$60,\\!000$ at the end of $5$ years?""",
            "solution": """### Step 1: Propose an one-step thought.\nTo find the initial investment, we'll use the compound interest formula: $A=P\\left(1+\\frac{r}{n}\\right)^{nt}$, where $A$ is the future value, $P$ is the principal (initial investment), $r$ is the annual interest rate, $n$ is the number of times compounded per year, and $t$ is the time in years.\n\n### Step 2: Propose next sub-question along with its answer.\nQ1: What are the given values in the problem?\nA1: We have $A=60000$, $r=0.07$, $n=4$ (quarterly compounding), and $t=5$ years.\n\n### Step 3: Propose next sub-question along with its answer.\nQ2: How can we rearrange the compound interest formula to solve for $P$?\nA2: We can rearrange the formula to solve for $P$: $P=\\frac{A}{\\left(1+\\frac{r}{n}\\right)^{nt}}$.\n\n### Step 4: Propose next sub-question along with its answer.\nQ3: What is the value of $P$ when we plug in the given values?\nA3: Plugging in the values, we get $P=\\frac{60000}{\\left(1+\\frac{0.07}{4}\\right)^{4 \\cdot 5}} \\approx 41587.28$.\n\n### Step 5: Propose next sub-question along with its answer.\nQ4: How much should Tim invest to the nearest dollar?\nA4: Rounding to the nearest dollar, Tim should invest approximately $41587.\n\n### Step 6: Complete the solution and present the final answer within \\boxed{}.\nWe used the compound interest formula to find the initial investment.\nGiven the values A = 60000, r = 0.07, n = 4, and t = 5, we calculated P ≈ 41587.28.\nRounding to the nearest dollar, the final answer is: $\\boxed{41587}$""",
            "critique": """The solution incorrectly calculates the compound factor. The correct computation of \\((1 + \\frac{0.07}{4})^{20}\\) is approximately \\(1.416\\), which leads to \\(P = \\frac{60000}{1.416} \\approx 42372\\), not \\(41587\\). Additionally, the final rounding should reflect this accurate calculation. Therefore, the solution contains errors in the computation of the compound interest formula.\n\nThe solution is \\boxed{incorrect}."""
        },
        {
            # False
            "problem": """What are all values of $p$ such that for every $q>0$, we have   $$\\frac{3(pq^2+p^2q+3q^2+3pq)}{p+q}>2p^2q?$$ Express your answer in interval notation in decimal form.""",
            "solution": """### Step 1: Propose an one-step thought.\nTo find the values of $p$ that satisfy the given inequality for every $q>0$, we'll analyze and simplify the inequality.\n\n### Step 2: Propose next sub-question along with its answer.\nQ1: How can we simplify the given inequality?\nA1: Factor the numerator and simplify the inequality:\n$$\\frac{3q(p+3)(q+p)}{p+q}>2p^2q$$\nCancel out the common factor $(p+q)$:\n$$3q(p+3)>2p^2q$$\nDivide both sides by $q$:\n$$3(p+3)>2p^2$$\n\n### Step 3: Propose next sub-question along with its answer.\nQ2: How can we rewrite the inequality in a more manageable form?\nA2: Expand and rearrange the inequality:\n$$3p+9>2p^2$$\n$$2p^2-3p-9<0$$\n\n### Step 4: Propose next sub-question along with its answer.\nQ3: How can we solve the quadratic inequality?\nA3: Factor the quadratic expression:\n$$(2p+3)(p-3)<0$$\nFind the critical points:\n$$2p+3=0 \\Rightarrow p=-\\frac{3}{2}$$\n$$p-3=0 \\Rightarrow p=3$$\nSince the inequality is less than $0$, the solution is:\n$$-\\frac{3}{2}<p<3$$\n\n### Step 5: Complete the solution and present the final answer within \\boxed{}.\nWe simplified the given inequality and solved the quadratic inequality.\nThe values of $p$ that satisfy the inequality are $-\\frac{3}{2}<p<3$.\n\nThe final answer is: $\\boxed{(-1.5, 3)}$""",
            "critique": """The solution incorrectly allows values of \\( p \\) less than 0. When simplifying the inequality, the cancellation of \\( (p + q) \\) assumes that \\( p + q > 0 \\) for all \\( q > 0 \\), which is only true if \\( p \\geq 0 \\). Therefore, the valid interval for \\( p \\) should be restricted to \\( [0, 3) \\), not \\( (-1.5, 3) \\).\n\nThe solution is \\boxed{incorrect}.""",
        },
        # {
        #     # False
        #     "problem": """The real numbers $a$ and $b$ satisfy\n\\[\\begin{pmatrix} 2 \\\\ a \\\\ -7 \\end{pmatrix} \\times \\begin{pmatrix} 5 \\\\ 4 \\\\ b \\end{pmatrix} = \\mathbf{0}.\\]Enter the ordered pair $(a,b).$""",
        #     "solution": """### Step 1: Propose an one-step thought.\nTo find the ordered pair $(a,b)$, we'll use the cross product of two vectors and set it equal to the zero vector.\n\n### Step 2: Propose next sub-question along with its answer.\nQ1: What is the cross product of two vectors?\nA1: The cross product of two vectors $\mathbf{u} = (u_1, u_2, u_3)$ and $\mathbf{v} = (v_1, v_2, v_3)$ is given by\n\[\mathbf{u} \\times \mathbf{v} = \\begin{pmatrix} u_2v_3 - u_3v_2 \\ u_3v_1 - u_1v_3 \\ u_1v_2 - u_2v_1 \end{pmatrix}.\]\n\n### Step 3: Propose next sub-question along with its answer.\nQ2: What is the cross product of the given vectors?\nA2: The cross product of the given vectors is\n\[\\begin{pmatrix} 2 \\ a \\ -7 \end{pmatrix} \\times \\begin{pmatrix} 5 \\ 4 \\ b \end{pmatrix} = \\begin{pmatrix} ab + 28 \\ -2b - 5 \\ 8 - 5a \end{pmatrix}.\]\n\n### Step 4: Propose next sub-question along with its answer.\nQ3: What is the condition for the cross product to be the zero vector?\nA3: For the cross product to be the zero vector, we must have\n\[ab + 28 = 0,\]\n\[-2b - 5 = 0,\]\n\[8 - 5a = 0.\]\n\n### Step 5: Propose next sub-question along with its answer.\nQ4: How can we solve the system of equations?\nA4: Solving the second equation, we get $b = -\\frac{5}{2}$. Substituting this into the first equation, we get $a = \\frac{28}{5/2} = \\frac{56}{5}$. Substituting this into the third equation, we get $8 - 5\left(\\frac{56}{5}\\right) = 0$, which is true.\n\n### Step 6: Complete the solution and present the final answer within \\boxed{}.\nWe used the cross product of two vectors and set it equal to the zero vector. We solved the system of equations and found the ordered pair $(a,b) = \left(\\frac{56}{5}, -\\frac{5}{2}\\right)$.\n\nThe final answer is: $\\boxed{\left(\\frac{56}{5}, -\\frac{5}{2}\\right)}$""",
        #     "critique": """The solution incorrectly computes the second component of the cross product as \(-2b - 5\) instead of the correct \(-35 - 2b\). This error propagates to the system of equations, leading to an incorrect ordered pair \(\left(\\frac{56}{5}, -\\frac{5}{2}\\right)\). The correct system should solve \(ab + 28 = 0\), \(-35 - 2b = 0\), and \(8 - 5a = 0\), resulting in a different solution for \((a, b)\).\n\nThe solution is \\boxed{incorrect}."""
        # },
    ]


def get_terminal_action(method="rstar"):
    """Return the action string that finishes a solution under ``method``.

    Only ``"rstar"`` is recognised and yields ``TERMINAL``; any other method gives ``None``.
    """
    return TERMINAL if method == "rstar" else None


def calc_reward(judges):
    r"""Turn a group of LLM critiques into a single reward in [0, 1].

    Each critique is scanned case-insensitively for a verdict phrase:
    "the solution is correct" or "the solution is incorrect", optionally with
    "proposed" before "solution" and optionally with the verdict wrapped in
    ``\boxed{...}``. A trailing period is not required: the reward prompt asks the
    judge to end with "The solution is correct", with no period. When a critique
    states more than one verdict, the last one wins, because the judge is asked to
    conclude with its verdict. Critiques without a recognisable verdict are ignored.

    Args:
        judges: iterable of critique strings.

    Returns:
        The fraction of verdict-bearing critiques that say "correct", or 0.0 when none
        of them carries a verdict.
    """
    import re

    # The trailing \b keeps "correctly"/"incorrectly" from counting as verdicts.
    verdict_re = re.compile(r"\bthe (?:proposed )?solution is (?:\\boxed\{)?(correct|incorrect)\b")

    corrects = 0
    incorrects = 0
    for judge in judges:
        verdicts = verdict_re.findall(judge.lower())
        if not verdicts:
            continue
        if verdicts[-1] == "correct":
            corrects += 1
        else:
            incorrects += 1
    return corrects / max(1, incorrects + corrects)


def save_root_nodes(root_list, filepath):
    """Pickle ``root_list`` (e.g. the MCTS root nodes) into the file at ``filepath``.

    The file is overwritten and written in binary mode with pickle's default protocol.
    """
    with open(filepath, mode="wb") as out_file:
        pickle.dump(root_list, out_file)


def load_root_nodes(filepath):
    """Recover the object written by ``save_root_nodes`` from ``filepath``.

    Security: this unpickles the file, which can execute arbitrary code. Only load
    files produced by a trusted run.
    """
    with open(filepath, mode="rb") as in_file:
        roots = pickle.load(in_file)
    return roots


def load_submodular_model(base_model_path,
                          lora_path="qlearning/1223_qlearning_v1/checkpoints/step-9000",
                          eos_token_id=128009):
    """Build the LoRA-adapted ``QValueEncoder`` used for submodular action selection.

    The encoder is created with ``is_trainable=False``, ``tau=0.8`` and
    ``"lasttoken"`` pooling, moved to CUDA with ``.cuda()``, and wrapped in
    ``nn.DataParallel``.

    Args:
        base_model_path: path of the base model passed as ``llama_model_path``.
        lora_path: LoRA checkpoint to load on top of the base model.
        eos_token_id: token id handed to the encoder as ``eos_token_id``.

    Returns:
        The ``nn.DataParallel``-wrapped encoder.
    """
    encoder_kwargs = dict(
        llama_model_path=base_model_path,
        lora_path=lora_path,
        tau=0.8,
        pooling_method="lasttoken",
        eos_token_id=eos_token_id,
        is_trainable=False,
    )
    encoder = QValueEncoder(**encoder_kwargs).cuda()
    return nn.DataParallel(encoder)


class MCTS:
    def __init__(self,
                 model_path=".../hf_models/Meta-Llama-3.1-8B-Instruct",
                 n_gpus=8,
                 seed=42,
                 temperature=0.6,
                 top_p=1.0,
                 max_len=1024,
                 max_depth: int = 8,
                 min_depth: int = 3,
                 n_confidence: int = 4,
                 n_attempts: int = 8,
                 n_iters: int = 2,
                 w_exp: float = 1.0,
                 submodular_weight: float = 0.9,
                 submodular_size: int = 5,
                 task_type: str = "math",
                 cum_reward: str = "mean",
                 calc_q: str = "max",
                 simulate_strategy: str | Callable[[list[float]], int] = 'max',  # default value
                 disable_tqdm: bool = True):
        self.w_exp = w_exp
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.n_iters = n_iters
        self.cum_reward = np.mean if cum_reward == 'mean' else max
        self.calc_q = np.mean if calc_q == 'mean' else max

        default_simulate_strategies: dict[str, Callable[[list[float]], int]] = {
            'max': lambda x: np.argmax(x),
            'sample': lambda x: np.random.choice(len(x), p=x),
            'random': lambda x: np.random.choice(len(x)),
        }
        self.simulate_choice: Callable[[list[float]], int] = default_simulate_strategies.get(simulate_strategy, simulate_strategy)
        self.root: Optional[MCTSNode] = None
        self.disable_tqdm = disable_tqdm

        self.temperature = temperature
        self.top_p = top_p
        self.max_len = max_len
        self.n_confidence = n_confidence
        self.n_attempts = n_attempts

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = LLM(
            model=model_path,
            tokenizer=model_path,
            tokenizer_mode="slow",
            dtype="bfloat16",
            tensor_parallel_size=n_gpus,
            seed=seed,
            gpu_memory_utilization=0.8,
        )

        self.submodular_tokenizer = AutoTokenizer.from_pretrained(
            ".../hf_models/Llama-3.2-1B-Instruct")
        self.submodular_model = load_submodular_model(
            base_model_path=".../hf_models/Llama-3.2-1B-Instruct",
            lora_path="qlearning/1223_qlearning_v1/checkpoints/step-9000",
            eos_token_id=self.submodular_tokenizer.eos_token_id
        )
        self.submodular_weight = submodular_weight
        self.submodular_size = submodular_size
        self.task_type = task_type

    def is_terminal_with_depth_limit(self, node: MCTSNode):
        return node.is_terminal or node.depth >= self.max_depth

    def uct(self, node: MCTSNode) -> float:
        return node.Q + self.w_exp * np.sqrt(np.log(len(node.parent.cum_rewards)) / max(1, len(node.cum_rewards)))

    def uct_select(self, node: MCTSNode) -> MCTSNode:
        return max(node.children, key=self.uct)

    def get_available_actions(self, node: MCTSNode, method="rstar"):
        if method == "rstar":
            if node.depth < self.min_depth - 1:
                return [
                    "Propose an one-step thought.",
                    # "Propose the remaining thought steps.",
                    "Propose next sub-question along with its answer.",
                    # "Answer the sub-question again.",
                    "Rephrase the question/sub-question.",
                ]
            elif node.depth < self.max_depth - 1:
                return [
                    "Propose an one-step thought.",
                    # "Propose the remaining thought steps.",
                    "Propose next sub-question along with its answer.",
                    # "Answer the sub-question again.",
                    "Rephrase the question/sub-question.",
                    TERMINAL,
                ]
            else:
                return [
                    TERMINAL,
                ]

    def get_submodular_actions(self, nodes: list[MCTSNode]):
        states = [get_state(node.task, [sub_result.sub_question for sub_result in node.state]) for node in nodes]

        action_embeddings = torch.load('action_embeddings_1223.pt')
        actions = action_embeddings["actions"]
        embeddings = action_embeddings["embeddings"]
        embeddings = embeddings.to(get_model_device(self.submodular_model))

        selected_indices = build_sets_batch(
            states,
            encoder=self.submodular_model,
            tokenizer=self.submodular_tokenizer,
            action_embeddings=embeddings,
            ratio=self.submodular_weight,
            set_size=self.submodular_size,
        )
        selected_indices = selected_indices.tolist()
        assert len(selected_indices) == len(nodes)
        selected_actions = []
        for node, batch_indices in zip(nodes, selected_indices):
            if node.is_terminal:
                selected_actions.append([])
            elif node.depth < self.min_depth - 1:
                selected_actions.append([actions[idx] for idx in batch_indices])
            elif node.depth < self.max_depth - 1:
                selected_actions.append([actions[idx] for idx in batch_indices] + [TERMINAL])
            else:
                selected_actions.append([TERMINAL])
        return selected_actions

    def get_solution_prompt(self, task, state, action):
        messages = [
            {
                "role": "user",
                "content": (
                    "Solve this problem step by step, following these rules:\n\n"
                    "1. Begin each step with '### Step'\n"
                    "2. When you see 'Complete the solution and present the final answer within \\boxed{}':\n"
                    "   - Continue solving ALL remaining steps without '### Step'\n" 
                    "   - You MUST enclose your final answer within \\boxed{}\n\n"
                    f"Problem:\n{task}"
                ),
            },
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        for idx, sub_question, sub_answer in state:
            prompt += f"### Step {idx}: {sub_question.strip()}\n{sub_answer.strip()}\n\n"
        prompt += f"### Step {len(state) + 1}: {action}\n"

        # print("\n\n\nsolution prompt: ")
        # print(prompt)
        # input(">>>>>")
        return prompt

    def get_solutions(self, node_list: list[MCTSNode], actions_list: list[list[str]], use_tqdm=None):
        prompts = []
        terminal_list = []
        for node, actions in zip(node_list, actions_list):
            if node.is_terminal:
                assert len(actions) == 0
            for action in actions:
                prompts.append(self.get_solution_prompt(node.task, node.state, action))
                terminal_list.append(action == get_terminal_action(method="rstar"))

        if len(prompts) > 0:
            sampling_params = SamplingParams(
                temperature=self.temperature,
                top_p=self.top_p,
                max_tokens=self.max_len,
                stop=["### Step", self.tokenizer.eos_token, "<|eot_id|>"]
            )
            if use_tqdm is None:
                use_tqdm = not self.disable_tqdm

            with torch.no_grad():
                sub_answer_list = self.model.generate(prompts, sampling_params, use_tqdm=use_tqdm)
                sub_answer_list = [sub_answer.outputs[0].text.strip() for sub_answer in sub_answer_list]

            for _ in range(1, self.n_attempts):
                remaining_indices = [idx for idx in range(len(sub_answer_list)) if terminal_list[idx] and
                                     (extract_answer(sub_answer_list[idx], task=self.task_type) is None or
                                      extract_answer(sub_answer_list[idx], task=self.task_type).strip() == "")]
                if len(remaining_indices) == 0:
                    break
                remaining_prompts = [prompts[idx] for idx in remaining_indices]
                with torch.no_grad():
                    remaining_answer_list = self.model.generate(remaining_prompts, sampling_params, use_tqdm=use_tqdm)
                    remaining_answer_list = [sub_answer.outputs[0].text.strip() for sub_answer in remaining_answer_list]
                for idx, sub_answer in zip(remaining_indices, remaining_answer_list):
                    sub_answer_list[idx] = sub_answer

        else:
            sub_answer_list = []

        solutions_list = []
        indent = 0
        for node, actions in zip(node_list, actions_list):
            solutions_list.append(sub_answer_list[indent: indent + len(actions)])
            indent += len(actions)
        return solutions_list

    def get_reward_prompt(self, task, state, action, solution):
        proposed_solution = ""
        for idx, sub_question, sub_answer in state:
            proposed_solution += f"### Step {idx}: {sub_question.strip()}\n{sub_answer.strip()}\n\n"
        proposed_solution += f"### Step {len(state) + 1}: {action.strip()}\n{solution.strip()}"

        messages = [
            {
                "role": "user",
                "content": (
                    "You are given a problem and a proposed solution.\n\n"
                    f"**Problem:**\n{task}\n\n" 
                    f"**Proposed Solution:**\n{proposed_solution}\n\n"
                    "Your task is to critique if the proposed solution is correct. A solution is correct if:\n"
                    "- It completely solves the problem correctly, OR\n"
                    "- It is a partial solution that could be extended into a complete correct solution\n\n"
                    "Consider a solution incorrect only if it:\n"
                    "- Contains errors OR\n"
                    "- Takes an approach that cannot lead to a correct solution\n\n"
                    "**Conclude your response with EXACTLY ONE of the following statements:**\n"
                    "- \"The solution is correct\" if the solution is complete OR could lead to a correct solution\n"
                    "- \"The solution is incorrect\" if the solution contains errors or cannot lead to a correct solution\n\n"
                    "**This is NOT optional. Your response MUST end with either \"The solution is correct\" or \"The solution is incorrect.\"**\n"
                    "**Limit your response to 200 words.**"
                )
            }
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        prompt += "\n\n### Critique\n"
        # print("\n\n\nreward prompt: ")
        # print(prompt)
        # input(">>>>>")
        return prompt

    def get_rewards(
            self,
            node_list: list[MCTSNode],
            actions_list: list[list[str]],
            solutions_list: list[list[str]],
            use_tqdm=None):
        prompts = []
        for node, actions, solutions in zip(node_list, actions_list, solutions_list):
            if node.is_terminal:
                assert len(actions) == 0 and len(solutions) == 0
            assert len(actions) == len(solutions)
            for action, solution in zip(actions, solutions):
                prompts.append(self.get_reward_prompt(node.task, node.state, action, solution))

        if len(prompts) > 0:
            sampling_params = SamplingParams(
                n=self.n_confidence,
                temperature=self.temperature,
                top_p=self.top_p,
                max_tokens=self.max_len,
                stop=[self.tokenizer.eos_token, "<|eot_id|>"]
            )

            if use_tqdm is None:
                use_tqdm = not self.disable_tqdm

            with torch.no_grad():
                judges_list = self.model.generate(prompts, sampling_params, use_tqdm=use_tqdm)
                judge_list = [judge.text.strip() for judges in judges_list for judge in judges.outputs]

        else:
            judge_list = []

        indent = 0
        rewards_list = []
        for node, actions, solutions in zip(node_list, actions_list, solutions_list):
            rewards = []
            assert len(actions) == len(solutions)
            for action, solution in zip(actions, solutions):
                reward = calc_reward(judge_list[indent: indent + self.n_confidence])
                # if len(node.state) > 0:
                #     print("last state: ", node.state[-1])
                # print("action: ", action)
                # print("---")
                # print("solution: ", solution)
                # print("---")
                # print("critique: ", judge_list[indent])
                # print("---")
                # print("reward: ", reward)
                rewards.append(reward)
                indent += self.n_confidence
            # input(">>>")
            rewards_list.append(rewards)
        return rewards_list

    def dfs_max_reward(self, path: list[MCTSNode]) -> tuple[float, list[MCTSNode]]:
        cur = path[-1]
        if cur.is_terminal:
            return [(cur.reward, len(cur.cum_rewards), cur)]
        if cur.children is None:
            return [(-math.inf, 0, None)]
        visited_children = [x for x in cur.children]
        results = []
        for child in visited_children:
            results += self.dfs_max_reward(path + [child])
        return results

    def select(self, node_list: list[MCTSNode]) -> list[list[MCTSNode]]:
        path_list = []
        for node in node_list:
            path = []
            while True:
                path.append(node)
                if node.children is None:  # if a node has children, it must not be terminal
                    path_list.append(path)
                    break
                node = self.uct_select(node)
        return path_list

    def expand(self, node_list: list[MCTSNode], use_tqdm=None):
        """Attach one generation of children to every non-terminal node in ``node_list``.

        Candidate sub-questions come from ``get_submodular_actions``. Each one is answered
        by ``get_solutions`` and scored by ``get_rewards``. All three calls are batched
        over the whole list. Terminal nodes are included in those batched calls but are
        skipped when children are created, so their ``children`` stays ``None``.

        Args:
            node_list: leaf nodes (``children is None``), each either terminal or
                shallower than ``self.max_depth``. Both conditions are asserted.
            use_tqdm: forwarded to ``get_solutions`` and ``get_rewards``.
        """
        assert all(node.children is None for node in node_list)
        assert all(node.is_terminal or node.depth < self.max_depth for node in node_list)

        sub_questions_list = self.get_submodular_actions(node_list)
        sub_answers_list = self.get_solutions(node_list, sub_questions_list, use_tqdm)
        rewards_list = self.get_rewards(node_list, sub_questions_list, sub_answers_list, use_tqdm)

        # Loop-invariant: a child whose action equals this string is marked terminal.
        terminal_action = get_terminal_action(method="rstar")
        for node, sub_questions, sub_answers, rewards in zip(
                node_list, sub_questions_list, sub_answers_list, rewards_list):
            if node.is_terminal:
                continue

            # Every child appends exactly one step, so they all share the same 1-based index.
            step_idx = len(node.state) + 1
            node.children = [
                MCTSNode(
                    state=node.state + [SubResult(idx=step_idx, sub_question=sub_question, sub_answer=sub_answer)],
                    action=sub_question,
                    parent=node,
                    reward=reward,
                    calc_q=self.calc_q,
                    is_terminal=(sub_question == terminal_action),
                )
                for sub_question, sub_answer, reward in zip(sub_questions, sub_answers, rewards)
            ]

    def simulate(self, path_list: list[list[MCTSNode]], iter_num: int):
        """Extend every selected path by rollout until it reaches a completed node.

        At each step, every unfinished rollout moves to the child picked by
        ``self.simulate_choice`` over the children's rewards. The chosen child is appended
        to its path in place. The newly reached nodes that are still unfinished are then
        expanded as one batch. A rollout is finished once
        ``self.is_terminal_with_depth_limit`` is true for its current node. At most
        ``self.max_depth`` steps are taken.

        Args:
            path_list: paths returned by ``select`` whose last nodes were already
                expanded. Each path is mutated in place.
            iter_num: MCTS iteration number, used for the progress-bar label and position.
        """
        node_list = [path[-1] for path in path_list]
        complete_list = [self.is_terminal_with_depth_limit(node) for node in node_list]

        for _ in trange(self.max_depth, desc=f"Iteration {iter_num}", leave=True, position=iter_num + 1):
            for i, node in enumerate(node_list):
                if complete_list[i]:
                    continue
                rewards = [child.reward for child in node.children]
                child = node.children[self.simulate_choice(rewards)]
                path_list[i].append(child)
                node_list[i] = child
                complete_list[i] = self.is_terminal_with_depth_limit(child)

            # Only unfinished rollouts need children. Expanding completed nodes wastes
            # batched generation calls, and for depth-limited nodes it can trip
            # expand's ``depth < max_depth`` assertion.
            pending = [node for node, done in zip(node_list, complete_list) if not done]
            if not pending:
                break
            self.expand(pending)

    def back_propagate(self, path_list: list[list[MCTSNode]]):
        """Record an aggregated return on every node of every path.

        Nodes are visited from the leaf back to the root. For the node at position
        ``start``, ``self.cum_reward`` receives a fresh list of the rewards from that node
        down to the leaf, in path order. The result is appended to the node's
        ``cum_rewards``.

        Returns:
            One value per path: the aggregate computed for its first node, or
            ``-math.inf`` for an empty path.
        """
        totals = []
        for path in path_list:
            total = -math.inf
            for start in range(len(path) - 1, -1, -1):
                suffix_rewards = [step.reward for step in path[start:]]
                total = self.cum_reward(suffix_rewards)
                path[start].cum_rewards.append(total)
            totals.append(total)
        return totals

    def search(self, tasks: list[str], pickle_path: str):
        """Run batched MCTS over ``tasks`` and choose one completion per task.

        Each task gets its own tree. All trees advance in lockstep for ``self.n_iters``
        iterations of select -> expand -> simulate -> back-propagate. The finished
        trees are passed to ``save_root_nodes`` together with ``pickle_path``. Each tree
        is then summarised by ``_choose_answer``.

        Args:
            tasks: one prompt per tree.
            pickle_path: destination handed to ``save_root_nodes``.

        Returns:
            ``(reward_list, output_list, answers_list)`` with one entry per task:
            the reward of the chosen completion, the chosen completion text, and the
            ``{answer: [(reward, visits, completion), ...]}`` grouping. For a tree with no
            visited terminal leaf, the entries are ``None``, ``""`` and ``{}``.
        """
        MCTSNode.reset_id()
        root_list = [MCTSNode(state=[], action=None, parent=None, task=task, calc_q=self.calc_q) for task in tasks]

        for iter_num in trange(self.n_iters, desc="MCTS iteration", leave=True, position=0):
            path_list = self.select(root_list)
            self.expand([path[-1] for path in path_list], use_tqdm=True)
            self.simulate(path_list, iter_num)
            self.back_propagate(path_list)

        save_root_nodes(root_list, pickle_path)

        reward_list, output_list, answers_list = [], [], []
        for root in root_list:
            reward, output, answer_dict = self._choose_answer(root)
            reward_list.append(reward)
            output_list.append(output)
            answers_list.append(answer_dict)
        return reward_list, output_list, answers_list

    @staticmethod
    def _render_completion(state: list[SubResult]) -> str:
        """Render a node's sub-question/answer trace as "### Step k: ..." sections."""
        return "\n\n".join(
            f"### Step {sub.idx}: {sub.sub_question}\n{sub.sub_answer}" for sub in state
        ).strip()

    def _choose_answer(self, root: MCTSNode) -> tuple[Optional[float], str, dict]:
        """Group a tree's visited terminal leaves by extracted answer and pick the winner.

        The answer group with the largest summed leaf reward wins; on ties, the group
        seen first wins. Within that group, the highest-reward completion is returned;
        on ties, the first one wins.

        Returns:
            ``(reward, completion, answer_dict)``, or ``(None, "", {})`` when the tree
            has no terminal leaf reached by any rollout.
        """
        answer_dict: dict = {}
        for reward, visits, node in self.dfs_max_reward([root]):
            # Skip unexpanded placeholders and terminal nodes that no rollout reached.
            if node is None or visits == 0:
                continue
            completion = self._render_completion(node.state)
            answer = extract_answer(completion, task=self.task_type)
            answer_dict.setdefault(answer, []).append((reward, visits, completion))

        if not answer_dict:
            # Previously an IndexError that aborted the whole batch.
            return None, "", answer_dict

        best_group = max(answer_dict.values(), key=lambda group: sum(r for r, _, _ in group))
        reward, _, completion = max(best_group, key=lambda entry: entry[0])
        return reward, completion, answer_dict


if __name__ == "__main__":
    # Command-line entry point: read prompts from a JSONL file, run batched MCTS over
    # all of them, and write one JSONL result line per input item.
    parser = argparse.ArgumentParser(description="Evaluate large language models on critical datasets.")
    parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset file.")
    parser.add_argument("--output_path", type=str, required=True, help="Path of the JSONL file to write results to.")
    parser.add_argument("--pickle_path", type=str, required=True, help="Destination for the pickled MCTS search trees.")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model.")
    parser.add_argument("--n_gpus", type=int, default=8, help="Number of GPUs to use for tensor parallelism.")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature for generation.")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top-p sampling for generation.")
    parser.add_argument("--max_len", type=int, default=2048, help="Maximum number of tokens to generate.")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_depth", type=int, default=8)
    parser.add_argument("--min_depth", type=int, default=3)
    parser.add_argument("--n_confidence", type=int, default=8)
    parser.add_argument("--n_iters", type=int, default=8)
    parser.add_argument("--n_attempts", type=int, default=8)
    parser.add_argument("--aggregate_child", type=str, default="max")
    parser.add_argument("--aggregate_reward", type=str, default="mean")
    parser.add_argument("--submodular_weight", type=float, default=0.9)
    parser.add_argument("--submodular_size", type=int, default=5)

    args = parser.parse_args()

    # One JSON object per non-blank line; each must carry a "prompt" field.
    items = []
    with open(args.data_path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                items.append(json.loads(line))
    tasks = [item["prompt"] for item in items]

    # Multiple-choice benchmarks go through extract_answer's "mmlu" branch, which falls
    # back to parenthesised choices. Everything else is treated as "math".
    task_type = "math"
    if "mmlu" in args.data_path or "gpqa" in args.data_path or "arc" in args.data_path:
        task_type = "mmlu"

    mcts = MCTS(
        model_path=args.model_path,
        n_gpus=args.n_gpus,
        seed=args.seed,
        temperature=args.temperature,
        top_p=args.top_p,
        max_len=args.max_len,
        max_depth=args.max_depth,
        min_depth=args.min_depth,
        n_confidence=args.n_confidence,
        n_iters=args.n_iters,
        n_attempts=args.n_attempts,
        w_exp=1.0,
        submodular_weight=args.submodular_weight,
        submodular_size=args.submodular_size,
        task_type=task_type,
        cum_reward=args.aggregate_child,
        calc_q=args.aggregate_reward,
        simulate_strategy='max',
        disable_tqdm=True,
    )

    reward_list, output_list, answers_list = mcts.search(tasks, args.pickle_path)

    # os.makedirs("") raises FileNotFoundError, so only create a parent directory
    # when the output path actually has one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_path, "w", encoding="utf-8") as f:
        for item, reward, completion, answers in zip(items, reward_list, output_list, answers_list):
            item["answers"] = answers
            item["completion"] = completion
            item["reward"] = reward
            f.write(json.dumps(item) + "\n")


