import logging
import re
from abc import ABC, abstractmethod
from typing import Any, List
from copy import deepcopy
import json
from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
import os

logger = logging.getLogger(__name__)

# System prompt: the agent's role text, followed by an "Output Requirements"
# section. Both fields are filled in by BaseGameAgent._make_system_prompt.
SYSTEM_PROMPT_TEMPLATE = """{role}

## Output Requirements
{output_requirements}
"""

# Preamble for the first user turn describing the environment (available
# actions, goal, rules); filled in by BaseGameAgent._make_init_user_prompt.
INTRO_USER_PROMPT_TEMPLATE = """## Environment Information
{environment_info}
"""

# Per-step user prompt wrapping the current observation; filled in by
# BaseGameAgent._make_user_prompt.
USER_PROMPT_TEMPLATE = """## Current Observation
{current_observation}
"""

# Response-format instructions appended to the most recent user turn. The
# response parser (BaseGameAgent._parse_model_response) relies on this layout:
# a <think>...</think> block, then a "REASON:" summary, then "ACTION:" with the
# action wrapped in a triple-backtick fence.
# NOTE: this text is sent to the model verbatim; rewording it changes the prompt.
OUTRO_USER_PROMPT_TEMPLATE = """Enclose your detailed reasoning process within <think> and </think> tags.
After that, summarize your reasoning (not exceed 8 sentences) following the 'REASON:' tag.
Finally, produce your action after the 'ACTION:' tag, wrapping it in triple backticks (```)."""

