import logging
import warnings
from functools import partial
from itertools import repeat
from multiprocessing import Pool
from pathlib import Path

import click
from agentgym import make_pipeline
from agentgym.agent_pipeline.make_pipeline import DEFENSES, ModelsEnum
from agentgym.attacks import AttackersEnum, make_attacker
from agentgym.benchmark import (
    SuiteResults,
    benchmark_suite_with_injections,
    benchmark_suite_without_injections,
)
from agentgym.logging import OutputLogger
from agentgym.task_suites import SUITES
from dotenv import load_dotenv
from rich import print
from rich.live import Live
from rich.logging import RichHandler

# Log format for the root logger. It uses a private name so the builtin
# ``format`` is not shadowed at module level.
_LOG_FORMAT = "%(message)s"

# Configure root logging once at import time: INFO level, short timestamps,
# and a Rich handler with markup enabled and source paths hidden.
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.INFO,
    datefmt="%H:%M:%S",
    handlers=[RichHandler(show_path=False, markup=True)],
)


def benchmark_suite(
    suite_name: str,
    model: ModelsEnum,
    logdir: Path,
    force_rerun: bool,
    user_tasks: tuple[str, ...] = (),
    injection_tasks: tuple[str, ...] = (),
    attacker: AttackersEnum | None = None,
    defense: str | None = None,
    system_message_name: str | None = None,
    system_message: str | None = None,
    live: Live | None = None,
) -> SuiteResults:
    """Benchmark a single task suite and return its results.

    Loads environment variables from ``.env`` (warning when that file is not
    found), builds the agent pipeline, then runs the suite inside an
    ``OutputLogger`` rooted at ``logdir``:

    * with no attacker, via ``benchmark_suite_without_injections``;
    * otherwise, via ``benchmark_suite_with_injections`` using an attacker
      built by ``make_attacker`` for this suite and pipeline.

    Empty ``user_tasks`` / ``injection_tasks`` tuples are forwarded as
    ``None`` (presumably meaning "no filter" to the benchmark helpers --
    TODO confirm). ``injection_tasks`` is only used when ``attacker`` is set.

    Raises:
        ValueError: if ``suite_name`` is not a key of ``SUITES``.
    """
    env_file_loaded = load_dotenv(".env")
    if not env_file_loaded:
        warnings.warn("No .env file found")
    if suite_name not in SUITES:
        raise ValueError(f"Unknown suite: '{suite_name}'")

    print(f"Running benchmark for suite: '{suite_name}'")
    print(f"Using model: '{model}'")
    if attacker is not None:
        print(f"Using attacker: '{attacker}'")
    if defense is not None:
        print(f"Using defense: '{defense}'")
    if user_tasks:
        print(f"Using user tasks: {', '.join(user_tasks)}")

    # An empty selection is handed to the benchmark helpers as None.
    selected_user_tasks = user_tasks or None

    suite = SUITES[suite_name]
    pipeline = make_pipeline(model, defense, system_message_name, system_message)
    with OutputLogger(str(logdir), live=live):
        if attacker is None:
            results = benchmark_suite_without_injections(
                pipeline,
                suite,
                user_tasks=selected_user_tasks,
                logdir=logdir,
                force_rerun=force_rerun,
            )
        else:
            # The attacker is built per suite because it depends on both the
            # suite and the pipeline under test.
            suite_attacker = make_attacker(attacker, suite, pipeline)
            results = benchmark_suite_with_injections(
                pipeline,
                suite,
                suite_attacker,
                user_tasks=selected_user_tasks,
                injection_tasks=injection_tasks or None,
                logdir=logdir,
                force_rerun=force_rerun,
            )
    print(f"Finished benchmark for suite: '{suite_name}'")

    return results


def show_results(suite_name: str, results: "SuiteResults", show_security_results: bool):
    """Print a human-readable summary of one suite's benchmark results.

    Args:
        suite_name: Suite name shown in the header line.
        results: Mapping holding a ``"utility_results"`` dict and, when
            ``show_security_results`` is true, ``"injection_tasks_utility_results"``
            and ``"security_results"`` dicts. Only the dict values are used;
            they are summed, so they are expected to be booleans or numbers.
        show_security_results: Also print the injection-task and security
            statistics (the caller enables this when an attacker was used).

    An average over an empty result set is printed as ``n/a`` instead of
    raising ``ZeroDivisionError``.
    """

    def _as_percentage(values) -> str:
        # Mean of the values formatted as a percentage; "n/a" when empty.
        collected = list(values)
        if not collected:
            return "n/a"
        return f"{sum(collected) / len(collected) * 100:.2f}%"

    utility = _as_percentage(results["utility_results"].values())
    print(f"Results for suite {suite_name}")
    print(f"Average utility: {utility}")

    if show_security_results:
        injection_outcomes = results["injection_tasks_utility_results"]
        passed_injection_tasks = sum(injection_outcomes.values())
        total_injection_tasks = len(injection_outcomes)
        print(f"\nPassed injection tasks as user tasks: {passed_injection_tasks}/{total_injection_tasks}")

        security = _as_percentage(results["security_results"].values())
        print(f"Average security: {security}")


