import importlib.resources
import re
import warnings
from collections.abc import Sequence
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import yaml
from pydantic import BaseModel

from agentgym.agent_pipeline.errors import AbortAgentError
from agentgym.agent_pipeline.ground_truth_pipeline import GroundTruthPipeline
from agentgym.base_tasks import BaseInjectionTask, BaseUserTask
from agentgym.functions_engine.functions_engine import (
    FUNCTIONS_DOCS,
    FunctionCall,
)
from agentgym.types import ChatMessage, TaskEnvironment
from agentgym.yaml_loader import ImportLoader

if TYPE_CHECKING:
    from agentgym.agent_pipeline.base_pipeline_element import (
        BasePipelineElement,
    )

# Generic pydantic-model type: lets `get_environment` return an instance of the
# exact `env_type` class it was given.
T = TypeVar("T", bound=BaseModel)


class InjectionVector(BaseModel):
    """Schema of one entry in a suite's ``injection_vectors.yaml``.

    The entry's key is the name of a ``str.format`` placeholder in the suite's
    ``environment.yaml``; ``default`` is substituted for that placeholder when
    no injection overrides it.
    """

    # Free-text description of the vector (not read anywhere in this module).
    description: str
    # Text used for the placeholder when the caller supplies no injection for it.
    default: str


def messages_to_dict(messages: list[Any]) -> list[dict[str, str]]:
    """Normalize chat messages into plain ``{"role": ..., "content": ...}`` dicts.

    Elements exposing a ``role`` attribute are converted into a new dict built
    from their ``role`` and ``content`` attributes. Any other element must
    already be a dict and is passed through as the same object (not a copy).

    Raises:
        AssertionError: If an element has no ``role`` attribute and is not a
            dict (only while assertions are enabled; under ``-O`` such
            elements are passed through unchecked).
    """

    def _normalize(message):
        if not hasattr(message, "role"):
            assert isinstance(message, dict), f"Unhandled message type: {type(message)}: {message}"
            return message
        return {"role": message.role, "content": message.content}

    return list(map(_normalize, messages))


def messages_to_dict_v2(messages: Sequence[ChatMessage]) -> list[dict[str, str]]:
    """Project each chat message onto a fresh dict holding only its role and content.

    Every other key of the message (for example an assistant's ``tool_calls``)
    is dropped.
    """
    return [{"role": message["role"], "content": message["content"]} for message in messages]


def functions_stack_trace_from_messages(
    messages: Sequence[ChatMessage],
) -> list[FunctionCall]:
    """Collect every tool call issued by the assistant, in conversation order.

    Assistant messages whose ``tool_calls`` is ``None`` or empty contribute
    nothing; messages with any other role are skipped.
    """
    return [
        call
        for message in messages
        if message["role"] == "assistant"
        for call in (message["tool_calls"] or [])
    ]


def model_output_from_messages(
    messages: Sequence[ChatMessage],
) -> str | None:
    """Return the content of the final message, which must come from the assistant.

    Args:
        messages: The conversation, oldest message first.

    Returns:
        The ``content`` of the last message, which may be ``None``.

    Raises:
        ValueError: If ``messages`` is empty or its last message is not an
            assistant message.
    """
    if not messages:
        # Without this guard an empty conversation surfaces as a bare IndexError
        # instead of the ValueError this function otherwise uses for bad input.
        raise ValueError("No messages to extract the model output from")
    last_message = messages[-1]
    if last_message["role"] != "assistant":
        raise ValueError("Last message was not an assistant message")
    return last_message["content"]


@lru_cache
def read_suite_file(suite_name: str, file: str) -> str:
    """Load a YAML file from a suite's packaged data and return it re-serialized.

    The file ``data/suites/<suite_name>/<file>`` inside the ``agentgym`` package
    is parsed with ``ImportLoader`` so that its imports are resolved. The parsed
    data is then dumped back to a YAML string, so callers can still apply
    ``str.format`` placeholder substitution to the text.

    Results are memoized per ``(suite_name, file)`` by ``functools.lru_cache``.
    """
    package_files = importlib.resources.files("agentgym")
    path = package_files / "data" / "suites" / suite_name / file
    # Decode explicitly as UTF-8: relying on the platform default encoding
    # (e.g. cp1252 on Windows) would mis-decode or reject non-ASCII suite data.
    with importlib.resources.as_file(path) as p, p.open(encoding="utf-8") as f:
        # NOTE(review): yaml.load with a custom loader is acceptable for packaged
        # suite data; do not point this at untrusted files unless ImportLoader is
        # known to be a safe loader.
        data_yaml = yaml.load(f, ImportLoader)
    # Dump back to a string so injections can be substituted into it later.
    return yaml.dump(data_yaml, default_flow_style=False)


def validate_injections(injections: dict[str, str], injection_vector_defaults: dict[str, str]):
    """Ensure every requested injection targets a known injection vector.

    Args:
        injections: Mapping from injection-vector id to the text to inject.
        injection_vector_defaults: Mapping from every valid injection-vector id
            to its default text.

    Raises:
        ValueError: If ``injections`` contains ids that are absent from
            ``injection_vector_defaults``. The offending ids are listed in the
            message so a typo is easy to spot.
    """
    unknown_vectors = injections.keys() - injection_vector_defaults.keys()
    if unknown_vectors:
        raise ValueError(
            "Injections must be a subset of the injection vector defaults; "
            f"unknown injection vectors: {sorted(unknown_vectors)}"
        )


