from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Optional

import pandas as pd
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from murmur.config import (
    DEFAULT_LLM_AGENT,
    DEFAULT_LLM_ARGS_AGENT,
    DEFAULT_LLM_ARGS_USER,
    DEFAULT_LLM_USER,
    DEFAULT_LOG_LEVEL,
    DEFAULT_MAX_CONCURRENCY,
    DEFAULT_MAX_ERRORS,
    DEFAULT_MAX_STEPS,
    DEFAULT_NUM_TRIALS,
    DEFAULT_SAVE_TO,
    DEFAULT_SEED,
)
from murmur.data_model.message import Message
from murmur.data_model.tasks import Action, EnvAssertion, RewardType, Task
from murmur.environment.environment import EnvironmentInfo
from murmur.utils.utils import get_now


class RunConfig(BaseModel):
    """
    Configuration for a simulation run.

    Every field has a default (many of them taken from ``murmur.config``), so
    ``RunConfig()`` is valid on its own and callers only override what they need.
    """

    # --- Domain and task selection ---
    domain: Annotated[
        str,
        Field(
            description="The domain to run the simulation on",
            default="airline",
        ),
    ]
    task_set_name: Annotated[
        Optional[str],
        Field(
            description="The task set to run the simulation on. If not provided, will load default task set for the domain.",
            default=None,
        ),
    ]
    task_ids: Annotated[
        Optional[list[str]],
        Field(
            description="The task IDs to run the simulation on",
            default=None,
        ),
    ]
    num_tasks: Annotated[
        Optional[int],
        Field(
            description="The number of tasks to run the simulation on",
            default=None,
        ),
    ]
    is_remote: Annotated[
        bool,
        Field(
            description="Whether to run the simulation remotely",
            default=False,
        ),
    ]
    # --- Agent and user (simulator) configuration ---
    agent: Annotated[
        str,
        Field(
            description="The type of agent to run the simulation on",
            default="llm_agent",
        ),
    ]
    llm_agent: Annotated[
        str,
        Field(
            description="The model to use for the agent",
            default=DEFAULT_LLM_AGENT,
        ),
    ]
    # default_factory + deepcopy: each instance gets its own dict, so mutating
    # it never alters the shared module-level default.
    llm_args_agent: Annotated[
        dict,
        Field(
            description="The arguments to pass to the LLM for the agent",
            default_factory=lambda: deepcopy(DEFAULT_LLM_ARGS_AGENT),
        ),
    ]
    user: Annotated[
        str,
        Field(
            description="The type of user to run the simulation on",
            default="user_simulator",
        ),
    ]
    llm_user: Annotated[
        str,
        Field(
            description="The model to use for the user",
            default=DEFAULT_LLM_USER,
        ),
    ]
    # Same deepcopy-per-instance pattern as llm_args_agent.
    llm_args_user: Annotated[
        dict,
        Field(
            description="The arguments to pass to the LLM for the user",
            default_factory=lambda: deepcopy(DEFAULT_LLM_ARGS_USER),
        ),
    ]
    # --- Run limits, output and reproducibility ---
    num_trials: Annotated[
        int,
        Field(
            description="The number of trials to run the simulation on",
            default=DEFAULT_NUM_TRIALS,
        ),
    ]
    max_steps: Annotated[
        int,
        Field(
            description="The maximum number of steps to run the simulation",
            default=DEFAULT_MAX_STEPS,
        ),
    ]
    max_errors: Annotated[
        int,
        Field(
            description="The maximum number of tool errors allowed in a row in the simulation",
            default=DEFAULT_MAX_ERRORS,
        ),
    ]
    save_to: Annotated[
        Optional[str],
        Field(
            description="The path to json file where to save the simulation results",
            default=DEFAULT_SAVE_TO,
        ),
    ]
    max_concurrency: Annotated[
        int,
        Field(
            description="The maximum number of concurrent simulations to run",
            default=DEFAULT_MAX_CONCURRENCY,
        ),
    ]
    seed: Annotated[
        Optional[int],
        Field(
            description="The seed to use for the simulation",
            default=DEFAULT_SEED,
        ),
    ]
    log_level: Annotated[
        Optional[str],
        Field(
            description="The log level to use for the simulation",
            default=DEFAULT_LOG_LEVEL,
        ),
    ]
    # --- Multi-user / multi-task extensions (all off by default) ---
    multi_user: Annotated[
        bool,
        Field(
            description="Whether to enable multi-user mode",
            default=False,
        ),
    ]
    allow_multiple_user_messages: Annotated[
        bool,
        Field(
            description="Allow multiple users to send messages before agent responds",
            default=False,
        ),
    ]
    multi_task: Annotated[
        bool,
        Field(
            description="Whether to enable multi-task mode",
            default=False,
        ),
    ]
    # Only meaningful when multi_task is enabled (per its description).
    max_concurrent_tasks: Annotated[
        int,
        Field(
            description="Maximum number of tasks to run simultaneously in multi-task mode",
            default=3,
        ),
    ]
    # Free-form string; the set of accepted strategy names is not defined here.
    task_selection_strategy: Annotated[
        str,
        Field(
            description="Strategy for selecting tasks in multi-task mode",
            default="random",
        ),
    ]
    injection_task: Annotated[
        Optional[str],
        Field(
            description="Name of the injection task to run before user tasks",
            default=None,
        ),
    ]

    def validate(self) -> None:
        """
        Validate the run config.

        Currently a no-op placeholder: no checks are performed and nothing is
        returned or raised.
        """
        # NOTE(review): this instance method shadows pydantic's (deprecated in
        # v2) ``BaseModel.validate`` classmethod — confirm that is intended.
        pass


