import ast
import asyncio
import functools
import json
import os
import re
from copy import deepcopy
from dataclasses import dataclass
from queue import Empty, Queue
from threading import Thread
from typing import Any, cast
from logging import getLogger

from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
from inspect_ai.model import ModelAPI, modelapi, GenerateConfig, ChatMessage, ModelOutput, \
    ChatCompletionChoice, ChatMessageAssistant, ModelUsage, ChatMessageUser, ChatMessageTool, \
    ChatMessageSystem
from inspect_ai.model._model_call import ModelCall
from inspect_ai.model._providers.anthropic import AnthropicAPI
# from inspect_ai.model._model_output import ModelCall
from inspect_ai.model._providers.openai import OpenAIAPI
from inspect_ai.model._providers.util import chat_api_input
from inspect_ai.tool import ToolInfo, ToolChoice, ToolCall, ToolFunction
from inspect_ai.model._model import simple_input_messages
from typing_extensions import override
from vllm import LLM, SamplingParams

from .soft_json_parser import soft_parse_json

# Prompt template that teaches the model the JSON tool-calling convention.
# It is filled in with `str.format`, so `{tool_description_with_args}` and
# `{tool_names}` are placeholders while the doubled braces (`{{` / `}}`) render
# as literal braces in the example. VLLMAPI.add_tool_prompt prepends the
# rendered text to the first message and relies on the "# Tool Use" heading to
# detect that the prompt has already been added.
TOOL_MESSAGE = """
# Tool Use

You have access to the following tools:

{tool_description_with_args}

The way you use the tools is by specifying a json blob (surrounded by "<tool_call>" and "</tool_call>") with `action` and `action_input` keys. The `action` key is name of the tool to use and the `action_input` key are the tool's arguments.

The only values that should be in the `action` field are: {tool_names}

Here is an example of a valid tool use:

<tool_call>
```
{{
  "action": "tool_name",
  "action_input": {{
    "arg1": "value1",
    "arg2": "value2"
  }}
}}
```
</tool_call>

Notice how it is surrounded by both <tool_call> tags and tick marks. Make sure to have the `action_input` in the right format for the tool you are using, and do not put variable names as input if you can find the right values. In particular, make sure to escape any quotes within strings.

"""

# Module-level logger named after this module.
logger = getLogger(__name__)

# Closing tag of a tool call. It is passed as a stop string / stop sequence so
# generation halts as soon as the model has finished writing a tool call.
TOOL_USE_MARKER = '</tool_call>'
# Turn delimiters: VLLMAPI.get_sampling_params checks whether the rendered chat
# starts with DEFAULT_START_TOKEN to decide whether DEFAULT_END_TOKEN should be
# used as a stop string, and post_process_output inspects DEFAULT_END_TOKEN at
# the end of generated text.
DEFAULT_START_TOKEN = '<|im_start|>'
DEFAULT_END_TOKEN = '<|im_end|>'


