import logging
import re
from typing import Any, Optional
from copy import deepcopy

from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
from rllm.agents.trajectory_logger import AgentTrajectoryLogger
from rllm.prompts.sudoku_agent import SUDOKU_SYSTEM_PROMPT, SUDOKU_USER_PROMPT, SUDOKU_NON_THINKING_SYSTEM_PROMPT, SUDOKU_NON_THINKING_USER_PROMPT, GEPA_SUDOKU_PROMPT_SET
from rllm.agents.utils import remove_concepts_from_messages
from rllm.pipeline.baseline.gepa.gepa_handler import GEPAHandler
from rllm.agents.utils import highlight_changed_text

logger = logging.getLogger(__name__)


class SudokuAgent(BaseAgent, AgentTrajectoryLogger):
    gepa_prompt_set = GEPA_SUDOKU_PROMPT_SET
    
    def __init__(
        self,
        max_steps: int = 5,
        use_accumulate_thinking: bool = False,
        use_accumulate_history: bool = True,
        history_window: int | None = None,
        use_summarize_history: bool = False,
        use_non_thinking: bool = False,
        use_concept: bool = True,
        gepa_handler: Optional[GEPAHandler] = None
    ):
        """
        Configure the agent and put it into its initial (reset) state.

        Args:
            max_steps: Step budget used to tell the model how many steps remain.
            use_accumulate_thinking: If True, the full model response is stored as
                the assistant message; otherwise only the parsed action is stored.
            use_accumulate_history: If False, `chat_completions` exposes only the
                first and the most recent message.
            history_window: When truthy, `chat_completions` keeps the first message
                plus the last ``2 * history_window + 1`` messages.
            use_summarize_history: Placeholder flag; history summarisation is not
                implemented and the full history is returned when it is set.
            use_non_thinking: Selects the non-thinking prompt templates.
            use_concept: If False, the concepts section is stripped from the
                system prompt.
            gepa_handler: When provided, prompts are built by this handler
                (GEPA baseline). Must not be combined with ``use_concept=True``.

        Raises:
            AssertionError: If a GEPA handler is given while `use_concept` is True.
        """
        # conversation / episode state
        self._trajectory = Trajectory()
        self.messages = []
        self.step: int = 0

        # configuration flags
        # controls whether the thinking portion of the response is kept in history
        self.use_accumulate_thinking = use_accumulate_thinking
        self.max_steps = max_steps
        self.accumulate_history = use_accumulate_history
        self.history_window = history_window
        self.summarize_history = use_summarize_history
        self.use_non_thinking = use_non_thinking
        self.use_concept = use_concept

        # latest processed observation; filled in by `update_from_env`
        self.current_observation = None
        # if not None, we are running the GEPA baseline
        self.gepa_handler: Optional[GEPAHandler] = gepa_handler

        assert not self.gepa_handler or not self.use_concept, "In GEPA baseline, `use_concept` should be False."

        self.reset()

    def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
        """
        Updates the agent's internal state after an environment step.
        Includes logic to check if the observation changed from the previous step.
        """
        self.current_observation = self._process_observation(observation)

        if done:
            return

        if self.step == 0:
            task_info = observation["task_info"]
            if self.gepa_handler:
                system_prompt = self.gepa_handler.construct_system_prompt(environment_info=task_info['environment_info'])
                user_prompt_content = self.gepa_handler.construct_user_prompt(current_observation=self.current_observation["observation"]).replace("Current Observation:", "Initial Observation:")
            else:
                if self.use_non_thinking:
                    system_prompt = SUDOKU_NON_THINKING_SYSTEM_PROMPT.format(
                        environment_info=task_info["environment_info"],
                        concepts=task_info["concepts"],
                    )
                    user_prompt = SUDOKU_NON_THINKING_USER_PROMPT.format(current_observation=self.current_observation["observation"])
                else:
                    system_prompt = SUDOKU_SYSTEM_PROMPT.format(
                    environment_info=task_info["environment_info"],
                    concepts=task_info.get("concepts", ""),
                    )
                    user_prompt = SUDOKU_USER_PROMPT.format(current_observation=self.current_observation["observation"])
                
            if not self.use_concept:
                system_prompt = remove_concepts_from_messages(system_prompt, "## Concepts to achieve the goal", "## Key Guidelines")
            self.messages.append({"role": "system", "content": system_prompt})
            user_prompt_content = user_prompt.replace("Current Observation:", "Initial Observation:")
        else:
            if self.gepa_handler:
                user_prompt_content = self.gepa_handler.construct_user_prompt(current_observation=self.current_observation["observation"])
            else:
                if self.use_non_thinking:
                    user_prompt_content = SUDOKU_NON_THINKING_USER_PROMPT.format(current_observation=self.current_observation["observation"])
                else:
                    user_prompt_content = SUDOKU_USER_PROMPT.format(current_observation=self.current_observation["observation"])

        if self.max_steps - self.step > 0:
            user_prompt_content += f"\nThe maximum number of steps remaining is {self.max_steps - self.step}."
        else:
            user_prompt_content += f"\nThis is the last step. You must output your final answer."
        
        if self.use_non_thinking and self.gepa_handler is None:
            user_prompt_content += "\n\\no_think"

        # Add the user message for the *next* interaction turn
        self.messages.append({"role": "user", "content": user_prompt_content})


    def update_from_model(self, response, **kwargs) -> Action:
        """
        Record the model's reply and turn it into an environment action.

        A ``None`` response is treated as an empty string. The reply is parsed into
        a thought and an action, an assistant message is appended to
        ``self.messages``, a `Step` snapshot is added to the trajectory, and the
        step counter is advanced.

        Args:
            response: Raw model output, or None.
            **kwargs: Ignored.

        Returns:
            An `Action` wrapping the parsed action string.
        """
        response = "" if response is None else response

        thought, action = self._parse_model_response(response)

        # With accumulated thinking the whole reply is kept in the history;
        # otherwise only the parsed action is stored as the assistant turn.
        assistant_content = response if self.use_accumulate_thinking else f"{action}"
        self.messages.append({"role": "assistant", "content": assistant_content})

        step_record = Step(
            chat_completions=deepcopy(self.chat_completions),
            thought=thought,
            action=action,
            model_response=response,
            observation=self.current_observation,
        )
        self._trajectory.steps.append(step_record)
        self.step += 1

        return Action(action=action)

    def _parse_model_response(self, response: str) -> tuple[str, list[str]]:
        # thought parsing
        thought = response
        if self.gepa_handler:
            thought_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
            if thought_match:
                try:
                    thought = thought_match.group(1).strip()
                except Exception as e:
                    pass
            
            if '<think_summary>' in response and '</think_summary>' in response:
                thought_summary = re.search(r"<think_summary>(.*?)</think_summary>", response, re.DOTALL).group(1).strip()
            elif '<think_summary>' not in response and '</think_summary>' in response and '</think>' in response:
                thought_summary = response.split('</think>')[-1].split('</think_summary>')[0].strip()
            else:
                thought_summary = ""
                
            if thought_summary:
                try:
                    thought += f"\n<think_summary>{thought_summary}</think_summary>"
                except Exception as e:
                    pass
        else:
            if self.use_non_thinking:
                # thought_match = re.search(r"<analysis>(.*?)</analysis>", response, re.DOTALL)
                thought_match = re.search(r"ANALYSIS: (.*?)<answer>", response, re.DOTALL)
            else:
                thought_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)

            if thought_match:
                try:
                    thought = thought_match.group(1).strip()
                except Exception as e:
                    pass
        
        if '<answer>' in response and '</answer>' in response:
            action = response.split('<answer>')[-1].split('</answer>')[0].strip()
        else:
            action = "None"
        # print(f"ACTION: {action}")
        # # action parsing
        # if "</think>" in response:
        #     action = response.split("</answer>")[0].split("<answer>")[-1].strip()
        # else:
        #     if "</answer>" in response:
        #         action = response.split("</answer>")[0].split("<answer>")[-1].strip()
        #     else:
        #         action = "None"
        # answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        # if answer_match:
        #     action = answer_match.group(1).strip()
        # else:
        #     try:
        #         action = response.split("</think>")[1].strip()
        #     except Exception as e:
        #         action = ""
        
        return thought, action

    @property
    def chat_completions(self) -> list[dict[str, str]]:
        if self.accumulate_history:
            # TODO: Implement Summarize History
            if self.summarize_history:
                return self.messages
            else:
                if self.history_window:
                    MAX_MESSAGES = self.history_window * 2 + 2
                    if len(self.messages) <= MAX_MESSAGES:
                        return self.messages
                    else:
                        total_messages = self.messages[0]   # system message
                        start_idx = - 1 - self.history_window * 2
                        spliced_messages = self.messages[start_idx:]
                        return [total_messages] + spliced_messages
                else:
                    return self.messages
        else:
            if len(self.messages) <= 1:
                return self.messages
            else:
                return [self.messages[0], self.messages[-1]]

    @property
    def trajectory(self) -> Trajectory:
        """Return the trajectory object that `update_from_model` appends steps to."""
        return self._trajectory

    def reset(self) -> None:
        """
        Clear all per-episode state so the agent can start a new episode.

        Replaces the trajectory with a fresh one, empties the message history,
        rewinds the step counter and drops the last stored observation. When a
        GEPA handler is configured, the first 100 characters of its static system
        prompt are logged for traceability.
        """
        self._trajectory = Trajectory()
        self.messages = []
        self.step = 0
        # Do not carry the previous episode's observation into the next one.
        self.current_observation = None

        if self.gepa_handler:
            logger.info("running GEPA baseline")
            # Logged instead of printed so output goes through the configured handlers.
            logger.info(
                "GEPA static system prompt (first 100 chars): %s",
                self.gepa_handler.prompt_set['static_system'][:100],
            )

    def _process_observation(self, observation: Any) -> str:
        """
        Process observation.
        """
        observation_str = observation.get("board_spaced", "")
        env_message = observation.get("env_message", None)
        if env_message:
            observation_str += f"\n\n- Error message from environment:\n{env_message}"

        return {
            "observation": observation_str,
            "observation_info": {
                "board_ascii": observation.get("board_ascii", ""),
                "board_spaced": observation.get("board_spaced", ""),
                "board_tokens": observation.get("board_tokens", ""),
                "current_turn": observation.get("current_turn", ""),
                "max_turns": observation.get("max_turns", ""),
                "progress": observation.get("progress", ""),
                "correct_cells": observation.get("correct_cells", ""),
                "total_cells": observation.get("total_cells", ""),
            }
        }
    
    # for trajectory logger
    def _get_thumbnail(self, observation: dict) -> str:
        board = observation['observation_info']['board_spaced']
        if len(self._trajectory.steps) > 0:
            old_board = self._trajectory.steps[-1].observation['observation_info']['board_spaced']
            board = highlight_changed_text(old_board, board)
        
        lines = board.strip().split('\n')
        n = int(observation['observation_info']['total_cells'] ** 0.5)
        rows = lines[:n]
        board_table = '<table style="border-collapse: collapse;"><tbody>'
        for row in rows:
            board_table += '<tr>'
            for cell in row.split(' ')[:n]:
                board_table += f'<td style="width: 25px; height: 25px; text-align: center; border: 1px solid grey;">{cell}</td>'
            board_table += '</tr>'
        board_table += '</tbody></table>'
        if lines[n:]:
            board_table += '<br>' + '\n'.join(lines[n:])
        
            
        return board_table

    def _get_visible_obs_info(self, observation: dict) -> dict:
        return {
            "progress": observation['observation_info']['progress'],
            "correct_cells": observation['observation_info']['correct_cells'],
            "total_cells": observation['observation_info']['total_cells'],
        }