class NLAssertionCheck(BaseModel):
    """
    A natural language assertion.

    Outcome of checking a single natural-language assertion.
    """

    nl_assertion: str  # the assertion text that was checked
    met: bool  # whether the assertion was found to hold
    justification: str  # free-text explanation of the verdict


class CommunicateCheck(BaseModel):
    """
    A communication check.

    Outcome of checking whether one piece of required information was
    communicated (see ``RewardInfo.communicate_checks``).
    """

    info: str  # the information that was expected to be communicated
    met: bool  # whether the check passed
    justification: str  # free-text explanation of the verdict


class DBCheck(BaseModel):
    """
    A database check.

    Outcome of comparing database state; the comparison itself is performed
    elsewhere.
    """

    db_match: bool  # whether the database state matched
    db_reward: float  # reward contributed by this check (scale defined by the evaluator)


class ActionCheck(BaseModel):
    """
    An action check.

    Outcome of checking one expected ``Action``.
    """

    action: Action  # the expected action being checked
    action_match: bool  # whether a matching action was found
    action_reward: float  # reward contributed by this check (scale defined by the evaluator)


class EnvAssertionCheck(BaseModel):
    """
    An environment assertion check.

    Outcome of evaluating one ``EnvAssertion``.
    """

    env_assertion: EnvAssertion  # the assertion that was evaluated
    met: bool  # whether the assertion held
    reward: float  # reward contributed by this check (scale defined by the evaluator)


class RewardInfo(BaseModel):
    """
    The reward received by the agent.

    Besides the scalar ``reward``, this records the individual checks and the
    ``reward_basis`` (which check families were used to compute the reward).
    Every field except ``reward`` is optional.
    """

    reward: Annotated[float, Field(description="The reward received by the agent.")]
    db_check: Annotated[
        Optional[DBCheck], Field(description="The database check.", default=None)
    ]
    env_assertions: Annotated[
        Optional[list[EnvAssertionCheck]],
        Field(description="The environment assertions.", default=None),
    ]
    action_checks: Annotated[
        Optional[list[ActionCheck]],
        Field(description="The action checks.", default=None),
    ]
    nl_assertions: Annotated[
        Optional[list[NLAssertionCheck]],
        Field(description="The natural language assertions.", default=None),
    ]
    communicate_checks: Annotated[
        Optional[list[CommunicateCheck]],
        Field(
            description="Checks that the agent communicated the required information.",
            default=None,
        ),
    ]
    # default_factory gives each instance a fresh list. Note the default basis
    # is DB + ENV_ASSERTION + NL_ASSERTION + COMMUNICATE; action checks are not
    # part of it unless explicitly requested.
    reward_basis: Annotated[
        Optional[list[RewardType]],
        Field(
            description="The basis of the reward. Fields that are used to calculate the reward.",
            default_factory=lambda: [RewardType.DB, RewardType.ENV_ASSERTION, RewardType.NL_ASSERTION, RewardType.COMMUNICATE],
        ),
    ]
    # Per-RewardType component values; None when no breakdown was recorded.
    reward_breakdown: Annotated[
        Optional[dict[RewardType, float]],
        Field(
            description="The breakdown of the reward.",
            default=None,
        ),
    ]
    info: Annotated[
        Optional[dict],
        Field(description="Additional information about the reward.", default=None),
    ]


