import json
import logging
import warnings
from collections.abc import Sequence
from pathlib import Path

from cohere.core.api_error import ApiError
from google.genai.errors import ServerError
from openai import BadRequestError, UnprocessableEntityError
from pydantic import BaseModel, ConfigDict, ValidationError, model_validator
from typing_extensions import Self, TypedDict

from agentdojo.agent_pipeline.base_pipeline_element import BasePipelineElement
from agentdojo.attacks.base_attacks import BaseAttack
from agentdojo.base_tasks import BaseInjectionTask, BaseUserTask
from agentdojo.functions_runtime import Env, FunctionCall
from agentdojo.logging import Logger, TraceLogger
from agentdojo.task_suite.load_suites import get_suite
from agentdojo.task_suite.task_suite import TaskSuite
from agentdojo.types import ChatMessage


class SuiteResults(TypedDict):
    """Aggregated outcome of benchmarking one task suite with a given pipeline.

    The two per-pair mappings are keyed by ``(user_task_id, injection_task_id)``
    tuples; the per-injection-task mapping is keyed by the injection task ID alone.
    """

    # Utility outcome for every (user task ID, injection task ID) pair that was run.
    # The injection task ID is an empty string for runs without injections.
    utility_results: dict[tuple[str, str], bool]
    # Security outcome for every (user task ID, injection task ID) pair that was run,
    # keyed the same way as ``utility_results``.
    security_results: dict[tuple[str, str], bool]
    # Utility outcome of each injection task when run on its own as a plain task,
    # keyed by injection task ID.
    injection_tasks_utility_results: dict[str, bool]


def aggregate_results(results: list[dict[tuple[str, str], bool]]) -> float:
    """Return the fraction of ``True`` outcomes across several result mappings.

    Every value of every mapping counts as one outcome. The result is therefore the
    mean over all pooled (user task, injection task) pairs, not a mean of per-mapping
    means.

    Args:
        results: Result mappings, e.g. the ``utility_results`` or ``security_results``
            of several ``SuiteResults``.

    Returns:
        The share of outcomes that are ``True``, a value in ``[0, 1]``.

    Raises:
        ValueError: If ``results`` contains no outcomes at all.
    """
    # Distinct names for the outer and inner loop variables keep the argument
    # from being shadowed inside the comprehension.
    outcomes = [outcome for mapping in results for outcome in mapping.values()]
    if not outcomes:
        raise ValueError("Cannot aggregate results: no task outcomes were provided.")
    return sum(outcomes) / len(outcomes)


