import json
import re
from copy import deepcopy
from typing import Any, Dict, List, Optional

from loguru import logger
from pydantic import BaseModel

from murmur.agent.base import (
    LocalAgent,
    ValidAgentInputMessage,
    is_valid_agent_history_message,
)
from murmur.data_model.message import (
    APICompatibleMessage,
    AssistantMessage,
    Message,
    MultiToolMessage,
    SystemMessage,
    UserMessage,
)
from murmur.environment.tool import Tool
from murmur.utils.llm_utils import generate

# Role instructions for the customer-service agent. LLMMurmurAgent.system_prompt
# substitutes this text into SYSTEM_PROMPT together with the domain policy.
AGENT_INSTRUCTION = "\n".join(
    [
        "You are a customer service agent that helps the user according to the <policy> provided below.",
        "In each turn you can either:",
        "- Send a message to the user.",
        "- Make a tool call.",
        "You cannot do both at the same time.",
        "",
        "Try to be helpful and always follow the policy. Always make sure you generate valid JSON only.",
    ]
)

# Template for the agent's system message.
# Placeholders (filled with str.format): {agent_instruction}, {domain_policy}.
SYSTEM_PROMPT = "\n".join(
    [
        "<instructions>",
        "{agent_instruction}",
        "</instructions>",
        "<policy>",
        "{domain_policy}",
        "</policy>",
    ]
)


class TaskState(BaseModel):
    """Per-task slice of a multi-task conversation.

    Holds the task's identifier, a running natural-language summary, the user
    identifiers attached to the task, and the messages routed to it.
    """

    task_id: str
    task_summary: str
    # Identifiers of the users taking part in this task.
    users: list[str]
    # Messages that belong to this task only (a subset of the global history).
    messages: list[APICompatibleMessage]


class LLMMurmurAgentState(BaseModel):
    """Agent state that keeps the global history alongside per-task histories."""

    # System messages placed at the front of every main-model prompt.
    system_messages: list[SystemMessage]
    # The full, interleaved message history across all users and tasks.
    messages: list[APICompatibleMessage]
    # Task-specific histories, keyed by task ID.
    task_states: dict[str, TaskState]
    # Counter intended for generating task IDs; starts at 1.
    next_task_id: int = 1
    # Storage for injection prompts. A literal [] default is safe here because
    # pydantic copies mutable defaults for each new instance.
    injection_memory: list[str] = []


