import json
import re
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field

from src.utils.logsetup import logger


class Role(str, Enum):
    """Who authored a chat message.

    Members subclass ``str``, so each compares equal to its plain string value
    (e.g. ``Role.USER == "user"``). Declaration order is significant: it fixes
    the order of the derived role-value tuple.
    """

    # Instructions that frame the conversation.
    SYSTEM = "system"
    # Input from the human / caller.
    USER = "user"
    # Output produced by the model.
    ASSISTANT = "assistant"
    # Result returned by a tool invocation.
    TOOL = "tool"

# Plain string values of every Role member, in declaration order.
ROLE_VALUES = tuple(map(lambda member: member.value, Role))
# Literal type admitting exactly those strings; Message.role is annotated with it.
ROLE_TYPE = Literal[ROLE_VALUES]  # type: ignore

class Function(BaseModel):
    """Name and serialized arguments of a function invoked by a tool call."""

    # Name of the function/tool to invoke.
    name: str
    # Arguments as a string; other code in this module decodes it with json.loads.
    arguments: str

class ToolCall(BaseModel):
    """One tool/function invocation requested inside a message."""

    # Identifier that tool-result messages echo back in ``tool_call_id``.
    id: str
    # Kind of call; this module only ever produces "function".
    type: str = "function"
    # Which function to call and with what arguments.
    function: Function

class AgentState(str, Enum):
    """Lifecycle state of an agent run.

    Values are upper-case strings identical to the member names; being a
    ``str`` subclass, each member compares equal to its value.
    """

    # Not started yet / waiting for work.
    IDLE = "IDLE"
    # Currently executing.
    RUNNING = "RUNNING"
    # Completed normally.
    FINISHED = "FINISHED"
    # Stopped because of a failure.
    ERROR = "ERROR"

