import functools
import re
from typing import Callable, List, Optional, Tuple

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationResult
from llm_mcts.mcts_algo.eval_result import EvalResult, EvalResultWithScore
from llm_mcts.tasks.base import Task

# Sentinel score returned by ScoreArithmeticGame.reduce_scores when reducing
# the parsed per-round scores yields None (e.g. no parsable ```score``` block).
# A very large negative value so such answers rank below real scores.
INVALID_SCORE = -1e12


class ScoreArithmeticGame(Task):
    """
    The base class for the score arithmetic game (a toy example).

    Each assistant turn is expected to contain a fenced block of the form
    ```score <number>```. The scores of the most recent ``num_rounds`` turns
    (the assistant messages in the request history plus the new generation)
    are folded with ``reducer`` to produce the evaluation score.
    """

    # Captures the contents of the first ```score ...``` fenced block. The
    # non-greedy group stops at the nearest closing fence, so a later code
    # block (or a second score block) in the same message is not swallowed
    # into the captured text.
    _SCORE_BLOCK_RE = re.compile(r"```score(.*?)```", re.DOTALL)

    def __init__(
        self,
        num_rounds: int,
        reducer: Callable[[Optional[float], Optional[float]], Optional[float]],
    ) -> None:
        """
        Args:
            num_rounds: Number of most recent rounds whose scores are
                combined. Must be >= 1.
            reducer: Binary function folded left-to-right over the window of
                scores via ``functools.reduce``. Individual scores may be
                ``None`` when a turn has no parsable score block.

        Raises:
            ValueError: If ``num_rounds`` is less than 1.
        """
        if num_rounds < 1:
            # ``scores[-0:]`` is the *whole* list, so 0 would silently mean
            # "all rounds", and negative values would drop the oldest rounds
            # instead of keeping the newest ones.
            raise ValueError(f"num_rounds must be >= 1, got {num_rounds}")
        self.num_rounds = num_rounds
        self.reducer = reducer

    def generate_eval_results(
        self, llm_answer: GenerationResult, kind: Action
    ) -> Optional[List[EvalResult]]:
        """
        Score ``llm_answer`` for the search. ``kind`` is not used by this task.

        Returns:
            A single-element list holding the reduced score (see
            ``reduce_scores``) with an empty reason.
        """
        score = self.reduce_scores(llm_answer)
        return [EvalResultWithScore(score=score, reason="")]

    def evaluate_on_test(
        self, llm_answer: GenerationResult
    ) -> Tuple[List[EvalResult], float]:
        """
        Test score is the same as eval score.

        Returns:
            The single-element eval result list and the reduced score itself.
        """
        score = self.reduce_scores(llm_answer)
        return [EvalResultWithScore(score=score, reason="")], score

    def reduce_scores(self, llm_answer: GenerationResult) -> float:
        """
        Combine the scores of the most recent ``num_rounds`` rounds.

        A score is parsed from every assistant message in the request history,
        followed by the new generation. Each may be ``None`` if unparsable.
        The last ``num_rounds`` of them are folded with ``self.reducer``. A
        single score is returned as-is, without calling the reducer.

        Returns:
            The reduced score, or ``INVALID_SCORE`` if the reduction is None.
        """
        # NOTE(review): assumes message.content is a str; a None/non-str
        # content would make the regex search raise TypeError — confirm.
        scores = [
            self.parse_score(message.content)
            for message in llm_answer.request.messages
            if message.role == "assistant"
        ]
        scores.append(self.parse_score(llm_answer.generation))

        # Only consider the most recent num_rounds scores. num_rounds >= 1 is
        # enforced in __init__, so this window is never empty.
        recent_scores = scores[-self.num_rounds :]
        total_score = functools.reduce(self.reducer, recent_scores)

        return total_score if total_score is not None else INVALID_SCORE

    def parse_score(self, score_str: str) -> Optional[float]:
        """
        Extract the numeric score from the first ```score ...``` block.

        Returns:
            The parsed float, or ``None`` if no score block is present or its
            contents (after stripping whitespace) are not a valid float.
        """
        match = self._SCORE_BLOCK_RE.search(score_str)
        if match is None:
            return None
        try:
            return float(match.group(1).strip())
        except ValueError:
            # float(str) only raises ValueError for malformed numbers.
            return None
