import contextvars
import json
import logging
import os
import textwrap
import time
import warnings

from pydantic import BaseModel
from rich.live import Live
from rich.table import Table

from agentdojo.ast_utils import create_python_function_from_tool_call
from agentdojo.types import ChatMessage

# Context-local stack of active loggers: Logger.__enter__ appends to the list
# and Logger.__exit__ pops from it; Logger.get() reads the last element.
# NOTE(review): the default list is created once on this line and is mutated in
# place (nothing in this file calls LOGGER_STACK.set()), so presumably every
# context shares this single list -- confirm that this sharing is intended.
LOGGER_STACK = contextvars.ContextVar("logger_stack", default=[])


class Logger:
    """Base context manager that registers itself on the ``LOGGER_STACK`` list."""

    def __enter__(self):
        """Append this logger to the active stack and return it."""
        stack = LOGGER_STACK.get()
        stack.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Remove the last entry of the active stack; exceptions are not suppressed."""
        stack = LOGGER_STACK.get()
        stack.pop()
        return False

    @staticmethod
    def get():
        """Return the last logger on the stack, or a new ``NullLogger`` if the stack is empty."""
        stack = LOGGER_STACK.get()
        return stack[-1] if stack else NullLogger()


class NullLogger(Logger):
    """Logger that discards every message and error it is given."""

    messages: list[ChatMessage]

    def __init__(self):
        # Set the attributes at construction time as well as in __enter__, so
        # an instance that is used without a `with` block (for example the one
        # created by Logger.get()) exposes `messages` and `logdir` instead of
        # raising AttributeError when they are read.
        self.messages = []
        self.logdir = None

    def __enter__(self):
        """Reset ``messages`` and ``logdir`` and return ``self``.

        Unlike ``Logger.__enter__`` this does not touch the logger stack.
        """
        self.messages = []
        self.logdir = None
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Do nothing on exit; exceptions are not suppressed."""
        return False

    def log(self, *args, **kwargs):
        """Accept and ignore any arguments."""
        pass

    def log_error(self, *args, **kwargs):
        """Accept and ignore any arguments."""
        pass


def assistant(text):
    """Return *text* wrapped in ``[red]`` markup tags after the ``:robot_face:`` emoji code."""
    return ":robot_face: [red]{}[/red]".format(text)


def tool(text):
    """Return *text* wrapped in ``[orange]`` markup tags after the ``:wrench:`` emoji code."""
    return ":wrench: [orange]{}[/orange]".format(text)


def user(text):
    """Return *text* wrapped in ``[green]`` markup tags after the ``:bust_in_silhouette:`` emoji code."""
    return ":bust_in_silhouette: [green]{}[/green]".format(text)


def system(text):
    """Return *text* wrapped in ``[blue]`` markup tags after the ``:book:`` emoji code."""
    return ":book: [blue]{}[/blue]".format(text)


