import logging
import re
from typing import Any, Optional
from copy import deepcopy

from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
from rllm.agents.trajectory_logger import AgentTrajectoryLogger
from rllm.prompts.reasoning_gym import REASONING_GYM_SYSTEM_PROMPT, REASONING_GYM_USER_PROMPT
from rllm.agents.utils import highlight_changed_text

logger = logging.getLogger(__name__)


class ReasoningGymAgent(BaseAgent, AgentTrajectoryLogger):
    """Single-turn agent that turns one question observation into one answer.

    The agent builds a system/user prompt from the observation's question,
    parses the model response into a thought and a normalized answer string,
    and records the exchange as a single ``Step`` on its trajectory.
    """

    # Labels that may introduce the final answer, tried in this order.
    _ANSWER_LABELS = ("ANSWER", "ACTION")

    # Regex tail appended to an answer label (compiled with re.VERBOSE | re.DOTALL).
    _ANSWER_BODY_PATTERN = r""":\s*
        (?:                                     # Option A: fenced code block
            ```(?:[a-zA-Z0-9_-]+)?\s*          # opening fence with optional language
            (?P<fenced>.*?)                    # content inside fences (non-greedy)
            \s*```                             # closing fence
        |
            (?P<inline>.+)                     # Option B: inline content until end
        )
        """

    def __init__(self, use_accumulate_thinking: bool = False):
        """Create the agent.

        Args:
            use_accumulate_thinking: When True the full model response
                (including its thinking section) is stored in the chat
                history; otherwise only the parsed answer is stored.
        """
        self._trajectory = Trajectory()
        self.messages = []
        self.step: int = 0
        # Most recent processed observation (see ``_process_observation``).
        self.current_observation = None
        self.use_accumulate_thinking = use_accumulate_thinking
        self.reset()

    def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
        """Store the processed observation and build the prompt messages.

        Args:
            observation: Mapping with at least ``question``, ``answer``,
                ``metadata`` and ``task_name`` keys.
            reward: Unused by this agent.
            done: When True, only the observation is stored and no prompt
                is built.
            info: Unused by this agent.

        Raises:
            AssertionError: If called for a non-terminal observation after a
                model response has already been recorded.
        """
        self.current_observation = self._process_observation(observation)

        if done:
            return

        assert self.step == 0, "ReasoningGymAgent is single-turn agent."

        system_prompt = REASONING_GYM_SYSTEM_PROMPT
        user_prompt = REASONING_GYM_USER_PROMPT.format(question=observation["question"])

        self.messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def update_from_model(self, response, **kwargs) -> Action:
        """Record the model response as a step and return the parsed action.

        Args:
            response: Raw model output; ``None`` is treated as an empty string.

        Returns:
            An ``Action`` wrapping the normalized answer string (empty when no
            answer could be parsed).
        """
        if response is None:
            response = ""

        thought, action = self._parse_model_response(response)

        # Without accumulated thinking, only the normalized answer (built from
        # the text after the last </think> tag) is kept in the chat history.
        assistant_content = response if self.use_accumulate_thinking else action
        self.messages.append({"role": "assistant", "content": assistant_content})

        new_step = Step(
            chat_completions=deepcopy(self.chat_completions),
            thought=thought,
            action=action,
            model_response=response,
            observation=self.current_observation,
        )
        self._trajectory.steps.append(new_step)

        self.step += 1

        return Action(action=action)

    def _parse_model_response(self, response: str) -> tuple[str, str]:
        """Split a model response into ``(thought, answer)``.

        The thought is taken from a ``THOUGHT:`` section when present,
        otherwise from the first ``<think>...</think>`` block. The answer is
        only extracted when the response contains ``</think>``, and only from
        the text following the last such tag.

        Returns:
            A tuple of the stripped thought text and the normalized answer
            string; either may be empty.
        """
        thought_match = None
        if "THOUGHT" in response:
            # Capture up to the next section marker (or end of text); a bare
            # lazy group with nothing after it would always match "".
            thought_match = re.search(
                r"THOUGHT:\s*(.*?)(?=</think>|(?:REASON|ANSWER|ACTION):|\Z)",
                response,
                re.DOTALL,
            )
        if thought_match is None:
            thought_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
        thought = thought_match.group(1).strip() if thought_match else ""

        if "</think>" not in response:
            return thought, ""

        # Only the text after the *last* closing tag is the final output;
        # draft answers inside the thinking section must not be picked up.
        after_think = response.rsplit("</think>", 1)[1]
        return thought, self._format_answer(after_think)

    def _format_answer(self, text: str) -> str:
        """Normalize the post-thinking text to ``REASON: ...`` plus a fenced answer.

        Returns an empty string when no ``ANSWER:`` / ``ACTION:`` section is found.
        """
        for label in self._ANSWER_LABELS:
            if label not in text:
                continue
            match = re.search(label + self._ANSWER_BODY_PATTERN, text, re.DOTALL | re.VERBOSE)
            if match is None:
                continue

            content = match.group("fenced")
            if content is None:
                content = match.group("inline")

            # The reason is whatever precedes the answer section, minus an
            # optional leading "REASON:" marker.
            reason = text[: match.start()].split("REASON:")[-1].strip()
            return "REASON: " + reason + "\n" + label + ": ```\n" + content.strip() + "\n```"

        return ""

    @property
    def chat_completions(self) -> list[dict[str, str]]:
        """The current chat message list (not a copy)."""
        return self.messages

    @property
    def trajectory(self) -> Trajectory:
        """The trajectory accumulated since the last reset."""
        return self._trajectory

    def reset(self) -> None:
        """Start a fresh trajectory, clear the messages and reset the step counter."""
        self._trajectory = Trajectory()
        self.messages = []
        self.step = 0

    def _process_observation(self, observation: Any) -> dict[str, Any]:
        """Convert a raw observation into the dict stored on each step.

        Returns:
            A dict with the question under ``observation`` and the metadata,
            expected answer and task name under ``observation_info``.
        """
        # A None metadata value is treated as "no metadata".
        metadata = observation["metadata"] or {}
        return {
            "observation": observation["question"],
            "observation_info": {
                **metadata,
                "answer": observation["answer"],
                "task_name": observation["task_name"],
            },
        }

    # for trajectory logger
    def _get_thumbnail(self, observation: dict) -> str:
        """Return a short summary of the question for ``leg_counting`` tasks, else ''."""
        if observation["observation_info"]["task_name"] != "leg_counting":
            return ""

        pieces = observation["observation"].split("how many legs are there in total if you have")
        if len(pieces) < 2:
            # Question is not phrased as expected; no thumbnail available.
            return ""
        return pieces[1].strip().rstrip("?")

    def _get_visible_obs_info(self, observation: dict) -> dict:
        """Return the observation info to expose in logs; nothing is exposed."""
        return {}
