import dataclasses
import datetime
import re
from pathlib import Path

import click
import pandas as pd

from agentdojo.benchmark import load_suite_results
from agentdojo.models import MODEL_PROVIDERS
from agentdojo.task_suite import get_suites
from agentdojo.types import ChatMessage, get_text_content_as_str


def find_information_tags(text: str) -> list[tuple[int, int, int]]:
    """Locate every ``<INFORMATION>...</INFORMATION>`` block inside ``text``.

    Matching is non-greedy and spans newlines, so each opening tag is paired
    with the nearest closing tag that follows it.

    Args:
        text: The string to scan.

    Returns:
        One ``(start, end, total_length)`` tuple per block, in order of
        appearance. ``start`` and ``end`` are the character offsets of the
        block (``end`` is exclusive) and ``total_length`` is ``len(text)``,
        repeated in every tuple so callers can compute relative positions.
    """
    total_length = len(text)
    spans: list[tuple[int, int, int]] = []
    for found in re.finditer(r"<INFORMATION>.*?</INFORMATION>", text, flags=re.DOTALL):
        begin, finish = found.span()
        spans.append((begin, finish, total_length))
    return spans


def get_injection_position_in_tool_outputs(messages: list[ChatMessage]) -> list[list[tuple[int, int, int]]]:
    """Collect the ``<INFORMATION>`` block positions found in tool messages.

    Only messages whose ``"role"`` is ``"tool"`` are inspected; their
    ``"content"`` is converted with ``get_text_content_as_str`` and scanned
    with ``find_information_tags``.

    Args:
        messages: The conversation messages to inspect.

    Returns:
        One inner list per tool message that contains at least one block, in
        message order. Tool messages without any block contribute nothing, so
        the outer list's length is the number of tool outputs that carried an
        injection, not the number of tool messages.
    """
    per_tool_message_matches = (
        find_information_tags(get_text_content_as_str(message["content"]))
        for message in messages
        if message["role"] == "tool"
    )
    # An empty match list means the tool output carried no injection: drop it.
    return [matches for matches in per_tool_message_matches if matches]


@dataclasses.dataclass
class EvalInfo:
    model: str
    date: datetime.date
    attack: str = "important_instructions"
    defense: str | None = None
    traceuid: str | None = None


def _build_result_row(eval_info: EvalInfo, attack: str | None, suite, user_task, injection_task, result) -> dict:
    """Build the flat record describing a single task result.

    Args:
        eval_info: The configuration the result belongs to.
        attack: Value stored in the ``"attack"`` column (``None`` marks a run
            without attack).
        suite: The suite the result was loaded for.
        user_task: First element of the result's key.
        injection_task: Second element of the result's key.
        result: Loaded result; its ``messages``, ``utility`` and ``security``
            attributes are read.

    Returns:
        A dict with one entry per output column.
    """
    injections_positions = get_injection_position_in_tool_outputs(result.messages)
    # Relative positions are only well defined when exactly one tool output
    # contains exactly one injection block.
    if len(injections_positions) == 1 and len(injections_positions[0]) == 1:
        injection_start, injection_end, output_len = injections_positions[0][0]
        start_injection_position = injection_start / output_len
        end_injection_position = injection_end / output_len
    else:
        start_injection_position = None
        end_injection_position = None
    return {
        "model": eval_info.model,
        "defense": eval_info.defense,
        "attack": attack,
        "suite": suite,
        "user_task": user_task,
        "injection_task": injection_task,
        "utility": result.utility,
        "security": result.security,
        "injection_seen": len(injections_positions) > 0,
        "times_injection_seen": sum(len(positions) for positions in injections_positions),
        "start_injection_position": start_injection_position,
        "end_injection_position": end_injection_position,
        "eval_date": eval_info.date,
        "traceuid": eval_info.traceuid,
    }


def make_dataframe(eval_infos: list[EvalInfo], benchmark_version: str, logdir: Path) -> pd.DataFrame:
    """Load the logged results of every configuration into one DataFrame.

    For each configuration and each suite returned by ``get_suites``, the
    results are loaded with ``load_suite_results`` and turned into one row per
    task result. Rows from the first returned mapping (runs without attack)
    get ``attack=None``; rows from the second get the configuration's attack
    name.

    Args:
        eval_infos: Configurations to load.
        benchmark_version: Benchmark version passed to ``get_suites`` and
            ``load_suite_results``.
        logdir: Directory the results are loaded from.

    Returns:
        A DataFrame with the columns ``model``, ``defense``, ``attack``,
        ``suite``, ``user_task``, ``injection_task``, ``utility``,
        ``security``, ``injection_seen``, ``times_injection_seen``,
        ``start_injection_position``, ``end_injection_position``,
        ``eval_date`` and ``traceuid`` (empty if no result was loaded).
    """
    rows = []
    for eval_info in eval_infos:
        # The pipeline name mirrors the model id with "/" made filesystem-safe,
        # followed by the defense name when one is used.
        model = eval_info.model.replace("/", "_")
        pipeline = model if eval_info.defense is None else f"{model}-{eval_info.defense}"
        attack = eval_info.attack if eval_info.attack is not None else ""
        for suite in get_suites(benchmark_version):
            results_without, results_with = load_suite_results(
                suite, pipeline, attack, logdir=logdir, benchmark_version=benchmark_version
            )
            # Runs without attack come first, then the attacked runs, so the
            # row order matches processing each mapping in turn.
            for row_attack, results in ((None, results_without), (eval_info.attack, results_with)):
                for (user_task, injection_task), result in results.items():
                    rows.append(_build_result_row(eval_info, row_attack, suite, user_task, injection_task, result))
    return pd.DataFrame(rows)