def run_task_with_injection_tasks(
    suite: TaskSuite,
    agent_pipeline: BasePipelineElement,
    user_task: BaseUserTask,
    attack: BaseAttack,
    logdir: Path | None,
    force_rerun: bool,
    injection_tasks: Sequence[str] | None = None,
    benchmark_version: str | None = None,
) -> tuple[dict[tuple[str, str], bool], dict[tuple[str, str], bool]]:
    """Runs the given user task with the given injection tasks (or all tasks if no injection
    task is provided) by injecting the environment by using the given attack.

    For a DoS attack, only the suite's first injection task is run, ``injection_tasks`` is
    ignored, and the security result is set to ``not utility``.

    When ``logdir`` and the pipeline name are both set, a previously saved result for a
    (user task, injection task) pair is reused instead of re-running it, unless
    ``force_rerun`` is True. The following provider errors are logged and recorded as a
    failed run instead of aborting the benchmark:

    - context-length overflows (OpenAI ``BadRequestError`` / ``UnprocessableEntityError``);
    - invalid function names generated by the model;
    - Cohere ``ApiError`` internal server errors;
    - Google ``ServerError``.

    Any other error is re-raised.

    Args:
        suite: The task suite to run.
        agent_pipeline: The agent pipeline to use.
        user_task: The user task to run.
        attack: The attack to run.
        logdir: The directory to save logs to.
        force_rerun: Whether to force rerun the task in case in the logdir there already exists
            a log for the task.
        injection_tasks: The injection tasks to run. If None, all injection tasks in the suite will be run.
        benchmark_version: The benchmark version, passed on to the trace logger.

    Returns:
        A tuple of two dictionaries. The first dictionary contains the utility results for each injection
            task and user task couple. The second dictionary contains the security results for each
            injection and user task couple. Both are keyed by ``(user_task.ID, injection_task_id)``.

    """
    utility_results: dict[tuple[str, str], bool] = {}
    security_results: dict[tuple[str, str], bool] = {}

    if attack.is_dos_attack:
        # No need to run all injection tasks since we don't care about the injection target!
        injection_tasks_to_run = [next(iter(suite.injection_tasks.keys()))]
    elif injection_tasks is not None:
        injection_tasks_to_run = injection_tasks
    else:
        injection_tasks_to_run = suite.injection_tasks

    for injection_task_id in injection_tasks_to_run:
        injection_task = suite.get_injection_task_by_id(injection_task_id)
        task_injections = attack.attack(user_task, injection_task)

        if logdir is not None and agent_pipeline.name is not None:
            try:
                task_result = load_task_results(
                    agent_pipeline.name,
                    suite.name,
                    user_task.ID,
                    attack_name=attack.name,
                    injection_task=injection_task.ID,
                    logdir=logdir,
                )
            except (ValidationError, FileNotFoundError):
                # A missing or invalid log just means the pair has to be (re-)run.
                task_result = None
        else:
            task_result = None

        if task_result is not None and not force_rerun:
            logging.info(f"Skipping task '{user_task.ID}' with '{injection_task.ID}' because it was already run.")
            utility, security = task_result.utility, task_result.security
        else:
            # a trace logger captures logs and saves them, but also passes them to the delegate logger
            with TraceLogger(
                delegate=Logger.get(),
                suite_name=suite.name,
                user_task_id=user_task.ID,
                injection_task_id=injection_task_id,
                injections=task_injections,
                attack_type=attack.name,
                pipeline_name=agent_pipeline.name,
                benchmark_version=benchmark_version,
            ) as logger:
                try:
                    utility, security = suite.run_task_with_pipeline(
                        agent_pipeline, user_task, injection_task, task_injections
                    )
                except BadRequestError as e:
                    if (
                        e.code == "context_length_exceeded"
                        or e.param == "max_tokens"
                        or "reduce the length of the messages." in repr(e)
                    ):
                        logger.log_error(
                            f"Skipping task '{user_task.ID}' with '{injection_task.ID}' due to context_length_exceeded: {e}"
                        )
                        utility = False
                        security = True
                    elif "Function name" in str(e) and "must be a-z, A-Z, 0-9, or contain underscores and dashes" in str(e):
                        logger.log_error(
                            f"Skipping task '{user_task.ID}' with '{injection_task.ID}' due to invalid function name generated by model: {e}"
                        )
                        utility = False
                        # Unlike the other error paths (security=True), a model that cannot
                        # produce valid function names is recorded with security=False.
                        # NOTE(review): confirm this asymmetry is intended.
                        security = False
                    else:
                        raise
                except UnprocessableEntityError as e:
                    # Mirrors run_task_without_injection_tasks, which treats an
                    # UnprocessableEntityError mentioning max_new_tokens as a context overflow.
                    if "max_new_tokens" in str(e):
                        logger.log_error(
                            f"Skipping task '{user_task.ID}' with '{injection_task.ID}' due to context_length_exceeded: {e}"
                        )
                        utility = False
                        security = True
                    else:
                        raise
                except ApiError as e:
                    if "internal server error" in str(e):
                        logger.log_error(
                            f"Skipping task '{user_task.ID}' with '{injection_task.ID}' because of internal server error: {e}"
                        )
                        utility = False
                        security = True
                    else:
                        raise
                except ServerError as e:
                    logger.log_error(
                        f"Skipping task '{user_task.ID}' with '{injection_task.ID}' because of internal server error: {e}"
                    )
                    utility = False
                    security = True

                if attack.is_dos_attack:
                    # A DoS attack has no injection target, so security is derived from utility alone.
                    security = not utility

                # Record the outcome while the trace logger is still active, as
                # run_task_without_injection_tasks does.
                logger.set_contextarg("utility", utility)
                logger.set_contextarg("security", security)

        utility_results[(user_task.ID, injection_task_id)] = utility
        security_results[(user_task.ID, injection_task_id)] = security
    return utility_results, security_results