class OutputLogger(Logger):
    """Logger that shows chat messages through ``logging`` or in a live table.

    When ``live`` is ``None`` every new message is emitted with
    ``logging.info``; otherwise it is added as a row to ``self.table`` and the
    table is pushed to ``live``.
    """

    def __init__(self, logdir: str | None, live: Live | None = None):
        """Store ``logdir`` and ``live`` and start with an empty "Chat log" table."""
        self.messages = []
        self.logdir = logdir
        self.live = live
        self.table = Table("Role", "Content", title="Chat log", show_lines=True)
        self.previous_user_task_id = None

    def log_error(self, message: str):
        """Emit *message* as an "API error" warning through ``logging``."""
        logging.warning(f"[red]API error: {message}[/red]", extra={"markup": True})

    def log(self, messages: list[ChatMessage], **kwargs):
        """Show the messages that were not already shown by a previous call.

        Args:
            messages: The conversation so far; only the part after the first
                ``len(self.messages)`` entries is rendered and remembered.
            **kwargs: Optional ``user_task_id``, ``injection_task_id``,
                ``suite_name`` and ``pipeline_name`` used to label the output.
                Missing or empty values (other than ``injection_task_id``) are
                replaced by an "unknown ..." placeholder.
        """
        messages = messages[len(self.messages) :]
        self.messages += messages

        # The fallback must be applied to the looked-up value: the previous
        # `kwargs.get("key" or "fallback")` form only ever evaluated to
        # `kwargs.get("key")`, so the placeholders were never used.
        requested_user_task_id = kwargs.get("user_task_id")
        user_task_id = requested_user_task_id or "unknown user_task_id"
        # None is meaningful here: it means "no injection task" below.
        injection_task_id = kwargs.get("injection_task_id")
        suite_name = kwargs.get("suite_name") or "unknown suite_name"
        pipeline_name = kwargs.get("pipeline_name") or "unknown pipeline_name"

        # Compare the raw value so that calls without a user_task_id keep the
        # table created in __init__ instead of replacing it.
        if requested_user_task_id != self.previous_user_task_id:
            if injection_task_id is not None:
                task_label = f"{user_task_id}{injection_task_id}"
            else:
                # Avoid a literal "None" being glued onto the user task id.
                task_label = f"{user_task_id}"
            self.table = Table(
                "Role",
                "Content",
                title=f"[{suite_name}][{task_label}]",
                show_lines=True,
            )
            self.previous_user_task_id = requested_user_task_id

        # The log-line prefix does not depend on the message, so build it once.
        if injection_task_id is not None:
            current_task_id = rf"\[{user_task_id}]\[{injection_task_id}]"
        else:
            current_task_id = rf"\[{user_task_id}]"
        log_prefix = rf"\[{pipeline_name}]\[{suite_name}]{current_task_id}"

        for message in messages:
            role = message["role"]
            content = str(message["content"]) if message["content"] is not None else ""

            # Only assistant messages are inspected for tool calls.
            tool_calls = (message.get("tool_calls") or []) if role == "assistant" else []
            if tool_calls:
                tool_calls_content = "\n------------\nTool calls:" + "".join(
                    f"\n   - {create_python_function_from_tool_call(tool_call)}" for tool_call in tool_calls
                )
            else:
                tool_calls_content = ""

            if role == "user":
                role = user(role)
            elif role == "assistant":
                role = assistant(role)
            elif role == "tool":
                # A tool error replaces the regular content.
                if "error" in message and message["error"] is not None:
                    content = f"[red]{message['error']}[/red]"
                if "tool_call" in message:
                    role = f"{role} [bold]{message['tool_call'].function}[/bold]"
                role = tool(role)
            else:
                role = system(role)

            if self.live is None:
                logging_content = textwrap.shorten(f"{content}{tool_calls_content}", width=100, placeholder="[...]")
                logging.info(
                    rf"{log_prefix} {role}: {logging_content}",
                    extra={"markup": True},
                )
            else:
                self.table.add_row(role, content + tool_calls_content)
                self.live.update(self.table)


class NotebookLogger(OutputLogger):
    """``OutputLogger`` variant that clears the IPython output before each ``log`` call."""

    def log(self, messages: list[ChatMessage], **kwargs):
        """Clear the current IPython output, then delegate to ``OutputLogger.log``."""
        # Imported inside the method, so IPython is only imported when log() runs.
        from IPython import display as ipython_display

        ipython_display.clear_output(wait=True)
        super().log(messages, **kwargs)


