from copy import deepcopy
from typing import List, Optional, Tuple

import numpy as np

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import (
    GenerationRequest,
    GenerationResult,
    Message,
    Model,
)
from llm_mcts.mcts_algo.eval_result import EvalResult
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.solver.base import AggregatedSolver
from llm_mcts.mcts_scorer.base import MCTSScorer
from llm_mcts.models.aggregated_model import AggregatedModel
from llm_mcts.prompts.base import PromptTemplate
from llm_mcts.tasks.base import Task


def softmax(x):
    """Compute the softmax of a collection of scores.

    The normalisation is taken over all elements of ``x``. The maximum
    score is subtracted before exponentiating so that large scores do not
    overflow ``np.exp``. This shift does not change the result.

    Args:
        x: Sequence or array of real-valued scores.

    Returns:
        A float ``np.ndarray`` with the same shape as ``x``. Its entries are
        non-negative and sum to 1. An empty input yields an empty array.
    """
    scores = np.asarray(x, dtype=float)
    if scores.size == 0:
        # np.max raises ValueError on an empty array; there is nothing to normalise.
        return scores
    e_x = np.exp(scores - np.max(scores))
    return e_x / e_x.sum()


class LLMSolver(AggregatedSolver):
    """Solver that expands MCTS nodes by sampling completions from an LLM.

    Each expansion sends the node's pending prompt (augmented with an
    instruction for the chosen action) to ``model``. It turns every
    completion into a child ``Node`` carrying the task's evaluation results
    and a follow-up prompt for the next turn.
    """

    def __init__(
        self,
        model: Model,
        prompt_template: PromptTemplate,
        task: Task,
    ) -> None:
        """
        Args:
            model: Generation backend. ``get_solvers``/``set_solver`` require
                it to be an ``AggregatedModel``.
            prompt_template: Builds action instructions and feedback prompts.
            task: Produces evaluation results for each generation.
        """
        self.model = model
        self.prompt_template = prompt_template
        self.task = task

    def generate_child_nodes(
        self,
        node: Node,
        kind: Action,
        num_samples: int,
        scorer: MCTSScorer,
        next_serial_number: int,
    ) -> Tuple[List[Node], int]:
        """Sample ``num_samples`` completions and attach them as children of ``node``.

        Each child is appended to ``node.children``. Its ``prior`` is set to
        the softmax of ``scorer.get_score`` over all newly created children.

        Args:
            node: Node to expand; its ``next_prompt`` must be set.
            kind: Action whose instruction is added to the prompt.
            num_samples: Number of completions to request.
            scorer: Provides the score used to derive child priors.
            next_serial_number: Serial number assigned to the first new child.

        Returns:
            The list of new children and the next unused serial number.
        """
        gen_results, llm_names = self._generate(node, kind, num_samples)

        children: List[Node] = []
        for gen_result, llm_name in zip(gen_results, llm_names):
            eval_results = self.task.generate_eval_results(gen_result, kind=kind)
            next_prompt = self.create_next_prompt(
                kind,
                gen_result,
                eval_results=eval_results,
            )

            child = Node(
                serial_number=next_serial_number,
                next_prompt=next_prompt,
                llm_name=llm_name,
                completion=gen_result,
                parent=node,
                last_action=kind,
                eval_results=eval_results,
            )
            node.children.append(child)
            children.append(child)
            next_serial_number += 1

        # We use child scores for policy and update prior.
        # With no children there are no priors to set, and softmax of an
        # empty score list would otherwise fail inside np.max.
        if children:
            priors = softmax([scorer.get_score(node=child) for child in children])
            for child, prior in zip(children, priors):
                child.prior = prior

        return children, next_serial_number

    def create_next_prompt(
        self,
        kind: Action,
        gen_result: GenerationResult,
        eval_results: Optional[List[EvalResult]],
    ) -> GenerationRequest:
        """Build the follow-up request that continues the conversation of ``gen_result``.

        The new request contains the original request messages, the
        generation as an assistant message, and a user message produced by
        ``prompt_template.feedback_prompt``.

        Args:
            kind: Action that produced ``gen_result``.
            gen_result: Completion whose conversation is being extended.
            eval_results: Evaluation results passed to the feedback prompt.

        Returns:
            A new ``GenerationRequest``. The messages of ``gen_result.request``
            are deep-copied and not mutated.
        """
        messages = deepcopy(
            gen_result.request.messages
        )  # Need deepcopy here to avoid mysterious bug that last user message does not contain task instruction sometimes
        messages.append(Message(role="assistant", content=gen_result.generation))
        messages.append(
            Message(
                role="user",
                content=self.prompt_template.feedback_prompt(
                    action=kind,
                    eval_results=eval_results,
                    generation_result=gen_result,
                ),
            )
        )
        return GenerationRequest(messages=messages)

    def _generate(
        self, node: Node, kind: Action, num_samples: int
    ) -> Tuple[List[GenerationResult], List[str]]:
        """
        Generate LLM response for the current problem.

        The node's ``next_prompt`` is deep-copied, augmented with the
        instruction for ``kind``, and submitted ``num_samples`` times in one
        batch call.

        Returns:
            The generation results and, in matching order, the name of the
            LLM that produced each one.
        """
        assert node.next_prompt is not None

        prompt = self.prompt_template.add_next_action_instruction(
            action=kind,
            # Copy so the template cannot mutate the node's stored prompt.
            next_prompt=deepcopy(node.next_prompt),
        )
        gen_results, llm_names = self.model.generate_with_llm_names(
            [prompt] * num_samples
        )
        return gen_results, llm_names

    def get_solvers(self) -> List[str]:
        """Return the ``model_name`` of every sub-model of the aggregated model.

        Raises:
            NotImplementedError: If ``self.model`` is not an ``AggregatedModel``.
        """
        if not isinstance(self.model, AggregatedModel):
            raise NotImplementedError(
                "get_solvers is supported only for AggregatedModel"
            )
        return [model.model_name for model in self.model.models]

    def set_solver(self, solver_name: str) -> None:
        """Route all generation to the sub-model named ``solver_name``.

        Sets ``self.model.model_prob`` to 1.0 for models whose name matches
        and 0.0 for all others.

        Raises:
            NotImplementedError: If ``self.model`` is not an ``AggregatedModel``.
            ValueError: If no sub-model is named ``solver_name``. Without this
                check, every probability would silently become 0.0.
        """
        if not isinstance(self.model, AggregatedModel):
            raise NotImplementedError(
                "set_solver is supported only for AggregatedModel"
            )

        model_names = [model.model_name for model in self.model.models]
        if solver_name not in model_names:
            raise ValueError(
                f"Unknown solver {solver_name!r}; available solvers: {model_names}"
            )

        self.model.model_prob = [
            1.0 if name == solver_name else 0.0 for name in model_names
        ]
