from dataclasses import dataclass
from typing import List, Optional, Union

from transformers import AutoTokenizer

from reason.inference.text_generation import ConcatedLMGenResult, _generate_vllm


@dataclass
class LMCallingConfig:
    """Per-call generation options for a language-model request.

    Every field is forwarded by keyword to the generation backend by
    ``VLLMRemoteCaller.__call__``. Instances are plain dataclasses, so they
    compare by value and can be built positionally in declaration order.
    """

    # Number of completions to request for the prompt.
    n: int = 1
    # Sampling temperature.
    temperature: float = 1.0
    # Nucleus (top-p) sampling threshold; 1.0 keeps the full distribution.
    top_p: float = 1.0
    # Top-k cutoff; -1 is the vLLM default (top-k filtering disabled).
    top_k: int = -1
    # Upper bound on the number of generated tokens.
    max_new_tokens: int = 1024
    # Token ids that end generation; None means no extra stop ids.
    stop_token_ids: Optional[List[int]] = None
    # Stop string(s); a single string or a list of strings.
    stop_str: Optional[Union[str, List[str]]] = None
    # Whether a matched stop string is kept in the returned text.
    include_stop_str_in_output: bool = False
    # Forwarded as ``first_generation`` to the backend; meaning is defined
    # there (NOTE(review): confirm semantics in _generate_vllm).
    first_generation: bool = False


class LanguageModelCallingFunction:
    """Base interface for objects that query a language model.

    Subclasses override ``__call__`` to turn ``messages`` and an
    ``LMCallingConfig`` into a ``ConcatedLMGenResult``.
    """

    def __init__(self, llm_step_tag: Optional[str] = None):
        """Store the optional step tag.

        Args:
            llm_step_tag: Kept as ``self.llm_step_tag``; this class does not
                read it itself (presumably a marker for reasoning-step
                boundaries used by callers — TODO confirm).
        """
        self.llm_step_tag = llm_step_tag

    def __call__(self, messages: List, config: LMCallingConfig) -> ConcatedLMGenResult:
        """Generate completions for ``messages`` according to ``config``.

        Raises:
            NotImplementedError: always; concrete subclasses must override.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement __call__"
        )


class VLLMRemoteCaller(LanguageModelCallingFunction):
    """Language-model caller that forwards requests to ``_generate_vllm``.

    Serving settings (model name, controller address, tokenizer, chat
    template handling, ...) are fixed at construction. Per-call sampling
    options come from the ``LMCallingConfig`` passed to ``__call__``.
    """

    def __init__(
        self,
        model_name,
        model_path,
        controller_addr: str = "http://localhost:21001",
        llm_step_tag: Optional[str] = None,
        apply_chat_template: bool = False,
        multi_gpu: bool = True,
        serve_type: str = "vllm",
        double_line_break: int = 0,
        model_idx: int = 0,  # model index for multi-model serving
    ):
        """Store serving settings and load the tokenizer for ``model_path``.

        Args:
            model_name: Forwarded to ``_generate_vllm`` as ``model_name``.
            model_path: Passed to ``AutoTokenizer.from_pretrained``.
            controller_addr: Forwarded to ``_generate_vllm``.
            llm_step_tag: Stored by the base class as ``self.llm_step_tag``.
            apply_chat_template: Forwarded to ``_generate_vllm``.
            multi_gpu: Forwarded to ``_generate_vllm``.
            serve_type: Backend selector; ``__call__`` supports only "vllm".
            double_line_break: Forwarded to ``_generate_vllm``.
            model_idx: Model index for multi-model serving, forwarded to
                ``_generate_vllm``.
        """
        super().__init__(llm_step_tag)
        self.model_name = model_name
        self.model_path = model_path
        self.controller_addr = controller_addr
        self.model_idx = model_idx
        # trust_remote_code=True lets the tokenizer execute Python code shipped
        # with the model repository; only point model_path at trusted models.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.apply_chat_template = apply_chat_template
        self.multi_gpu = multi_gpu
        self.serve_type = serve_type
        self.double_line_break = double_line_break

    def __call__(self, messages: str, config: LMCallingConfig) -> ConcatedLMGenResult:
        """Generate completions for ``messages`` through the configured backend.

        Args:
            messages: Prompt input forwarded unchanged to ``_generate_vllm``.
            config: Per-call sampling options; every field is forwarded.

        Returns:
            The result of ``_generate_vllm``.

        Raises:
            ValueError: if ``serve_type`` is not "vllm" (the only backend
                implemented here).
        """
        # Previously an unsupported serve_type fell through and returned None,
        # breaking the declared return type; fail loudly instead.
        if self.serve_type != "vllm":
            raise ValueError(
                f"Unsupported serve_type {self.serve_type!r}; only 'vllm' is supported"
            )
        return _generate_vllm(
            messages=messages,
            model_name=self.model_name,
            n=config.n,
            temperature=config.temperature,
            top_p=config.top_p,
            top_k=config.top_k,
            max_new_tokens=config.max_new_tokens,
            stop_token_ids=config.stop_token_ids,
            stop_str=config.stop_str,
            controller_addr=self.controller_addr,
            include_stop_str_in_output=config.include_stop_str_in_output,
            tokenizer=self.tokenizer,
            apply_chat_template=self.apply_chat_template,
            multi_gpu=self.multi_gpu,
            double_line_break=self.double_line_break,
            first_generation=config.first_generation,
            model_idx=self.model_idx,
        )
        