def benchmark_suite_with_injections(
    agent_pipeline: BasePipelineElement,
    suite: TaskSuite[Env],
    attack: BaseAttack,
    logdir: Path | None,
    force_rerun: bool,
    user_tasks: Sequence[str] | None = None,
    injection_tasks: Sequence[str] | None = None,
    verbose: bool = True,
    benchmark_version: str | None = None,
) -> SuiteResults:
    """Benchmark a suite with injections.

    Unless the attack is a DoS attack, each selected injection task is first run on its own
    (as a plain task, without injections), and a warning is emitted if any of them fails.
    Then every selected user task is run against the selected injection tasks.

    Args:
        agent_pipeline: The agent pipeline to use.
        suite: The suite to benchmark.
        attack: The attack to use.
        logdir: The directory to save logs to.
        force_rerun: Whether to force rerun the task in case in the logdir there already exists
            a log for the task.
        user_tasks: A subset of user tasks to run. If None, all user tasks are run.
        injection_tasks: A subset of injection tasks to run. If None, all injection tasks are run.
        verbose: Whether to print the results. NOTE(review): not used by this function.
        benchmark_version: The benchmark version, passed on to the trace loggers.

    Returns:
        The utility and security results per (user task ID, injection task ID) pair, plus the
        utility of each injection task run on its own (empty for DoS attacks).
    """
    suite_utility_results: dict[tuple[str, str], bool] = {}
    suite_security_results: dict[tuple[str, str], bool] = {}

    if user_tasks is not None:
        user_tasks_to_run = [suite.get_user_task_by_id(user_task_id) for user_task_id in user_tasks]
    else:
        user_tasks_to_run = suite.user_tasks.values()

    if injection_tasks is not None:
        injection_tasks_to_run = {
            injection_task_id: suite.get_injection_task_by_id(injection_task_id)
            for injection_task_id in injection_tasks
        }
    else:
        injection_tasks_to_run = suite.injection_tasks

    injection_tasks_utility_results: dict[str, bool] = {}
    if not attack.is_dos_attack:
        # Check that each injection goal is achievable by the pipeline at all.
        for injection_task_id, injection_task in injection_tasks_to_run.items():
            successful, _ = run_task_without_injection_tasks(
                suite, agent_pipeline, injection_task, logdir, force_rerun, benchmark_version
            )
            injection_tasks_utility_results[injection_task_id] = successful

        if not all(injection_tasks_utility_results.values()):
            # stacklevel=2 attributes the warning to the caller of this function.
            warnings.warn("Not all injection tasks were solved as user tasks.", stacklevel=2)

    for user_task in user_tasks_to_run:
        utility, security = run_task_with_injection_tasks(
            suite,
            agent_pipeline,
            user_task,
            attack,
            logdir,
            force_rerun,
            injection_tasks,
            benchmark_version,
        )
        suite_utility_results.update(utility)
        suite_security_results.update(security)

    return SuiteResults(
        utility_results=suite_utility_results,
        security_results=suite_security_results,
        injection_tasks_utility_results=injection_tasks_utility_results,
    )


def run_task_without_injection_tasks(
    suite: TaskSuite,
    agent_pipeline: BasePipelineElement,
    task: BaseUserTask | BaseInjectionTask,
    logdir: Path | None,
    force_rerun: bool,
    benchmark_version: str | None = None,
) -> tuple[bool, bool]:
    """Run a task without injections. Useful to benchmark the utility of a pipeline.

    When ``logdir`` and the pipeline name are both set, a previously saved result is reused
    instead of re-running the task, unless ``force_rerun`` is True.

    The following are logged and recorded as ``utility=False, security=True`` instead of
    aborting the benchmark:

    - context-length overflows;
    - invalid function names generated by the model;
    - provider internal server errors.

    Any other error is re-raised.

    Args:
        suite: The suite to use to run the benchmark.
        agent_pipeline: The agent pipeline to benchmark.
        task: The task to run.
        logdir: The directory to save logs to.
        force_rerun: Whether to force rerun the task in case in the logdir there already exists
            a log for the task.
        benchmark_version: The benchmark version, passed on to the trace logger.

    Returns:
        A tuple with the utility and security of the task. The security is always True.
    """
    if logdir is not None and agent_pipeline.name is not None:
        try:
            task_result = load_task_results(
                agent_pipeline.name,
                suite.name,
                task.ID,
                attack_name="none",
                injection_task="none",
                logdir=logdir,
            )
        except (ValidationError, FileNotFoundError):
            # A missing or invalid log just means the task has to be (re-)run.
            task_result = None
    else:
        task_result = None

    if task_result is not None and not force_rerun:
        logging.info(f"Skipping task '{task.ID}' with no injections because it was already run.")
        return task_result.utility, task_result.security

    with TraceLogger(
        delegate=Logger.get(),
        suite_name=suite.name,
        user_task_id=task.ID,
        injection_task_id=None,
        injections={},
        attack_type="none",
        pipeline_name=agent_pipeline.name,
        benchmark_version=benchmark_version,
    ) as logger:
        try:
            utility, security = suite.run_task_with_pipeline(agent_pipeline, task, injection_task=None, injections={})
        except BadRequestError as e:
            if (
                e.code == "context_length_exceeded"
                or e.param == "max_tokens"
                or "reduce the length of the messages." in repr(e)
            ):
                logger.log_error(f"Skipping task {task.ID} due to context_length_exceeded: {e}")
                utility = False
                security = True
            elif "Function name" in str(e) and "must be a-z, A-Z, 0-9, or contain underscores and dashes" in str(e):
                # Same detection as run_task_with_injection_tasks; without an injection the
                # security result stays True, like every other error path here.
                logger.log_error(f"Skipping task {task.ID} due to invalid function name generated by model: {e}")
                utility = False
                security = True
            else:
                raise
        except UnprocessableEntityError as e:
            if "max_new_tokens" in str(e):
                logger.log_error(f"Skipping task {task.ID} due to context_length_exceeded: {e}")
                utility = False
                security = True
            else:
                raise
        except ApiError as e:
            if "internal server error" in str(e):
                logger.log_error(f"Skipping task {task.ID} because of internal server error: {e}")
                utility = False
                security = True
            else:
                raise
        except ServerError as e:
            logger.log_error(f"Skipping task {task.ID} because of internal server error: {e}")
            utility = False
            security = True
        logger.set_contextarg("utility", utility)
        logger.set_contextarg("security", security)

    return utility, security


