import logging
from enum import Enum
from typing import Optional

from murmur.data_model.message import Message
from murmur.data_model.simulation import RewardInfo, SimulationRun, TerminationReason
from murmur.data_model.tasks import RewardType, Task
from murmur.evaluator.evaluator_action import ActionEvaluator
from murmur.evaluator.evaluator_communicate import CommunicateEvaluator
from murmur.evaluator.evaluator_env import EnvironmentEvaluator
from murmur.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator
from murmur.registry import registry


class EvaluationType(str, Enum):
    """Selects which evaluator(s) `evaluate_simulation` runs for a task."""
    # Run only the environment evaluator.
    ENV = "env"
    # Run only the natural-language assertions evaluator.
    NL_ASSERTIONS = "nl_assertions"
    # Run only the communicate evaluator.
    COMMUNICATE = "communicate"
    # Run only the action evaluator.
    ACTION = "action"
    # Run the evaluators together and combine their rewards into one score.
    ALL = "all"


def evaluate_simulation(
    simulation: SimulationRun,
    task: Task,
    evaluation_type: EvaluationType,
    solo_mode: bool,
    domain: str,
    full_history: list[Message],
) -> RewardInfo:
    """
    Score a single simulation run against a single task.

    Args:
        simulation: The run to score; its ``messages`` are the trajectory given
            to the environment, NL-assertion and communicate evaluators.
        task: The task whose ``evaluation_criteria`` define what is checked.
        evaluation_type: Which evaluator to run, or ``ALL`` to combine them.
        solo_mode: Forwarded to the environment evaluator.
        domain: Domain name used to look up the environment constructor. In
            ``ALL`` mode the environment evaluator only runs for ``"airline"``.
        full_history: Trajectory given to the action evaluator (instead of
            ``simulation.messages``).

    Returns:
        A ``RewardInfo``. The reward is 0.0 when the run ended with
        ``TOO_MANY_ERRORS`` or ``MAX_STEPS``, 1.0 when the task has no
        evaluation criteria, and otherwise whatever the selected evaluator(s)
        produce. In ``ALL`` mode the reward is the product of the sub-rewards
        whose reward type appears in the task's ``reward_basis``.

    Raises:
        ValueError: If ``evaluation_type`` is not a known evaluation type.
    """
    # Runs that were cut short are never scored.
    premature_reasons = {
        TerminationReason.TOO_MANY_ERRORS,
        TerminationReason.MAX_STEPS,
    }
    if simulation.termination_reason in premature_reasons:
        note = f"Simulation terminated prematurely. Termination reason: {simulation.termination_reason}"
        return RewardInfo(reward=0.0, info={"note": note})

    # Nothing to check means the task trivially succeeds.
    if task.evaluation_criteria is None:
        return RewardInfo(reward=1.0, info={"note": "No evaluation criteria"})

    # Single-evaluator modes return the evaluator's result untouched.
    if evaluation_type == EvaluationType.ENV:
        return EnvironmentEvaluator.calculate_reward(
            environment_constructor=registry.get_env_constructor(domain),
            task=task,
            full_trajectory=simulation.messages,
            solo_mode=solo_mode,
        )
    if evaluation_type == EvaluationType.NL_ASSERTIONS:
        return NLAssertionsEvaluator.calculate_reward(
            task=task,
            full_trajectory=simulation.messages,
        )
    if evaluation_type == EvaluationType.COMMUNICATE:
        return CommunicateEvaluator.calculate_reward(
            task=task,
            full_trajectory=simulation.messages,
        )
    if evaluation_type == EvaluationType.ACTION:
        return ActionEvaluator.calculate_reward(
            task=task,
            full_trajectory=full_history,
        )
    if evaluation_type != EvaluationType.ALL:
        raise ValueError(f"Unknown evaluation type: {evaluation_type}")

    # --- ALL: run the evaluators, then fold their rewards together. ---
    is_airline = domain == "airline"
    trajectory = simulation.messages

    # The environment evaluator is only run for the airline domain.
    env_info = None
    if is_airline:
        env_info = EnvironmentEvaluator.calculate_reward(
            environment_constructor=registry.get_env_constructor(domain),
            task=task,
            full_trajectory=trajectory,
            solo_mode=solo_mode,
        )
    action_info = ActionEvaluator.calculate_reward(
        task=task,
        full_trajectory=full_history,
    )
    communicate_info = CommunicateEvaluator.calculate_reward(
        task=task,
        full_trajectory=trajectory,
    )
    nl_info = NLAssertionsEvaluator.calculate_reward(
        task=task,
        full_trajectory=trajectory,
    )

    requested_basis = set(task.evaluation_criteria.reward_basis)

    # (reward types that enable this component, the component's result),
    # listed in the order their rewards and breakdowns are merged.
    components = []
    if is_airline:
        components.append(({RewardType.DB, RewardType.ENV_ASSERTION}, env_info))
    components.append(({RewardType.ACTION}, action_info))
    components.append(({RewardType.NL_ASSERTION}, nl_info))
    components.append(({RewardType.COMMUNICATE}, communicate_info))

    combined_reward = 1.0
    combined_breakdown = {}
    for enabling_types, component in components:
        if not (requested_basis & enabling_types):
            continue
        if component.reward_breakdown is not None:
            combined_breakdown.update(component.reward_breakdown)
        combined_reward *= component.reward

    # Outside the airline domain the reported basis additionally lists ACTION.
    extra_basis = [] if is_airline else [RewardType.ACTION]

    return RewardInfo(
        reward=combined_reward,
        db_check=env_info.db_check if is_airline else None,
        env_assertions=env_info.env_assertions if is_airline else None,
        action_checks=action_info.action_checks,
        nl_assertions=nl_info.nl_assertions,
        communicate_checks=communicate_info.communicate_checks,
        reward_basis=task.evaluation_criteria.reward_basis + extra_basis,
        reward_breakdown=combined_breakdown,
        info={
            "env": env_info.info if is_airline else None,
            "nl": nl_info.info,
            "communicate": communicate_info.info,
            "action": action_info.info,
        },
    )


