from typing import Dict, List, Union
from .base_model import BaseModel


class VLLMError(Exception):
    """Exception that carries a status code alongside a human-readable message."""

    def __init__(self, status_code, message):
        # Give the text to Exception so that str(err) and err.args expose it.
        super().__init__(message)
        # Keep both values as attributes for callers that inspect the error.
        self.message = message
        self.status_code = status_code


# vLLM is required by everything below. If the import fails, abort the import
# of this module with a VLLMError (status_code=1) chained to the original
# ImportError so the root cause stays visible in the traceback.
try:
    from vllm import LLM, SamplingParams
except ImportError as e:
    raise VLLMError(
        status_code=1,
        message="Failed to import 'vllm' package. Make sure it is installed correctly.",
    ) from e


class LocalVLLM(BaseModel):
    """Model backend that runs generation in-process through vLLM's ``LLM``."""

    def __init__(
        self,
        model_path: str,
        **vllm_kwargs,
    ) -> None:
        """Construct the vLLM engine.

        Args:
            model_path: Passed to ``BaseModel.__init__`` and to ``LLM`` as
                its ``model`` argument.
            **vllm_kwargs: Forwarded unchanged to the ``LLM`` constructor.
        """
        super().__init__(model_path=model_path)

        self.model = LLM(
            model=model_path,
            **vllm_kwargs,
        )

    def validate_vllm(self):
        """Return ``True`` unconditionally; no check is performed."""
        return True

    @staticmethod
    def _check_conversations(messages) -> None:
        """Verify that ``messages`` is an iterable of chat conversations.

        Each conversation must be a ``list`` whose items are ``dict`` objects
        containing both a ``"role"`` and a ``"content"`` key.

        Raises:
            AssertionError: On the first malformed entry. The error is raised
                explicitly rather than through an ``assert`` statement so the
                check is not stripped when Python runs with ``-O``; the
                exception type matches what the ``assert``-based check raised.
        """
        for conversation in messages:
            if not isinstance(conversation, list):
                raise AssertionError("Each message must be provided as a list")
            for turn in conversation:
                if not isinstance(turn, dict):
                    raise AssertionError(
                        "Each message must be provided as a dictionary"
                    )
                if "role" not in turn:
                    raise AssertionError("Each message must contain 'role' key")
                if "content" not in turn:
                    raise AssertionError("Each message must contain 'content' key")

    def completions(
        self, messages: List[List[Dict[str, str]]], **kwargs
    ) -> List[str]:
        """Generate one chat reply per conversation.

        Args:
            messages: A batch of conversations; each conversation is a list of
                ``{"role": ..., "content": ...}`` dictionaries.
            **kwargs: Used to build ``SamplingParams``.

        Returns:
            The text of the first completion of every returned output, in the
            order the engine returned them.

        Raises:
            AssertionError: If ``messages`` does not have the expected shape.
        """
        self._check_conversations(messages)

        params = SamplingParams(**kwargs)
        outputs = self.model.chat(
            messages=messages,
            sampling_params=params,
            use_tqdm=True,
        )
        return [output.outputs[0].text for output in outputs]

    def generate(self, messages: List[Dict[str, str]], **kwargs):
        """Run raw (non-chat) generation and summarise the result.

        Args:
            messages: Forwarded to ``LLM.generate`` as its prompt input.
                NOTE(review): the annotation mirrors ``completions``; vLLM
                expects prompt strings or prompt dicts here - confirm against
                callers.
            **kwargs: Used to build ``SamplingParams``.

        Returns:
            A dict with the keys ``prompt_text``, ``prompt_text_token_len``,
            ``output`` (whitespace-stripped) and ``output_token_len``. Only
            the LAST output returned by the engine is described, so callers
            should pass a single prompt. An empty dict is returned when the
            engine yields no outputs.
        """
        sampling_params = SamplingParams(**kwargs)

        # `LLM.generate` calls its input `prompts` and has no `messages`
        # keyword, so the value is passed positionally.
        outputs = self.model.generate(
            messages,
            sampling_params=sampling_params,
            use_tqdm=True,
        )

        if not outputs:
            return {}

        last = outputs[-1]
        completion = last.outputs[0]
        return {
            'prompt_text': last.prompt,
            'prompt_text_token_len': len(last.prompt_token_ids),
            'output': completion.text.strip(),
            'output_token_len': len(completion.token_ids),
        }