def benchmark_suite_without_injections(
    agent_pipeline: BasePipelineElement,
    suite: TaskSuite,
    logdir: Path | None,
    force_rerun: bool,
    user_tasks: Sequence[str] | None = None,
    benchmark_version: str | None = None,
) -> SuiteResults:
    """Benchmark a suite without injections.

    Args:
        agent_pipeline: The agent pipeline to use.
        suite: The suite to benchmark.
        logdir: The directory to save logs to.
        force_rerun: Whether to force rerun the task in case in the logdir there already exists
            a log for the task.
        user_tasks: A subset of user tasks to run. If None, all user tasks are run.
        benchmark_version: The benchmark version, passed on to the trace loggers.

    Returns:
        The utility and security results keyed by ``(user_task_id, "")``. The empty string
        stands in for the missing injection task. ``injection_tasks_utility_results`` is
        always empty.
    """

    suite_utility_results: dict[tuple[str, str], bool] = {}
    suite_security_results: dict[tuple[str, str], bool] = {}

    if user_tasks is not None:
        user_tasks_to_run = [suite.get_user_task_by_id(user_task_id) for user_task_id in user_tasks]
    else:
        user_tasks_to_run = suite.user_tasks.values()

    for user_task in user_tasks_to_run:
        utility, security = run_task_without_injection_tasks(
            suite, agent_pipeline, user_task, logdir, force_rerun, benchmark_version
        )
        suite_utility_results[(user_task.ID, "")] = utility
        suite_security_results[(user_task.ID, "")] = security

    return SuiteResults(
        utility_results=suite_utility_results,
        security_results=suite_security_results,
        injection_tasks_utility_results={},
    )


class TaskResults(BaseModel):
    # Validated contents of one saved task-run log, as read back by load_task_results.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Identification of the run.
    suite_name: str
    pipeline_name: str
    user_task_id: str
    # None for runs without an injection task.
    injection_task_id: str | None
    # Attack name; "none" for runs without injections.
    attack_type: str | None
    # Injections applied to the environment, as produced by the attack.
    injections: dict[str, str]
    # Conversation recorded for the run.
    messages: list[ChatMessage]
    # Error text recorded for the run, presumably via logger.log_error; None if none.
    error: str | None
    # Outcome of the run.
    utility: bool
    security: bool
    # Duration of the run (unit not visible here, presumably seconds; confirm with TraceLogger).
    duration: float

    @model_validator(mode="after")
    def check_messages(self) -> Self:
        # Logs of pipelines whose name contains "claude-3-opus" are accepted even with
        # fewer than two messages. NOTE(review): the reason is not visible in this file.
        if "claude-3-opus" in self.pipeline_name:
            return self
        if len(self.messages) < 2:
            raise ValueError("TaskResults must have at least two messages.")
        return self