class AgentInfo(BaseModel):
    """
    Agent information.

    Records which agent implementation ran and, if it is LLM-backed, the model
    name and the arguments passed to it.
    """

    implementation: Annotated[str, Field(description="The type of agent.")]
    llm: Annotated[
        Optional[str],
        Field(description="The LLM used by the agent.", default=None),
    ]
    llm_args: Annotated[
        Optional[dict],
        Field(
            description="The arguments to pass to the LLM for the agent.",
            default=None,
        ),
    ]


class UserInfo(BaseModel):
    """
    User information.

    Records which user (simulator) implementation ran, its optional LLM model
    and arguments, and any global simulation guidelines it was given.
    """

    implementation: Annotated[str, Field(description="The type of user.")]
    llm: Annotated[
        Optional[str],
        Field(description="The LLM used by the user.", default=None),
    ]
    llm_args: Annotated[
        Optional[dict],
        Field(
            description="The arguments to pass to the LLM for the user.",
            default=None,
        ),
    ]
    global_simulation_guidelines: Annotated[
        Optional[str],
        Field(
            description="The global simulation guidelines for the user.",
            default=None,
        ),
    ]


class Info(BaseModel):
    """
    Information about the simulator.

    Run-wide metadata: code version, run limits, the user / agent /
    environment descriptions and the optional seed.
    """

    git_commit: Annotated[str, Field(description="The git commit hash.")]
    num_trials: Annotated[int, Field(description="The number of trials.")]
    max_steps: Annotated[int, Field(description="The maximum number of steps.")]
    max_errors: Annotated[int, Field(description="The maximum number of errors.")]
    user_info: Annotated[UserInfo, Field(description="User information.")]
    agent_info: Annotated[AgentInfo, Field(description="Agent information.")]
    environment_info: Annotated[
        EnvironmentInfo, Field(description="Environment information.")
    ]
    seed: Annotated[
        Optional[int],
        Field(description="The seed used for the simulation.", default=None),
    ]


class TerminationReason(str, Enum):
    """
    Why a simulation, or one task within a multi-task simulation, ended.

    Being a ``str`` subclass, members compare equal to (and serialize as)
    their string values.
    """

    USER_STOP = "user_stop"  # presumably the user side ended the conversation
    AGENT_STOP = "agent_stop"  # presumably the agent side ended the conversation
    MAX_STEPS = "max_steps"  # presumably the step limit was reached — confirm
    TOO_MANY_ERRORS = "too_many_errors"  # presumably the tool-error limit was hit — confirm


class SimulationRun(BaseModel):
    """
    Simulation run for the given task(s).

    Multi-task runs record per-task data in ``task_ids``,
    ``task_termination_reasons``, ``task_messages`` and ``task_rewards``.
    The single-valued ``task_id`` / ``termination_reason`` fields are kept for
    backward compatibility. Older result files presumably carry only those, so
    the list fields default to empty lists and such files still validate.
    """

    id: str = Field(description="The unique identifier for the simulation run.")

    # Multi-task support: store lists of task_ids and termination reasons.
    # default_factory=list: these used to be required, which made results
    # saved before multi-task support fail to load even though the
    # single-task fields below exist precisely for backward compatibility.
    task_ids: list[str] = Field(
        description="The list of task identifiers for multi-task simulations.",
        default_factory=list,
    )
    task_termination_reasons: list[TerminationReason] = Field(
        description="The termination reasons for each task (corresponds to task_ids order).",
        default_factory=list,
    )

    # Backward compatibility: keep single task_id and termination_reason
    task_id: Optional[str] = Field(
        description="The unique identifier for the task (backward compatibility).",
        default=None,
    )
    termination_reason: Optional[TerminationReason] = Field(
        description="The reason for the termination of the simulation (backward compatibility).",
        default=None,
    )

    timestamp: str = Field(
        description="The timestamp of the simulation.", default_factory=get_now
    )
    start_time: str = Field(description="The start time of the simulation.")
    end_time: str = Field(description="The end time of the simulation.")
    duration: float = Field(description="The duration of the simulation.")

    agent_cost: Optional[float] = Field(
        description="The cost of the agent.", default=None
    )
    user_cost: Optional[float] = Field(
        description="The cost of the user.", default=None
    )
    reward_info: Optional[RewardInfo] = Field(
        description="The reward received by the agent.", default=None
    )
    # Entries are free-form dicts. ``Results.to_df`` reads the "task_id" and
    # "reward_info" keys from them. Because the value type is ``object``, a
    # RewardInfo stored here comes back as a plain dict after a JSON round trip.
    task_rewards: Optional[list[dict[str, object]]] = Field(
        description="The reward received by the agent for each task.", default=None
    )

    # Multi-task support: separate trajectories per task
    task_messages: dict[str, list[Message]] = Field(
        description="The messages exchanged for each task (task_id -> messages).",
        default_factory=dict,
    )

    # Global messages for backward compatibility and agent's global view
    messages: list[Message] = Field(
        description="The global messages exchanged between the user, agent and environment."
    )

    trial: Optional[int] = Field(description="Trial number", default=None)
    seed: Optional[int] = Field(
        description="Seed used for the simulation.", default=None
    )


