import subprocess
import time
from typing import List, Optional, Dict

import openai
import requests

from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.formatter import OpenaiFormatter
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent


class SglangMessenger(LLMMessenger):
    """LLM messenger backed by a locally launched SGLang server.

    Constructing an instance spawns ``sglang.launch_server`` as a child
    process (via ``popen_launch_server``), blocks until the server answers on
    its OpenAI-compatible ``/v1`` endpoint, and then sends chat completions
    through the ``openai`` client. Call ``close()`` (or use the instance as a
    context manager) to stop the server process when finished.
    """

    def __init__(
        self,
        model_name: str,
        temperature: float = 0.0,
        max_tokens: int = 128,
        log_directory: str = "",
        base_url: str = "http://127.0.0.1:30000",
    ):
        """Launch the SGLang server and connect an OpenAI client to it.

        Args:
            model_name: Model path handed to the server's ``--model-path``.
            temperature: Sampling temperature used for every request.
            max_tokens: Maximum number of tokens requested per response.
            log_directory: Forwarded to ``LLMMessenger.__init__``.
            base_url: ``scheme://host:port`` the server binds to and the
                client connects to.

        Raises:
            TimeoutError: if the server does not become ready within 600 s.
        """
        super().__init__(model_name, temperature, log_directory)
        # Placeholder key: the same value is given to the server (--api-key)
        # and to the client below, so both sides agree.
        api_key = "NOT-USED"
        # A trailing slash would otherwise produce "…//v1" URLs below.
        base_url = base_url.rstrip("/")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.process = popen_launch_server(
            model_name,
            base_url,
            timeout=600,
            api_key=api_key,
            other_args=(
                "--chat-template",
                "chatml-llava",
                "--disable-cuda-graph",
                "--trust-remote-code",
            ),
        )
        self.client = openai.Client(base_url=f"{base_url}/v1", api_key=api_key)
        self.formatter = OpenaiFormatter()
        # None means "no conversation context is open"; see __update_context.
        self._context: Optional[List[Dict]] = None

    def __update_context(self, contents: List[Content], model_response: str):
        """Append the user turn and the model's reply to the open context.

        Does nothing when no context is open (``self._context is None``).
        Image contents are dropped from the stored user turn unless
        ``self._keep_image_history`` is truthy (that flag is presumably set by
        ``LLMMessenger`` — verify against the base class).
        """
        if self._context is not None:
            if not self._keep_image_history:
                contents = [
                    content
                    for content in contents
                    if not isinstance(content, ImageContent)
                ]
            self._context += [self.formatter.user(contents)]
            self._context += [self.formatter.assistant(model_response)]

    def ask(self, contents: List[Content]) -> str:
        """Send ``contents`` as a user message and return the model's reply.

        The request is prefixed with ``self.get_context()`` (provided outside
        this class). The exchange is logged, printed to stdout, and recorded
        into the context when one is open.
        """
        if self._context is None:
            self.log("INFO", [TextContent("Opening new context")])

        self.log("USER", contents)

        context = self.get_context()
        message = self.formatter.user(contents)
        messages = context + [message]

        response = self.client.chat.completions.create(
            model="default",
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        # The OpenAI client types message.content as Optional[str]; keep the
        # declared `-> str` contract and avoid logging a None text.
        model_response = response.choices[0].message.content or ""
        print(model_response)

        self.log("ASSISTANT", [TextContent(model_response)])
        self.__update_context(contents, model_response)
        return model_response

    def close(self) -> None:
        """Stop the SGLang server process started by this messenger.

        Sends terminate, waits up to 30 s, then force-kills. Calling it again
        after the process has exited is a no-op for the process. Any
        ``close`` defined by a parent class is invoked afterwards.
        """
        process = getattr(self, "process", None)
        if process is not None and process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
        parent_close = getattr(super(), "close", None)
        if callable(parent_close):
            parent_close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False


def _parse_host_port(base_url: str):
    """Return the ``(host, port)`` strings parsed from ``base_url``.

    Accepts trailing slashes/paths and bracketed IPv6 hosts
    (``http://[::1]:9000`` -> ``("::1", "9000")``).

    Raises:
        ValueError: if ``base_url`` has no hostname or no explicit port, or
            the port is not a valid integer.
    """
    from urllib.parse import urlparse

    parsed = urlparse(base_url)
    port = parsed.port  # raises ValueError for a non-numeric/out-of-range port
    if not parsed.hostname or port is None:
        raise ValueError(
            f"base_url must look like 'scheme://host:port', got {base_url!r}"
        )
    return parsed.hostname, str(port)


def _terminate_process(process: subprocess.Popen, grace_period: float = 10.0) -> None:
    """Terminate ``process`` if still running, force-killing after ``grace_period`` s."""
    if process.poll() is not None:
        return
    process.terminate()
    try:
        process.wait(timeout=grace_period)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()


def popen_launch_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: str,
    other_args: tuple = (),
    env: Optional[dict] = None,
    return_stdout_stderr: bool = False,
    poll_interval: float = 10.0,
):
    """Launch ``sglang.launch_server`` and block until it answers requests.

    Readiness is detected by polling ``GET {base_url}/v1/models`` until it
    returns HTTP 200.

    Args:
        model: Value for ``--model-path``.
        base_url: ``scheme://host:port`` the server should bind to.
        timeout: Maximum seconds to wait for readiness.
        api_key: Passed as ``--api-key`` (when non-empty) and sent as a
            Bearer token on the readiness probe.
        other_args: Extra CLI arguments appended to the launch command.
        env: Environment for the child process (``None`` inherits ours).
        return_stdout_stderr: Capture stdout/stderr as text pipes instead of
            inheriting this process's streams.
        poll_interval: Seconds between readiness probes.

    Returns:
        The running ``subprocess.Popen`` for the server.

    Raises:
        ValueError: if ``base_url`` lacks a host or explicit port.
        TimeoutError: if the server exits early or is not ready within
            ``timeout`` seconds. The child process is terminated first.
    """
    host, port = _parse_host_port(base_url)

    command = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        *other_args,
    ]
    if api_key:
        command += ["--api-key", api_key]

    if return_stdout_stderr:
        # NOTE(review): nothing here reads these pipes. Per the subprocess
        # docs an undrained PIPE can fill and block the child, so the caller
        # is presumably expected to consume them.
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            text=True,
        )
    else:
        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

    models_url = f"{base_url.rstrip('/')}/v1/models"
    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {api_key}",
    }
    # Bound each probe so a half-open connection cannot stall past the deadline.
    probe_timeout = 5.0
    deadline = time.monotonic() + timeout
    try:
        while time.monotonic() < deadline:
            exit_code = process.poll()
            if exit_code is not None:
                # Kept as TimeoutError so existing `except TimeoutError`
                # handlers still catch a failed startup.
                raise TimeoutError(
                    f"Server process exited with code {exit_code} before becoming ready."
                )
            try:
                response = requests.get(
                    models_url, headers=headers, timeout=probe_timeout
                )
            except requests.RequestException:
                response = None  # server not accepting connections yet
            if response is not None and response.status_code == 200:
                return process
            remaining = deadline - time.monotonic()
            time.sleep(max(0.0, min(poll_interval, remaining)))
    except BaseException:
        # Don't leave an orphaned server behind on errors or Ctrl-C.
        _terminate_process(process)
        raise
    _terminate_process(process)
    raise TimeoutError("Server failed to start within the timeout period.")