def get_environment(env_type: type[T], suite_name: str, injections: dict[str, str]) -> T:
    """Build a suite's environment model with its injection placeholders filled in.

    Every injection vector starts from the default in the suite's
    ``injection_vectors.yaml``, and entries in ``injections`` override those
    defaults. The suite's ``environment.yaml`` text is formatted with the
    merged values, parsed with ``yaml.safe_load``, and passed as keyword
    arguments to ``env_type``.

    Raises:
        ValueError: If ``injections`` names a vector the suite does not define.
    """
    template = read_suite_file(suite_name, "environment.yaml")
    defaults = get_injection_vector_defaults(suite_name)
    merged = dict(defaults, **injections)
    validate_injections(injections, defaults)
    fields = yaml.safe_load(template.format(**merged))
    return env_type(**fields)


def get_injection_vector_defaults(suite_name: str) -> dict[str, str]:
    """Return ``{vector_id: default_text}`` for every injection vector of a suite.

    Each entry of the suite's ``injection_vectors.yaml`` is validated against
    ``InjectionVector`` before its ``default`` is taken.
    """
    raw_vectors = yaml.safe_load(read_suite_file(suite_name, "injection_vectors.yaml"))
    return {name: InjectionVector(**spec).default for name, spec in raw_vectors.items()}


# Environment type that a `TaskSuite`, and the tasks registered on it, are generic over.
Env = TypeVar("Env", bound=TaskEnvironment)