@modelapi(name="vllm2")
class VLLMAPI(ModelAPI):
    """Model provider (registered as ``vllm2``) backed by an in-process vLLM ``LLM``.

    Tool calling is emulated through prompting: the tool descriptions are
    prepended to the first message using ``TOOL_MESSAGE``, generation stops at
    ``TOOL_USE_MARKER``, and the generated text is parsed back into a
    ``ToolCall`` with ``extract_tool_call``.
    """

    def __init__(
            self,
            model_name: str,
            base_url: str | None = None,
            config: GenerateConfig = GenerateConfig(),
            **model_args: Any,
    ) -> None:
        """Load the vLLM model and its tokenizer.

        Args:
            model_name: Model identifier; replaced by ``model_path`` when that
                model arg is given.
            base_url: Forwarded unchanged to ``ModelAPI.__init__``.
            config: Default generation config; its ``seed`` is remembered as a
                fallback for calls whose own config has no seed.
            **model_args: Recognised keys (``device``, ``tokenizer``,
                ``model_path``, ``tokenizer_path``, ``batch_size``,
                ``chat_template``) are consumed here; ``api_key`` is dropped;
                everything else is forwarded to ``vllm.LLM``.
        """
        super().__init__(model_name, base_url, config)

        # fallback seed for generate() calls whose config does not carry one
        self.seed = config.seed

        def collect_model_arg(name: str) -> Any | None:
            # Always remove the key, even when its value is falsy (e.g.
            # device=0); otherwise it would leak into the LLM(...) kwargs below.
            return model_args.pop(name, None)

        device = collect_model_arg("device")
        tokenizer = collect_model_arg("tokenizer")
        model_path = collect_model_arg("model_path")
        tokenizer_path = collect_model_arg("tokenizer_path")
        self.batch_size = collect_model_arg("batch_size")
        self.chat_template = collect_model_arg("chat_template")

        # get devices and use os.environ to set CUDA_VISIBLE_DEVICES
        if device is not None:
            if isinstance(device, list):
                device = ",".join(map(str, device))
            os.environ["CUDA_VISIBLE_DEVICES"] = str(device)

        if model_path:
            model_name = model_path

        # https://github.com/vllm-project/vllm/pull/6051
        # Gemma 2 models require FlashInfer backend for softcap logits
        if 'google/gemma-2' in model_name:
            os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER'

        # tokenizer precedence: explicit tokenizer > tokenizer_path > model name
        if not tokenizer:
            tokenizer = tokenizer_path if tokenizer_path else model_name

        # default tensor parallelism to the number of comma-separated entries
        # in CUDA_VISIBLE_DEVICES (an unset/empty variable counts as one device)
        if 'tensor_parallel_size' not in model_args:
            devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
            num_devices = len(devices)
            model_args["tensor_parallel_size"] = num_devices if num_devices > 1 else 1

        # not a vLLM argument
        model_args.pop("api_key", None)

        # load model
        self.model = LLM(
            model_name,
            tokenizer=tokenizer,
            **model_args
        )

        self.tokenizer = self.model.get_tokenizer()

    def apply_chat_template(self, messages: list[ChatMessage]) -> str:
        """Render ``messages`` into a single prompt string with the tokenizer's chat template.

        ``self.chat_template`` (possibly ``None``) is passed through to the
        tokenizer, and a generation prompt is appended.
        """
        # handle system message and consecutive user messages
        messages = simple_input_messages(messages)
        # convert to chat template input format
        chat_messages = chat_api_input(messages)
        # apply chat template
        chat = self.tokenizer.apply_chat_template(
            chat_messages,
            add_generation_prompt=True,
            tokenize=False,
            chat_template=self.chat_template,
        )
        return cast(str, chat)

    @override
    def max_connections(self) -> int:
        """Effectively the batch size."""
        return 32

    def get_sampling_params(self, config: GenerateConfig, chat: str) -> SamplingParams:
        """Build vLLM ``SamplingParams`` for one request.

        Args:
            config: Per-call generation settings; unset fields are omitted so
                vLLM's own defaults apply (``max_tokens`` falls back to
                ``DEFAULT_MAX_TOKENS``).
            chat: The rendered prompt, inspected only to choose stop criteria.

        Returns:
            Sampling params that always stop at ``TOOL_USE_MARKER`` and keep
            the stop string in the output.
        """
        kwargs: dict[str, Any] = dict()
        kwargs["max_tokens"] = (
            config.max_tokens if config.max_tokens is not None else DEFAULT_MAX_TOKENS
        )
        if config.temperature is not None:
            kwargs["temperature"] = config.temperature
        if config.top_p is not None:
            kwargs["top_p"] = config.top_p
        if config.top_k is not None:
            kwargs["top_k"] = config.top_k
        # per-call seed wins over the seed captured at construction time
        if config.seed is not None:
            kwargs["seed"] = config.seed
        elif self.seed is not None:
            kwargs["seed"] = self.seed

        if chat.startswith(DEFAULT_START_TOKEN):
            # prompt uses the <|im_start|>/<|im_end|> delimiters: stop on the
            # end delimiter as a string
            sampling_params = SamplingParams(
                **kwargs,
                stop=[DEFAULT_END_TOKEN, TOOL_USE_MARKER],
                include_stop_str_in_output=True
            )
        else:
            # otherwise stop on any of the tokenizer's special token ids
            sampling_params = SamplingParams(
                **kwargs,
                stop_token_ids=self.tokenizer.all_special_ids,
                stop=[TOOL_USE_MARKER],
                include_stop_str_in_output=True
            )
        return sampling_params

    @staticmethod
    def add_tool_prompt(message: ChatMessage, tools: list[ToolInfo]):
        """Prepend the rendered ``TOOL_MESSAGE`` to ``message`` in place.

        Nothing happens when ``tools`` is empty/``None`` or when the message
        already contains the ``# Tool Use`` heading, which makes repeated
        calls on the same message idempotent.
        """
        # TODO: implement tool_choice
        if tools is not None and len(tools) > 0 and '# Tool Use' not in message.text:
            tools_descriptions = [tool.model_dump(exclude_none=True) for tool in tools]
            tool_names = [tool.name for tool in tools]

            assert isinstance(message, (ChatMessageUser, ChatMessageSystem))

            tool_message = TOOL_MESSAGE.format(tool_description_with_args=tools_descriptions, tool_names=tool_names)
            message.text = tool_message + message.text

    async def generate(
            self,
            input: list[ChatMessage],
            tools: list[ToolInfo] | None = None,
            tool_choice: ToolChoice | None = None,
            config: GenerateConfig | None = None,
    ) -> ModelOutput:
        """Generate one assistant message for ``input``.

        Args:
            input: Conversation so far. Tool messages are converted to user
                messages; the first message is modified in place when the tool
                prompt is added.
            tools: Tools to describe in the prompt (may be ``None``).
            tool_choice: Currently ignored.
            config: Generation settings; a default ``GenerateConfig`` is used
                when ``None``.

        Returns:
            A ``ModelOutput`` with a single choice, a parsed tool call when the
            text contains one, and token usage counted with the tokenizer.
        """
        input = [convert_tool_to_user_message(m) for m in input]
        self.add_tool_prompt(input[0], tools)
        chat = self.apply_chat_template(input)

        if config is None:
            config = GenerateConfig()

        # prepare generator
        sampling_params = self.get_sampling_params(config, chat)
        generator = functools.partial(self.model.generate, sampling_params=sampling_params, use_tqdm=False)

        # generate (requests are grouped into batches by a background thread)
        response = await batched_generate(
            GenerateInput(
                input=chat,
                generator=generator,
                batch_size=config.max_connections or self.max_connections(),
            )
        )

        # extract_tool_call expects the generated text, not the GenerateOutput
        # wrapper (passing the wrapper makes the regex search raise TypeError)
        tool_call = extract_tool_call(response.output)
        tool_calls = [tool_call] if tool_call is not None else None

        choice = ChatCompletionChoice(
            message=ChatMessageAssistant(
                content=response.output,
                tool_calls=tool_calls,
                source="generate",
            ),
        )

        input_tokens = len(self.tokenizer(chat)['input_ids'])
        output_tokens = len(self.tokenizer(response.output)['input_ids'])
        total_tokens = input_tokens + output_tokens

        # return output
        return ModelOutput(
            model=self.model_name,
            choices=[choice],
            usage=ModelUsage(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=total_tokens,
            ),
        )