class Results(BaseModel):
    """
    Run results.

    Holds the simulator information, the task definitions and every
    simulation run. It also provides helpers to persist them as JSON and to
    flatten them into a pandas DataFrame.
    """

    timestamp: Optional[str] = Field(
        description="The timestamp of the simulation.", default_factory=get_now
    )
    info: Info = Field(description="Information.")
    tasks: list[Task] = Field(description="The list of tasks.")
    simulations: list[SimulationRun] = Field(description="The list of simulations.")

    @classmethod
    def load(cls, path: Path) -> "Results":
        """
        Load results from a JSON file, such as one written by ``save``.

        Args:
            path: Path of the JSON file to read.

        Returns:
            The validated ``Results`` object.
        """
        # Decode explicitly as UTF-8 so loading does not depend on the
        # platform's default encoding (e.g. cp1252 on Windows).
        with open(path, "r", encoding="utf-8") as f:
            return cls.model_validate_json(f.read())

    def save(self, path: Path) -> None:
        """
        Save the results to a file as indented JSON, encoded as UTF-8.

        Args:
            path: Path of the JSON file to (over)write.
        """
        # Message text may contain non-ASCII characters; the platform default
        # encoding may be unable to encode them.
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.model_dump_json(indent=4))

    def to_df(self) -> pd.DataFrame:
        """
        Convert a Results object to a pandas DataFrame.

        Produces one row per (simulation, task) pair:

        * a simulation whose ``task_ids`` holds more than one id is expanded
          into one row per task (``is_multi_task=True``). Cost and duration
          columns are simulation-wide and repeated on each of its rows.
        * any other simulation yields a single row (``is_multi_task=False``).

        Every row carries the run-level ``info_*`` columns. When the row's
        task id is found in ``self.tasks``, it also carries the ``task_*``
        metric columns.

        Returns:
            A DataFrame with one row per (simulation, task) pair; empty when
            there are no simulations.
        """

        def transfer_only(task: Task) -> bool:
            """Return True if the task's only expected action is a transfer."""
            criteria = task.evaluation_criteria
            if criteria is None or criteria.actions is None:
                return False
            actions = criteria.actions
            return len(actions) == 1 and "transfer" in actions[0].name.lower()

        def get_task_metrics(task: Task) -> dict:
            """Return the ``task_*`` metric columns for ``task``."""
            eval_metrics = (
                task.evaluation_criteria.info()
                if task.evaluation_criteria is not None
                else {}
            )
            # A task without evaluation criteria has nothing to count: default
            # to 0 instead of raising KeyError on the empty dict.
            num_agent_actions = eval_metrics.get("num_agent_actions", 0)
            num_user_actions = eval_metrics.get("num_user_actions", 0)
            # Transfer-only tasks are reported with a sentinel of -1 actions.
            num_actions = (
                -1 if transfer_only(task) else num_agent_actions + num_user_actions
            )
            return {
                "task_num_agent_actions": num_agent_actions,
                "task_num_user_actions": num_user_actions,
                "task_num_actions": num_actions,
                "task_num_env_assertions": eval_metrics.get("num_env_assertions", 0),
                "task_num_nl_assertions": eval_metrics.get("num_nl_assertions", 0),
            }

        def get_task_reward(sim: SimulationRun, task_id: str) -> Optional[float]:
            """
            Return the reward recorded for ``task_id`` in ``sim.task_rewards``.

            Falls back to the simulation-level reward (or None) when no
            per-task reward is found.
            """
            entry = next(
                (
                    tr
                    for tr in sim.task_rewards or []
                    if isinstance(tr, dict) and tr.get("task_id") == task_id
                ),
                None,
            )
            reward = None
            if entry is not None:
                reward_info = entry.get("reward_info")
                # In memory this is usually a RewardInfo. After ``load()`` it is
                # a plain dict, because ``task_rewards`` values are typed
                # ``object``. Accept both shapes so per-task rewards survive a
                # save/load round trip.
                if isinstance(reward_info, dict):
                    reward = reward_info.get("reward")
                else:
                    reward = getattr(reward_info, "reward", None)
            if reward is None and sim.reward_info is not None:
                reward = sim.reward_info.reward
            return reward

        # Run-level columns shared by every row.
        info = self.info
        info_columns = {
            "info_git_commit": info.git_commit,
            "info_seed": info.seed,
            "info_num_trials": info.num_trials,
            "info_max_steps": info.max_steps,
            "info_max_errors": info.max_errors,
            "info_domain": info.environment_info.domain_name,
            "info_user_implementation": info.user_info.implementation,
            "info_user_llm": info.user_info.llm,
            "info_user_llm_args": info.user_info.llm_args,
            "info_agent_implementation": info.agent_info.implementation,
            "info_agent_llm": info.agent_info.llm,
            "info_agent_llm_args": info.agent_info.llm_args,
        }

        # Index tasks once instead of scanning self.tasks for every row;
        # setdefault keeps the first task when an id is duplicated.
        tasks_by_id: dict = {}
        for task in self.tasks:
            tasks_by_id.setdefault(task.id, task)

        def make_row(sim_columns: dict, task_id: Optional[str]) -> dict:
            """Combine per-simulation, run-level and task-metric columns."""
            row = {**sim_columns, **info_columns}
            task = tasks_by_id.get(task_id) if task_id else None
            if task is not None:
                row.update(get_task_metrics(task))
            return row

        rows = []
        for sim in self.simulations:
            task_ids = sim.task_ids or []
            reasons = sim.task_termination_reasons or []

            if len(task_ids) > 1:
                # Multi-task simulation: one row per task.
                for i, task_id in enumerate(task_ids):
                    termination_reason = (
                        reasons[i] if i < len(reasons) else sim.termination_reason
                    )
                    # Prefer the per-task trajectory; fall back to the global
                    # message list when no per-task messages were recorded.
                    num_messages = (
                        len(sim.task_messages.get(task_id, []))
                        if sim.task_messages
                        else len(sim.messages)
                    )
                    sim_columns = {
                        "simulation_id": sim.id,
                        "task_id": task_id,
                        "trial": sim.trial,
                        "seed": sim.seed,
                        "reward": get_task_reward(sim, task_id),
                        "agent_cost": sim.agent_cost,  # Shared across all tasks
                        "user_cost": sim.user_cost,  # Shared across all tasks
                        "termination_reason": termination_reason,
                        "duration": sim.duration,  # Shared across all tasks
                        "num_messages": num_messages,
                        "is_multi_task": True,
                        "task_index": i,
                        "total_tasks": len(task_ids),
                    }
                    rows.append(make_row(sim_columns, task_id))
            else:
                # Single-task simulation. Support both the legacy single-valued
                # fields and the newer list fields.
                task_id = sim.task_id or (task_ids[0] if task_ids else None)
                termination_reason = sim.termination_reason
                if termination_reason is None and reasons:
                    termination_reason = reasons[0]
                sim_columns = {
                    "simulation_id": sim.id,
                    "task_id": task_id,
                    "trial": sim.trial,
                    "seed": sim.seed,
                    "reward": (
                        sim.reward_info.reward if sim.reward_info is not None else None
                    ),
                    "agent_cost": sim.agent_cost,
                    "user_cost": sim.user_cost,
                    "termination_reason": termination_reason,
                    "duration": sim.duration,
                    "num_messages": len(sim.messages),
                    "is_multi_task": False,
                    "task_index": 0,
                    "total_tasks": 1,
                }
                rows.append(make_row(sim_columns, task_id))
        return pd.DataFrame(rows)