class TraceLogger(Logger):
    """
    Logs agent interactions as JSON files in a directory.

    Files are organized by the context metadata (keyword arguments given to the
    constructor or set later through ``set_contextarg``) as
    ``<dirpath>/<pipeline_name>/<suite_name>/<user_task_id>/<attack_type>/<injection_task_id>.json``.
    Messages and errors are also forwarded to the wrapped ``delegate`` logger.
    """

    def __init__(self, delegate: NullLogger | OutputLogger, **contextargs):
        """Wrap *delegate* and record *contextargs* as the initial context.

        Args:
            delegate: Logger that receives every ``log``/``log_error`` call; its
                ``logdir`` (when set) is used as the output directory.
            **contextargs: Initial context entries (e.g. ``suite_name``,
                ``user_task_id``, ``attack_type``, ``pipeline_name``).
        """
        from importlib.metadata import PackageNotFoundError, version

        self.delegate = delegate
        # getattr: a delegate on which `logdir` was never assigned is treated
        # like one whose logdir is None instead of raising AttributeError.
        logdir = getattr(delegate, "logdir", None)
        self.dirpath = logdir or os.path.join(os.path.dirname(__file__), "../../", "runs")
        self.context = contextargs
        self.messages: list[ChatMessage] = []
        self.delegate.messages = []
        self.error: str | None = None
        self.start: float | None = None

        # Add timestamp of evaluation
        self.context["evaluation_timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")

        # Add agentdojo package version
        try:
            self.context["agentdojo_package_version"] = version("agentdojo")
        except PackageNotFoundError:
            self.context["agentdojo_package_version"] = "unknown"

        os.makedirs(self.dirpath, exist_ok=True)

    def __enter__(self):
        """Record the start time and push this logger onto the logger stack."""
        self.start = time.time()
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        """Store the elapsed time as the ``duration`` context entry, then pop the stack."""
        if self.start is not None:
            self.set_contextarg("duration", time.time() - self.start)
        else:
            self.set_contextarg("duration", None)

        return super().__exit__(exc_type, exc_value, traceback)

    def set_contextarg(self, key, value):
        """Set one context entry and immediately re-save the trace."""
        self.context[key] = value
        self.save()

    def save(self):
        """Write the current context, messages and error to the trace JSON file.

        Emits a warning and writes nothing when ``attack_type`` or
        ``pipeline_name`` is missing from the context.
        """
        # Fetches suite_name, user_task_id=user_task.ID, injection_task_id=injection_task_id, injections=task_injections from context
        suite_name = self.context.get("suite_name", "unknown_suite_name")
        user_task_id = self.context.get("user_task_id", "unknown_user_task_id")
        injection_task_id = self.context.get("injection_task_id", "unknown_injection_task_id")
        injections = self.context.get("injections", {})
        attack_type = self.context.get("attack_type", None)
        pipeline_name = self.context.get("pipeline_name", None)
        # Keys that get a dedicated field in the JSON document; everything else
        # in the context is copied through unchanged.
        dedicated_keys = (
            "suite_name",
            "user_task_id",
            "injection_task_id",
            "injections",
            "attack_type",
            "pipeline_name",
        )
        other_context = {k: v for k, v in self.context.items() if k not in dedicated_keys}

        if attack_type is None:
            warnings.warn(
                f"Cannot log executed benchmark (suite_name={suite_name}, user_task_id={user_task_id}, injection_task_id={injection_task_id}, attack_type={attack_type}) because attack_type is not known. Make sure to run the benchmark with the attack_type parameter set (cf. benchmark_suite_with_injections(...))."
            )
            return

        if pipeline_name is None:
            warnings.warn(
                f"Cannot log executed benchmark (suite_name={suite_name}, user_task_id={user_task_id}, injection_task_id={injection_task_id}, attack_type={attack_type}) because pipeline_name is not known. Make sure to run the benchmark with the pipeline_name parameter set (cf. benchmark_suite_with_injections(...))."
            )
            return
        # "/" in the pipeline name would otherwise add extra directory levels.
        pipeline_name = pipeline_name.replace("/", "_")

        # Serialize before opening the file: opening with "w" truncates it, so a
        # serialization error would otherwise wipe the previously saved trace.
        serialized = json.dumps(
            {
                "suite_name": suite_name,
                "pipeline_name": pipeline_name,
                "user_task_id": user_task_id,
                "injection_task_id": injection_task_id,
                "attack_type": attack_type if attack_type != "none" else None,
                "injections": injections,
                "messages": self.messages,
                "error": self.error,
                **other_context,
            },
            indent=4,
            default=lambda o: o.model_dump() if isinstance(o, BaseModel) else o,
        )

        directory = f"{self.dirpath}/{pipeline_name}/{suite_name}/{user_task_id}/{attack_type}"
        os.makedirs(directory, exist_ok=True)
        file = f"{directory}/{injection_task_id or 'none'}.json"
        with open(file, "w") as f:
            f.write(serialized)

    def log_error(self, error: str):
        """Remember *error*, forward it to the delegate and re-save the trace."""
        self.error = error
        self.delegate.log_error(error)
        self.save()

    def log(self, messages: list[ChatMessage]):
        """Forward *messages* to the delegate with context labels, store them and re-save."""
        user_task_id = self.context.get("user_task_id", "unknown_user_task_id")
        # Fallbacks now name the field they stand in for (the previous values
        # were "unknown_user_task_id" and "pipeline_suite_name").
        injection_task_id = self.context.get("injection_task_id", "unknown_injection_task_id")
        suite_name = self.context.get("suite_name", "unknown_suite_name")
        pipeline_name = self.context.get("pipeline_name", "unknown_pipeline_name")
        self.delegate.log(
            messages,
            user_task_id=user_task_id,
            injection_task_id=injection_task_id,
            suite_name=suite_name,
            pipeline_name=pipeline_name,
        )
        self.messages = messages
        self.save()