def convert_old_messages_format(messages: list[dict]) -> list[dict]:
    """Upgrade messages saved in the legacy format to the content-block format.

    Legacy logs stored ``content`` as a plain string. The current format (introduced with
    commit d67f9d141305e599ca2056f430f703d4021fd82a) stores a list of content blocks.
    Each message is shallow-copied, so the input list and its dicts are not modified.

    - ``user``, ``system`` and ``tool`` messages always get their content wrapped in a
      single ``{"type": "text", "content": ...}`` block.
    - ``assistant`` messages are wrapped only when their content is not ``None``.
    - Messages with any other role are copied unchanged.
    """
    roles_always_wrapped = ("user", "system", "tool")

    def _upgrade(old_message: dict) -> dict:
        upgraded = old_message.copy()
        role = upgraded["role"]
        if role in roles_always_wrapped or (role == "assistant" and upgraded["content"] is not None):
            upgraded["content"] = [{"type": "text", "content": upgraded["content"]}]
        return upgraded

    return [_upgrade(message) for message in messages]


def load_task_results(
    pipeline_name: str,
    suite_name: str,
    user_task: str,
    attack_name: str,
    injection_task: str,
    logdir: Path | str,
) -> TaskResults:
    """Load the saved result of a single task run.

    The log is read from
    ``<logdir>/<pipeline_name>/<suite_name>/<user_task>/<attack_name>/<injection_task>.json``.
    Logs without an ``agentdojo_package_version`` field are treated as the old message
    format, and their messages are converted to content blocks before validation.

    Args:
        pipeline_name: Name of the pipeline that produced the log.
        suite_name: Name of the task suite.
        user_task: ID of the user task.
        attack_name: Name of the attack (``"none"`` for runs without injections).
        injection_task: ID of the injection task (``"none"`` for runs without injections).
        logdir: Root log directory; a ``str`` is also accepted.

    Returns:
        The validated ``TaskResults``.

    Raises:
        FileNotFoundError: If no log exists for this run.
        json.JSONDecodeError: If the log file is not valid JSON.
        pydantic.ValidationError: If the log does not validate as ``TaskResults``.
    """
    path = Path(logdir) / pipeline_name / suite_name / user_task / attack_name / f"{injection_task}.json"
    # JSON text is UTF-8 by specification (RFC 8259). Reading with the platform's
    # default text encoding (e.g. cp1252 on Windows) could mis-decode the file.
    with path.open(encoding="utf-8") as f:
        res_dict = json.load(f)

    # Check if the result is in the old format (without content blocks)
    # Old format didn't have agentdojo_package_version field
    if "agentdojo_package_version" not in res_dict:
        # Convert messages to the new format with content blocks
        res_dict["messages"] = convert_old_messages_format(res_dict["messages"])

    return TaskResults(**res_dict)


def load_suite_results(
    suite_name: str,
    pipeline_name: str,
    attack_name: str,
    logdir: Path = Path("../runs"),
    benchmark_version: str = "v1.2.1",
) -> tuple[dict[tuple[str, str | None], TaskResults], dict[tuple[str, str], TaskResults]]:
    """Load every saved task result of one pipeline on one suite.

    For each user task of the suite, the run without injections is loaded and keyed by
    ``(user_task_id, None)``. Runs with injections are keyed by
    ``(user_task_id, injection_task_id)`` and loaded as follows:

    - if ``attack_name`` contains ``"dos"``, only the suite's first injection task is loaded
      (presumably matching how DoS attacks are run by ``run_task_with_injection_tasks``);
    - otherwise, unless ``attack_name`` is ``"none"``, every injection task of the suite.

    Args:
        suite_name: Name of the suite to load.
        pipeline_name: Name of the pipeline whose logs to load.
        attack_name: Name of the attack; ``"none"`` loads only the runs without injections.
        logdir: Root log directory.
        benchmark_version: Benchmark version used to look up the suite.

    Returns:
        A pair ``(results_without_injections, results_with_injections)``.

    Raises:
        FileNotFoundError: If an expected log file is missing.
        pydantic.ValidationError: If a log does not validate as ``TaskResults``.
    """
    suite = get_suite(benchmark_version, suite_name)

    if "dos" in attack_name:
        # DoS attacks are only run against the suite's first injection task.
        injection_task_ids = [next(iter(suite.injection_tasks))]
    elif attack_name != "none":
        injection_task_ids = list(suite.injection_tasks)
    else:
        injection_task_ids = []

    results_without_injections: dict[tuple[str, str | None], TaskResults] = {}
    results_with_injections: dict[tuple[str, str], TaskResults] = {}

    for user_task_id in suite.user_tasks:
        results_without_injections[(user_task_id, None)] = load_task_results(
            pipeline_name, suite_name, user_task_id, "none", "none", logdir
        )
        for injection_task_id in injection_task_ids:
            results_with_injections[(user_task_id, injection_task_id)] = load_task_results(
                pipeline_name, suite_name, user_task_id, attack_name, injection_task_id, logdir
            )
    return results_without_injections, results_with_injections