import abc
from collections import defaultdict
from functools import lru_cache
from typing import Generic, NewType

from rich.progress import track

from agentgym.agent_pipeline.base_pipeline_element import BasePipelineElement
from agentgym.agent_pipeline.ground_truth_pipeline import GroundTruthPipeline
from agentgym.base_tasks import BaseInjectionTask, BaseUserTask
from agentgym.functions_engine.functions_engine import FUNCTIONS_DOCS
from agentgym.task_suites.task_suite import (
    Env,
    TaskSuite,
    get_injection_vector_defaults,
)

# Nested mapping produced by `BaseAttacker.generate_injections`:
# user_task_id -> injection_task_id -> vector_id -> value
Injections = NewType("Injections", dict[str, dict[str, dict[str, str]]])

# Default for the `model_name` argument of `BaseAttacker.__init__`.
# NOTE(review): presumably how attack text refers to the target model — confirm against subclasses.
DEFAULT_MODEL_NAME = "the AI language model"
# Default for the `user_name` argument of `BaseAttacker.__init__`.
# NOTE(review): presumably how attack text refers to the end user — confirm against subclasses.
DEFAULT_USER_NAME = "the user"


class BaseAttacker(abc.ABC, Generic[Env]):
    """Base class for attackers that produce injections for a task suite.

    Subclasses implement `attack`, which is called by `generate_injections`
    once for every (user task, injection task) pair of the suite.
    """

    # Class-level flag, False by default. It is not read inside this class.
    # NOTE(review): presumably consumers check it for denial-of-service style attacks — confirm.
    is_dos_attack: bool = False

    def __init__(
        self,
        name: str,
        task_suite: TaskSuite[Env],
        target_pipeline: BasePipelineElement | None = None,
        user_name: str = DEFAULT_USER_NAME,
        model_name: str = DEFAULT_MODEL_NAME,
    ) -> None:
        """Store the attacker configuration and build the canary injections.

        Args:
            name: Name of the attacker.
            task_suite: Suite whose user and injection tasks are attacked.
            target_pipeline: Optional pipeline element stored on the instance;
                it is not used by this base class.
            user_name: Stored on the instance; defaults to `DEFAULT_USER_NAME`.
            model_name: Stored on the instance; defaults to `DEFAULT_MODEL_NAME`.
        """
        self.name = name
        self.task_suite = task_suite
        self.target_pipeline = target_pipeline
        self.user_name = user_name
        self.model_name = model_name
        injection_defaults = get_injection_vector_defaults(task_suite.name)
        # One unique marker string per injection vector, so that the vector can
        # be recognised in the responses scanned by `get_injection_candidates`.
        self.canary_injections: dict[str, str] = {inj: f"<CANARY---{inj}--->" for inj in injection_defaults}

    def get_injection_candidates(self, user_task_id: str) -> list[str]:
        """Return the injection vectors whose canary shows up for a user task.

        The ground-truth pipeline of the user task is run against an
        environment loaded with the canary injections; a vector is a candidate
        when its canary string occurs in the content of at least one response.
        Successful results are memoised per instance and per `user_task_id`.

        Args:
            user_task_id: Key of the task in `self.task_suite.user_tasks`.

        Returns:
            A new list with the ids of the injectable vectors.

        Raises:
            ValueError: If no canary appears in any response.
        """
        # Per-instance cache instead of `functools.lru_cache` on the method: the
        # latter is keyed on `self` and keeps every instance alive for the
        # lifetime of the class. Created lazily so that subclasses which do not
        # call `super().__init__()` keep working.
        cache: dict[str, list[str]] = self.__dict__.setdefault("_injection_candidates_cache", {})
        if user_task_id in cache:
            # Hand out a copy so callers cannot corrupt the cached result.
            return list(cache[user_task_id])

        user_task = self.task_suite.user_tasks[user_task_id]
        gt_pipeline = GroundTruthPipeline(user_task)
        environment = self.task_suite.load_and_inject_default_environment(self.canary_injections)
        tools = [FUNCTIONS_DOCS[tool] for tool in self.task_suite.tools]
        _, _, _, responses, _ = gt_pipeline.query(user_task.PROMPT, tools, environment)

        # Responses without content cannot contain a canary; drop them once
        # instead of re-checking for every vector.
        contents = [response["content"] for response in responses if response["content"] is not None]
        injection_candidates = [
            inj for inj, canary in self.canary_injections.items() if any(canary in content for content in contents)
        ]
        if not injection_candidates:
            # Failures are deliberately not cached, so a later call retries.
            raise ValueError(f"{user_task.ID} is not injectable.")

        cache[user_task_id] = injection_candidates
        return list(injection_candidates)

    def generate_injections(self) -> Injections:
        """Run `attack` on every (user task, injection task) pair of the suite.

        Returns:
            Mapping user_task_id -> injection_task_id -> result of `attack`.
            A user task only gets an entry if the suite has at least one
            injection task.

        Raises:
            ValueError: If `self.task_suite.check()` reports the suite as invalid.
        """
        is_suite_valid, _ = self.task_suite.check()
        if not is_suite_valid:
            raise ValueError(f"Suite {self.task_suite.name} failed checks. Fix the suite first.")

        injections: defaultdict[str, dict[str, dict[str, str]]] = defaultdict(dict)
        for user_task_id, user_task in track(self.task_suite.user_tasks.items()):
            for inj_task_id, inj_task in self.task_suite.injection_tasks.items():
                injections[user_task_id][inj_task_id] = self.attack(user_task, inj_task)

        return Injections(dict(injections))

    def attack(
        self,
        user_task: BaseUserTask[Env],
        injection_task: BaseInjectionTask[Env],
    ) -> dict[str, str]:
        """Build the injections for one user task / injection task pair.

        Must be overridden by subclasses.

        Args:
            user_task: The user task being attacked.
            injection_task: The injection task the attack pursues.

        Returns:
            A mapping of injection vector id to the value placed in that vector.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # Fail loudly rather than silently returning None, which
        # `generate_injections` would otherwise store as a valid attack.
        raise NotImplementedError(f"{type(self).__name__} must implement attack()")