class Message(BaseModel):
    """Represents a chat message in the conversation.

    Factory classmethods (``user_message``, ``system_message``,
    ``assistant_message``, ``tool_message``, ``from_tool_calls``) build
    messages for each role; ``to_dict`` serializes only the fields that are
    set; ``+`` concatenates messages into plain lists.
    """
    role: ROLE_TYPE = Field(...)  # type: ignore  # required; one of the Role values
    content: Optional[str] = Field(default=None)
    # Tool invocations requested by an assistant message.
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    # On tool messages: the name of the tool that produced the result.
    name: Optional[str] = Field(default=None)
    # On tool messages: the ToolCall.id this message answers.
    tool_call_id: Optional[str] = Field(default=None)
    base64_image: Optional[str] = Field(default=None)
    # Prompt-cache marker flag; always emitted by to_dict.
    cache: bool = Field(default=False)
    # Optional caller-defined label.
    tag: Optional[str] = Field(default=None)

    def __add__(self, other) -> List["Message"]:
        """Support ``Message + list`` and ``Message + Message``, producing a new list.

        Raises:
            TypeError: if ``other`` is neither a list nor a Message.
        """
        if isinstance(other, list):
            return [self] + other
        elif isinstance(other, Message):
            return [self, other]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
            )

    def __radd__(self, other) -> List["Message"]:
        """Support ``list + Message``, producing a new list with this message appended.

        Raises:
            TypeError: if ``other`` is not a list.
        """
        if isinstance(other, list):
            return other + [self]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'"
            )

    def to_dict(self) -> dict:
        """Convert the message to a plain dict.

        ``role`` and ``cache`` are always present; every other field is
        included only when it is not None. Tool calls are dumped with
        ``model_dump``.
        """
        message = {"role": self.role, 'cache': self.cache}
        if self.content is not None:
            message["content"] = self.content
        if self.tool_calls is not None:
            message["tool_calls"] = [tool_call.model_dump() for tool_call in self.tool_calls]
        if self.name is not None:
            message["name"] = self.name
        if self.tool_call_id is not None:
            message["tool_call_id"] = self.tool_call_id
        if self.base64_image is not None:
            message["base64_image"] = self.base64_image
        if self.tag is not None:
            message["tag"] = self.tag
        return message

    @classmethod
    def user_message(
            cls, content: str, base64_image: Optional[str] = None, tag: Optional[str] = None
    ) -> "Message":
        """Create a user message."""
        return cls(role=Role.USER, content=content, base64_image=base64_image, tag=tag)

    @classmethod
    def system_message(cls, content: str, cache: bool = False, base64_image: Optional[str] = None,
                       tag: Optional[str] = None) -> "Message":
        """Create a system message, optionally flagged with ``cache``."""
        return cls(role=Role.SYSTEM, content=content, cache=cache, base64_image=base64_image, tag=tag)

    @classmethod
    def assistant_message(
            cls, content: Optional[str] = None, base64_image: Optional[str] = None, tag: Optional[str] = None
    ) -> "Message":
        """Create an assistant message (content may be None)."""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image, tag=tag)

    @classmethod
    def tool_message(
            cls, content: str, name: Optional[str], tool_call_id: str, base64_image: Optional[str] = None,
            tag: Optional[str] = None
    ) -> "Message":
        """Create a tool-result message.

        Args:
            content: Tool output.
            name: Name of the tool that produced the output.
            tool_call_id: ``ToolCall.id`` this result answers.
            base64_image: Optional base64 encoded image.
            tag: Optional message tag.
        """
        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
            tag=tag,
        )

    @staticmethod
    def _read_tool_call_field(obj: Any, key: str) -> Any:
        """Read ``key`` from a tool call's ``function`` payload.

        Dict-like payloads (they have ``get`` but no ``model_dump``) are read
        by key; pydantic/SDK models and any other object are read by
        attribute. A missing key/attribute yields None.
        """
        if hasattr(obj, 'get') and not hasattr(obj, 'model_dump'):
            return obj.get(key)
        return getattr(obj, key, None)

    @staticmethod
    def convert_tool_calls(tool_calls: Optional[List[Any]]) -> List[ToolCall]:
        """Normalize tool calls from heterogeneous sources into ``ToolCall`` models.

        Each element may be a ``ToolCall`` (kept as is), an object exposing
        ``id`` and ``function`` attributes, or a dict with ``"id"`` and
        ``"function"`` keys. The ``function`` payload may be an object or a
        dict providing ``name`` and ``arguments``. Dict-valued ``arguments``
        are JSON-encoded because ``Function.arguments`` is a string. Elements
        that cannot be converted are logged and skipped.

        Args:
            tool_calls: Raw tool calls; None or empty yields an empty list.

        Returns:
            The converted calls, in input order.
        """
        if not tool_calls:
            return []
        formatted_calls: List[ToolCall] = []
        for call in tool_calls:
            # Check every element individually so mixed lists are normalized too.
            if isinstance(call, ToolCall):
                formatted_calls.append(call)
                continue
            try:
                if hasattr(call, 'id') and hasattr(call, 'function'):
                    id_, function = call.id, call.function
                else:
                    id_, function = call.get('id'), call.get('function', {})
                arguments = Message._read_tool_call_field(function, 'arguments')
                if isinstance(arguments, dict):
                    arguments = json.dumps(arguments, ensure_ascii=False)
                formatted_calls.append(
                    ToolCall(
                        id=id_,
                        type="function",
                        function=Function(
                            name=Message._read_tool_call_field(function, 'name'),
                            arguments=arguments,
                        ),
                    )
                )
            except Exception as e:
                # Best effort: one malformed call must not drop the whole message.
                logger.warning(f"Skipping tool call that could not be converted: {e}")
        return formatted_calls

    @classmethod
    def from_tool_calls(
            cls,
            tool_calls: List[Any],
            content: Union[str, List[str]] = "",
            base64_image: Optional[str] = None,
            tag: Optional[str] = None,
            **kwargs,
    ):
        """Create an assistant message with tool calls.

        Args:
            tool_calls: Tool calls in any format accepted by ``convert_tool_calls``.
            content: Message text. A list of strings is joined with newlines,
                because ``Message.content`` holds a single string.
            base64_image: Optional base64 encoded image.
            tag: Optional message tag.
            **kwargs: Additional Message fields.

        Returns:
            Message: An assistant message with tool calls.
        """
        if isinstance(content, list):
            content = "\n".join(content)
        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=cls.convert_tool_calls(tool_calls),
            base64_image=base64_image,
            tag=tag,
            **kwargs,
        )