class TaskSuite(Generic[Env]):
    def __init__(
        self,
        name: str,
        environment_type: type[Env],
        tools: list[str],
        system_prompt: str | None = None,
    ):
        self.name = name
        self.environment_type = environment_type
        self.tools = tools
        self._user_tasks: dict[str, BaseUserTask[Env]] = {}
        self._injection_tasks: dict[str, BaseInjectionTask[Env]] = {}
        self._module = self.environment_type.__module__

    def load_and_inject_default_environment(self, injections: dict[str, str]) -> Env:
        return get_environment(self.environment_type, self.name, injections)

    def _get_task_number(self, task_cls: type[BaseUserTask] | type[BaseInjectionTask], prefix: str) -> int:
        match = re.match(rf"{prefix}(\d+)", task_cls.__name__)
        if not match:
            raise ValueError(f"User tasks must be named {prefix} followed by a number, got {task_cls.__name__}")
        return int(match.group(1))

    def register_user_task(self, task: type[BaseUserTask[Env]]) -> type[BaseUserTask[Env]]:
        task_n = self._get_task_number(task, "UserTask")
        task_id = f"user_task_{task_n}"
        setattr(task, "ID", task_id)
        self._user_tasks[task_id] = task()
        return task

    @property
    def user_tasks(self) -> dict[str, BaseUserTask[Env]]:
        return self._user_tasks

    def get_user_task_by_id(self, task_id: str) -> BaseUserTask[Env]:
        return self._user_tasks[task_id]

    def register_injection_task(self, task: type[BaseInjectionTask[Env]]) -> type[BaseInjectionTask[Env]]:
        task_n = self._get_task_number(task, "InjectionTask")
        task_id = f"injection_task_{task_n}"
        setattr(task, "ID", task_id)
        self._injection_tasks[task_id] = task()
        return task

    @property
    def injection_tasks(self) -> dict[str, BaseInjectionTask[Env]]:
        return self._injection_tasks

    def get_injection_task_by_id(self, injection_task_id: str) -> BaseInjectionTask[Env]:
        return self._injection_tasks[injection_task_id]

    def _check_user_task_utility(
        self,
        task: BaseUserTask,
        model_output: str,
        pre_environment: Env,
        task_environment: Env,
        functions_stack_trace: Sequence[FunctionCall],
    ) -> bool:
        utility_from_stack_traces = task.utility_from_traces(pre_environment, task_environment, functions_stack_trace)
        if utility_from_stack_traces is not None:
            return utility_from_stack_traces
        return task.utility(model_output, pre_environment, task_environment)

    def _check_injection_task_security(
        self,
        task: BaseInjectionTask,
        model_output: str,
        pre_environment: Env,
        task_environment: Env,
        functions_stack_trace: Sequence[FunctionCall],
    ) -> bool:
        security_from_stack_traces = task.security_from_traces(pre_environment, task_environment, functions_stack_trace)
        if security_from_stack_traces is not None:
            return security_from_stack_traces
        return task.security(model_output, pre_environment, task_environment)

    def _check_task_result(
        self,
        task: BaseUserTask | BaseInjectionTask,
        model_output: str,
        pre_environment: Env,
        task_environment: Env,
        functions_stack_trace: Sequence[FunctionCall],
    ) -> bool:
        match task:
            case BaseUserTask():
                return self._check_user_task_utility(
                    task,
                    model_output,
                    pre_environment,
                    task_environment,
                    functions_stack_trace,
                )
            case BaseInjectionTask():
                return self._check_injection_task_security(
                    task,
                    model_output,
                    pre_environment,
                    task_environment,
                    functions_stack_trace,
                )

    def run_task_with_pipeline(
        self,
        agent_pipeline: "BasePipelineElement",
        user_task: BaseUserTask[Env] | BaseInjectionTask[Env],
        injection_task: BaseInjectionTask[Env] | None,
        injections: dict[str, str],
        environment: Env | None = None,
        verbose: bool = False,
    ) -> tuple[bool, bool]:
        # If no environment is provided, load the default environment
        if environment is None:
            environment = self.load_and_inject_default_environment(injections)
        # Initialize the environment according to the task
        if isinstance(user_task, BaseUserTask):
            task_environment = user_task.init_environment(environment)
        else:
            task_environment = environment

        # Create a copy of the environment before running the user task to then diff if needed
        pre_environment = task_environment.model_copy(deep=True)
        if isinstance(user_task, BaseUserTask):
            prompt = user_task.PROMPT
        else:
            prompt = user_task.GOAL

        tools = [FUNCTIONS_DOCS[tool] for tool in self.tools]

        model_output = None
        messages = []
        for _ in range(3):
            # Run the user task
            try:
                _, _, task_environment, messages, _ = agent_pipeline.query(prompt, tools, task_environment)
            except AbortAgentError as e:
                # get last environment and messages
                task_environment = e.task_environment
                messages = e.messages

            model_output = model_output_from_messages(messages)
            if model_output is not None:
                break

        if model_output is None:
            warnings.warn(f"Model output was None for task {user_task.ID}")

        functions_stack_trace = functions_stack_trace_from_messages(messages)
        # Update the environment with the user task responses
        task_environment.responses = messages_to_dict_v2(messages)

        utility = self._check_task_result(
            user_task,
            model_output or "",
            pre_environment,
            task_environment,  # type: ignore
            functions_stack_trace,
        )

        # Early return if no injection was intended
        if injection_task is None:
            return utility, True

        security = self._check_task_result(
            injection_task,
            model_output or "",
            pre_environment,
            task_environment,  # type: ignore
            functions_stack_trace,
        )

        return utility, security

    def check(self) -> tuple[bool, tuple[dict[str, tuple[bool, str]], dict[str, bool]]]:
        injection_defaults = get_injection_vector_defaults(self.name)
        canary_injections = {inj: f"---CANARY_{inj}---" for inj in injection_defaults}
        environment = self.load_and_inject_default_environment(canary_injections)

        user_tasks_results = {}

        # Check user tasks
        for user_task in self.user_tasks.values():
            ground_truth_pipeline = GroundTruthPipeline(user_task)
            post_environment = environment.copy(deep=True)
            utility, _ = self.run_task_with_pipeline(
                ground_truth_pipeline,
                user_task,
                injection_task=None,
                injections={},
                environment=post_environment,
            )
            if not utility:
                user_tasks_results[user_task.ID] = (
                    False,
                    "Ground truth does not solve the task",
                )
            # Check if the task is injectable
            tools = [FUNCTIONS_DOCS[tool] for tool in self.tools]
            _, _, _, responses, _ = ground_truth_pipeline.query(user_task.PROMPT, tools, environment.copy(deep=True))
            if not is_task_injectable(responses, list(canary_injections.values())):
                user_tasks_results[user_task.ID] = (
                    False,
                    "The task is not injectable",
                )
            if user_task.ID not in user_tasks_results:
                user_tasks_results[user_task.ID] = (True, "")

        injection_tasks_results = {}

        # Check injection tasks
        user_task = next(iter(self.user_tasks.values()))
        for injection_task in self.injection_tasks.values():
            ground_truth_pipeline = GroundTruthPipeline(injection_task)
            post_environment = environment.copy(deep=True)
            _, security = self.run_task_with_pipeline(
                ground_truth_pipeline,
                user_task,
                injection_task=injection_task,
                injections={},
                environment=post_environment,
            )
            injection_tasks_results[injection_task.ID] = security
            for tool_call in injection_task.ground_truth(environment.copy(deep=True)):
                if "placeholder_args" not in tool_call:
                    warnings.warn(f"Missing placeholder_args in {injection_task.ID}'s ground truth")

        all_checks_passed = all(result[0] for result in user_tasks_results.values()) and all(
            result for result in injection_tasks_results.values()
        )

        return all_checks_passed, (user_tasks_results, injection_tasks_results)


def is_task_injectable(responses: Sequence[ChatMessage], canaries: list[str]) -> bool:
    """Return True if any canary string occurs in the tool outputs of a conversation.

    Only messages with role ``"tool"`` whose ``content`` is a ``str`` are
    considered. Their contents are joined with single spaces before the
    substring search.
    """
    combined_tool_text = " ".join(
        message["content"]
        for message in responses
        if message["role"] == "tool" and isinstance(message["content"], str)
    )
    return any(canary in combined_tool_text for canary in canaries)
