import inspect
from typing import Dict, Optional

import numpy as np
from llm_mcts.mcts_algo.algo_base import MockSolver
from llm_mcts.mcts_algo.eval_result import EvalResult
from llm_mcts.mcts_algo.hierarchical_thompson.algo import HierarchicalThompsonAlgo
from llm_mcts.mcts_algo.hierarchical_thompson.thompson_state import (
    GaussianPrior,
    HierarchicalThompsonState,
    PriorConfig,
)
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.node import Node as LLMMCTSNode
from llm_mcts.mcts_algo.pymc_mixed.algo import PyMCMixedAlgo, PyMCMixedAlgoConfig
from llm_mcts.mcts_algo.pymc_mixed.pymc_interface import Observation
from llm_mcts.mcts_algo.score_funcs import UCTScore, UCTConfig
from llm_mcts.mcts_algo.standard.algo import StandardAlgo
from llm_mcts.mcts_scorer.base import MCTSScorer
from llm_mcts.mcts_scorer.default import DefaultScorer

from .journal import Journal
from .journal import Node as AideNode
from .utils.metric import MetricValue

# Score returned by PlainEvalResult.get_score when a metric is flagged as worst.
# NOTE: despite its name this is a finite sentinel (-8), not negative infinity.
NEGATIVE_INF = -8

# First solver name reported by MockSolver; this module uses it as the one and
# only model name (see the "single model" notes in the score-update helpers).
MOCK_LLM_NAME = MockSolver().get_solvers()[0]


class PlainEvalResult(EvalResult):
    """Evaluation result backed by the metric of a single AIDE node."""

    def __init__(self, node: AideNode):
        # Only the metric is kept; get_score reads nothing else from the node.
        self.metric: MetricValue = node.metric

    def get_score(self) -> float:
        """
        Turn the stored metric into a "larger is better" score.

        The value is not rescaled, so there is no restriction on its range.
        TODO: Check if we get better result when rescaling the metric value to [0, 1]

        Returns:
            NEGATIVE_INF when the metric is flagged as worst. Otherwise the
            metric value (averaged first when it is a numpy array), negated
            when the metric is one to be minimized.
        """
        metric = self.metric
        if metric.is_worst:
            return NEGATIVE_INF

        raw_value = metric.value
        score = (
            float(np.average(raw_value))
            if isinstance(raw_value, np.ndarray)
            else float(raw_value)
        )

        # When a lower metric value is better, flip the sign so that a higher
        # score always means a better node.
        return score if metric.maximize else -score


def map_tree(
    node: AideNode,
    node_table: Dict[LLMMCTSNode, AideNode],
    parent: Optional[LLMMCTSNode] = None,
) -> LLMMCTSNode:
    """
    Mirror the AIDE subtree rooted at ``node`` as LLM-MCTS nodes (depth-first).

    Args:
        node: AIDE node to convert.
        node_table: Mapping filled in place; every created LLM-MCTS node is
            registered here with the AIDE node it was built from.
        parent: LLM-MCTS node to attach the converted node to (None for the root).

    Returns:
        The LLM-MCTS node created for ``node``, with its usable children attached.
    """
    mirrored = LLMMCTSNode(
        serial_number=node.step,
        next_prompt=None,
        parent=parent,
        eval_results=[PlainEvalResult(node)],
        last_action="answer",
        llm_name=MOCK_LLM_NAME,
    )
    node_table[mirrored] = node

    for aide_child in node.children:
        # Because of the process pool, some children may not be evaluated yet
        # (metric is None); those are skipped together with buggy children and
        # children whose metric is flagged as worst.
        unusable = (
            aide_child.is_buggy
            or aide_child.metric is None
            or aide_child.metric.is_worst
        )
        if unusable:
            continue
        mirrored.children.append(map_tree(aide_child, node_table, mirrored))

    return mirrored


