import importlib

import click

from agentdojo.task_suite.load_suites import get_suite, get_suites
from agentdojo.task_suite.task_suite import TaskSuite


def check_suite(suite: TaskSuite, check_injectable: bool) -> bool:
    """Run `suite.check` and print a colored report of the outcome.

    Args:
        suite: The task suite to check.
        check_injectable: Forwarded unchanged to `suite.check`.

    Returns:
        `True` if every user-task result and every injection-task result
        reported by `suite.check` is successful, `False` otherwise.
    """
    print(f"\033[94mChecking suite: `{suite.name}`\033[00m")

    # Only the second element of the returned pair is needed here: the per-task
    # results for user tasks and for injection tasks, keyed by task id.
    _, (user_tasks_results, injection_tasks_results) = suite.check(check_injectable)

    # A user-task result is indexable: [0] is the pass flag and [1] the failure
    # detail. An injection-task result is used directly as the pass flag.
    user_tasks_ok = all(result[0] for result in user_tasks_results.values())
    injection_tasks_ok = all(injection_tasks_results.values())

    if user_tasks_ok:
        print(f"\033[92mSuite '{suite.name}' passed for all {len(suite.user_tasks)} user tasks!\033[00m")
    else:
        print(f"\033[91mSuite '{suite.name}' failed for some user tasks!\033[00m")
    if injection_tasks_ok:
        print(f"\033[92mSuite '{suite.name}' passed for all {len(suite.injection_tasks)} injection tasks!\033[00m")
    else:
        print(f"\033[91mSuite '{suite.name}' failed for some injection tasks!\033[00m")

    # After the two summaries, list each individual failure.
    for user_task_id, result in user_tasks_results.items():
        if not result[0]:
            print(f"\033[91mUser task '{user_task_id}' failed: {result[1]}\033[00m")

    for injection_task_id, result in injection_tasks_results.items():
        if not result:
            print(f"\033[91mThe ground truth does not solve the injection task '{injection_task_id}'\033[00m")

    return user_tasks_ok and injection_tasks_ok


@click.command()
@click.option(
    "--suite",
    "-s",
    "suite_names",
    type=str,
    multiple=True,
    default=tuple(),
)
@click.option(
    "--benchmark-version",
    "-v",
    "benchmark_version",
    type=str,
    default="v1.1.2",
)
@click.option(
    "--module-to-load",
    "-ml",
    "modules_to_load",
    type=str,
    multiple=True,
    default=tuple(),
    help="The modules to load before running the benchmark. Useful for registering custom attacks, defenses, or suites.",
)
@click.option(
    "--no-check-injectable",
    type=bool,
    is_flag=True,
    help="Whether to *not* check if tasks are injectable. This is useful when running a utility-only benchmark.",
)
def main(
    suite_names: tuple[str, ...],
    benchmark_version: str = "v1.1.2",
    modules_to_load: tuple[str, ...] = (),
    no_check_injectable: bool = False,
):
    """Check task suites and exit with a non-zero status if any check fails.
    \f

    Args:
        suite_names: Names of the suites to check. When empty, every suite
            registered for `benchmark_version` is checked.
        benchmark_version: Version passed to `get_suite` / `get_suites`.
        modules_to_load: Modules imported before the suites are looked up.
        no_check_injectable: When set, `check_suite` is called with
            `check_injectable=False`.

    Raises:
        SystemExit: With exit code 1 when at least one suite fails its check.
    """
    # Import the extra modules first so that anything they register on import
    # is available when the suites are looked up.
    for module in modules_to_load:
        importlib.import_module(module)

    if suite_names:
        suites = [get_suite(benchmark_version, suite_name) for suite_name in suite_names]
    else:
        suites = get_suites(benchmark_version).values()

    check_injectable = not no_check_injectable
    # Build the full list (no short-circuiting) so every suite is checked and
    # reported, even after an earlier one has failed.
    results = [check_suite(suite, check_injectable) for suite in suites]

    # Surface failures through the exit status so callers (e.g. CI) can react.
    if not all(results):
        raise SystemExit(1)


# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