@click.command()
@click.option(
    "--model",
    default="gpt-4o-2024-05-13",
    type=click.Choice([value for value in ModelsEnum]),
    help="The language model to benchmark. Defaults to `gpt-4o-2024-05-13`.",
)
@click.option(
    "--logdir",
    default="./runs",
    type=Path,
    help="The directory to save logs. Defaults to `./runs`.",
)
@click.option(
    "--attacker",
    type=click.Choice([value for value in AttackersEnum]),
    default=None,
    help="The attacker to use. None by default.",
)
@click.option(
    "--defense",
    type=click.Choice(DEFENSES),
    default=None,
    help="The defense to use. None by default.",
)
@click.option(
    "--system-message-name",
    type=str,
    default=None,
    help="The name of the system message to use among the default ones in `data/system_messages.yaml`.",
)
@click.option(
    "--system-message",
    type=str,
    default=None,
    help="The system message to use (as a string). If provided, `--system-message-name` is ignored.",
)
@click.option(
    "--user-task",
    "-ut",
    "user_tasks",
    type=str,
    multiple=True,
    default=tuple(),
    help="The user tasks to benchmark. If not provided, all tasks in the suite are run.",
)
@click.option(
    "--injection-task",
    "-it",
    "injection_tasks",
    type=str,
    multiple=True,
    default=tuple(),
    help="The injection tasks to benchmark (only used with `--attacker`). If not provided, all injection tasks in the suite are run.",
)
@click.option(
    "--suite",
    "-s",
    "suites",
    type=click.Choice([key for key in SUITES]),
    multiple=True,
    default=tuple(),
    help="The suites to benchmark. If not provided, all suites are run.",
)
@click.option(
    "--max-workers",
    type=click.IntRange(min=1),
    default=1,
    help="How many suites can be benchmarked in parallel through multiprocessing. Nothing is parallelized by default.",
)
@click.option(
    "--force-rerun",
    "-f",
    is_flag=True,
    help="Whether to re-run tasks that have already been run.",
)
def main(
    suites: tuple[str, ...],
    model: ModelsEnum,
    logdir: Path,
    user_tasks: tuple[str, ...] = (),
    injection_tasks: tuple[str, ...] = (),
    attacker: AttackersEnum | None = None,
    defense: str | None = None,
    system_message_name: str | None = None,
    system_message: str | None = None,
    max_workers: int = 1,
    force_rerun: bool = False,
):
    """Benchmark an agent pipeline on one or more task suites and print a summary for each.

    \f
    Suites run sequentially in this process, or in a ``multiprocessing.Pool``
    of at most ``max_workers`` processes when several suites are selected.

    Raises:
        ValueError: if user tasks are given while more than one suite is run.
    """
    if not suites:
        suites = tuple(SUITES.keys())
    # Drop repeated ``-s`` values (keeping order) so the same suite is not
    # benchmarked twice into the same log directory.
    suites = tuple(dict.fromkeys(suites))

    if len(suites) != 1 and user_tasks:
        raise ValueError("A user task can be specified only when one suite is being executed")

    if len(suites) != 1:
        print(f"Benchmarking suites {', '.join(suites)}")

    # Bind every option by keyword so each run (in-process or in a worker)
    # gets them in the right parameters regardless of the positional order of
    # ``benchmark_suite``'s signature. A ``partial`` of a module-level function
    # is picklable, as ``Pool.map`` requires.
    run_suite = partial(
        benchmark_suite,
        model=model,
        logdir=logdir,
        force_rerun=force_rerun,
        user_tasks=user_tasks,
        injection_tasks=injection_tasks,
        attacker=attacker,
        defense=defense,
        system_message_name=system_message_name,
        system_message=system_message,
    )

    # Never start more processes than there are suites; a single suite (or a
    # single worker) runs in this process.
    workers = min(max_workers, len(suites))
    if workers <= 1:
        results = [run_suite(suite_name) for suite_name in suites]
    else:
        with Pool(workers) as pool:
            results = pool.map(run_suite, suites)

    for suite_name, result in zip(suites, results):
        show_results(suite_name, result, attacker is not None)


if __name__ == "__main__":
    main()