def update_scores_thompson(
    node: LLMMCTSNode,
    algo: HierarchicalThompsonAlgo,
    scorer: MCTSScorer,
    prior_config: PriorConfig,
) -> LLMMCTSNode:
    """
    Populate the Thompson states of ``algo`` for the subtree rooted at ``node``.

    Nodes are processed parent-first: an expanded node without a state gets a
    fresh HierarchicalThompsonState, every non-root node is registered in its
    parent's state and backpropagated, and then the children are processed.

    Args:
        node: Root of the subtree to process.
        algo: Algorithm whose ``thompson_states`` table is updated in place.
        scorer: Scorer handed to ``algo.backpropagate``.
        prior_config: Prior configuration used for newly created states.

    Returns:
        The same ``node`` that was passed in.
    """
    # NOTE: Currently we only support a single model
    single_model = MOCK_LLM_NAME
    states = algo.thompson_states

    missing_state = node not in states and node.is_expanded()
    if missing_state:
        states[node] = HierarchicalThompsonState(
            model_names=[single_model], prior_config=prior_config
        )

    if not node.is_root():
        parent_state = states[node.parent]
        parent_state.add_new_node(model_name=single_model, node=node)
        algo.backpropagate(node, scorer, model_name=single_model)

    for descendant in node.children:
        update_scores_thompson(descendant, algo, scorer, prior_config)

    return node


def add_observations(
    algo: PyMCMixedAlgo, root: LLMMCTSNode, scorer: MCTSScorer, model_name: str
) -> None:
    """
    Record an Observation in ``algo.all_observations`` for every non-root node
    of the subtree under ``root``.

    Nodes are visited in pre-order (a node before its children, children in
    their stored order). Each observation's reward is ``scorer.get_score`` of
    the node.

    Args:
        algo: Algorithm whose ``all_observations`` mapping is filled in place.
        root: Node where the traversal starts.
        scorer: Scorer used to compute each node's reward.
        model_name: Model name stored in every observation.
    """
    pending = [root]
    while pending:
        current = pending.pop()
        if not current.is_root():
            algo.all_observations[current] = Observation(
                reward=scorer.get_score(current), model_name=model_name, node=current
            )
        # Push children reversed so they are popped in their original order.
        pending.extend(reversed(current.children))


def update_scores_pymc_thompson(
    node: LLMMCTSNode,
    algo: PyMCMixedAlgo,
    scorer: MCTSScorer,
) -> LLMMCTSNode:
    """
    Register the observations of the subtree rooted at ``node`` in ``algo``.

    This delegates to ``add_observations`` using the single mock model name.

    Args:
        node: Root of the subtree whose nodes are registered.
        algo: PyMC mixed algorithm whose observations are filled in place.
            (It was previously annotated as HierarchicalThompsonAlgo, which
            does not match what ``add_observations`` expects.)
        scorer: Scorer used to compute the reward of each node.

    Returns:
        The same ``node`` that was passed in.
    """
    model_name = MOCK_LLM_NAME  # NOTE: Currently we only support a single model
    add_observations(algo, node, scorer, model_name)
    return node


def find_next_node_thompson(journal: Journal) -> Optional[AideNode]:
    """
    Choose the AIDE node to generate a child from, using hierarchical Thompson sampling.

    Args:
        journal: Journal whose draft nodes become the children of a synthetic root.

    Returns:
        The selected AIDE node, or None when the algorithm selects the
        synthetic root itself.
    """
    # Synthetic root whose children are the journal's draft nodes.
    aide_root = AideNode(code="", children=set(journal.draft_nodes))
    # Convert the AIDE tree to an LLM-MCTS tree, remembering which AIDE node
    # each LLM-MCTS node came from.
    node_table: Dict[LLMMCTSNode, AideNode] = {}
    root = map_tree(aide_root, node_table)

    ########### LLM-MCTS config part start
    gaussian_prior = GaussianPrior(m=0, kappa=1, nu=1, tau_square=0.1)
    prior_config = PriorConfig(dist_type="gaussian", prior=gaussian_prior)
    mcts_config = MCTSConfig(actions=("answer",))
    algo = HierarchicalThompsonAlgo(config=mcts_config, prior_config=prior_config)
    scorer = DefaultScorer()
    ########### LLM-MCTS config part end

    root = update_scores_thompson(root, algo, scorer=scorer, prior_config=prior_config)

    # Translate the chosen LLM-MCTS node back to its AIDE counterpart.
    chosen = node_table[algo.next_node_to_generate_child(root)]
    return None if chosen is aide_root else chosen