class LLMMurmurAgent(LocalAgent[LLMMurmurAgentState]):
    """
    An enhanced LLM agent that handles multi-user multi-task conversations
    by maintaining separate context states for different tasks.

    Each incoming user message is routed to a task: directly when its user
    identifier is already attached to a task, otherwise via a call to the
    classifier model. The main LLM is then prompted with the system messages
    plus only that task's history. Every message is also kept in the global
    history (``state.messages``).
    """

    def __init__(
        self,
        tools: List[Tool],
        domain_policy: str,
        llm: Optional[str] = None,
        llm_args: Optional[Dict[str, Any]] = None,
        classifier_model: str = "gpt-4o-mini",
    ):
        """
        Initialize the LLMMurmurAgent.

        Args:
            tools: List of available tools
            domain_policy: Domain-specific policy
            llm: Main LLM model to use
            llm_args: Arguments for the main LLM. Deep-copied, so later
                mutation (e.g. by ``set_seed``) never leaks back to the caller.
            classifier_model: Model to use for task classification
        """
        super().__init__(tools=tools, domain_policy=domain_policy)
        self.llm = llm
        self.llm_args = deepcopy(llm_args) if llm_args is not None else {}
        self.classifier_model = classifier_model

    @property
    def system_prompt(self) -> str:
        """The system prompt: agent instructions followed by the domain policy."""
        return SYSTEM_PROMPT.format(
            domain_policy=self.domain_policy, agent_instruction=AGENT_INSTRUCTION
        )

    def extract_user_identifier(self, message_content: str) -> Optional[str]:
        """
        Extract user identifier from message content.
        Expected format: [5char_username] message content

        Args:
            message_content: The message content to parse

        Returns:
            The user identifier (the text between the brackets), or None if
            the content is empty or does not start with such a tag
        """
        if not message_content:
            return None
        # Five alphanumerics, an underscore, then anything up to the closing "]".
        pattern = r'^\[([a-zA-Z0-9]{5}_[^\]]+)\]'
        match = re.match(pattern, message_content.strip())
        if match:
            return match.group(1)
        return None

    def classify_message_task(
        self,
        message: UserMessage,
        existing_tasks: Dict[str, TaskState],
    ) -> tuple[str, str]:
        """
        Classify which task a message belongs to.

        A message whose user identifier is already attached to a task is
        routed to that task without an LLM call. Otherwise
        ``self.classifier_model`` is asked to pick an existing task or to
        declare a new one. Any failure (model call or reply parsing) falls
        back to opening a new task.

        Args:
            message: The user message to classify
            existing_tasks: Dictionary of existing task states, keyed by task ID

        Returns:
            A ``(task_id, task_summary)`` tuple. ``task_id`` is either a key of
            ``existing_tasks`` or a fresh ``task_N`` ID not present in it.
        """
        content = message.content or ""
        user_id = self.extract_user_identifier(content)

        # Fast path: a known user stays on the task they already belong to.
        if user_id:
            for task_id, task_state in existing_tasks.items():
                if user_id in task_state.users:
                    return task_id, task_state.task_summary

        # Prepare classification prompt
        task_summaries: List[str] = []
        for task_id, task_state in existing_tasks.items():
            users_str = ", ".join(task_state.users)  # type: ignore
            task_summaries.append(f"- {task_id}: {task_state.task_summary} (Users: {users_str})")

        classification_prompt = f"""
Current tasks:
{"No tasks till now" if not task_summaries else chr(10).join(task_summaries)}

New message from user {message.content}

Based on the message content and existing tasks, determine:
1. Which existing task this message belongs to (if any). If this is a new task, use "new_task" as the task_id.
2. A summary of the task. Be sure to include any user identifiers, or task identifiers that are present in the message. If it is an existing task, update the summary of the task with any new information that is present in the message.

Respond with JSON in this format:
{{
    "task_id": "task_X or new_task",
    "task_summary": "brief summary of the task",
    "reasoning": "why this message belongs to this task"
}}
"""

        try:
            classification_messages: List[Message] = [
                SystemMessage(role="system", content="You are a task classification assistant. You are helping to classify user messages into different tasks in a multi-task environment."),
                UserMessage(role="user", content=classification_prompt),
            ]
            response = generate(
                model=self.classifier_model,
                messages=classification_messages,
            )
            result = self._parse_classifier_json(response.content)
        except Exception as e:
            # Classification is best-effort: a failed model call or an
            # unparseable reply opens a new task instead of aborting the turn.
            logger.warning(f"Failed to classify message task: {e}. Creating new task.")
            return self._new_task_id(existing_tasks), content

        logger.info(f"Task classification response: {result}")

        summary = result.get("task_summary")
        if not isinstance(summary, str) or not summary:
            summary = content

        task_id = result.get("task_id")
        if isinstance(task_id, str) and task_id in existing_tasks:
            return task_id, summary
        if task_id != "new_task":
            # Never trust an ID the classifier invented: adopting it could
            # collide with IDs generated later and merge unrelated tasks.
            logger.warning(f"Classifier returned unknown task id {task_id!r}; creating new task.")
        return self._new_task_id(existing_tasks), summary

    @staticmethod
    def _new_task_id(existing_tasks: Dict[str, TaskState]) -> str:
        """Return the first ``task_N`` with N >= len(existing_tasks) + 1 that is not in use."""
        n = len(existing_tasks) + 1
        while f"task_{n}" in existing_tasks:
            n += 1
        return f"task_{n}"

    @staticmethod
    def _parse_classifier_json(raw: Optional[str]) -> Dict[str, Any]:
        """
        Parse the classifier's reply into a dict.

        Tolerates a surrounding Markdown code fence (```` ```json ... ``` ````).
        An empty reply parses as ``{}``.

        Raises:
            ValueError: If the text is not valid JSON (``json.JSONDecodeError``
                subclasses ``ValueError``) or is not a JSON object.
        """
        text = (raw or "").strip()
        fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL)
        if fenced:
            text = fenced.group(1)
        result = json.loads(text or "{}")
        if not isinstance(result, dict):
            raise ValueError(f"expected a JSON object, got {type(result).__name__}")
        return result

    def get_or_create_task_state(
        self,
        task_id: str,
        message: UserMessage,
        task_summary: str,
        state: LLMMurmurAgentState
    ) -> TaskState:
        """
        Get existing task state or create a new one.

        The message's user identifier (if any) is attached to the task. For an
        existing task, a non-empty ``task_summary`` replaces the stored one,
        because the classifier is asked to fold new information into it.

        Args:
            task_id: The task identifier
            message: The triggering message
            task_summary: Summary for a new task, or refreshed summary for an existing one
            state: Current agent state (``state.task_states`` is mutated)

        Returns:
            The task state
        """
        task_state = state.task_states.get(task_id)
        if task_state is None:
            task_state = TaskState(
                task_id=task_id,
                task_summary=task_summary,
                users=[],
                messages=[],
            )
            state.task_states[task_id] = task_state
        elif task_summary:
            task_state.task_summary = task_summary

        user_id = self.extract_user_identifier(message.content or "")
        if user_id and user_id not in task_state.users:
            task_state.users.append(user_id)  # type: ignore
        return task_state

    def get_last_user_task_id(self, state: LLMMurmurAgentState) -> Optional[str]:
        """
        Get the task ID of the most recent user message that can be routed.

        User messages are scanned newest-first. A message is attributed to the
        task whose history holds that exact message object. This covers
        messages without a user identifier, which were routed by the
        classifier. Failing that, the message is attributed to the task its
        user identifier is attached to. Messages matching neither are skipped.

        Args:
            state: Current agent state

        Returns:
            Task ID of the last routable user message, or None
        """
        # Keying on id() is safe here: every message counted is kept alive by state.
        owner_by_message = {
            id(task_message): task_id
            for task_id, task_state in state.task_states.items()
            for task_message in task_state.messages
            if isinstance(task_message, UserMessage)
        }
        for message in reversed(state.messages):
            if not isinstance(message, UserMessage):
                continue
            owner = owner_by_message.get(id(message))
            if owner is not None:
                return owner
            user_id = self.extract_user_identifier(message.content or "")
            if user_id:
                for task_id, task_state in state.task_states.items():
                    if user_id in task_state.users:
                        return task_id
        return None

    def get_init_state(
        self, message_history: Optional[List[Message]] = None
    ) -> LLMMurmurAgentState:
        """Get the initial state of the agent.

        Args:
            message_history: The message history of the conversation.

        Returns:
            The initial state of the agent. The history goes into the global
            ``messages`` only; no task states are created from it.
        """
        if message_history is None:
            message_history = []
        assert all(is_valid_agent_history_message(m) for m in message_history), (
            "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent."
        )
        return LLMMurmurAgentState(
            system_messages=[SystemMessage(role="system", content=self.system_prompt)],
            messages=message_history,  # type: ignore
            task_states={},
            next_task_id=1,
        )

    def generate_next_message(
        self, message: ValidAgentInputMessage, state: LLMMurmurAgentState
    ) -> tuple[AssistantMessage, LLMMurmurAgentState]:
        """
        Respond to a user or tool message with task-aware context management.

        A user message is classified once and appended to its task's history.
        Tool messages (single or ``MultiToolMessage``) are appended to the task
        of the most recent routable user message. The main LLM is then called
        with the system messages plus that task's history, or the global
        history if no task could be determined. Its reply is recorded in both
        the global and the task history.

        Args:
            message: The incoming user, tool, or multi-tool message.
            state: The agent state; mutated in place and also returned.

        Returns:
            The assistant's reply and the updated state.

        Raises:
            ValueError: If no main LLM model is configured.
        """
        # Fail before mutating state or spending a classifier call.
        if self.llm is None:
            raise ValueError("LLM model is not set")

        current_task_id: Optional[str]
        if isinstance(message, UserMessage):
            state.messages.append(message)
            # Classify exactly once and reuse the result for generation. A
            # second classifier call would cost another LLM request and, for
            # messages without a user identifier, could pick a different (or
            # not-yet-created) task, so the reply would use the wrong context.
            task_id, task_summary = self.classify_message_task(message, state.task_states)
            task_state = self.get_or_create_task_state(task_id, message, task_summary, state)
            task_state.messages.append(message)
            current_task_id = task_id
        else:
            incoming = (
                message.tool_messages
                if isinstance(message, MultiToolMessage)
                else [message]
            )
            state.messages.extend(incoming)
            # Tool output belongs to the task the latest user turn was routed to.
            current_task_id = self.get_last_user_task_id(state)
            if current_task_id is not None and current_task_id in state.task_states:
                state.task_states[current_task_id].messages.extend(incoming)

        # Build the prompt context: system messages + task (or global) history.
        if current_task_id is not None and current_task_id in state.task_states:
            task_messages = state.task_states[current_task_id].messages
            context_messages = state.system_messages + task_messages
            logger.info(f"Using task-specific context for {current_task_id} with {len(task_messages)} messages")
        else:
            # Fallback to global context
            context_messages = state.system_messages + state.messages
            logger.info(f"Using global context with {len(state.messages)} messages")

        logger.info(f"Context messages: {context_messages}")

        assistant_message = generate(
            model=self.llm,
            tools=self.tools,
            messages=context_messages,  # type: ignore
            **self.llm_args,
        )

        # Add response to both global and task-specific histories
        state.messages.append(assistant_message)
        if current_task_id is not None and current_task_id in state.task_states:
            state.task_states[current_task_id].messages.append(assistant_message)

        return assistant_message, state  # type: ignore

    def set_seed(self, seed: int):
        """Set the seed for the LLM (stored in ``llm_args["seed"]``).

        Raises:
            ValueError: If no main LLM model is configured.
        """
        if self.llm is None:
            raise ValueError("LLM is not set")
        cur_seed = self.llm_args.get("seed", None)  # type: ignore
        if cur_seed is not None:
            logger.warning(f"Seed is already set to {cur_seed}, resetting it to {seed}")
        self.llm_args["seed"] = seed  # type: ignore