# Fixed (non-templated) prompt describing the bash-only tool convention:
# a single `$ command` wrapped in <tool_call> tags. It is prepended verbatim to
# the first message by the add_tool_prompt methods of VLLMServerAPI and
# Anthropic2API, and commands in this form are parsed by extract_tool_call_server.
TOOL_MESSAGE_SERVER = """
In order to use bash commands, you need to surround them with the following tags: <tool_call> and </tool_call> and begin the command with a '$' like so

<tool_call>
$ cat example.txt
</tool_call>

You can only write one command at a time (i.e. only one '$' within the tags). Make sure the bash command is formated correctly and strings are properly escaped.

"""


@modelapi(name="vllm3")
class VLLMServerAPI(OpenAIAPI):
    """``OpenAIAPI`` variant (registered as ``vllm3``) with prompt-based bash tool calling.

    Native tool calling is bypassed: the bash convention from
    ``TOOL_MESSAGE_SERVER`` is prepended to the first message, generation stops
    at ``TOOL_USE_MARKER``, and the reply text is parsed with
    ``extract_tool_call_server``.
    """

    def completion_params(self, config: GenerateConfig, tools: bool) -> dict[str, Any]:
        """Return the base completion params plus ``include_stop_str_in_output``.

        The flag is merged into any ``extra_body`` the base class already
        produced instead of replacing it.
        """
        params = super().completion_params(config, tools)
        extra_body = dict(params.get("extra_body") or {})
        extra_body["include_stop_str_in_output"] = True
        params["extra_body"] = extra_body
        return params

    @staticmethod
    def add_tool_prompt(message: ChatMessage, tools: list[ToolInfo]):
        """Prepend ``TOOL_MESSAGE_SERVER`` to ``message`` in place, at most once.

        Nothing happens when ``tools`` is empty/``None`` or when the message
        already carries the prompt. The message is mutated, so the presence
        check is what keeps repeated ``generate`` calls on the same
        conversation from stacking up copies of the prompt.
        """
        # TODO: implement tool_choice
        if not tools:
            return
        text = message.text
        # '# Tool Use' is kept for backward compatibility; TOOL_MESSAGE_SERVER
        # itself does not contain that heading, so it is checked explicitly.
        if '# Tool Use' in text or TOOL_MESSAGE_SERVER in text:
            return

        assert isinstance(message, (ChatMessageUser, ChatMessageSystem))

        message.text = TOOL_MESSAGE_SERVER + text

    async def generate(
            self,
            input: list[ChatMessage],
            tools: list[ToolInfo],
            tool_choice: ToolChoice,
            config: GenerateConfig,
    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
        """Generate a reply and attach a parsed bash tool call, if any.

        Args:
            input: Conversation so far; its first message may be modified in
                place by ``add_tool_prompt``.
            tools: Only used to decide whether the tool prompt is added; the
                underlying API call is made with no tools.
            tool_choice: Ignored (``"none"`` is sent to the base class).
            config: Generation settings. A copy is used, with ``stop_seqs``
                replaced by ``[TOOL_USE_MARKER]``.

        Returns:
            Whatever the base class returns (a ``ModelOutput`` or a
            ``(ModelOutput, ModelCall)`` tuple), with the first choice's
            message rebuilt to include the parsed tool call.
        """
        self.add_tool_prompt(input[0], tools)

        # fold a leading system message into the message that follows it
        # (only possible when there is a following message)
        if isinstance(input[0], ChatMessageSystem) and len(input) > 1:
            input = deepcopy(input)
            input[1].content = input[0].content + input[1].content
            input.pop(0)

        input = [convert_tool_to_user_message(m) for m in input]

        # work on a copy so the caller's config keeps its own stop sequences
        config = deepcopy(config)
        config.stop_seqs = [TOOL_USE_MARKER]
        output = await super().generate(input, [], "none", config)

        model_call = None
        if isinstance(output, tuple):
            output, model_call = output

        choices = output.choices

        if len(choices) > 0:
            choice = choices[0]
            message = choice.message
            # .text yields a plain string even when content is a list of parts
            tool_call = extract_tool_call_server(message.text)
            choice.message = ChatMessageAssistant(
                content=message.content,
                tool_calls=[tool_call] if tool_call is not None else None,
                source=message.source,
            )

        if model_call is not None:
            return output, model_call
        return output


def convert_tool_to_user_message(message: ChatMessage):
    """Rewrite a message so it carries no native tool-calling structure.

    * Tool messages become user messages whose content is
      ``'Command Output:\\n'`` followed by the tool error message (when an
      error is present) or the tool's text.
    * Assistant messages are rebuilt from ``source`` and ``content`` only,
      which drops any attached tool calls.
    * Every other message is returned unchanged (same object).
    """
    if isinstance(message, ChatMessageTool):
        if message.error is not None:
            body = message.error.message
        else:
            body = message.text
        return ChatMessageUser(
            content='Command Output:\n' + body,
            source=message.source,
        )

    if isinstance(message, ChatMessageAssistant):
        # filter out tool calls for anthropic api
        return ChatMessageAssistant(
            content=message.content,
            source=message.source,
        )

    return message


@modelapi(name="anthropic2")
class Anthropic2API(AnthropicAPI):
    """``AnthropicAPI`` variant (registered as ``anthropic2``) with prompt-based bash tool calling.

    Each request ends with an assistant message containing ``"<tool_call>\\n$"``
    so the model continues a bash command; generation stops at
    ``TOOL_USE_MARKER`` and the reply is parsed with
    ``extract_tool_call_server``.
    """

    @staticmethod
    def add_tool_prompt(message: ChatMessage, tools: list[ToolInfo]):
        """Prepend ``TOOL_MESSAGE_SERVER`` to ``message`` in place, at most once.

        Nothing happens when ``tools`` is empty/``None`` or when the message
        already carries the prompt. The message is mutated, so the presence
        check is what keeps repeated ``generate`` calls on the same
        conversation from stacking up copies of the prompt.
        """
        # TODO: implement tool_choice
        if not tools:
            return
        text = message.text
        # '# Tool Use' is kept for backward compatibility; TOOL_MESSAGE_SERVER
        # itself does not contain that heading, so it is checked explicitly.
        if '# Tool Use' in text or TOOL_MESSAGE_SERVER in text:
            return

        assert isinstance(message, (ChatMessageUser, ChatMessageSystem))

        message.text = TOOL_MESSAGE_SERVER + text

    async def generate(
            self,
            input: list[ChatMessage],
            tools: list[ToolInfo],
            tool_choice: ToolChoice,
            config: GenerateConfig,
    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
        """Generate a reply that is forced into the bash tool-call format.

        Args:
            input: Conversation so far; its first message may be modified in
                place by ``add_tool_prompt``.
            tools: Only used to decide whether the tool prompt is added; the
                underlying API call is made with no tools.
            tool_choice: Ignored (``"none"`` is sent to the base class).
            config: Generation settings. A copy is used, with ``stop_seqs``
                replaced by ``[TOOL_USE_MARKER]``.

        Returns:
            Whatever the base class returns (a ``ModelOutput`` or a
            ``(ModelOutput, ModelCall)`` tuple). The first choice's text is
            wrapped as ``prefix + text + TOOL_USE_MARKER`` and the parsed tool
            call is attached.
        """
        self.add_tool_prompt(input[0], tools)

        input = [convert_tool_to_user_message(m) for m in input]
        # trailing assistant message that the model is asked to continue
        prefix = "<tool_call>\n$"
        input += [ChatMessageAssistant(content=prefix)]

        # work on a copy so the caller's config keeps its own stop sequences
        config = deepcopy(config)
        config.stop_seqs = [TOOL_USE_MARKER]
        output = await super().generate(input, [], "none", config)

        model_call = None
        if isinstance(output, tuple):
            output, model_call = output

        choices = output.choices

        if len(choices) > 0:
            choice = choices[0]
            message = choice.message
            # re-attach the prefix sent in the request and close the tag so the
            # stored text is a complete <tool_call> block
            message.text = prefix + message.text + TOOL_USE_MARKER
            tool_call = extract_tool_call_server(message.text)
            choice.message = ChatMessageAssistant(
                content=message.content,
                tool_calls=[tool_call] if tool_call is not None else None,
                source=message.source,
            )

        if model_call is not None:
            return output, model_call
        return output


@dataclass
class GenerateInput:
    """One generation request queued for the background batching thread."""

    # rendered prompt text
    input: str
    # callable invoked by process_batches with a list of prompts
    generator: Any
    # process_batches stops collecting once a batch reaches this many requests
    batch_size: int


@dataclass
class GenerateOutput:
    """Result of one generation request: the post-processed generated text."""

    # generated text as returned by post_process_output
    output: str
    # input_tokens: int
    # output_tokens: int
    # total_tokens: int
    # logprobs: torch.Tensor | None


@dataclass
class _QueueItem:
    """Entry placed on ``batch_queue`` by ``batched_generate``."""

    # the request to run
    input: GenerateInput
    # future awaited by the coroutine that submitted the request
    future: asyncio.Future[GenerateOutput]
    # event loop that created `future`; needed to resolve it from the worker thread
    loop: asyncio.AbstractEventLoop


# Background worker running process_batches; created lazily by the first
# batched_generate call and None until then.
batch_thread: Thread | None = None

# Queue through which batched_generate hands requests to the worker thread.
batch_queue: "Queue[_QueueItem]" = Queue()


async def batched_generate(input: GenerateInput) -> GenerateOutput:
    """Submit ``input`` to the background batching thread and await its result.

    The worker thread is started (as a daemon) on first use. The request is
    queued together with a future bound to the running event loop; the worker
    resolves that future once the batch containing the request has been
    generated.

    Args:
        input: Prompt, generator callable and batch size for this request.

    Returns:
        The ``GenerateOutput`` produced for this request.

    Raises:
        Exception: Whatever exception the worker recorded on the future when
            generating the batch failed.
    """
    # start the background thread if necessary
    global batch_thread
    if batch_thread is None:
        batch_thread = Thread(target=process_batches, daemon=True)
        batch_thread.start()

    # enqueue the job; get_running_loop() is the recommended way to obtain the
    # loop from inside a coroutine
    loop = asyncio.get_running_loop()
    future: asyncio.Future[GenerateOutput] = loop.create_future()
    batch_queue.put(_QueueItem(input=input, future=future, loop=loop))

    # awaiting the future yields its result (or raises its exception)
    return await future


def post_process_output(output) -> str:
    """Return the text of the first completion with a trailing end token removed.

    Args:
        output: A generation result exposing ``outputs[0].text``.

    Returns:
        The generated text; when it ends with ``DEFAULT_END_TOKEN`` that
        suffix is stripped, otherwise the text is returned unchanged.
    """
    text: str = output.outputs[0].text
    # removesuffix drops the token only when it is actually at the end
    # (slicing with [:len(token)] would keep just the first few characters)
    return text.removesuffix(DEFAULT_END_TOKEN)


def _settle_future(
        item: _QueueItem,
        result: GenerateOutput | None = None,
        error: BaseException | None = None,
) -> None:
    """Resolve ``item.future`` on the event loop that owns it."""

    def _apply() -> None:
        # the awaiting task may have been cancelled or already resolved
        if item.future.done():
            return
        if error is not None:
            item.future.set_exception(error)
        else:
            item.future.set_result(result)

    try:
        # asyncio futures are not thread safe, so the update is scheduled on
        # the future's own loop.
        # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
        item.loop.call_soon_threadsafe(_apply)
    except RuntimeError as ex:
        # raised when the loop is already closed; nobody is waiting any more,
        # and the worker thread must stay alive for later requests
        logger.warning(f"Could not deliver generation result: {ex}")


def process_batches() -> None:
    """Worker loop: collect queued requests into batches and run them.

    Runs forever. Requests are taken from ``batch_queue`` until either the
    queue has been quiet for 2 seconds or the batch has reached the
    ``batch_size`` of the most recently dequeued request. The whole batch is
    then generated with the generator of its first request, and every
    request's future is resolved on its own event loop. If generation fails,
    the exception is delivered to every pending future of the batch.
    """
    while True:
        # drain the queue (wait until no new messages have shown up for 2 seconds)
        batch: list[_QueueItem] = []
        while True:
            try:
                item = batch_queue.get(timeout=2)  # wait 2 seconds max TODO: what's optimal wait time?
            except Empty:
                # we have exhausted the queue
                break
            batch.append(item)
            if len(batch) == item.input.batch_size:
                # max batch size reached
                break

        # see if we have any work to do
        if not batch:
            continue

        try:
            # the first request's generator is used for the whole batch
            generator = batch[0].input.generator

            # generate
            outputs = generator([item.input.input for item in batch])

            for item, output in zip(batch, outputs):
                _settle_future(item, result=GenerateOutput(output=post_process_output(output)))

        except Exception as e:
            # futures that already received a result are left untouched
            for item in batch:
                _settle_future(item, error=e)


def extract_tool_call_server(response: str) -> ToolCall | None:
    """Parse a ``<tool_call>$ command</tool_call>`` block into a bash ``ToolCall``.

    Args:
        response: Text generated by the model.

    Returns:
        * a ``ToolCall`` for the ``bash`` function with the command under
          ``arguments["cmd"]`` when a well-formed block is found;
        * a ``ToolCall`` with empty arguments and ``parse_error`` set when an
          opening ``<tool_call>`` tag is present but the block does not match;
        * ``None`` when the text contains no ``<tool_call>`` tag at all.
    """
    # <tool_call>\n$ (the command to capture)\n</tool_call>
    bash_call_pattern = re.compile(r"<tool_call>\s*\$\s*(.*?)\s*<\/tool_call>", re.DOTALL)
    call_id = str(id(response))  # TODO: better id choice?

    found = bash_call_pattern.search(response)
    if found is not None:
        return ToolCall(
            id=call_id,
            function="bash",
            arguments={"cmd": found.group(1)},
            type="function",
            parse_error=None,
        )

    if "<tool_call>" not in response:
        return None

    # an opening tag without a parsable command: report it back to the model
    return ToolCall(
        id=call_id,
        function="bash",
        arguments={},
        type="function",
        parse_error="Error parsing tool call.\nMake sure it is of the form:\n<tool_call>\n$ cat example.txt\n</tool_call>",
    )


TOOL_USE_PATTERN = re.compile(r"<tool_call>(?:\s*```json|\s*```)?(.*?)(?:```\s*)?</tool_call>", re.DOTALL)
TOOL_USE_FORMATTER = '<tool_call>\n```\n{tool_use_json}\n```\n</tool_call>'


def extract_from_xml(response: str) -> str | None:
    # there should only be one instance of "</tool_call>" since we stop generation when we find it
    match = re.search(TOOL_USE_PATTERN, response)
    if match:
        return match.group(1).strip()
    else:
        # If no match found, return an empty dictionary
        return None


def extract_tool_call(response: str) -> ToolCall | None:
    """Parse a JSON ``<tool_call>`` block from ``response`` into a ``ToolCall``.

    The payload is expected to be a JSON object with a string ``action`` (the
    tool name) and a dict ``action_input`` (the tool arguments).

    Args:
        response: Text generated by the model.

    Returns:
        ``None`` when the text contains no ``<tool_call>`` block. Otherwise a
        ``ToolCall``; when the payload is malformed, ``parse_error`` describes
        the problem, ``function`` is ``"unknown"`` unless a valid ``action``
        was found, and ``arguments`` is empty.
    """
    tool_use_json = extract_from_xml(response)
    if tool_use_json is None:
        return None

    # canonical rendering of what the model wrote, quoted in the error message
    formated_tool_use = TOOL_USE_FORMATTER.format(tool_use_json=tool_use_json)

    error: str | None = None
    func_name = "unknown"
    arguments: dict = {}

    try:
        parsed = soft_parse_json(tool_use_json)
        # explicit checks instead of assert statements, so validation still
        # runs when Python is started with -O
        if not isinstance(parsed, dict):
            raise ValueError("Tool call must be a dict.")

        if "action" not in parsed:
            raise ValueError("Tool call must contain \"action\" key.")
        if not isinstance(parsed["action"], str):
            raise ValueError("Tool call's \"action\" must be a string.")
        func_name = parsed["action"]

        if "action_input" not in parsed:
            raise ValueError("Tool call must contain \"action_input\" key.")
        if not isinstance(parsed["action_input"], dict):
            raise ValueError("Tool call's \"action_input\" must be a dict.")
        arguments = parsed["action_input"]

    # json.JSONDecodeError is a ValueError; AssertionError is still caught in
    # case the parser itself asserts
    except (AssertionError, ValueError) as ex:
        example_tool_use = TOOL_USE_FORMATTER.format(
            tool_use_json=json.dumps({'action': 'tool_name', 'action_input': {'arg1': '...'}}, indent=4)
        )
        error = f"Error parsing the following tool call: {formated_tool_use}\nTool call must be formated like:\n{example_tool_use}\nError details: {ex}"
        logger.info(error)

    return ToolCall(
        id=str(id(response)),  # TODO: better id choice?
        function=func_name,
        arguments=arguments,
        type="function",
        parse_error=error,
    )


def main():
    """Manual smoke test: load a model through ``VLLMAPI`` and print one generation.

    Builds the provider for ``google/gemma-2b-it`` on device ``'5'``, sends a
    single user message with a fixed seed and temperature, and prints both the
    full model output and the content of the first choice.
    """
    provider = VLLMAPI("google/gemma-2b-it", device='5')
    conversation = [ChatMessageUser(content="Lorem ipsum dolor")]
    generate_config = GenerateConfig(seed=42, temperature=0.7)

    # generate() is a coroutine, so it is driven with asyncio.run
    result = asyncio.run(provider.generate(conversation, config=generate_config))

    print(result)
    print(result.choices[0].message.content)


if __name__ == "__main__":
    main()