def find_next_node_pymc(journal: Journal) -> Optional[AideNode]:
    """
    Choose the AIDE node to generate a child from, using the PyMC mixed algorithm.

    Args:
        journal: Journal whose draft nodes become the children of a synthetic root.

    Returns:
        The selected AIDE node, or None when the algorithm selects the
        synthetic root itself.
    """
    # Synthetic root whose children are the journal's draft nodes.
    aide_root = AideNode(code="", children=set(journal.draft_nodes))
    # Convert the AIDE tree to an LLM-MCTS tree, remembering which AIDE node
    # each LLM-MCTS node came from.
    node_table: Dict[LLMMCTSNode, AideNode] = {}
    root = map_tree(aide_root, node_table)

    ########### LLM-MCTS config part start
    # Pruning only matters when scores are degenerate, so it is disabled here
    # (this might change later).
    pymc_config = PyMCMixedAlgoConfig(enable_pruning=False, algo="thompson")
    mcts_config = MCTSConfig(actions=("answer",))
    algo = PyMCMixedAlgo(config=mcts_config, pymc_config=pymc_config)
    scorer = DefaultScorer()
    ########### LLM-MCTS config part end

    root = update_scores_pymc_thompson(root, algo, scorer=scorer)

    # Translate the chosen LLM-MCTS node back to its AIDE counterpart.
    chosen = node_table[algo.next_node_to_generate_child(root)]
    return None if chosen is aide_root else chosen


def softmax(x):
    """
    Return the softmax of ``x``, normalised over all of its elements.

    The maximum is subtracted before exponentiating so that large scores do
    not overflow; this does not change the result.
    """
    peak = np.max(x)
    exponentials = np.exp(x - peak)
    return exponentials / exponentials.sum()


def update_scores_standard(
    node: LLMMCTSNode,
    algo: StandardAlgo,
    scorer: MCTSScorer,
) -> LLMMCTSNode:
    """
    Set child priors and backpropagate scores for the subtree rooted at ``node``.

    For a node with children, each child's prior is the softmax of the
    children's scores, and the children are then processed recursively. After
    that, a non-root expanded node backpropagates its own score.

    Args:
        node: Root of the subtree to process.
        algo: Only forwarded to the recursive calls; not otherwise used here.
        scorer: Scorer providing the score of each node.

    Returns:
        The same ``node`` that was passed in.
    """
    children = node.children
    if children:
        # Child scores act as the policy: their softmax becomes the prior.
        child_scores = [scorer.get_score(node=child) for child in children]
        for child, prior in zip(children, softmax(child_scores)):
            child.prior = prior

        for child in children:
            update_scores_standard(child, algo, scorer)

    # Only non-root nodes that are expanded backpropagate their score.
    if node.is_root() or not node.is_expanded():
        return node

    node.backpropagate(scorer.get_score(node=node))
    return node


def find_next_node_standard(journal: Journal) -> Optional[AideNode]:
    """
    Choose the AIDE node to generate a child from, using the standard algorithm with a UCT score.

    Args:
        journal: Journal whose draft nodes become the children of a synthetic root.

    Returns:
        The selected AIDE node, or None when the algorithm selects the
        synthetic root itself.
    """
    # Synthetic root whose children are the journal's draft nodes.
    aide_root = AideNode(code="", children=set(journal.draft_nodes))
    # Convert the AIDE tree to an LLM-MCTS tree, remembering which AIDE node
    # each LLM-MCTS node came from.
    node_table: Dict[LLMMCTSNode, AideNode] = {}
    root = map_tree(aide_root, node_table)

    ########### LLM-MCTS config part start
    mcts_config = MCTSConfig(
        actions=("answer",),
        # These parameters are not used, but they are set to avoid an AssertionError
        num_expand_samples=1,
        num_simulations=1,
        initial_expand_samples=1,
    )
    score_func = UCTScore(UCTConfig(ucb_c=0.1))
    algo = StandardAlgo(config=mcts_config, score_func=score_func)
    scorer = DefaultScorer()
    ########### LLM-MCTS config part end

    root = update_scores_standard(root, algo, scorer)

    # Translate the chosen LLM-MCTS node back to its AIDE counterpart.
    chosen = node_table[algo.next_node_to_generate_child(root)]
    return None if chosen is aide_root else chosen
