import copy
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Generic, Optional, TypeVar, Union

from redacted import (
    ChatModel,
    ChatRequest,
    ChatResponse,
    ChatResponseMessage,
    Function,
    Message,
    Role,
    Usage,
)
from pydantic import BaseModel, Field
from rich.console import Console
from rich.text import Text

# Ordered chat history sent to the model: locally built Message entries
# (system / user / tool replies) interleaved with ChatResponseMessage entries,
# presumably the model replies that Agent.resume appends.
Conversation = list[Union[Message, ChatResponseMessage]]


class Environment:
    """State handed to an Agent run and to every Tool it executes.

    Holds an optional ``root`` path and the running ``conversation``.
    Subclass it to carry additional state for custom tools.
    """

    def __init__(self, root: Optional[Path] = None):
        # ``root`` defaults to None, so the annotation must admit None
        # (implicit Optional is rejected by modern type checkers).
        self.root = root
        self.conversation: Conversation = []

    @property
    def iterations(self) -> int:
        """Number of ChatResponseMessage entries in the conversation.

        Agent.resume starts its iteration counter from this value.
        """
        return sum(
            1 for message in self.conversation
            if isinstance(message, ChatResponseMessage)
        )


class ToolResult(BaseModel):
    """Raw output of ``Tool.execute``; ``Tool.compile`` turns it into text for the model."""

    # Arbitrary named values produced by the tool.
    values: dict[str, Any] = Field(default_factory=dict)
    # Paths produced or referenced by the tool, keyed by a string label
    # (meaning of the keys is up to each tool).
    files: dict[str, Path] = Field(default_factory=dict)


# Environment type an Agent and its Tools operate on; bound to Environment so
# both can be parameterised by an Environment subclass.
T = TypeVar("T", bound="Environment")


class Tool(ABC, Generic[T]):
    """Abstract base for a capability an Agent exposes to the model.

    A concrete tool advertises itself through ``definition``, does its work
    in ``execute``, and converts the raw result into the text returned to
    the model in ``compile``.
    """

    @abstractmethod
    def definition(self) -> Function:
        """Return the function definition advertised to the model.

        The Agent indexes tools by this definition's ``name``.
        """

    @abstractmethod
    def execute(
        self,
        arguments: dict[str, Any],
        environment: T,
        identifier: Optional[str] = None,
    ) -> ToolResult:
        """Perform the tool's work and return its raw result.

        Args:
            arguments: JSON-decoded arguments supplied by the model.
            environment: The environment of the current Agent run.
            identifier: Tool-call id from the model response, if any.
        """

    @abstractmethod
    def compile(self, result: ToolResult) -> str:
        """Render ``result`` as the text content sent back to the model."""


class ToolCall(BaseModel):
    """Record of a single tool invocation made during an Agent run."""

    # Tool-call id from the model response (also sent back as tool_call_id).
    identifier: str
    # Name of the invoked tool, as given by its Function definition.
    tool: str
    # JSON-decoded arguments the model supplied for the call.
    arguments: dict[str, Any]
    # Raw result returned by Tool.execute.
    result: ToolResult


class AgentResult(BaseModel):
    """Outcome of ``Agent.run`` / ``Agent.resume``."""

    # Final assistant text; only set when the model finished with "stop"
    # and its reply had non-empty content.
    message: Optional[str] = None
    # One inner list per model response that requested tools, in call order.
    tools: list[list[ToolCall]] = Field(default_factory=list)
    # Usage reported by each model call, in call order.
    usage: list[Usage] = Field(default_factory=list)