def _build_task_simulation(
    simulation: SimulationRun,
    task: Task,
    task_messages: list,
    task_termination_reason,
) -> SimulationRun:
    """Build a single-task ``SimulationRun`` that reuses the parent run's timing, cost and seed."""
    return SimulationRun(
        id=f"{simulation.id}_{task.id}",
        # New multi-task fields
        task_ids=[task.id],
        task_termination_reasons=[task_termination_reason],
        task_messages={task.id: task_messages},
        # Backward compatibility fields
        task_id=task.id,
        termination_reason=task_termination_reason,
        start_time=simulation.start_time,
        end_time=simulation.end_time,
        duration=simulation.duration,
        reward_info=None,
        user_cost=simulation.user_cost,
        agent_cost=simulation.agent_cost,
        messages=task_messages,
        seed=simulation.seed,
    )


def evaluate_multi_task_simulation(
    simulation: SimulationRun,
    tasks: list[Task],
    task_termination_reasons: dict[str, str],
    evaluation_type: EvaluationType,
    solo_mode: bool,
    domain: str,
    injection_task: Optional[Task] = None,
) -> list[dict]:
    """
    Evaluate a multi-task simulation where each task should be evaluated individually.

    When ``simulation.task_messages`` is populated, each task is scored on its
    own recorded messages and its termination reason is taken by position from
    ``simulation.task_termination_reasons``. Otherwise the messages are
    filtered out of ``simulation.messages`` and the termination reason is
    looked up in ``task_termination_reasons``. In both cases a missing reason
    falls back to ``simulation.termination_reason`` and then to
    ``TerminationReason.USER_STOP``.

    Args:
        simulation: The simulation run containing all messages
        tasks: List of tasks that were run
        task_termination_reasons: Termination reason per task id. Only used
            when the simulation has no ``task_messages``; may be empty.
        evaluation_type: Type of evaluation to perform
        solo_mode: Whether running in solo mode
        domain: The domain being evaluated
        injection_task: Optional extra task that is evaluated once per task in
            ``tasks``; its entries carry ``'is_injection_task': True``.

    Returns:
        List of dictionaries containing task_id, reward_info, and
        termination_reason for each task, each followed by the injection-task
        entry when ``injection_task`` is given.
    """
    task_rewards: list[dict] = []
    run_level_reason = simulation.termination_reason or TerminationReason.USER_STOP

    # Use the new multi-task structure if available
    per_task_messages = getattr(simulation, "task_messages", None)
    has_task_structure = bool(per_task_messages)

    if has_task_structure:
        # Tolerate a missing or None list instead of failing in len().
        positional_reasons = getattr(simulation, "task_termination_reasons", None) or []
        reasons_by_id = {}
    else:
        # Fallback to old behavior: filter messages from global messages.
        # Debug-level log so library callers do not get the full transcript on stdout.
        logging.getLogger(__name__).debug("Simulation messages: %s", simulation.messages)
        positional_reasons = []
        reasons_by_id = task_termination_reasons or {}

    for i, task in enumerate(tasks):
        if has_task_structure:
            task_messages = per_task_messages.get(task.id, [])
            if i < len(positional_reasons):
                task_termination_reason = positional_reasons[i]
            else:
                task_termination_reason = run_level_reason
        else:
            task_messages = filter_messages_for_task(simulation.messages, task)
            task_termination_reason = reasons_by_id.get(task.id, run_level_reason)

        task_simulation = _build_task_simulation(
            simulation, task, task_messages, task_termination_reason
        )

        # Evaluate this specific task; the action evaluator sees the full run history.
        reward_info = evaluate_simulation(
            simulation=task_simulation,
            task=task,
            evaluation_type=evaluation_type,
            solo_mode=solo_mode,
            domain=domain,
            full_history=simulation.messages,
        )
        task_rewards.append({
            'task_id': task.id,
            'reward_info': reward_info,
            'termination_reason': task_termination_reason,
        })

        if not injection_task:
            continue

        # NOTE(review): the two paths differ here and that is preserved as-is:
        # with task_messages the injection task is scored on the per-task run
        # and gets a per-task id; on the fallback path it is scored on the
        # whole run under the bare injection-task id (so the id repeats for
        # every task). Confirm whether this asymmetry is intended.
        if has_task_structure:
            injection_simulation = task_simulation
            injection_id = f"{injection_task.id}_{task.id}"
        else:
            injection_simulation = simulation
            injection_id = injection_task.id

        injection_reward_info = evaluate_simulation(
            simulation=injection_simulation,
            task=injection_task,
            evaluation_type=evaluation_type,
            solo_mode=solo_mode,
            domain=domain,
            full_history=task_simulation.messages,
        )
        task_rewards.append({
            'task_id': injection_id,
            'reward_info': injection_reward_info,
            'termination_reason': run_level_reason,
            'is_injection_task': True,  # Mark this as an injection task
        })

    return task_rewards


def filter_messages_for_task(messages: list, task: Task) -> list:
    """
    Return the subset of ``messages`` relevant to ``task``, in original order.

    A message is kept when any of the following holds (checked in this order):
    its role is 'assistant' or 'tool'; it has a non-empty ``user_id`` that
    contains ``task.id``; its ``task_id`` equals ``task.id``; its role is
    'system'. Attributes a message does not have are treated as non-matching.
    """

    def is_relevant(candidate) -> bool:
        role = getattr(candidate, 'role', None)
        # Agent and tool messages are shared across all tasks.
        if role == 'assistant' or role == 'tool':
            return True
        # User messages are matched through their user_id containing the task id.
        owner = getattr(candidate, 'user_id', None)
        if owner and task.id in owner:
            return True
        # hasattr (not a None default) so a None task id never matches a
        # message that simply lacks the attribute.
        if hasattr(candidate, 'task_id') and candidate.task_id == task.id:
            return True
        # System messages are kept as well.
        return role == 'system'

    return [candidate for candidate in messages if is_relevant(candidate)]
