from typing import Dict, List, Optional

import numpy as np

from opencompass.models.base import BaseModel
from opencompass.utils import get_logger

# vLLM is an optional dependency. When it cannot be imported, every name
# pulled from it is bound to ``None`` so the rest of the module can test for
# availability instead of tripping over an undefined name.
try:
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
except ImportError:
    LLM, SamplingParams, LoRARequest = None, None, None

# Baseline keyword arguments for the vLLM engine constructor; user-supplied
# model kwargs are layered on top of a copy of this mapping.
DEFAULT_MODEL_KWARGS = {'trust_remote_code': True}


class VLLM(BaseModel):
    """Model wrapper that runs generation and scoring through vLLM.

    Args:
        path (str): Model path or name passed to ``vllm.LLM``.
        max_seq_len (int): Maximum sequence length. Used as the token budget
            when ``mode='mid'`` truncates prompts. Defaults to 2048.
        model_kwargs (dict, optional): Extra keyword arguments for
            ``vllm.LLM``; merged over ``DEFAULT_MODEL_KWARGS``.
        generation_kwargs (dict, optional): Keyword arguments forwarded to
            ``SamplingParams`` on every call. A ``do_sample`` key, if
            present, is dropped. The dict is copied, never mutated.
        meta_template (Dict, optional): Meta prompt template, forwarded to
            the base class.
        mode (str): ``'none'`` leaves prompts untouched; ``'mid'`` removes
            the middle of prompts that exceed the token budget.
        use_fastchat_template (bool): Stored on the instance; not otherwise
            used inside this class.
        lora_path (str, optional): Path of a LoRA adapter. When set,
            ``generate`` attaches a ``LoRARequest`` for it.
        stop_words (List[str], optional): Stop strings applied to every
            ``generate`` call.
    """

    def __init__(
        self,
        path: str,
        max_seq_len: int = 2048,
        model_kwargs: Optional[dict] = None,
        generation_kwargs: Optional[dict] = None,
        meta_template: Optional[Dict] = None,
        mode: str = 'none',
        use_fastchat_template: bool = False,
        lora_path: Optional[str] = None,
        stop_words: Optional[List[str]] = None,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template)

        assert LLM, ('Please install VLLM with `pip install vllm`. '
                     'note: torch==2.1.2 is required.')
        # Validate the cheap argument before the expensive model load so a
        # typo in ``mode`` fails immediately.
        assert mode in ['none', 'mid']
        self.logger = get_logger()
        self._load_model(path, model_kwargs)
        self.tokenizer = self.model.get_tokenizer()
        # Work on private copies: popping from the caller's dict (or from a
        # shared default object) would leak state between instances.
        self.generation_kwargs = dict(generation_kwargs or {})
        # ``do_sample`` is removed before the kwargs reach SamplingParams
        # (presumably a HuggingFace-style option vLLM does not accept).
        self.generation_kwargs.pop('do_sample', None)
        self.lora_path = lora_path
        self.mode = mode
        self.use_fastchat_template = use_fastchat_template
        self.stop_words = list(stop_words or [])

    def _load_model(self,
                    path: str,
                    add_model_kwargs: dict = None,
                    num_retry: int = 3):
        """Instantiate the vLLM engine and store it on ``self.model``.

        Args:
            path (str): Model path or name passed to ``vllm.LLM``.
            add_model_kwargs (dict, optional): Overrides merged on top of a
                copy of ``DEFAULT_MODEL_KWARGS``.
            num_retry (int): Accepted for interface compatibility; currently
                unused.
        """
        model_kwargs = DEFAULT_MODEL_KWARGS.copy()
        if add_model_kwargs is not None:
            model_kwargs.update(add_model_kwargs)

        # ray is only needed to tear down an already-running instance; when
        # it is not installed there is nothing to shut down.
        try:
            import ray
        except ImportError:
            ray = None

        if ray is not None and ray.is_initialized():
            self.logger.info('shutdown ray instance to avoid '
                             '"Calling ray.init() again" error.')
            ray.shutdown()
        self.model = LLM(path, **model_kwargs)

    def _truncate_middle(self, inputs: List[str],
                         max_out_len: int) -> List[str]:
        """Drop the middle of prompts exceeding the input token budget.

        The budget is ``self.max_seq_len - max_out_len`` tokens. Over-long
        prompts keep their first and last ``budget / 2`` tokens. Every
        prompt is round-tripped through the tokenizer with special tokens
        skipped on decode.
        """
        budget = self.max_seq_len - max_out_len
        decode = self.tokenizer.decode
        input_ids = self.tokenizer(inputs, truncation=False)['input_ids']
        truncated = []
        for input_id in input_ids:
            if len(input_id) > budget:
                # NOTE(review): when ``half`` is 0 the tail slice
                # ``[-0:]`` keeps the whole prompt; behaviour is left
                # unchanged here - confirm whether that is intended.
                half = int(budget / 2)
                truncated.append(
                    decode(input_id[:half], skip_special_tokens=True) +
                    decode(input_id[-half:], skip_special_tokens=True))
            else:
                truncated.append(decode(input_id, skip_special_tokens=True))
        return truncated

    def generate(self,
                 inputs: List[str],
                 max_out_len: int,
                 stopping_criteria: Optional[List[str]] = None,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.
            stopping_criteria (List[str], optional): Extra stop strings for
                this call, combined with ``self.stop_words``.
            **kwargs: Additional ``SamplingParams`` arguments. Entries in
                ``self.generation_kwargs`` take precedence over them.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.mode == 'mid':
            inputs = self._truncate_middle(inputs, max_out_len)

        generation_kwargs = kwargs.copy()
        generation_kwargs.update(self.generation_kwargs)
        generation_kwargs['max_tokens'] = max_out_len
        # De-duplicate while preserving order so the stop list is
        # deterministic across runs (a plain ``set`` is not).
        generation_kwargs['stop'] = list(
            dict.fromkeys(
                list(self.stop_words) + list(stopping_criteria or [])))
        sampling_kwargs = SamplingParams(**generation_kwargs)

        if not self.lora_path:
            outputs = self.model.generate(inputs, sampling_kwargs)
        else:
            outputs = self.model.generate(inputs,
                                          sampling_kwargs,
                                          lora_request=LoRARequest(
                                              'sql_adapter', 1,
                                              self.lora_path))

        # Only the first candidate of each request is returned.
        return [output.outputs[0].text for output in outputs]

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Compute a per-sample negative mean log-probability of the prompt.

        Args:
            inputs (List[str]): A list of strings to score.
            mask_length (List[int], optional): When given, only the last
                ``mask_length[i]`` token log-probabilities of sample ``i``
                are summed. A value of 0 selects no tokens.

        Returns:
            A numpy array with one loss value per input. The summed
            log-probabilities are always divided by the number of scored
            prompt tokens (all tokens but the first), also when
            ``mask_length`` is given.
        """
        # Merge instead of passing both as keywords, so a
        # ``prompt_logprobs`` entry in generation_kwargs cannot trigger a
        # duplicate-keyword TypeError.
        params = dict(self.generation_kwargs)
        params['prompt_logprobs'] = 0
        sampling_kwargs = SamplingParams(**params)
        # forward
        outputs = self.model.generate(inputs, sampling_kwargs)
        # compute ppl
        ce_loss = []
        for idx, output in enumerate(outputs):
            # The first prompt position is skipped (presumably because it
            # has no preceding context and therefore no log-probability).
            token_logprobs = output.prompt_logprobs[1:]
            token_ids = output.prompt_token_ids[1:]
            logprobs = np.array([
                candidates[token_id].logprob
                for candidates, token_id in zip(token_logprobs, token_ids)
            ])
            if mask_length is not None:
                # Compute the start index explicitly: ``[-0:]`` would keep
                # every token instead of none when the mask length is 0.
                start = max(len(logprobs) - mask_length[idx], 0)
                logprobs = logprobs[start:]
            # NOTE(review): the denominator counts all scored tokens even
            # when ``mask_length`` selects a suffix; kept as in the
            # original - confirm whether that is intended.
            loss = -logprobs.sum(axis=-1) / len(token_ids)
            ce_loss.append(loss)
        return np.array(ce_loss)

    def get_loglikelihood(self, inputs: List[str],
                          conts: List[str]) -> List[float]:
        """Score each continuation as the negated ``get_ppl`` value.

        Args:
            inputs (List[str]): Full strings to score.
            conts (List[str]): Continuations; the token length of each one
                (without special tokens) is used as the mask length for the
                matching input.

        Returns:
            A numpy array with one value per input.
        """
        mask_length = [
            self.get_token_len(c, add_special_tokens=False) for c in conts
        ]
        return -self.get_ppl(inputs, mask_length)

    def get_token_len(self,
                      prompt: str,
                      add_special_tokens: bool = True) -> int:
        """Get lengths of the tokenized strings.

        Args:
            prompt (str): Input string.
            add_special_tokens (bool): Whether the tokenizer adds its
                special tokens before counting. Defaults to True.

        Returns:
            int: Length of the input tokens
        """
        tokenizer = self.model.get_tokenizer()
        token_ids = tokenizer.encode(prompt,
                                     add_special_tokens=add_special_tokens)
        return len(token_ids)