# Configurations included in the generated results table.
# Positional arguments are the model id and the evaluation date. Entries that
# do not pass `attack` use the EvalInfo default ("important_instructions"), and
# entries that do not pass `defense` are evaluated without a defense.
# `traceuid=None` means no trace link is rendered for that row.
TABLE_MODELS = [
    EvalInfo(
        "claude-3-7-sonnet-20250219",
        datetime.date(2025, 2, 24),
        traceuid=None,
    ),
    EvalInfo(
        "claude-3-5-sonnet-20241022",
        datetime.date(2024, 11, 15),
        traceuid="invariant/agentdojo-claude-3-5-sonnet-20241022",
    ),
    EvalInfo(
        "claude-3-5-sonnet-20240620",
        datetime.date(2024, 6, 20),
        traceuid="invariant/agentdojo-claude-3-5-sonnet-20240620",
    ),
    EvalInfo("gpt-4o-2024-05-13", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-gpt-4o-2024-05-13"),
    # gpt-4o with each defense (default attack).
    EvalInfo(
        "gpt-4o-2024-05-13",
        datetime.date(2024, 6, 5),
        defense="transformers_pi_detector",
        traceuid="invariant/agentdojo-gpt-4o-2024-05-13-transformers_pi_detector",
    ),
    EvalInfo(
        "gpt-4o-2024-05-13",
        datetime.date(2024, 6, 5),
        defense="tool_filter",
        traceuid="invariant/agentdojo-gpt-4o-2024-05-13-tool_filter",
    ),
    EvalInfo(
        "gpt-4o-2024-05-13",
        datetime.date(2024, 6, 5),
        defense="repeat_user_prompt",
        traceuid="invariant/agentdojo-gpt-4o-2024-05-13-repeat_user_prompt",
    ),
    EvalInfo(
        "gpt-4o-2024-05-13",
        datetime.date(2024, 6, 5),
        defense="spotlighting_with_delimiting",
        traceuid="invariant/agentdojo-gpt-4o-2024-05-13-spotlighting_with_delimiting",
    ),
    # gpt-4o with each non-default attack (no defense); these entries reuse the
    # trace id of the undefended gpt-4o entry.
    EvalInfo(
        "gpt-4o-2024-05-13",
        datetime.date(2024, 6, 5),
        attack="injecagent",
        traceuid="invariant/agentdojo-gpt-4o-2024-05-13",
    ),
    EvalInfo(
        "gpt-4o-2024-05-13",
        datetime.date(2024, 6, 5),
        attack="tool_knowledge",
        traceuid="invariant/agentdojo-gpt-4o-2024-05-13",
    ),
    EvalInfo(
        "gpt-4o-2024-05-13",
        datetime.date(2024, 6, 5),
        attack="ignore_previous",
        traceuid="invariant/agentdojo-gpt-4o-2024-05-13",
    ),
    EvalInfo(
        "gpt-4o-2024-05-13",
        datetime.date(2024, 6, 5),
        attack="direct",
        traceuid="invariant/agentdojo-gpt-4o-2024-05-13",
    ),
    EvalInfo("gpt-4-0125-preview", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-gpt-4-0125-preview"),
    EvalInfo(
        "gpt-4o-mini-2024-07-18", datetime.date(2024, 7, 19), traceuid="invariant/agentdojo-gpt-4o-mini-2024-07-18"
    ),
    # NOTE(review): this date is years earlier than every other entry in the
    # table (all 2024-2025) - possibly a typo; confirm the intended date.
    EvalInfo(
        "meta-llama/Llama-3-70b-chat-hf",
        datetime.date(2021, 9, 5),
        traceuid="invariant/agentdojo-meta-llama_Llama-3-70b-chat-hf",
    ),
    EvalInfo("gpt-3.5-turbo-0125", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-gpt-3-5-turbo-0125"),
    EvalInfo(
        "gpt-4-turbo-2024-04-09", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-gpt-4-turbo-2024-04-09"
    ),
    EvalInfo(
        "claude-3-opus-20240229", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-claude-3-opus-20240229"
    ),
    EvalInfo(
        "claude-3-sonnet-20240229", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-claude-3-sonnet-20240229"
    ),
    EvalInfo(
        "claude-3-haiku-20240307", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-claude-3-haiku-20240307"
    ),
    EvalInfo("command-r-plus", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-command-r-plus"),
    EvalInfo("command-r", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-command-r"),
    EvalInfo("gemini-1.5-pro-002", datetime.date(2024, 11, 4), traceuid="invariant/agentdojo-gemini-1-5-pro-002"),
    EvalInfo("gemini-1.5-pro-001", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-gemini-1-5-pro-001"),
    EvalInfo("gemini-1.5-flash-002", datetime.date(2024, 11, 5), traceuid="invariant/agentdojo-gemini-1-5-flash-002"),
    EvalInfo("gemini-1.5-flash-001", datetime.date(2024, 6, 5), traceuid="invariant/agentdojo-gemini-1-5-flash-001"),
    EvalInfo("gemini-2.0-flash-exp", datetime.date(2024, 12, 11), traceuid=None),
    EvalInfo("gemini-2.0-flash-001", datetime.date(2025, 1, 30), traceuid=None),
]


@click.command()
@click.argument("logdir", type=Path, default=Path("runs"))
@click.argument("benchmark_version", type=str, default="v1.1.2")
def main(logdir: Path, benchmark_version: str) -> None:
    """Aggregate the logged results and write them as an HTML table.

    Loads the results of every configuration in TABLE_MODELS from LOGDIR for
    BENCHMARK_VERSION, averages utility and security per configuration
    (separately for runs with and without attack), and writes the table to
    docs/results-table.html, relative to the current working directory.
    """
    full_results_df = make_dataframe(TABLE_MODELS, benchmark_version, logdir)
    results_df = full_results_df[["model", "defense", "attack", "utility", "security", "eval_date", "traceuid"]]
    # Rows without an attack name are the runs without attack.
    no_attack_df: pd.DataFrame = results_df[results_df["attack"].isna()]  # type: ignore
    attack_df: pd.DataFrame = results_df[results_df["attack"].notna()]  # type: ignore

    # dropna=False keeps groups whose key contains a missing value
    # (e.g. no defense or no trace id).
    stats_under_attack = (
        attack_df.groupby(["model", "defense", "attack", "eval_date", "traceuid"], dropna=False).mean().reset_index()
    )
    stats_under_attack = stats_under_attack.rename(
        columns={"utility": "Utility under attack", "security": "Targeted ASR"}
    )
    stats_no_attack = (
        no_attack_df.groupby(["model", "defense", "eval_date", "traceuid"], dropna=False).mean().reset_index()
    )
    stats_no_attack = stats_no_attack.rename(columns={"utility": "Utility"})

    overall_results = pd.merge(
        stats_no_attack, stats_under_attack, on=["model", "defense", "eval_date", "traceuid"], how="outer"
    )
    # Both sides carry an "attack" column; keep only the one from the attacked runs.
    overall_results = overall_results.drop(columns=["attack_x"])
    overall_results = overall_results.rename(columns={"attack_y": "attack"})
    overall_results["provider"] = overall_results["model"].apply(lambda m: MODEL_PROVIDERS[m])
    overall_results = overall_results.sort_values(
        ["eval_date", "provider", "model", "attack", "defense"], ascending=False
    )
    # Render the averaged fractions as percentages with two decimals.
    for percentage_column in ("Utility", "Utility under attack", "Targeted ASR"):
        overall_results[percentage_column] = overall_results[percentage_column].apply(lambda x: f"{x * 100:.2f}%")
    overall_results = overall_results.rename(
        columns={
            "provider": "Provider",
            "model": "Model",
            "defense": "Defense",
            "attack": "Attack",
            "eval_date": "Date",
            "traceuid": "Traces",
        }
    )

    overall_results = overall_results[
        ["Provider", "Model", "Defense", "Attack", "Utility", "Utility under attack", "Targeted ASR", "Date", "Traces"]
    ]

    # render traceuid s as links
    def traceurl_formatter(traceuid: str | None) -> str:
        # A missing trace id may arrive as None or as a pandas NaN; pd.isna covers both.
        if pd.isna(traceuid):
            return ""
        return f'<a style="font-size: 12pt;" href="https://explorer.invariantlabs.ai/u/{traceuid}/t/1" target="_blank">Explore &rarr;</a>'

    overall_results["Traces"] = overall_results["Traces"].apply(traceurl_formatter)  # type: ignore

    # escape=False is required so the trace links are emitted as HTML
    # rather than as escaped text.
    overall_results.to_html(
        "docs/results-table.html",
        index=False,
        na_rep="None",
        classes=["row-border", "order-column", "hover"],
        border=0,
        justify="left",
        table_id="results-table",
        escape=False,
    )


# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
