import re
from logging import getLogger

import litellm

from pcot.parameters import get_default_max_tokens, get_default_temperature
from pcot.tasks.rsp.bot import Bot, Move

# Module-scoped logger; LLMBot logs model outputs and retry warnings through it.
logger = getLogger(name=__name__)


class LLMBot(Bot):
    """Rock-paper-scissors bot that asks an LLM (through ``litellm``) for each move.

    Every move is requested with ``system_prompt`` as the system message and
    :meth:`user_prompt` as the user message. The model must put its choice
    inside ``<answer></answer>`` tags. Raw completions are accumulated in
    ``output_history``. Reasoning traces are also stored there when the
    response message exposes ``reasoning_content``.
    """

    # Default number of completion attempts in generate_move(); overridable
    # per instance via the ``max_attempts`` constructor argument.
    max_attempts: int = 20

    # Lower-cased <answer> contents accepted by generate_move(), mapped to moves.
    _ANSWER_TO_MOVE: dict[str, str] = {
        "rock": "R",
        "r": "R",
        "scissors": "S",
        "s": "S",
        "paper": "P",
        "p": "P",
    }

    # DOTALL lets the answer span several lines; IGNORECASE also accepts
    # tags such as <ANSWER>.
    _ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)

    def __init__(
        self,
        litellm_name: str,
        system_prompt: str,
        show_history: bool = True,
        max_attempts: int = 20,
    ):
        """Create the bot.

        Args:
            litellm_name: Model identifier passed to ``litellm.completion``;
                also returned by :attr:`name`.
            system_prompt: System message sent with every request.
            show_history: Whether :meth:`user_prompt` lists previous rounds' moves.
            max_attempts: Maximum number of completion attempts per move before
                :meth:`generate_move` raises ``RuntimeError``.
        """
        self.litellm_name = litellm_name
        self.system_prompt = system_prompt
        self.max_attempts = max_attempts

        # NOTE(review): never appended to in this class; only cleared by reset().
        self.response_history = []
        self.move_history: list[str] = []
        self.opponent_move_history: list[str] = []
        self.show_history = show_history

        # Raw model outputs and reasoning traces from every request. This is
        # not cleared by reset(), so it accumulates across games.
        self.output_history: dict[str, list] = {"output": [], "reasoning": []}

    def reset(self) -> None:
        """Forget the move histories (``output_history`` keeps accumulating)."""
        self.response_history = []
        self.move_history = []
        self.opponent_move_history = []

    @property
    def name(self) -> str:
        """The litellm model identifier this bot was created with."""
        return self.litellm_name

    def get_move(self, input: Move) -> str:
        """Record the opponent's last move (if any) and return this bot's next move.

        Args:
            input: The opponent's previous move. An empty value (``""`` or
                ``None``) means there is nothing to record, presumably in the
                first round.

        Returns:
            "R", "S" or "P"; the move is also appended to ``move_history``.
        """
        # Skip empty values so user_prompt()'s ", ".join never sees a non-string.
        if input:
            self.opponent_move_history.append(input)

        move = self.generate_move()
        self.move_history.append(move)
        return move

    def _provider_kwargs(self) -> dict:
        """Return extra ``litellm.completion`` kwargs that pin the serving provider.

        Models whose name contains "o4-mini" or "gpt-4o" get no extra kwargs.
        Every other model gets an ``extra_body`` that orders a single provider:
        "DeepInfra" when the name contains "qwq", otherwise "lambda". Fallbacks
        are disabled. NOTE(review): this looks like OpenRouter's
        provider-routing schema; confirm against the configured backend.
        """
        if "o4-mini" in self.litellm_name or "gpt-4o" in self.litellm_name:
            return {}
        provider = "DeepInfra" if "qwq" in self.litellm_name else "lambda"
        return {
            "extra_body": {
                "provider": {
                    "order": [provider],
                    # "require_parameters": True,
                    "allow_fallbacks": False,
                }
            }
        }

    def _record_reasoning(self, message) -> None:
        """Log and store ``message.reasoning_content`` if the message has that attribute."""
        try:
            reasoning = message.reasoning_content
        except AttributeError:
            # The response message may carry no reasoning trace; nothing to record.
            return
        logger.info(f"REASONING: \n{reasoning}")
        self.output_history["reasoning"].append(reasoning)

    def generate_move(self) -> Move:
        """Query the model until it returns a parseable move.

        Any exception raised during a request or while reading the response is
        logged and counts as a failed attempt. An unparseable answer also
        counts as a failed attempt.

        Returns:
            "R", "S" or "P".

        Raises:
            RuntimeError: if no parseable answer is obtained within
                ``max_attempts`` attempts.
        """
        req = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self.user_prompt()},
        ]
        kwargs = self._provider_kwargs()

        for _ in range(self.max_attempts):
            try:
                response = litellm.completion(
                    messages=req,
                    model=self.litellm_name,
                    num_retries=3,
                    input_cost_per_token=0,
                    output_cost_per_token=0,
                    temperature=get_default_temperature(self.litellm_name),
                    max_tokens=get_default_max_tokens(self.litellm_name),
                    **kwargs,
                )
                message = response.choices[0].message
                self._record_reasoning(message)

                generated_text: str | None = message.content
                logger.info(f"GENERATED: \n{generated_text}")
                self.output_history["output"].append(generated_text)

                answer = self.parse_answer(generated_text)
                move = self._ANSWER_TO_MOVE.get(answer)
                if move is not None:
                    return move
                logger.warning(
                    f"ANSWER is {answer}, FAILED TO PARSE ANSWER FROM: {generated_text}, retrying..."
                )
            except Exception as e:
                # Retry boundary: API/network errors and malformed responses
                # are logged and count as a failed attempt.
                logger.warning(f"error in answer generation: {e}, retrying...")
        raise RuntimeError(
            f"Answer generation failed after {self.max_attempts} attempts"
        )

    def user_prompt(self) -> str:
        """Build the user message, optionally listing both players' previous moves."""
        prompt = """
You are playing rock, scissors, and paper against other agent. You need to pick a move from rock, scissors or paper.
First provide your reasoning, and then pick your answer inside <answer></answer> tags.
""".strip()

        if self.show_history and (self.move_history or self.opponent_move_history):
            prompt += "\nPrevious rounds' moves are as follows. R, S and P represent rock, scissors and papers, respectively."
            if self.move_history:
                prompt += f"\nYour moves: {', '.join(self.move_history)}"
            if self.opponent_move_history:
                prompt += f"\nOpponent moves: {', '.join(self.opponent_move_history)}"

        return prompt

    def parse_answer(self, generated_text: str | None) -> str | None:
        """Extract the contents of the last ``<answer>...</answer>`` tag.

        Args:
            generated_text: Raw model output; may be ``None`` when the
                response carried no content.

        Returns:
            The stripped, lower-cased tag contents. Returns ``None`` if the
            text is empty/``None`` or contains no complete answer tag.
        """
        if not generated_text:
            return None
        matches = self._ANSWER_RE.findall(generated_text)
        # The last tag wins, so the final answer after the reasoning is used.
        return matches[-1].strip().lower() if matches else None
