from typing import List

try:
    import torch
    from vllm import LLM, SamplingParams
except ImportError:
    pass

from ..message import Message
from .base import BaseModel, build_prompt


# Process-wide vllm engine. It starts out as None and is created lazily by the
# first `VLLMModel.__call__`; later calls reuse whatever engine is stored here.
ENGINE = None


class Response:
    """Wrap a single response value so it can be consumed as an iterable.

    Iterating over an instance yields the wrapped value exactly once; the
    value is also available directly as the `response` attribute.
    """

    def __init__(self, response):
        """Store `response` as the only item this object will yield."""
        self.response = response

    def __iter__(self):
        """Return a fresh iterator over the one wrapped response."""
        single_item = (self.response,)
        return iter(single_item)


async def consume_gen(async_gen):
    """Exhaust an async iterable and return the last item it produced.

    Returns `None` when the iterable yields nothing.
    """
    last_seen = None
    async for current in async_gen:
        # Only the most recent item is kept; earlier ones are discarded.
        last_seen = current
    return last_seen


class VLLMModel(BaseModel):
    """Model builder for vllm models.

    `model` should be in the format ``model_name@checkpoint_path``. The part
    before the first ``@`` is handed to `build_prompt` together with the
    messages; the part after it is the checkpoint passed to vllm's ``LLM``.

    Call with a list of `Message` objects to generate a response.

    NOTE: the vllm engine lives in the module-level ``ENGINE`` global and is
    only created while that global is ``None``. Once an engine exists, every
    instance reuses it regardless of its own ``checkpoint_path``.
    """

    supports_system_message = True

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        stream: bool = False,
        top_p: float = 1.0,
        max_tokens: int = 512,
        stop=None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        **kwargs,
    ):
        """Parse the model specification and store the sampling settings.

        Args:
            model: ``model_name@checkpoint_path``. Only the first ``@`` is
                treated as the separator, so the checkpoint path itself may
                contain ``@``.
            temperature: Forwarded to ``SamplingParams``.
            stream: Stored on the instance; `__call__` does not read it.
            top_p: Forwarded to ``SamplingParams``.
            max_tokens: Forwarded to ``SamplingParams``.
            stop: Forwarded to ``SamplingParams``.
            frequency_penalty: Forwarded to ``SamplingParams``.
            presence_penalty: Forwarded to ``SamplingParams``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If `model` does not contain an ``@`` separator.
        """
        # partition() splits on the first "@" only, so a checkpoint path that
        # contains "@" no longer breaks the two-way unpacking.
        model_name, separator, checkpoint_path = model.partition("@")
        if not separator:
            raise ValueError(
                "model must be in the format model_name@checkpoint_path, "
                f"got {model!r}"
            )
        self.model_name = model_name
        self.checkpoint_path = checkpoint_path
        self.temperature = temperature
        self.stream = stream
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.stop = stop
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty

    def __call__(self, messages: List[Message], api_key: str = None):
        """Generate one completion for `messages`.

        Args:
            messages: Conversation to turn into a prompt via `build_prompt`.
            api_key: Unused by this method.

        Returns:
            A `Response` wrapping the whitespace-stripped text of the first
            output of the first generation result.
        """
        global ENGINE
        if ENGINE is None:
            # Lazily build the shared engine on first use, passing the CUDA
            # device count reported by torch as the tensor-parallel size.
            ENGINE = LLM(
                self.checkpoint_path, tensor_parallel_size=torch.cuda.device_count()
            )

        sampling_params = SamplingParams(
            use_beam_search=False,
            temperature=self.temperature,
            top_p=self.top_p,
            n=1,
            max_tokens=self.max_tokens,
            stop=self.stop,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )
        prompt = build_prompt(messages, self.model_name)
        # A single prompt is submitted, so only the first result is read.
        generation = ENGINE.generate(prompt, sampling_params, use_tqdm=False)[0]
        response = generation.outputs[0].text.strip()
        return Response(response)
