# Copyright (c) Alibaba, Inc. and its affiliates.
import asyncio
import inspect
import os
import time
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import lmdeploy
import torch
from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig, pipeline
from lmdeploy.api import autoget_backend_config
from lmdeploy.serve import async_engine
from packaging import version
from transformers import GenerationConfig

from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer
from swift.plugin import Metric
from swift.utils import get_logger, get_seed
from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig)
from .infer_engine import InferEngine
from .patch import patch_auto_config, patch_auto_tokenizer
from .utils import InferStreamer, patch_lmdeploy

try:
    # Older lmdeploy releases provide the generation config under the name ``EngineGenerationConfig``.
    from lmdeploy import EngineGenerationConfig as LmdeployGenerationConfig
except ImportError:
    # compat lmdeploy >= 0.6.*: the import above fails there, so fall back to ``GenerationConfig``.
    from lmdeploy import GenerationConfig as LmdeployGenerationConfig

# Module-level logger obtained from swift.utils.
logger = get_logger()


class LmdeployEngine(InferEngine):
    """Inference engine that serves a model through an lmdeploy ``pipeline``.

    swift only loads the tokenizer/processor and config here (``load_model=False``). The model itself
    is built by lmdeploy's ``pipeline`` in ``_prepare_engine``.
    """

    def __init__(
        self,
        model_id_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        *,
        model_type: Optional[str] = None,
        use_hf: Optional[bool] = None,
        hub_token: Optional[str] = None,
        revision: Optional[str] = None,
        # engine_kwargs
        tp: int = 1,
        session_len: Optional[int] = None,
        cache_max_entry_count: float = 0.8,
        quant_policy: int = 0,  # e.g. 4, 8
        vision_batch_size: int = 1,  # max_batch_size in VisionConfig
        devices: Optional[List[int]] = None,
        reload_weights: bool = False,
        engine_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Load the processor and build the lmdeploy pipeline for ``model_id_or_path``.

        Args:
            model_id_or_path: Model id or local path, resolved via ``get_model_tokenizer``.
            torch_dtype: dtype stored on ``self.config.torch_dtype``. Falls back to
                ``self.model_info.torch_dtype`` when not given.
            model_type, use_hf, hub_token, revision: Forwarded to ``get_model_tokenizer``.
            tp: Tensor-parallel degree written into the backend config. ``patch_lmdeploy`` is only
                applied when ``tp == 1`` (and lmdeploy>=0.7.0).
            session_len, cache_max_entry_count, quant_policy: Written into the backend config.
            vision_batch_size: ``max_batch_size`` of the ``VisionConfig`` (multimodal models only).
            devices: Device ids assigned to the backend config if it has a ``devices`` attribute.
                Defaults to ``[0]``.
            reload_weights: Requires lmdeploy>=0.7.0. Forwarded to ``patch_lmdeploy``.
            engine_kwargs: Extra ``TurbomindEngineConfig`` kwargs. The dict is not modified.

        Raises:
            AssertionError: if ``reload_weights`` is set with lmdeploy<0.7.0.
        """
        version_7 = version.parse(lmdeploy.__version__) >= version.parse('0.7.0')
        if reload_weights:
            assert version_7, 'grpo or reload_weights need lmdeploy>=0.7.0'
        if version_7 and tp == 1:
            patch_lmdeploy(reload_weights)
        self.processor = get_model_tokenizer(
            model_id_or_path,
            torch_dtype,
            load_model=False,
            download_model=True,
            model_type=model_type,
            use_hf=use_hf,
            hub_token=hub_token,
            revision=revision)[1]
        self._post_init()

        if self.max_model_len is not None:
            # NOTE(review): reserves one token below the model maximum; the reason is not visible here.
            self.max_model_len -= 1
        self._prepare_engine_kwargs(
            tp=tp,
            session_len=session_len,
            cache_max_entry_count=cache_max_entry_count,
            quant_policy=quant_policy,
            vision_batch_size=vision_batch_size,
            devices=devices,
            engine_kwargs=engine_kwargs)

        self.config.torch_dtype = torch_dtype or self.model_info.torch_dtype

        @contextmanager
        def disable_deepspeed():
            # Make transformers report ZeRO-3 as disabled while lmdeploy builds the model.
            # The try/finally restores the original function even if building the engine raises.
            # Without it, the patch would otherwise leak for the rest of the process.
            from transformers import modeling_utils
            modeling_utils.is_deepspeed_zero3_enabled_origin = modeling_utils.is_deepspeed_zero3_enabled
            modeling_utils.is_deepspeed_zero3_enabled = lambda: False
            try:
                yield
            finally:
                modeling_utils.is_deepspeed_zero3_enabled = modeling_utils.is_deepspeed_zero3_enabled_origin
                del modeling_utils.is_deepspeed_zero3_enabled_origin

        with disable_deepspeed():
            self._prepare_engine()
        self._load_generation_config()

    def _prepare_engine_kwargs(self,
                               tp: int = 1,
                               session_len: Optional[int] = None,
                               cache_max_entry_count: float = 0.8,
                               quant_policy: int = 0,
                               vision_batch_size: int = 1,
                               devices: Optional[List[int]] = None,
                               engine_kwargs: Optional[Dict[str, Any]] = None):
        """Populate ``self.backend_config`` and ``self.pipeline_kwargs`` for ``pipeline``.

        The explicit arguments override same-named keys in ``engine_kwargs``. The resulting
        ``TurbomindEngineConfig`` is passed through ``autoget_backend_config``, which may return a
        different config type; hence the ``hasattr`` check before setting ``devices``.
        """
        # Work on a copy so the caller's dict is never mutated.
        engine_kwargs = dict(engine_kwargs or {})
        engine_kwargs.update(
            tp=tp, session_len=session_len, cache_max_entry_count=cache_max_entry_count, quant_policy=quant_policy)
        backend_config = TurbomindEngineConfig(**engine_kwargs)
        backend_config = autoget_backend_config(self.model_dir, backend_config)
        if hasattr(backend_config, 'devices'):
            if devices is None:
                devices = [0]
            backend_config.devices = devices
        self.backend_config = backend_config
        logger.info(f'backend_config: {backend_config}')

        pipeline_kwargs = {}
        if self.model_meta.is_multimodal:
            vision_config = VisionConfig(max_batch_size=vision_batch_size)
            pipeline_kwargs['vision_config'] = vision_config
            logger.info(f'vision_config: {vision_config}')
        self.pipeline_kwargs = pipeline_kwargs

    @contextmanager
    def _patch_pipeline(self):
        """Temporarily force ``async_engine.best_match_model`` to return this model's ``model_type``."""
        _old_best_match_model = async_engine.best_match_model

        def _best_match_model(*args, **kwargs) -> Optional[str]:
            return self.model_info.model_type

        async_engine.best_match_model = _best_match_model
        try:
            yield
        finally:
            async_engine.best_match_model = _old_best_match_model

    def _prepare_engine(self):
        """Build the lmdeploy pipeline with swift's tokenizer/config patched in, storing it on ``self.engine``."""
        with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config), self._patch_pipeline():
            engine = pipeline(self.model_dir, backend_config=self.backend_config, **self.pipeline_kwargs)
        self.engine = engine

    def _load_generation_config(self):
        """Set ``self.generation_config`` from ``generation_config.json`` in the model dir.

        Only non-``None`` fields that ``LmdeployGenerationConfig`` accepts as constructor parameters
        are kept. If the file is missing, a default ``LmdeployGenerationConfig()`` is used.
        """
        generation_config_path = os.path.join(self.model_dir, 'generation_config.json')
        if not os.path.isfile(generation_config_path):
            self.generation_config = LmdeployGenerationConfig()
            return
        hf_kwargs = GenerationConfig.from_pretrained(self.model_dir).to_dict()
        parameters = inspect.signature(LmdeployGenerationConfig).parameters
        kwargs = {k: v for k, v in hf_kwargs.items() if k in parameters and v is not None}
        self.generation_config = LmdeployGenerationConfig(**kwargs)

    def _get_stop_token_ids(self, stop_words: List[Union[str, List[int], None]]) -> List[int]:
        """Convert stop words into a de-duplicated list of single token ids, preserving order.

        Strings are tokenized without special tokens. Only stop words that map to exactly one token
        are kept; multi-token sequences are skipped. ``None`` and values of any other type are ignored.
        """
        stop_token_ids: List[int] = []
        for stop_word in stop_words:
            if isinstance(stop_word, str):
                stop_word = self.tokenizer.encode(stop_word, add_special_tokens=False)
            if isinstance(stop_word, list):
                if len(stop_word) != 1:
                    continue
                stop_token = stop_word[0]
            elif isinstance(stop_word, int):
                stop_token = stop_word
            else:
                # None or an unsupported type: skip it, and never reuse a stale/unbound ``stop_token``.
                continue
            assert isinstance(stop_token, int)
            if stop_token not in stop_token_ids:
                stop_token_ids.append(stop_token)
        return stop_token_ids

    def _add_stop_words(self, generation_config: LmdeployGenerationConfig, request_config: RequestConfig,
                        template_meta: TemplateMeta) -> None:
        """Merge request, default-config and template stop words into ``generation_config`` (in place)."""
        stop_words = (request_config.stop or []) + (self.generation_config.stop_words or []) + template_meta.stop_words
        generation_config.stop_words = self._get_stop_token_ids(stop_words)
        # compat lmdeploy >= 0.6.*
        generation_config.stop_token_ids = generation_config.stop_words

    def _prepare_generation_config(self, request_config: RequestConfig) -> LmdeployGenerationConfig:
        """Build the per-request lmdeploy generation config.

        Sampling fields missing from ``request_config`` fall back to ``self.generation_config``.
        Side effect: if ``request_config.seed`` is ``None``, a seed from ``get_seed()`` is written back
        into ``request_config``. ``top_logprobs`` is attached to the result as an extra attribute.
        """
        kwargs = {'max_new_tokens': request_config.max_tokens}
        for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
            new_value = getattr(request_config, key)
            kwargs[key] = getattr(self.generation_config, key) if new_value is None else new_value
        if request_config.seed is None:
            request_config.seed = get_seed()
        kwargs['random_seed'] = request_config.seed
        if request_config.temperature == 0:
            # Greedy decoding expressed as top_k=1; temperature=1 avoids unnecessary processing.
            kwargs['temperature'] = 1
            kwargs['top_k'] = 1

        if request_config.logprobs:
            kwargs['logprobs'] = 1
            if request_config.top_logprobs is not None:
                kwargs['logprobs'] = max(1, request_config.top_logprobs)

        res = LmdeployGenerationConfig(**kwargs)
        res.top_logprobs = request_config.top_logprobs
        return res

    async def _iter_stream_responses(self, template: Template, inputs: Dict[str, Any],
                                     generation_config: LmdeployGenerationConfig,
                                     gen) -> AsyncIterator[ChatCompletionStreamResponse]:
        """Convert the raw lmdeploy output stream ``gen`` into ``ChatCompletionStreamResponse`` chunks.

        ``output.token_ids`` is treated as cumulative (see the ``token_idx`` slicing). Intermediate
        outputs without printable text are skipped. Once ``gen`` is exhausted, a final chunk is emitted
        that carries tool calls parsed from the fully decoded text.
        """
        infer_streamer = InferStreamer(template)
        token_idx = 0
        is_finished = False
        while not is_finished:
            try:
                output = await gen.__anext__()
            except StopAsyncIteration:
                is_finished = True
            delta_text = infer_streamer.get_printable_text(output.token_ids, is_finished)
            if not delta_text and not is_finished:
                continue

            logprobs = self._get_logprobs(output.logprobs, output.token_ids[token_idx:],
                                          generation_config.top_logprobs)
            token_idx = len(output.token_ids)

            usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
            toolcall = None
            if is_finished:
                toolcall = self._get_toolcall(template.decode(output.token_ids), template.tools_prompt)
            finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
                                                    output.status.name == 'FINISH')
            choices = [
                ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall),
                    finish_reason=finish_reason,
                    logprobs=logprobs)
            ]
            yield ChatCompletionStreamResponse(model=self.model_name, choices=choices, usage=usage_info)

    async def _infer_stream_async(
            self, template: Template, inputs: Dict[str, Any],
            generation_config: LmdeployGenerationConfig) -> AsyncIterator[ChatCompletionStreamResponse]:
        """Stream one request as ``ChatCompletionStreamResponse`` chunks.

        For lmdeploy>=0.6.5 the ``model_inst`` context stays entered for the whole generation, as in
        ``_infer_full_async``. Previously it was exited before ``safe_run`` ran, so the instance was
        presumably handed back while still generating. PyTorch-backend sessions are ended explicitly
        afterwards, matching the non-streaming path.
        """
        session_id = time.time_ns()
        kwargs = {'stream_output': True, 'gen_config': generation_config, 'sequence_start': True, 'sequence_end': True}
        if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'):
            async with self.engine.model_inst(session_id) as inst:
                async with self.engine.safe_run(inst, session_id, **inputs, **kwargs) as gen:
                    async for response in self._iter_stream_responses(template, inputs, generation_config, gen):
                        yield response
                if self.engine.backend == 'pytorch':
                    # manually end pytorch session
                    await inst.async_end(session_id)
        else:
            async with self.engine.safe_run(session_id):
                generator = await self.engine.get_generator(False, session_id)
                gen = generator.async_stream_infer(session_id=session_id, **inputs, **kwargs)
                async for response in self._iter_stream_responses(template, inputs, generation_config, gen):
                    yield response

    async def _infer_full_async(self, template: Template, inputs: Dict[str, Any],
                                generation_config: LmdeployGenerationConfig) -> ChatCompletionResponse:
        """Run one request to completion and wrap the last lmdeploy output in a ``ChatCompletionResponse``."""
        session_id = time.time_ns()
        kwargs = {'stream_output': False, 'gen_config': generation_config, 'sequence_start': True, 'sequence_end': True}
        if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'):
            async with self.engine.model_inst(session_id) as inst:
                async with self.engine.safe_run(inst, session_id, **inputs, **kwargs) as gen:
                    async for output in gen:
                        pass
                if self.engine.backend == 'pytorch':
                    # manually end pytorch session
                    await inst.async_end(session_id)

        else:
            async with self.engine.safe_run(session_id):
                generator = await self.engine.get_generator(False, session_id)
                async for output in generator.async_stream_infer(session_id=session_id, **inputs, **kwargs):
                    pass

        # Only the final output is used below (the loops above just drain the generator).
        response = template.decode(output.token_ids)
        logprobs = self._get_logprobs(output.logprobs, output.token_ids, generation_config.top_logprobs)

        usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
        toolcall = self._get_toolcall(response, template.tools_prompt)
        finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
                                                output.status.name == 'FINISH')
        choices = [
            ChatCompletionResponseChoice(
                index=0,
                message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
                finish_reason=finish_reason,
                logprobs=logprobs)
        ]
        return ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info)

    async def infer_async(self,
                          infer_request: InferRequest,
                          request_config: Optional[RequestConfig] = None,
                          *,
                          template: Optional[Template] = None,
                          pre_infer_hook=None,
                          **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
        """Encode ``infer_request`` and run it on the lmdeploy engine.

        The request is encoded in a thread-pool executor. Images, if any, go through lmdeploy's vision
        encoder, and the template then prepares backend-specific inputs. ``request_config`` is
        deep-copied, so the caller's object is not modified. ``pre_infer_hook``, if given, may rewrite
        the kwargs passed to the infer call.

        Returns:
            An async iterator of stream responses when ``request_config.stream`` is set; otherwise the
            full ``ChatCompletionResponse``.
        """
        request_config = deepcopy(request_config or RequestConfig())
        if template is None:
            template = self.default_template

        template.set_mode('lmdeploy')

        loop = asyncio.get_running_loop()
        with torch.inference_mode():
            inputs = await loop.run_in_executor(None, template.encode, infer_request)
        images = inputs.pop('images', None)
        if images:
            if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'):
                messages = self.engine._convert_prompts(('', images))
                messages = await self.engine.async_convert_to_pil_images(messages)
                results = await self.engine.vl_encoder.preprocess(messages)
                if self.engine.backend == 'turbomind':
                    results = await self.engine.vl_encoder.async_infer(results)
                    inputs['images'] = [result['content'] for result in results if result['role'] == 'forward'][0]
                    await template.prepare_lmdeploy_turbomind_inputs(inputs)
                else:
                    inputs['images'] = results[1]['content']
                    await template.prepare_lmdeploy_pytorch_inputs(inputs)
            else:
                inputs['images'] = await self.engine.vl_encoder.async_infer(images)
                await template.prepare_lmdeploy_turbomind_inputs(inputs)

        self.set_default_max_tokens(request_config, inputs)
        generation_config = self._prepare_generation_config(request_config)
        self._add_stop_words(generation_config, request_config, template.template_meta)
        kwargs.update({'template': template, 'inputs': inputs, 'generation_config': generation_config})
        if pre_infer_hook:
            kwargs = pre_infer_hook(kwargs)
        if request_config.stream:
            return self._infer_stream_async(**kwargs)
        else:
            return await self._infer_full_async(**kwargs)

    def _batch_infer_stream(self, *args, **kwargs):
        """Reset lmdeploy's cached async state, then defer to ``InferEngine._batch_infer_stream``.

        NOTE(review): clearing ``vl_encoder._loop_task`` and ``free_insts`` presumably lets lmdeploy
        recreate them on the event loop used by the batch run; confirm against lmdeploy.
        """
        if hasattr(self.engine, 'vl_encoder'):
            self.engine.vl_encoder._loop_task = None
        if hasattr(self.engine, 'free_insts'):
            self.engine.free_insts = None
        return super()._batch_infer_stream(*args, **kwargs)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
        metrics: Optional[List[Metric]] = None,
        *,
        template: Optional[Template] = None,
        use_tqdm: Optional[bool] = None,
    ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
        """Synchronous batch inference; delegates to ``InferEngine.infer`` with an explicit signature."""
        return super().infer(infer_requests, request_config, metrics, template=template, use_tqdm=use_tqdm)