class Memory(BaseModel):
    """Ordered store of conversation messages with history-management helpers.

    Besides storage, it trims history to ``max_messages``, moves prompt-cache
    markers, collapses tool traffic after a ``process_confirm`` tool result,
    and exports a conversation as a single-turn ShareGPT record.
    """
    messages: List[Message] = Field(default_factory=list)
    # Size of the recent window kept by add_message; older assistant messages
    # without tool calls are retained in addition to it.
    max_messages: int = Field(default=100)

    def _trim(self) -> None:
        """Enforce ``max_messages`` on ``self.messages`` (no-op within the limit).

        The most recent ``max_messages`` messages form the kept window. Of the
        older messages, only assistant messages without tool calls survive.
        Tool messages at the start of the window are dropped. Their
        assistant tool-call message was cut off, so they would be orphaned.
        """
        total = len(self.messages)
        if total <= self.max_messages:
            return
        cut = total - max(self.max_messages, 0)
        older = [
            message for message in self.messages[:cut]
            if message.role == Role.ASSISTANT and not message.tool_calls
        ]
        recent = self.messages[cut:]
        start = 0
        while start < len(recent) and recent[start].role == Role.TOOL:
            start += 1
        self.messages = older + recent[start:]

    def add_message(self, message: Message) -> None:
        """Append ``message``, then enforce the ``max_messages`` limit."""
        self.messages.append(message)
        self._trim()

    def add_messages(self, messages: List[Message]) -> None:
        """Append several messages in order (the ``max_messages`` limit is not applied)."""
        self.messages.extend(messages)

    def clear(self) -> None:
        """Remove all messages."""
        self.messages.clear()

    def get_recent_messages(self, n: int) -> List[Message]:
        """Return (a new list of) the ``n`` most recent messages; ``[]`` when ``n <= 0``."""
        if n <= 0:
            # A plain messages[-n:] would return everything for n == 0.
            return []
        return self.messages[-n:]

    def to_dict_list(self) -> List[dict]:
        """Convert messages to a list of dicts via ``Message.to_dict``."""
        return [msg.to_dict() for msg in self.messages]

    def update_cache(self) -> None:
        """Mark the newest message as a cache point, keeping at most two markers.

        The last message always gets ``cache=True``. Of the earlier marked
        messages, only the most recent keeps its marker; all older markers are
        cleared. Repeated calls without new messages leave the markers
        unchanged.
        """
        if not self.messages:
            return
        self.messages[-1].cache = True
        marked = [message for message in self.messages if message.cache]
        for message in marked[:-2]:
            message.cache = False

    def clear_tool_messages(self) -> None:
        """Collapse tool traffic once the last message is a ``process_confirm`` tool result.

        The history is rebuilt as follows:
        - user messages keep only the text before 'First, you should create folder';
        - other messages are kept only if they are not tool messages and carry no tool calls;
        - the confirmation content minus its first line is appended as an assistant message.
        """
        if not self.messages:
            return
        last = self.messages[-1]
        if last.role != Role.TOOL or last.name != 'process_confirm':
            return
        new_messages = []
        for message in self.messages:
            if message.role == Role.USER:
                stage_prompt = (message.content or '').split('First, you should create folder')[0]
                new_messages.append(Message.user_message(content=stage_prompt, tag=message.tag))
            elif message.role != Role.TOOL and not message.tool_calls:
                new_messages.append(message)
        # Drop the first line of the confirmation (presumably a header) and keep the rest.
        steps_summary = (last.content or '').split('\n', 1)[-1]
        new_messages.append(Message.assistant_message(content=steps_summary))
        self.messages = new_messages

    def count_tool_round(self) -> int:
        """Return the number of tool-role messages in memory."""
        return sum(1 for message in self.messages if message.role == Role.TOOL)

    def find_last_tool_message(self, name: str) -> Optional[str]:
        """
        Search backwards for the first tool message matching the specified name and return its content
        Args:
            name: The tool name to search for
        Returns:
            The content of the matching message, or None if not found
        """
        for message in reversed(self.messages):
            if message.role == Role.TOOL and message.name == name:
                return message.content
        return None

    @staticmethod
    def filter_empty_content(messages: List[Message]) -> List[Message]:
        """Filter out user/assistant messages with empty content.

        Assistant messages that carry tool calls are kept even when empty.

        Args:
            messages: List of messages to filter

        Returns:
            Filtered list of messages
        """
        if not messages:
            return []

        filtered_messages = []
        for message in messages:
            if message.role in [Role.USER, Role.ASSISTANT] and not message.content:
                if message.role == Role.ASSISTANT and message.tool_calls:
                    filtered_messages.append(message)
                else:
                    logger.info(f"Filtered out empty content message: role={message.role}")
                    continue
            else:
                filtered_messages.append(message)

        return filtered_messages

    @staticmethod
    def _extract_answer_content(text: str) -> str:
        """Heuristically extract a short answer from free-form conclusion text.

        Strategy, in order:
        1. After the first marker phrase found (case-insensitive, e.g.
           "the answer is"), take the text up to the first sentence end,
           strip trailing punctuation, and cap it at 100 characters.
        2. Otherwise return the last non-empty sentence if it is 6-199 characters long.
        3. Otherwise return the whole stripped text, capped at 100 characters.

        A sentence ends at '.' followed by whitespace or end of text, or at
        '。'. Decimals such as "3.14" are therefore not split.
        """
        sentence_end = re.compile(r'\.(?=\s|$)|。')
        trailing_punct = '.,;!?。，；！？'
        markers = (
            "the answer is:",
            "the answer is ",
            "based on my research, i can conclude that:",
            "based on my research, i can conclude that ",
            "i can conclude that:",
            "i can conclude that ",
            "the conclusion is:",
            "the conclusion is ",
        )

        for marker in markers:
            # Search the original text case-insensitively so offsets stay valid
            # (str.lower() can change string length for some characters).
            match = re.search(re.escape(marker), text, re.IGNORECASE)
            if not match:
                continue
            remaining_text = text[match.end():].strip()
            answer = sentence_end.split(remaining_text, maxsplit=1)[0].strip()
            if answer:
                answer = answer.rstrip(trailing_punct)
                if len(answer) > 100:
                    answer = answer[:97] + "..."
                return answer

        # Fallback: the last non-empty sentence, if it has a plausible answer length.
        sentences = [part.strip() for part in sentence_end.split(text) if part.strip()]
        if sentences:
            last_sentence = sentences[-1]
            if 5 < len(last_sentence) < 200:
                return last_sentence.rstrip(trailing_punct)

        # Final fallback: return truncated text
        cleaned = text.strip()
        if len(cleaned) > 100:
            cleaned = cleaned[:97] + "..."
        return cleaned

    @staticmethod
    def _extract_answer_from_tool_response(tool_response: str) -> str:
        """Extract an answer from a tool response (e.g. an answer_summarizer result).

        Tries, in order:
        - a JSON object's ``final_answer``/``answer``/``result``/``conclusion`` field;
        - a ``\\boxed{\\text{...}}`` or ``\\boxed{...}`` expression;
        - ``_extract_answer_content`` on the raw text.
        Returns "" for empty input.
        """
        if not tool_response:
            return ""

        if tool_response.strip().startswith('{'):
            try:
                response_data = json.loads(tool_response)
            except json.JSONDecodeError:
                response_data = None
            if isinstance(response_data, dict):
                for field in ('final_answer', 'answer', 'result', 'conclusion'):
                    value = response_data.get(field)
                    if value:
                        return str(value).strip()

        if '\\boxed{' in tool_response:
            boxed_match = (
                re.search(r'\\boxed\{\\text\{([^}]+)\}\}', tool_response)
                or re.search(r'\\boxed\{([^}]+)\}', tool_response)
            )
            if boxed_match:
                return boxed_match.group(1)

        return Memory._extract_answer_content(tool_response)

    @staticmethod
    def _parse_tool_arguments(arguments: Optional[str]) -> Any:
        """Decode a tool call's JSON ``arguments`` for export.

        Empty/None yields ``{}``. Malformed JSON is returned as the raw
        string, so a single bad call does not abort the export.
        """
        if not arguments:
            return {}
        try:
            return json.loads(arguments)
        except (json.JSONDecodeError, TypeError):
            return arguments

    @staticmethod
    def to_sharegpt(messages: List[Message], tools: Optional[List[Dict]] = None) -> dict:
        """Convert messages to a single-turn ShareGPT record with a thinking-chain XML structure.

        The first user message becomes the "human" turn. Each assistant
        message is replayed into ``<think>`` as follows:
        - its text goes into an ``<observation>``;
        - each tool call goes into a ``<tool_call>`` as JSON;
        - the first later tool message with a matching ``tool_call_id`` goes
          into a ``<tool_response>``.
        An answer is extracted from the last matched tool response, or else
        from the last assistant message without tool calls. When found, it is
        appended as ``<answer>\\boxed{\\text{...}}</answer>``.

        Args:
            messages: Conversation to convert.
            tools: Optional tool schemas, serialized to a JSON string.

        Returns:
            ``{"conversations": [...], "tools": "<json>"}``. ``conversations``
            is empty when ``messages`` is empty.
        """
        tools_json = json.dumps(tools, ensure_ascii=False) if tools else "[]"
        if not messages:
            return {"conversations": [], "tools": tools_json}

        user_query = next(
            (message.content for message in messages if message.role == Role.USER), ""
        ) or ""

        think_parts: List[str] = []
        last_tool_response_content = None
        for i, message in enumerate(messages):
            if message.role != Role.ASSISTANT:
                continue
            if message.content and message.content.strip():
                think_parts.append(f"<observation>\n{message.content.strip()}\n</observation>\n")
            for tool_call in message.tool_calls or []:
                tool_call_json = {
                    "name": tool_call.function.name,
                    "arguments": Memory._parse_tool_arguments(tool_call.function.arguments),
                }
                think_parts.append(f"<tool_call>\n{json.dumps(tool_call_json, ensure_ascii=False)}\n</tool_call>\n")
                response = next(
                    (
                        messages[j] for j in range(i + 1, len(messages))
                        if messages[j].role == Role.TOOL and messages[j].tool_call_id == tool_call.id
                    ),
                    None,
                )
                if response is not None:
                    think_parts.append(f"<tool_response>\n{response.content}\n</tool_response>\n")
                    last_tool_response_content = response.content
        think_content = "".join(think_parts)

        final_assistant_msg = None
        for message in reversed(messages):
            if message.role == Role.ASSISTANT and message.content and not message.tool_calls:
                final_assistant_msg = message.content.strip()
                break

        # Defensive: append the final observation only if the loop above did not.
        if final_assistant_msg and f"<observation>\n{final_assistant_msg}\n</observation>" not in think_content:
            think_content += f"<observation>\n{final_assistant_msg}\n</observation>\n"

        answer_content = ""
        if last_tool_response_content:
            answer_content = Memory._extract_answer_from_tool_response(last_tool_response_content)
        if not answer_content and final_assistant_msg:
            answer_content = Memory._extract_answer_content(final_assistant_msg)

        assistant_response = f"<think>\n{think_content}</think>\n"
        if answer_content:
            formatted_answer = f"\\boxed{{\\text{{{answer_content}}}}}"
            assistant_response += f"<answer>\n{formatted_answer}\n</answer>"

        return {
            "conversations": [
                {
                    "from": "human",
                    "value": user_query
                },
                {
                    "from": "gpt",
                    "value": assistant_response.strip()
                }
            ],
            "tools": tools_json
        }