class BaseGameAgent(BaseAgent, ABC):
    """
    Base Game Agent class.

    Shared prompt construction and response parsing for LLM agents that play
    text-based games. Subclasses provide the prompt fragments below and
    implement ``_process_observation``.

    The raw conversation lives in ``self.messages`` as alternating
    user/assistant turns (the system prompt is kept separately in
    ``self.system_message``). ``chat_completions`` renders it into the prompt
    sent to the model:

    * earlier assistant turns are optionally reduced to their parsed
      REASON/ACTION summary;
    * the history is optionally truncated to the last ``history_window`` steps;
    * the result is laid out either as a multi-turn chat or as a single user
      turn.
    """

    # Prompt fragments supplied by subclasses.
    role: str = ""  # role text at the top of the system prompt
    action_space: str = ""  # shown under "### Available Actions" in the first user turn
    goal: str = ""  # shown under "### Goal" in the first user turn
    output_requirements: str = ""  # shown under "## Output Requirements" in the system prompt

    # Normalized action returned whenever a model response cannot be parsed.
    _PARSE_ERROR_ACTION = "REASON: I'm sorry, I have made a parsing error. I'll try again.\nACTION: NULL (Parsing Error)"

    # Free-form "THOUGHT: ..." section, up to the next REASON:/ACTION: tag or the end of the text.
    _THOUGHT_PATTERN = re.compile(r"THOUGHT: (.*?)(?=REASON:|ACTION:|\Z)", re.DOTALL)
    # Content of a <think>...</think> block.
    _THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
    # "ACTION:" followed by a fenced code block or, failing that, the rest of the text.
    _ACTION_PATTERN = re.compile(
        r"""
        ACTION:\s*
        (?:                                     # Option A: fenced code block
            ```(?:[a-zA-Z0-9_-]+)?\s*          # opening fence with optional language
            (?P<fenced>.*?)                    # content inside fences (non-greedy)
            \s*```                             # closing fence
        |
            (?P<inline>.+)                     # Option B: inline content until end
        )
        """,
        re.DOTALL | re.VERBOSE,
    )

    def __init__(self, max_steps: int = 30, use_accumulate_thinking: bool = False, history_window: int | None = None, use_multi_turn_format: bool = True, additional_info_path: str | None = None):
        """
        Args:
            max_steps: Step budget of an episode; used to tell the model how many
                steps remain.
            use_accumulate_thinking: If True, earlier assistant turns are kept
                verbatim (including their thinking). If False, they are replaced
                by their parsed "REASON/ACTION" summary.
            history_window: If set (and non-zero), only the most recent
                ``history_window`` completed steps are kept in the prompt.
            use_multi_turn_format: If True, render history as alternating chat
                turns; otherwise collapse it into a single user message
                (reasoning models tend to do well with the single-turn format).
            additional_info_path: Optional path to a JSON file whose content is
                appended to the role text of the system prompt.
        """
        self.use_accumulate_thinking = use_accumulate_thinking
        self.max_steps = max_steps
        self.history_window = history_window
        self.use_multi_turn_format = use_multi_turn_format
        self.additional_info_path = additional_info_path

        # state
        self._trajectory = Trajectory()
        self.step: int = 0
        self.current_observation = None
        self.additional_info = None

        self.reset()

    def _make_system_prompt(self, task_info: dict | None = None) -> str:
        """
        Build the system prompt from ``role``, ``additional_info`` and ``output_requirements``.

        ``task_info`` is accepted for subclass overrides and is unused here.
        """
        role_text = self.role
        if self.additional_info:
            role_text += f"\n{self.additional_info}"

        return SYSTEM_PROMPT_TEMPLATE.format(
            role=role_text,
            output_requirements=self.output_requirements,
        )

    def _make_init_user_prompt(self, task_info: dict) -> str:
        """
        Build the environment description prepended to the first user turn.

        Args:
            task_info: Must contain a ``"rules"`` entry (a ``KeyError`` is raised otherwise).
        """
        environment_info_str = f"""### Available Actions
{self.action_space}

### Goal
{self.goal}

### Rules
{task_info['rules']}
"""
        return INTRO_USER_PROMPT_TEMPLATE.format(
            environment_info=environment_info_str,
        )

    def _make_user_prompt(self, observation: Any) -> str:
        """Render a processed observation (see ``_process_observation``) as a user prompt."""
        return USER_PROMPT_TEMPLATE.format(
            current_observation=observation["observation"],
        )

    def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
        """
        Update the agent's state after an environment step.

        Stores the processed observation and, unless the episode is done,
        appends the next user turn (including a note on the remaining step
        budget). On the first step the system prompt and the initial user
        prompt are built from ``info["task_info"]``, falling back to
        ``observation["task_info"]``. ``reward`` is not used here.
        """
        self.current_observation = self._process_observation(observation)

        if done:
            return

        if self.step == 0:
            task_info = info["task_info"] if "task_info" in info else observation["task_info"]
            self.system_message = self._make_system_prompt(task_info)
            self.init_user_prompt = self._make_init_user_prompt(task_info)

        user_prompt = self._make_user_prompt(self.current_observation)
        steps_remaining = self.max_steps - self.step
        if steps_remaining == 1:
            user_prompt += "\nThis is the last step to complete the task. Complete it, and the task will be fully finished without anything more required."
        else:
            user_prompt += f"\nThe maximum number of steps remaining is {steps_remaining}."

        self.messages.append({"role": "user", "content": user_prompt})

    def update_from_model(self, response: str | None, **kwargs) -> Action:
        """
        Record a model response, log it as a trajectory step, and return it as the action.

        When the ``AUTO_INSERT_THINK`` environment variable is ``"1"``, missing
        ``</think>`` and ``<think>`` tags are prepended. This is presumably for
        models whose chat template opens the thinking block in the prompt. The
        returned ``Action`` wraps the full (possibly tag-augmented) response.
        The parsed REASON/ACTION is stored on the trajectory step.
        """
        if response is None:
            response = ""

        if os.environ.get("AUTO_INSERT_THINK", "0") == "1":
            if '</think>' not in response:
                response = '</think>\n' + response
            if '<think>' not in response:
                response = '<think>\n' + response

        thought, reason_action = self._parse_model_response(response)

        self.messages.append({"role": "assistant", "content": response})

        new_step = Step(
            chat_completions=deepcopy(self.chat_completions),
            thought=thought,
            action=reason_action,
            model_response=response,
            observation=self.current_observation,
        )
        self._trajectory.steps.append(new_step)

        self.step += 1

        return Action(action=response)

    def _parse_model_response(self, response: str) -> tuple[str, str]:
        """
        Split a raw model response into its thinking text and a normalized action.

        Returns:
            A ``(thought, action)`` pair.

            ``thought`` is one of:

            * the text after ``"THOUGHT: "`` up to the next ``REASON:``/``ACTION:``
              tag;
            * failing that, the text inside the first ``<think>...</think>``;
            * ``""`` when neither is present.

            ``action`` is a ``"REASON: ..."`` line followed by ``"ACTION: "`` and
            the lower-cased action wrapped in a triple-backtick fence. Both parts
            are taken from the text after the first ``</think>``.
            ``_PARSE_ERROR_ACTION`` is returned when there is no ``</think>`` or
            no parsable ``ACTION:`` after it.
        """
        thought_match = None
        if "THOUGHT" in response:
            thought_match = self._THOUGHT_PATTERN.search(response)
        if thought_match is None:
            thought_match = self._THINK_PATTERN.search(response)
        thought = thought_match.group(1).strip() if thought_match else ""

        if "</think>" not in response:
            return thought, self._PARSE_ERROR_ACTION

        # Only parse the answer that follows the thinking block, so an "ACTION:"
        # the model merely mentions while reasoning is not mistaken for its action.
        answer = response.split("</think>", 1)[1]
        if "ACTION" not in answer:
            # e.g. baseline models that do not follow the output format
            return thought, self._PARSE_ERROR_ACTION

        m = self._ACTION_PATTERN.search(answer)
        if m is None:
            return thought, self._PARSE_ERROR_ACTION

        content = m.group("fenced") if m.group("fenced") is not None else m.group("inline")
        reason = answer.split("ACTION")[0].split("REASON:")[-1].strip()
        action = "REASON: " + reason + "\n" + "ACTION: ```\n" + content.lower().strip() + "\n```"
        return thought, action

    def _make_single_turn_format(self, history: list[dict], start_idx: int = 0) -> list[dict]:
        """
        Collapse a chat history into an optional system message plus one user turn.

        Previous steps are listed as ``"Step k:"`` (numbered from
        ``start_idx + 1``), each followed by its assistant response. The latest
        user turn follows under ``"Current Step:"``. If the history ends with an
        assistant message, that message is kept as a separate trailing assistant
        turn.

        Not used by ``chat_completions``, which renders its single-turn layout in
        ``_context_manager``. Presumably kept for subclasses; verify before removing.
        """
        reformatted_message = []
        if history[0]["role"] == "system":
            reformatted_message.append(history[0])
            history_task = history[1:]
        else:
            history_task = history
        # Drop the current step (the last user turn, plus its response if present).
        last_role = history_task[-1]["role"]
        if last_role == "user":
            history_task = history_task[:-1]
        else:
            history_task = history_task[:-2]

        history_str = "Here is the recent history:\n"
        for i, item in enumerate(history_task):
            if i % 2 == 0:
                if item["role"] == "user":
                    history_str += "Step " + str(i // 2 + 1 + start_idx) + ":\n"
                    history_str += f"{item['content']}\n"
            else:
                if item["role"] == "assistant":
                    history_str += f"{item['content']}\n\n"

        history_str += "Current Step:\n"
        if last_role == "user":
            history_str += f"{history[-1]['content']}\n"
        else:
            history_str += f"{history[-2]['content']}\n"

        reformatted_message.append(
            {
                "role": "user",
                "content": history_str,
            }
        )

        if last_role == "assistant":
            reformatted_message.append(
                {
                    "role": "assistant",
                    "content": history[-1]['content'],
                }
            )

        return reformatted_message

    def _memory_manager(self, raw_history: list[dict]) -> list[dict]:
        """
        Truncate the user/assistant history to the configured window.

        Keeps the most recent user turn (and its response, if already present)
        plus the ``history_window`` completed steps before it. The window is
        anchored on the last user turn so that the history recorded right after
        the model responds is exactly the generation-time prompt plus the
        response. Returns a copy of the whole history when ``history_window``
        is ``None`` or 0.
        """
        if not raw_history:
            return []
        if not self.history_window:
            return list(raw_history)
        if raw_history[-1]["role"] == "user":
            last_user_idx = len(raw_history) - 1
        else:
            last_user_idx = len(raw_history) - 2
        window_start_idx = max(0, last_user_idx - self.history_window * 2)
        return raw_history[window_start_idx:]

    def _context_manager(self, history: list[dict]) -> list[dict]:
        """
        Render a (windowed) user/assistant history into the final chat messages.

        The output always starts with the system message. In multi-turn format:

        * each turn is kept as its own message;
        * the first user turn is prefixed with the initial environment prompt;
        * the latest user turn gets the response-format instructions appended.

        In single-turn format, everything up to the latest user turn is folded
        into one user message, and the response to the latest user turn, if
        any, is appended as an assistant message.

        Raises:
            ValueError: if ``history`` does not strictly alternate user/assistant
                starting with a user turn.
        """
        final_chat_completions = [{"role": "system", "content": self.system_message}]

        if history[-1]["role"] == "user":
            last_user_index = len(history) - 1
        else:
            last_user_index = len(history) - 2

        if self.use_multi_turn_format:
            for i, message in enumerate(history):
                if i % 2 == 0:
                    if message["role"] != "user":
                        raise ValueError(f"Expected user message at index {i}, but got {message['role']}")
                    content = ""
                    # for the first user message, add the init user prompt
                    if i == 0:
                        content += self.init_user_prompt + "\n\n"
                    content += message["content"]
                    # add outro
                    if i == last_user_index:
                        content += "\n" + OUTRO_USER_PROMPT_TEMPLATE
                    final_chat_completions.append({
                        "role": "user",
                        "content": content,
                    })
                else:
                    if message["role"] != "assistant":
                        raise ValueError(f"Expected assistant message at index {i}, but got {message['role']}")
                    final_chat_completions.append({
                        "role": "assistant",
                        "content": message["content"],
                    })
            return final_chat_completions

        # single-turn format
        user_content = ""
        assistant_content = None
        # Earlier steps exist only if the latest user turn is not the first
        # message. (Testing len(history) > 1 would treat [user, assistant],
        # i.e. step 0 right after the model answered, as having history, so the
        # recorded prompt would differ from the one the model was given.)
        has_previous_steps = last_user_index > 0
        for i, message in enumerate(history):
            if i % 2 == 0:
                if message["role"] != "user":
                    raise ValueError(f"Expected user message at index {i}, but got {message['role']}")
                if has_previous_steps:
                    if i == 0:
                        user_content += self.init_user_prompt + "\n\nHere is the recent history:\n"
                    # add outro
                    if i == last_user_index:
                        user_content += "\n\nNow you focus on the current step.\n" + message["content"] + "\n" + OUTRO_USER_PROMPT_TEMPLATE
                    else:
                        user_content += message["content"].replace("Current Observation", "Observation") + "\n"
                else:
                    # initial prompt
                    user_content += self.init_user_prompt + "\n\n" + message["content"] + "\n" + OUTRO_USER_PROMPT_TEMPLATE
            else:
                if message["role"] != "assistant":
                    raise ValueError(f"Expected assistant message at index {i}, but got {message['role']}")
                if i == last_user_index + 1:
                    assistant_content = message["content"]
                else:
                    # earlier responses become part of the history text
                    user_content += "Response:\n" + message["content"] + "\n\n"
        final_chat_completions.append({
            "role": "user",
            "content": user_content,
        })
        if assistant_content:
            final_chat_completions.append({
                "role": "assistant",
                "content": assistant_content,
            })
        return final_chat_completions

    @property
    def chat_completions(self) -> list[dict[str, str]]:
        """
        Messages to send to the model for the current state.

        Unless ``use_accumulate_thinking`` is set, every assistant message except
        the very last message is replaced by its parsed REASON/ACTION summary.
        The history is then windowed (``_memory_manager``) and laid out
        (``_context_manager``).
        """
        if self.use_accumulate_thinking:
            chat_completion_list = deepcopy(self.messages)
        else:
            last_index = len(self.messages) - 1
            chat_completion_list = []
            for i, msg in enumerate(self.messages):
                if msg["role"] == "assistant" and i < last_index:
                    _, reason_action = self._parse_model_response(msg["content"])
                    chat_completion_list.append({
                        "role": msg["role"],
                        "content": reason_action,
                    })
                else:
                    # user turns and the latest assistant response are kept verbatim
                    chat_completion_list.append(msg)

        history = self._memory_manager(chat_completion_list)
        return self._context_manager(history)

    @property
    def trajectory(self) -> Trajectory:
        """The trajectory recorded for the current episode."""
        return self._trajectory

    def reset(self) -> None:
        """
        Clear per-episode state and (re)load the optional additional info.

        If ``additional_info_path`` is set, its JSON content is loaded into
        ``additional_info``. On any failure the error is logged and
        ``additional_info`` stays ``None`` (best effort).
        """
        self._trajectory = Trajectory()
        self.messages = []
        self.step = 0
        self.current_observation = None
        self.additional_info = None
        if self.additional_info_path:
            try:
                # JSON text is UTF-8; do not depend on the platform's default encoding.
                with open(self.additional_info_path, "r", encoding="utf-8") as f:
                    self.additional_info = json.load(f)
            except Exception as e:  # best effort: a missing/bad file must not break the agent
                logger.error("Error reading additional info file: %s", e)
                self.additional_info = None

    @abstractmethod
    def _process_observation(self, observation: Any) -> dict[str, Any]:
        """
        Convert a raw environment observation into the dict used for prompting.

        The result must contain an ``"observation"`` entry, which
        ``_make_user_prompt`` inserts into the prompt. The default body wraps
        the raw observation unchanged.
        """
        return {
            "observation": observation,
        }