import subprocess
import time
from typing import List, Optional, Dict

import openai
import portpicker
import requests

from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.formatter import OpenaiFormatter
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent


class VllmMessenger(LLMMessenger):
    """Messenger that serves a model with a local ``vllm serve`` process.

    Constructing an instance picks a free local port, launches the vLLM
    server on it via ``popen_launch_server`` (blocking until the server
    answers or 600 seconds elapse) and creates an OpenAI-compatible client
    pointed at ``http://localhost:<port>/v1``.

    The server handle is stored in ``self.process``; nothing in this class
    terminates it, so the owner of the messenger is responsible for stopping
    the process when it is no longer needed.
    """

    def __init__(
        self,
        model_name: str,
        temperature: float = 1.0,
        max_tokens: int = 2048,
        max_output_tokens: int = 1536,
        log_directory: str = "",
        custom_args: Optional[List[str]] = None,
    ):
        """Launch the vLLM server and prepare the client.

        Args:
            model_name: Model identifier passed to ``vllm serve``.
            temperature: Sampling temperature sent with every request.
            max_tokens: Value passed to the server as ``--max-model-len``.
            max_output_tokens: Per-request completion limit; must be strictly
                smaller than ``max_tokens``.
            log_directory: Forwarded to the base-class constructor.
            custom_args: Extra command-line arguments appended to the
                ``vllm serve`` invocation. ``None`` means no extra arguments.

        Raises:
            ValueError: If ``max_output_tokens`` is not smaller than
                ``max_tokens``.
            TimeoutError: Propagated from ``popen_launch_server`` when the
                server does not become ready in time.
        """
        # Explicit check instead of `assert`, which is removed under `-O`.
        if max_output_tokens >= max_tokens:
            raise ValueError(
                "max_output_tokens must be smaller than max_tokens "
                f"(got {max_output_tokens} >= {max_tokens})"
            )
        super().__init__(model_name, temperature, log_directory)

        # A fresh list per call avoids sharing a mutable default between
        # instances.
        extra_args = [] if custom_args is None else list(custom_args)

        api_key = "NOT-USED"
        port = portpicker.pick_unused_port()
        base_url = f"http://localhost:{port}"
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_output_tokens = max_output_tokens
        self.process = popen_launch_server(
            model_name,
            base_url,
            timeout=600,
            api_key=api_key,
            other_args=(
                "--port",
                str(port),
                "--max-model-len",
                str(max_tokens),
                "--trust-remote-code",
                # classify_picked_images_to_sides uses 4 images
                # (2 per side + 2 test instances)
                "--limit-mm-per-prompt",
                "image=4",
                *extra_args,
            ),
        )
        self.client = openai.Client(base_url=f"{base_url}/v1", api_key=api_key)
        self.formatter = OpenaiFormatter()
        self._context: Optional[List[Dict]] = None  # set type explicitly

    def __update_context(self, contents: List[Content], model_response: str):
        """Append the latest user/assistant exchange to the open context.

        Does nothing when no context is open (``self._context`` is ``None``).
        When ``self._keep_image_history`` is falsy, image contents are dropped
        from the stored user turn.
        """
        if self._context is None:
            return
        # NOTE(review): `_keep_image_history` is not set in this class;
        # presumably provided by LLMMessenger - confirm.
        if not self._keep_image_history:
            contents = [
                content
                for content in contents
                if not isinstance(content, ImageContent)
            ]
        self._context += [
            self.formatter.user(contents),
            self.formatter.assistant(model_response),
        ]

    def ask(self, contents: List[Content]) -> str:
        """Send ``contents`` as a user turn and return the model's reply.

        The request consists of the current context (from ``get_context()``)
        followed by the new user message. The exchange is logged, echoed to
        stdout and, if a context is open, appended to it.

        Args:
            contents: Content items forming the user message.

        Returns:
            The text of the first completion choice, or an empty string when
            the server returns no text content.
        """
        if self._context is None:
            self.log("INFO", [TextContent("Opening new context")])

        self.log("USER", contents)

        messages = self.get_context() + [self.formatter.user(contents)]

        response = self.client.chat.completions.create(
            model=self.get_name(),
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_output_tokens,
        )
        # `message.content` is nullable in the OpenAI schema; normalise it so
        # the declared `str` return type and the downstream consumers hold.
        model_response = response.choices[0].message.content or ""
        print(model_response)

        self.log("ASSISTANT", [TextContent(model_response)])
        self.__update_context(contents, model_response)
        return model_response


def _stop_server_process(
    process: subprocess.Popen, grace_period: float = 10.0
) -> None:
    """Terminate ``process``, escalating to kill if it does not exit in time."""
    if process.poll() is None:
        process.terminate()
        try:
            process.wait(timeout=grace_period)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
    # Release the pipe file descriptors when output was being captured.
    for stream in (process.stdout, process.stderr):
        if stream is not None:
            stream.close()


def popen_launch_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: str,
    other_args: tuple = (),
    env: Optional[dict] = None,
    return_stdout_stderr: bool = False,
):
    """Start ``vllm serve`` and block until its HTTP API responds.

    Readiness is detected by polling ``GET {base_url}/v1/models`` until it
    returns HTTP 200.

    Args:
        model: Model identifier passed to ``vllm serve``.
        base_url: Base URL (scheme, host and port) the server will listen on.
        timeout: Maximum number of seconds to wait for readiness.
        api_key: Passed to the server as ``--api-key`` and sent as the bearer
            token of the readiness probe.
        other_args: Additional command-line arguments for ``vllm serve``.
        env: Environment for the child process; ``None`` inherits the
            current environment.
        return_stdout_stderr: When true, the child's stdout and stderr are
            captured as text pipes on the returned ``Popen`` object (the
            caller is then responsible for reading them); otherwise they are
            inherited from this process.

    Returns:
        The ``subprocess.Popen`` handle of the running server.

    Raises:
        TimeoutError: If the server does not become ready within ``timeout``
            seconds, or if the server process exits before becoming ready.
            In both cases the child process is stopped before raising.
    """
    command = ["vllm", "serve", model, "--api-key", api_key, *other_args]
    if return_stdout_stderr:
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            text=True,
        )
    else:
        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {api_key}",
    }
    models_url = f"{base_url}/v1/models"
    poll_interval = 10.0
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout

    try:
        while time.monotonic() < deadline:
            # Fail fast when the server crashed instead of waiting out the
            # whole timeout. TimeoutError is kept as the failure type so
            # existing callers that catch it keep working.
            exit_code = process.poll()
            if exit_code is not None:
                raise TimeoutError(
                    f"Server process exited with code {exit_code} "
                    "before becoming ready."
                )

            try:
                # Bound each probe so a hung connection cannot block the
                # loop past the deadline indefinitely.
                response = requests.get(
                    models_url, headers=headers, timeout=poll_interval
                )
            except requests.RequestException:
                # Server not accepting connections yet; retry after a pause.
                response = None
            if response is not None and response.status_code == 200:
                return process

            remaining = deadline - time.monotonic()
            if remaining > 0:
                time.sleep(min(poll_interval, remaining))

        raise TimeoutError("Server failed to start within the timeout period.")
    except BaseException:
        # The handle is lost to the caller on any failure (including Ctrl-C),
        # so stop the child here rather than leaking a running server.
        _stop_server_process(process)
        raise
