import json
import time
import uuid
import random
from copy import deepcopy
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Optional, Dict, List

from loguru import logger

from murmur.agent.base import BaseAgent, is_valid_agent_history_message
from murmur.agent.llm_agent import LLMAgentState, LLMSoloAgent
from murmur.data_model.message import (
    AssistantMessage,
    Message,
    MultiToolMessage,
    ToolMessage,
    UserMessage,
)
from murmur.data_model.simulation import SimulationRun, TerminationReason
from murmur.data_model.tasks import EnvFunctionCall, InitializationData, Task
from murmur.environment.environment import Environment, EnvironmentInfo
from murmur.user.base import BaseUser, is_valid_user_history_message
from murmur.user.user_simulator import DummyUser, UserSimulator, UserState
from murmur.utils.llm_utils import get_cost
from murmur.utils.utils import format_time, get_now
from murmur.orchestrator.orchestrator import Role, DEFAULT_FIRST_AGENT_MESSAGE
from murmur.injection.injection_tasks import InjectionTaskWrapper
from murmur.data_model.tasks import Task

class MultiUserOrchestrator:
    """
    Multi-user orchestrator for the simulation given a task.
    Manages multiple users interacting with a single agent.
    Users can either coordinate on a common task or work on separate tasks.

    Routing is driven by the ``from_role`` / ``to_role`` pair on each step:

    * AGENT -> USER: injection messages are sent first (if an injection task
      was given), then a random user from an unfinished task responds.
    * USER -> AGENT: the agent responds to the last user message.
    * AGENT/USER -> ENV: tool calls are executed against the environment.
    * ENV -> AGENT/USER: tool results go back to whoever issued the call.
    """

    def __init__(
        self,
        domain: str,
        agent: BaseAgent,
        users: Dict[str, BaseUser],  # user_id -> BaseUser
        environment: Environment,
        tasks: List[Task] = None,  # For multi-task support
        task: Task = None,  # For backward compatibility
        task_to_users: Dict[str, List[str]] = None,  # task_id -> list of user_ids
        max_steps: int = 100,
        max_errors: int = 10,
        seed: Optional[int] = None,
        solo_mode: bool = False,
        allow_multiple_user_messages: bool = False,  # Allow multiple users to send messages before agent responds
        common_task: bool = True,  # Whether users are working on a common task or separate tasks
        injection_task_obj: Optional[Task] = None,  # Task-based injection task
    ):
        """
        Args:
            domain: Domain name; stored on the instance.
            agent: The single agent all users talk to.
            users: Mapping of user_id -> user.
            environment: Environment that executes tool calls.
            tasks: Tasks being simulated. The first one is the primary task
                whose initial state ``initialize`` uses.
            task: Single-task alternative to ``tasks``; used only when
                ``tasks`` is None.
            task_to_users: Mapping of task_id -> user_ids working on that
                task. Ids missing from ``users`` are ignored.
            max_steps: Step budget enforced by ``run``.
            max_errors: Error budget enforced by ``run`` against ``num_errors``.
            seed: Optional seed for the agent, all users and ``random``.
            solo_mode: Whether the agent runs against dummy users only.
            allow_multiple_user_messages: Stored only; not read by this class.
            common_task: Stored only; not read by this class.
            injection_task_obj: Optional task whose injection messages are sent
                to the agent before normal user interaction begins.
        """
        self.domain = domain
        self.agent = agent
        self.users = users  # Dictionary mapping user_id to BaseUser
        self.environment = environment

        # Handle backward compatibility and new multi-task structure
        if tasks is not None:
            self.tasks = tasks
            self.task = tasks[0] if tasks else None  # Primary task for compatibility
        else:
            self.tasks = [task] if task else []
            self.task = task

        # Use simplified task to users mapping
        self.task_to_users: Dict[str, List[str]] = task_to_users or {}

        # Build task_groups from task_to_users for backward compatibility
        self.task_groups: Dict[str, Dict[str, BaseUser]] = {}
        for task_id, user_ids in self.task_to_users.items():
            self.task_groups[task_id] = {user_id: users[user_id] for user_id in user_ids if user_id in users}

        self.seed = seed
        self.solo_mode = solo_mode
        self.allow_multiple_user_messages = allow_multiple_user_messages
        self.common_task = common_task
        self.injection_task_obj = injection_task_obj

        self.agent_state: LLMAgentState = None
        self.user_states: Dict[str, UserState] = {user_id: UserState(system_messages=[], messages=[]) for user_id in users.keys()}
        self.trajectory: list[Message] = []  # Agent's global trajectory across all tasks
        self.task_trajectories: Dict[str, list[Message]] = {}  # task_id -> user-specific trajectory
        self.max_steps = max_steps
        self.max_errors = max_errors
        self.step_count = 0

        # Per-task "done" list; not read elsewhere in this class (task_done is
        # the source of truth) but kept as a public attribute for compatibility.
        self.done: List[bool] = [False] * len(self.tasks) if self.tasks else [False]
        self.task_done: Dict[str, bool] = {task.id: False for task in self.tasks} if self.tasks else {}

        # Track termination reasons for each task
        self.termination_reason: Optional[TerminationReason] = None  # Global termination reason
        self.task_termination_reasons: Dict[str, Optional[TerminationReason]] = {task.id: None for task in self.tasks} if self.tasks else {}
        # NOTE(review): num_errors is never incremented inside this class;
        # presumably updated externally — verify, otherwise max_errors is inert.
        self.num_errors = 0
        self.from_role: Optional[Role] = None
        self.to_role: Optional[Role] = None
        self.message: Optional[Message] = None

        # Multi-user specific state
        self.active_user_id: Optional[str] = None  # Currently active user
        self.current_task_id: Optional[str] = None  # Currently active task

        # Injection phase state
        self.injection_phase: bool = False  # Whether we're in injection phase
        self.injection_turn: int = 0  # Index of the next injection message to send
        self.injection_messages: List[Dict[str, Any]] = []  # Injection messages to send

        # Initialize task trajectories for user context isolation
        for task in self.tasks:
            self.task_trajectories[task.id] = []

    def initialize(self):
        """
        Initialize the multi-user orchestrator.

        Sets up the environment from the primary task's initial state (if
        any), seeds the agent, users and ``random`` when a seed is given, and
        initializes agent/user states either from the initial message history
        or from scratch. A fresh (non-solo) start begins with the default
        first agent message and, when an injection task is configured, enters
        the injection phase.

        Raises:
            AssertionError: in solo mode, if the environment, agent or users
                are not configured for solo mode.
        """
        # self.task is None when the orchestrator was built without tasks.
        initial_state = self.task.initial_state if self.task is not None else None
        initialization_data = (
            initial_state.initialization_data if initial_state is not None else None
        )
        initialization_actions = (
            initial_state.initialization_actions if initial_state is not None else None
        )
        message_history = (
            deepcopy(initial_state.message_history)
            if initial_state is not None and initial_state.message_history is not None
            else []
        )
        for msg in message_history:
            msg.turn_idx = None

        # Add timestamps to the message history
        message_history = self._add_timestamps(message_history)

        if self.solo_mode:
            assert self.environment.solo_mode, "Environment should be in solo mode"
            assert isinstance(self.agent, LLMSoloAgent), (
                "Agent must be a LLMSoloAgent in solo mode"
            )
            assert all(isinstance(user, DummyUser) for user in self.users.values()), (
                "All users must be DummyUser in solo mode"
            )

        # Initialize Environment state
        self._initialize_environment(
            initialization_data=initialization_data,
            initialization_actions=initialization_actions,
            message_history=message_history,
        )

        # Set seeds for the agent and all users
        if self.seed is not None:
            self.agent.set_seed(self.seed)
            for user in self.users.values():
                user.set_seed(self.seed)

            # Seeds the global RNG used by _select_user_randomly.
            random.seed(self.seed)

        # Initialize the agent and user states
        if len(message_history) > 0:
            self.validate_message_history(message_history)
            # Initialize based on message history - simplified for multi-user case
            self._initialize_from_history(message_history)
        else:
            # Initialize fresh conversation
            self.agent_state = self.agent.get_init_state()
            for user_id, user in self.users.items():
                self.user_states[user_id] = user.get_init_state()

            # NOTE(review): in solo mode with no history, from_role/to_role stay
            # None and the first step() raises "Invalid role combination".
            if not self.solo_mode:
                first_message = deepcopy(DEFAULT_FIRST_AGENT_MESSAGE)
                first_message.timestamp = get_now()
                self.trajectory = [first_message]
                self.message = first_message
                self.from_role = Role.AGENT
                self.to_role = Role.USER

                # Set up injection phase if injection task is provided
                if self.injection_task_obj:
                    self.injection_phase = True
                    self.injection_turn = 0

                    wrapper = InjectionTaskWrapper(self.injection_task_obj)
                    self.injection_messages = wrapper.get_injection_messages()
                    logger.info(f"Starting injection phase with Task-based injection: {self.injection_task_obj.id}")

                    logger.info(f"Starting injection phase with {len(self.injection_messages)} messages")

        self.environment.sync_tools()

    def _initialize_from_history(self, message_history: List[Message]):
        """
        Initialize agent/user states and routing from an existing message history.

        The agent receives every valid agent-history message. Each user
        receives the valid user-history messages, excluding UserMessages
        authored by other users (UserMessages with no user_id are shared).
        Routing resumes from the last message. A pending tool call is sent to
        the environment first.

        Args:
            message_history: Timestamped history from the task's initial state.
        """
        # For simplicity, start fresh but with history in trajectory
        self.agent_state = self.agent.get_init_state(
            message_history=[
                msg for msg in message_history
                if is_valid_agent_history_message(msg)
            ]
        )

        for user_id, user in self.users.items():
            # Filter messages for this specific user
            user_messages = [
                msg for msg in message_history
                if is_valid_user_history_message(msg) and
                (not isinstance(msg, UserMessage) or msg.user_id == user_id or msg.user_id is None)
            ]
            self.user_states[user_id] = user.get_init_state(message_history=user_messages)

        self.trajectory = message_history

        # Determine the starting state based on last message
        if message_history:
            last_message = message_history[-1]
            if isinstance(last_message, AssistantMessage):
                self.from_role = Role.AGENT
                # A pending agent tool call must be executed before any user speaks.
                self.to_role = Role.ENV if last_message.is_tool_call() else Role.USER
                self.message = last_message
            elif isinstance(last_message, UserMessage):
                self.from_role = Role.USER
                # A pending user tool call goes to the environment, not the agent.
                self.to_role = Role.ENV if last_message.is_tool_call() else Role.AGENT
                self.message = last_message
                self.active_user_id = last_message.user_id
                # Mirror _generate_user_response so follow-up messages land in
                # the right task trajectory.
                self.current_task_id = self._get_task_id_for_user(last_message.user_id)
            # Handle other message types as needed

    def run(self) -> SimulationRun:
        """
        Run the multi-user simulation until every task is done.

        Tasks are force-finished when ``max_steps`` or ``max_errors`` is hit.
        Tasks without a recorded reason fall back to the global termination
        reason, and then to ``USER_STOP``.

        Returns:
            SimulationRun with the global trajectory, per-task trajectories,
            termination reasons, timing and cost information.
        """
        start_time = get_now()
        start = time.perf_counter()
        self.initialize()

        while not self._all_tasks_done():
            self.step()
            if self.step_count >= self.max_steps:
                self._mark_all_tasks_done(TerminationReason.MAX_STEPS)
                self.termination_reason = TerminationReason.MAX_STEPS
            if self.num_errors >= self.max_errors:
                self._mark_all_tasks_done(TerminationReason.TOO_MANY_ERRORS)
                self.termination_reason = TerminationReason.TOO_MANY_ERRORS
        duration = time.perf_counter() - start
        messages = self.get_trajectory()
        res = get_cost(messages)
        if res is None:
            agent_cost, user_cost = None, None
        else:
            agent_cost, user_cost = res

        # Prepare data for multi-task simulation
        task_ids = [task.id for task in self.tasks] if self.tasks else []
        task_termination_reasons_list = [
            self.task_termination_reasons.get(task_id) or self.termination_reason or TerminationReason.USER_STOP
            for task_id in task_ids
        ]

        simulation_run = SimulationRun(
            id=str(uuid.uuid4()),

            # Multi-task fields
            task_ids=task_ids,
            task_termination_reasons=task_termination_reasons_list,
            task_messages=dict(self.task_trajectories),  # Shallow copy of task-specific trajectories

            start_time=start_time,
            end_time=get_now(),
            duration=duration,
            reward_info=None,
            user_cost=user_cost,
            agent_cost=agent_cost,
            messages=messages,  # Global message trajectory
            seed=self.seed,
        )

        return simulation_run

    def step(self):
        """
        Perform one step of the multi-user simulation.
        Handles multi-user coordination and message routing.

        Raises:
            ValueError: if all tasks are already done, or the current
                from/to role combination is not routable.
        """
        if self._all_tasks_done():
            raise ValueError("Simulation is done")

        logger.debug(
            f"Step {self.step_count}.\nFrom role: {self.from_role}\nTo role: {self.to_role}\nMessage: {self.message}"
        )

        # AGENT -> USERS (possibly multiple users)
        if self.from_role == Role.AGENT and self.to_role == Role.USER:
            self._handle_agent_to_users()

        # USER(S) -> AGENT
        elif self.from_role == Role.USER and self.to_role == Role.AGENT:
            self._handle_users_to_agent()

        # AGENT/USER -> ENV
        elif self.from_role in [Role.AGENT, Role.USER] and self.to_role == Role.ENV:
            self._handle_to_environment()

        # ENV -> AGENT/USER. Tool results must return to the caller; routing
        # ENV -> USER through _handle_agent_to_users would hand a user's tool
        # result to a randomly selected (possibly different) user.
        elif self.from_role == Role.ENV:
            self._handle_from_environment()

        else:
            raise ValueError(
                f"Invalid role combination. From role: {self.from_role}, To role: {self.to_role}"
            )

        self.step_count += 1
        self.environment.sync_tools()

    def _handle_agent_to_users(self):
        """Route an agent message to a user: an injection message while in the injection phase, otherwise a randomly chosen user."""

        # Check if we're in injection phase
        if self.injection_phase:
            self._handle_injection_phase()
        else:
            # Randomly select a user from any non-completed task
            selected_user_id = self._select_user_randomly()

            if selected_user_id is None:
                # All tasks are done - this should not happen normally as we check _all_tasks_done in run loop
                self._mark_all_tasks_done(TerminationReason.USER_STOP)
                return

            logger.debug(f"Selected user {selected_user_id} to respond")
            self._generate_user_response(selected_user_id)

    def _handle_injection_phase(self):
        """
        Send the next predefined injection message to the agent as a user message.

        Once all injection messages have been sent, leave the injection phase
        and let a randomly selected user respond normally.
        """
        if self.injection_turn >= len(self.injection_messages):
            # Injection phase is complete, transition to normal operation
            self.injection_phase = False
            self.injection_turn = 0
            logger.info("Injection phase completed, transitioning to normal user interaction")

            # Continue with normal user selection
            selected_user_id = self._select_user_randomly()
            if selected_user_id is None:
                self._mark_all_tasks_done(TerminationReason.USER_STOP)
                return
            logger.debug(f"Selected user {selected_user_id} to respond after injection phase")
            self._generate_user_response(selected_user_id)
            return

        # Send the current injection message
        injection_msg = self.injection_messages[self.injection_turn]
        logger.info(f"Sending injection message {self.injection_turn + 1}/{len(self.injection_messages)}: {injection_msg['content']}")

        # Create a fake user message with injection content
        user_msg = UserMessage(
            role="user",
            content=injection_msg["content"],
            user_id=injection_msg["user_id"],
            timestamp=get_now()
        )

        # Only the global trajectory records injection messages; no task
        # trajectory or user state is updated here.
        self.trajectory.append(user_msg)
        self.message = user_msg
        self.from_role = Role.USER
        self.to_role = Role.AGENT
        self.active_user_id = injection_msg["user_id"]

        # Increment injection turn for next iteration
        self.injection_turn += 1

    def _handle_users_to_agent(self):
        """Handle messages from users to agent."""
        # Enforce strict agent-user alternation: always hand a user message
        # straight to the agent.
        self._generate_agent_response()

    def _handle_to_environment(self):
        """
        Execute the current message's tool calls against the environment.

        Tool results are appended to the global trajectory and to the current
        task's trajectory. The results are returned to the caller as a single
        ToolMessage, or as a MultiToolMessage when there were several calls.

        Raises:
            ValueError: if the current message is not a tool call.
        """
        if not self.message.is_tool_call():
            raise ValueError("Agent or User should send tool call to environment")

        tool_msgs = []
        for tool_call in self.message.tool_calls:
            tool_msg = self.environment.get_response(tool_call)
            tool_msgs.append(tool_msg)

        self.trajectory.extend(tool_msgs)

        # Add to current task's trajectory
        if self.current_task_id and self.current_task_id in self.task_trajectories:
            self.task_trajectories[self.current_task_id].extend(tool_msgs)

        if len(tool_msgs) > 1:
            self.message = MultiToolMessage(
                role="tool",
                tool_messages=tool_msgs,
            )
        else:
            self.message = tool_msgs[0]

        # Return to whoever sent the tool call
        self.to_role = self.from_role
        self.from_role = Role.ENV

    def _handle_from_environment(self):
        """
        Deliver environment tool results back to the role that requested them.

        Raises:
            ValueError: if the target role is neither AGENT nor USER.
        """
        if self.to_role == Role.AGENT:
            self._generate_agent_response()
        elif self.to_role == Role.USER:
            # The user who issued the tool call is the active user; fall back
            # to the first registered user if none is recorded.
            user_id = self.active_user_id or list(self.users.keys())[0]
            logger.debug(f"Selected user {user_id} to receive tool response")
            self._generate_user_response(user_id)
        else:
            # Previously a silent no-op that spun the loop until max_steps.
            raise ValueError(
                f"Invalid role combination. From role: {self.from_role}, To role: {self.to_role}"
            )

    def _generate_user_response(self, user_id: str):
        """
        Generate the next message from ``user_id`` and route it.

        The message is tagged with ``user_id`` and recorded in the global and
        task trajectories. If it is a stop message, the user's task is marked
        done. It is then copied (as an AssistantMessage) into the states of
        the other users in the same task group, and routed to ENV for tool
        calls or to AGENT otherwise.

        Args:
            user_id: Key into ``self.users`` / ``self.user_states``.
        """
        user = self.users[user_id]
        user_state = self.user_states[user_id]

        user_msg, new_user_state = user.generate_next_message(self.message, user_state)
        user_msg.validate()
        # Keep the returned state, as _generate_agent_response does for the agent;
        # previously it was discarded.
        self.user_states[user_id] = new_user_state

        # Set the user_id for the message
        user_msg.user_id = user_id

        # Determine which task this user belongs to
        user_task_id = self._get_task_id_for_user(user_id)
        self.current_task_id = user_task_id

        if UserSimulator.is_stop(user_msg):
            # Mark the current task as done
            if user_task_id and user_task_id in self.task_done:
                self.task_done[user_task_id] = True
                self.task_termination_reasons[user_task_id] = TerminationReason.USER_STOP

            # Check if all tasks are done
            if self._all_tasks_done():
                self.termination_reason = TerminationReason.USER_STOP

        # Add to global agent trajectory and task-specific trajectory
        self.trajectory.append(user_msg)
        if user_task_id in self.task_trajectories:
            self.task_trajectories[user_task_id].append(user_msg)

        self.message = user_msg
        self.from_role = Role.USER
        self.active_user_id = user_id

        if user_msg.is_tool_call():
            self.to_role = Role.ENV
        else:
            self.to_role = Role.AGENT

        # Only share messages with users in the same task group (context isolation)
        if user_task_id in self.task_groups:
            task_user_ids = self.task_groups[user_task_id].keys()
            for itr_user_id in task_user_ids:
                # NOTE(review): the MultiToolMessage / AssistantMessage branches
                # only apply if a user can emit those types — confirm.
                if isinstance(user_msg, MultiToolMessage):
                    self.user_states[itr_user_id].messages.extend(user_msg.tool_messages)
                else:
                    if isinstance(user_msg, AssistantMessage):
                        self.user_states[itr_user_id].messages.append(user_msg)
                    elif itr_user_id == user_msg.user_id:
                        # The sender's own state comes from generate_next_message.
                        continue
                    else:
                        # Peers see the sender's message re-cast as an assistant message.
                        self.user_states[itr_user_id].messages.append(
                            AssistantMessage(
                                role="assistant",
                                content=user_msg.content,
                                tool_calls=user_msg.tool_calls,
                                cost=user_msg.cost,
                                usage=user_msg.usage,
                                raw_data=user_msg.raw_data,
                            )
                        )

    def _generate_agent_response(self):
        """
        Generate the agent's next message and route it.

        An agent stop marks the current task done. When no current task is
        known, it marks every task done. Non-tool-call replies are appended to
        the states of the users in the current task group only; tool calls are
        routed to ENV.
        """

        agent_msg, new_agent_state = self.agent.generate_next_message(
            self.message, self.agent_state
        )
        agent_msg.validate()

        if self.agent.is_stop(agent_msg):
            # Mark current task as done with agent stop reason
            if self.current_task_id and self.current_task_id in self.task_done:
                self.task_done[self.current_task_id] = True
                self.task_termination_reasons[self.current_task_id] = TerminationReason.AGENT_STOP
            else:
                # If no current task, mark all as done
                self._mark_all_tasks_done(TerminationReason.AGENT_STOP)

            if self._all_tasks_done():
                self.termination_reason = TerminationReason.AGENT_STOP

        # Add to global agent trajectory
        self.trajectory.append(agent_msg)

        # Add to current task's trajectory
        if self.current_task_id and self.current_task_id in self.task_trajectories:
            self.task_trajectories[self.current_task_id].append(agent_msg)

        self.agent_state = new_agent_state
        self.message = agent_msg
        self.from_role = Role.AGENT

        if agent_msg.is_tool_call():
            self.to_role = Role.ENV
        else:
            # Only send to users in the current task (context isolation)
            if self.current_task_id and self.current_task_id in self.task_groups:
                task_user_ids = self.task_groups[self.current_task_id].keys()
                for itr_user_id in task_user_ids:
                    if isinstance(agent_msg, MultiToolMessage):
                        self.user_states[itr_user_id].messages.extend(agent_msg.tool_messages)
                    else:
                        self.user_states[itr_user_id].messages.append(agent_msg)
            self.to_role = Role.USER

    def _all_tasks_done(self) -> bool:
        """Return True when every task is done (vacuously True with no tasks)."""
        return all(self.task_done.values()) if self.task_done else True

    def _mark_all_tasks_done(self, reason: Optional[TerminationReason] = None):
        """
        Mark every unfinished task as done.

        Args:
            reason: Recorded for tasks that were not yet done and have no
                termination reason yet; already-finished tasks are untouched.
        """
        for task_id in self.task_done:
            if not self.task_done[task_id]:  # Only update if not already done
                self.task_done[task_id] = True
                if reason and self.task_termination_reasons[task_id] is None:
                    self.task_termination_reasons[task_id] = reason

    def _get_task_id_for_user(self, user_id: str) -> Optional[str]:
        """Return the first task id whose user list contains ``user_id``, or None."""
        for task_id, user_ids in self.task_to_users.items():
            if user_id in user_ids:
                return task_id
        return None

    def _select_user_randomly(self) -> Optional[str]:
        """
        Randomly select a known user from any non-completed task.

        User ids listed in ``task_to_users`` but absent from ``self.users`` are
        skipped, matching how ``task_groups`` is built; selecting them would
        raise KeyError in ``_generate_user_response``.

        Returns:
            A user_id, or None when no eligible user remains.
        """
        available_user_ids = [
            user_id
            for task_id, user_ids in self.task_to_users.items()
            if not self.task_done.get(task_id, False)  # Task is not done
            for user_id in user_ids
            if user_id in self.users
        ]

        if not available_user_ids:
            return None  # No available users

        # Uses the global RNG, seeded in initialize() when a seed is given.
        return random.choice(available_user_ids)

    def get_trajectory(self) -> list[Message]:
        """
        Return a copy of the trajectory sorted by timestamp.

        Each returned message is a deep copy with ``turn_idx`` set to its
        position. ``self.trajectory`` itself is not modified.
        """
        # sorted() is stable and does not mutate; copying each message once
        # afterwards is equivalent to (and cheaper than) copying twice.
        ordered = sorted(self.trajectory, key=lambda x: x.timestamp)
        trajectory = []
        for i, msg in enumerate(ordered):
            msg = deepcopy(msg)
            msg.turn_idx = i
            trajectory.append(msg)
        return trajectory

    @classmethod
    def validate_message_history(cls, message_history: list[Message]):
        """Validate a message history for multi-user scenarios."""
        # Delegates to the single-user Orchestrator's validation for now;
        # could be extended with multi-user specific checks.
        from murmur.orchestrator.orchestrator import Orchestrator
        Orchestrator.validate_message_history(message_history)

    def _initialize_environment(
        self,
        initialization_data: Optional[InitializationData],
        initialization_actions: Optional[list[EnvFunctionCall]],
        message_history: list[Message],
    ):
        """Forward the initial state to ``environment.set_state``."""
        self.environment.set_state(
            initialization_data=initialization_data,
            initialization_actions=initialization_actions,
            message_history=message_history,
        )

    def _add_timestamps(self, message_history: list[Message]) -> list[Message]:
        """
        Assign increasing timestamps, one second apart, that end just before now.

        Mutates the messages in place and returns the same list.
        """
        time_offset = datetime.now() - timedelta(seconds=len(message_history))
        for i, msg in enumerate(message_history):
            msg.timestamp = format_time(time_offset + timedelta(seconds=i))
        return message_history