class Agent(Generic[T]):
    """Run a chat model in a tool-calling loop against an environment.

    Each model reply is appended to ``environment.conversation``. Any tool
    calls in the reply are executed with the matching Tool, and their
    compiled results are appended as ``Role.Tool`` messages before the model
    is called again. The loop ends when the model finishes with ``"stop"``
    or once the conversation holds ``iterations`` model responses.
    """

    def __init__(
        self,
        model: ChatModel,
        system: Optional[str] = None,
        tools: Optional[list[Tool[T]]] = None,
        tool_callback: Optional[Callable[[ToolCall, T], None]] = None,
        request: Optional[ChatRequest] = None,
        iterations: int = 8,
        console: Optional[Console] = None,
    ):
        """Configure the agent.

        Args:
            model: Chat model whose ``chat`` method is called each iteration.
            system: Optional system prompt that ``run`` places first.
            tools: Tools exposed to the model; ``None`` means no tools.
            tool_callback: Called with each executed ToolCall and the
                environment.
            request: Base request settings. It is shallow-copied before its
                ``tools`` are set, so the caller's object is not modified.
            iterations: Upper bound on model responses in the conversation.
            console: Console for transcript rendering; defaults to a quiet one.
        """
        self.model = model
        self.system = system
        # The signature allows None; normalise it so the loops below work.
        self.tools = tools if tools is not None else []
        # Compute each definition once and reuse it for lookup and request.
        definitions = [tool.definition() for tool in self.tools]
        self.tool_map = {
            definition.name: tool
            for definition, tool in zip(definitions, self.tools)
        }
        self.tool_callback = tool_callback
        # Copy before assigning ``tools``: mutating the caller's request in
        # place would leak this agent's tools into every other agent sharing
        # the same ChatRequest instance.
        self.request = copy.copy(request) if request is not None else ChatRequest()
        self.request.tools = definitions
        self.iterations = iterations
        self.console = console or Console(quiet=True)

    def run(self, utterance: str, environment: T) -> AgentResult:
        """Start a fresh conversation in ``environment`` and run the loop.

        Any existing conversation is discarded. The system prompt (if set)
        and ``utterance`` (if non-empty) are appended, then ``resume`` runs.
        """
        environment.conversation = list()
        if self.system:
            environment.conversation.append(
                Message(
                    role=Role.System,
                    content=self.system,
                )
            )
        if utterance:
            environment.conversation.append(
                Message(
                    role=Role.User,
                    content=utterance,
                )
            )
        # resume() renders the last message itself; render only the earlier
        # ones here so nothing (e.g. a lone system prompt) is printed twice.
        for message in environment.conversation[:-1]:
            _render_message(message, self.console)
        return self.resume(environment)

    def resume(self, environment: T) -> AgentResult:
        """Continue the loop on the existing ``environment.conversation``.

        Renders the latest message (if any), then repeatedly calls the model
        and executes requested tools until the model finishes with ``"stop"``
        or the conversation holds ``self.iterations`` model responses.

        Raises:
            ValueError: If the model requests a tool this agent does not have,
                or sends tool arguments that are not valid JSON.
        """
        result = AgentResult()
        if environment.conversation:
            _render_message(environment.conversation[-1], self.console)
        iteration = environment.iterations
        while iteration < self.iterations:
            response: ChatResponse = self.model.chat(
                messages=environment.conversation, request=self.request
            )
            choice = response.choices[0]
            response_message = choice.message
            result.usage.append(response.usage)
            environment.conversation.append(response_message)
            _render_message(response_message, self.console)
            if response_message.tool_calls:
                calls: list[ToolCall] = []
                result.tools.append(calls)
                for tool_call in response_message.tool_calls:
                    calls.append(self._call_tool(tool_call, environment))
            if choice.finish_reason == "stop":
                if response_message.content:
                    result.message = response_message.content
                break
            iteration += 1
        return result

    def _call_tool(self, tool_call: Any, environment: T) -> ToolCall:
        """Execute one model-requested tool call and record its reply.

        Appends a ``Role.Tool`` message with the compiled result to the
        conversation, invokes ``tool_callback`` and renders the call.
        """
        name = tool_call.function.name
        tool = self.tool_map.get(name)
        if tool is None:
            raise ValueError(f"Tool {name} not found in agent tools.")
        # json.loads("") raises, so treat an empty/missing argument string
        # as "no arguments".
        arguments = json.loads(tool_call.function.arguments or "{}")
        raw = tool.execute(
            environment=environment,
            arguments=arguments,
            identifier=tool_call.id,
        )
        response = tool.compile(raw)
        record = ToolCall(
            identifier=tool_call.id,
            tool=name,
            arguments=arguments,
            result=raw,
        )
        environment.conversation.append(
            Message(
                role=Role.Tool,
                content=response,
                tool_call_id=tool_call.id,
            )
        )
        if self.tool_callback:
            self.tool_callback(record, environment)
        _render_tool(name, arguments, response, self.console)
        return record


def _render_message(message: Message, console: Console) -> None:
    """Print ``message`` to ``console`` under an icon-and-role header.

    The body is hard-wrapped at 180 columns and indented by two spaces.
    Messages whose ``content`` is ``None`` produce no output.
    """
    content = message.content
    if content is not None:
        icons = {
            Role.User: "👤",
            Role.Assistant: "🤖",
            Role.System: "⚙️",
            Role.Tool: "⛏️",
        }
        icon = icons[message.role]
        title = message.role.value.capitalize()
        console.print(f"> {icon}  {title}:")
        body = _indent(_wrap(content, width=180), spaces=2)
        console.print(Text(body), highlight=False)


def _render_tool(name: str, arguments: dict, response: str, console: Console) -> None:
    """Print a dimmed ``name(key=value, ...)`` line followed by the tool's response.

    The response is hard-wrapped at 180 columns and indented by five spaces.
    """
    dim = "gray50"
    call_args = ", ".join([f"{key}={value}" for key, value in arguments.items()])
    console.print(Text(f"  ⛏️  {name}({call_args})", style=dim), highlight=False)
    wrapped = _wrap(response, width=180)
    console.print(Text(_indent(wrapped, spaces=5), style=dim), highlight=False)


def _indent(s: str, spaces: int = 2) -> str:
    return "\n".join(" " * spaces + line for line in s.splitlines())


def _wrap(s: str, width: int = 180) -> str:
    return "\n".join(
        line[i : i + width]
        for line in s.splitlines()
        for i in range(0, len(line), width)
    )
