# Copyright (c) Alibaba, Inc. and its affiliates.
import asyncio
import hashlib
import inspect
import pickle
import time
from copy import deepcopy
from queue import Queue
from threading import Thread
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union

import json
import torch
from tqdm import tqdm
from transformers import GenerationConfig, LogitsProcessorList
from transformers.utils import is_torch_npu_available

from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer, safe_snapshot_download, to_device
from swift.plugin import Metric
from swift.tuners import Swift
from swift.utils import get_logger
from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid)
from .infer_engine import InferEngine
from .utils import AdapterRequest, InferStreamer, LogitsStreamer, TokensIteratorStreamer, prepare_generation_config

import requests
import re
import os

import pdb

logger = get_logger()


class _GenerationConfig(GenerationConfig):
    """`GenerationConfig` subclass that only customises how the config is printed."""

    def __repr__(self) -> str:
        """Return `GenerationConfig({...})` built from the JSON form, without the library version."""
        # `ignore_metadata` is only passed when this `to_json_string` accepts it.
        json_kwargs = {}
        if 'ignore_metadata' in inspect.signature(self.to_json_string).parameters:
            json_kwargs['ignore_metadata'] = True
        config_fields = json.loads(self.to_json_string(**json_kwargs))
        config_fields.pop('transformers_version', None)
        return f'GenerationConfig({config_fields})'


class PtEngine(InferEngine):

    def __init__(
            self,
            model_id_or_path: str,
            torch_dtype: Optional[torch.dtype] = None,
            *,
            adapters: List[str] = None,
            max_batch_size: int = 1,
            #
            model_type: Optional[str] = None,
            use_hf: Optional[bool] = None,
            revision: Optional[str] = None,
            hub_token: Optional[str] = None,
            load_model: bool = True,
            # model kwargs
            attn_impl: Literal['flash_attn', 'sdpa', 'eager', None] = None,
            device_map: Optional[Union[str, Dict[str, Any]]] = None,
            quantization_config: Optional[Dict[str, Any]] = None,
            model_kwargs: Optional[Dict[str, Any]] = None,
            **kwargs):
        """Load the model/processor pair, attach the requested adapters and finish engine setup.

        Args:
            model_id_or_path: Passed to `get_model_tokenizer` as the model to load.
            torch_dtype: Passed positionally to `get_model_tokenizer`.
            adapters: A single adapter identifier (str) or a list of them. Each one is resolved
                with `safe_snapshot_download` and attached through `_add_adapter`.
            max_batch_size: Stored on the instance as the batch-size cap.
            model_type, use_hf, revision, hub_token, load_model, attn_impl, device_map,
            quantization_config, model_kwargs, **kwargs: Forwarded unchanged to
                `get_model_tokenizer` (which is always called with `download_model=True`).
        """
        loader_options = dict(
            load_model=load_model,
            model_type=model_type,
            download_model=True,
            use_hf=use_hf,
            hub_token=hub_token,
            revision=revision,
            device_map=device_map,
            quantization_config=quantization_config,
            attn_impl=attn_impl,
            model_kwargs=model_kwargs,
        )
        self.model, self.processor = get_model_tokenizer(model_id_or_path, torch_dtype, **loader_options, **kwargs)
        self.max_batch_size = max_batch_size

        # A bare string is treated as a one-element list; None (or any falsy value) means no adapters.
        adapter_list = [adapters] if isinstance(adapters, str) else (adapters or [])
        self.adapters = adapter_list
        for adapter_id in adapter_list:
            adapter_dir = safe_snapshot_download(adapter_id, use_hf=use_hf, hub_token=hub_token)
            self._add_adapter(adapter_dir)
        self._post_init()

    def _post_init(self):
        """Run the parent `_post_init`, then set up the model aliases and the request-batching state."""
        super()._post_init()
        model = self.model
        self.engine = model  # dummy
        self.generation_config = model.generation_config
        # State shared with the background worker: incoming requests, grouped pending
        # requests, and the handle of the worker thread (created lazily).
        self._queue = Queue()
        self._task_pool = {}
        self._task_thread = None

    def _start_infer_worker(self):
        if self._task_thread is None:
            self._task_thread = Thread(target=self._infer_worker)
            self._task_thread.daemon = True
            self._task_thread.start()

    def _fetch_infer_requests(self):
        while not self._queue.empty():
            infer_request, kwargs, queue = self._queue.get()
            template = kwargs['template']
            info = hashlib.sha256(pickle.dumps((kwargs['request_config'], template
                                                and template.template_meta))).hexdigest()
            if info not in self._task_pool:
                self._task_pool[info] = kwargs, []
            self._task_pool[info][1].append((infer_request, queue))
        if len(self._task_pool) == 0:
            return
        key, (kwargs, data) = next(iter(self._task_pool.items()))
        max_batch_size = self.max_batch_size or len(data)
        data, remain_data = data[:max_batch_size], data[max_batch_size:]
        if remain_data:
            self._task_pool[key] = kwargs, remain_data
        else:
            self._task_pool.pop(key)
        kwargs = kwargs.copy()
        kwargs['infer_requests'] = [d[0] for d in data]
        queue_list = [d[1] for d in data]
        return kwargs, queue_list

    def _infer_worker(self):
        """Background loop that runs batched inference and hands results back to async callers.

        Every 10 ms the loop asks `_fetch_infer_requests` for a batch. Results are pushed to
        each caller's `asyncio.Queue` on that caller's event loop:

        * non-streaming: one result per caller;
        * streaming: one item per caller for every generator step, then a final `None`
          that tells the consumer the stream is over.

        If inference raises, the exception is logged and the exception object itself is
        delivered to every caller of the batch (followed by the `None` sentinel when
        streaming). Without this, a single failure would terminate the only worker thread and
        leave every pending and future caller waiting forever, because the thread is never
        restarted once created.
        """

        def _deliver(queue_list, results):
            # Schedule `queue.put` on the event loop that owns each queue; asyncio queues
            # must not be touched directly from this thread.
            for (queue, loop), res in zip(queue_list, results):
                asyncio.run_coroutine_threadsafe(queue.put(res), loop)

        while True:
            time.sleep(0.01)
            item = self._fetch_infer_requests()
            if item is None:
                continue
            kwargs, queue_list = item
            is_stream = False
            try:
                is_stream = kwargs['request_config'].stream
                res_list_or_gen = self._infer(**kwargs)
                if is_stream:
                    for res_list in res_list_or_gen:
                        _deliver(queue_list, res_list)
                    # End-of-stream marker for every caller.
                    _deliver(queue_list, [None] * len(queue_list))
                else:
                    _deliver(queue_list, res_list_or_gen)
            except Exception as e:
                logger.exception('Inference failed in the background worker; reporting the error to callers.')
                try:
                    _deliver(queue_list, [e] * len(queue_list))
                    if is_stream:
                        _deliver(queue_list, [None] * len(queue_list))
                except Exception:
                    # E.g. the caller's event loop is already closed; keep the worker alive.
                    logger.exception('Failed to report an inference error to the callers.')

    def _add_adapter(self, adapter_path: str, adapter_name: Optional[str] = None) -> None:
        """Attach the adapter at `adapter_path` to the current model.

        `self.model` is replaced by whatever `Swift.from_pretrained` returns for the current
        model, the adapter path and the (optional) adapter name.
        """
        wrapped_model = Swift.from_pretrained(self.model, adapter_path, adapter_name)
        self.model = wrapped_model

    @classmethod
    def from_model_template(cls, model, template=None, *, max_batch_size: int = 1):
        """Build an engine around an already-loaded model without going through `__init__`.

        The processor is taken from `template.processor`, so although `template` defaults to
        None a template must actually be supplied. `_post_init` is run before returning.
        """
        engine = super().__new__(cls)
        engine.model = model
        engine.default_template = template
        engine.processor = template.processor
        engine.max_batch_size = max_batch_size
        engine._post_init()
        return engine

    def _prepare_generation_config(self, request_config: RequestConfig) -> _GenerationConfig:
        """Derive the per-request generation config and wrap it in `_GenerationConfig`.

        Starts from `prepare_generation_config(...)` and then forces a few fields: dict-style
        generate output, logits output (always on, independent of `request_config.logprobs`),
        `top_logprobs` pinned to 0 (independent of `request_config.top_logprobs`), and
        `num_return_sequences` taken from `request_config.n`.
        """
        config = prepare_generation_config(self.generation_config, request_config, self.tokenizer)
        forced_fields = {
            'return_dict_in_generate': True,
            'output_logits': True,
            'top_logprobs': 0,
            'num_return_sequences': request_config.n,
        }
        for field_name, value in forced_fields.items():
            setattr(config, field_name, value)
        return _GenerationConfig(**config.to_dict())

    def _add_stop_words(self, generation_config: _GenerationConfig, request_config: RequestConfig,
                        template_meta: TemplateMeta) -> None:
        """Set `generation_config.stop_words` from the request's and the template's stop words.

        The request's `stop` list (empty when unset) is concatenated with
        `template_meta.stop_words` and passed through `_get_stop_words`.
        """
        requested_stops = request_config.stop or []
        combined = requested_stops + template_meta.stop_words
        generation_config.stop_words = self._get_stop_words(combined)

    @staticmethod
    def preprocess_logits(batched_logits: Optional[List[torch.Tensor]], batched_generate_ids: torch.Tensor,
                          top_logprobs: int):
        batch_size = batched_generate_ids.shape[0]
        if batched_logits is None:
            return None
        batched_logprobs = []
        for i in range(batch_size):
            logprobs_list = []
            generate_ids = batched_generate_ids[i]
            for j, logits in enumerate(batched_logits):
                token = generate_ids[j].item()
                logprobs = torch.log_softmax(logits[i], -1)
                tokens = [token] + logprobs.argsort(descending=True, dim=-1)[:top_logprobs].tolist()
                logprobs_list.append({token: logprobs[token].item() for token in tokens})
            batched_logprobs.append(logprobs_list)
        return batched_logprobs

    @staticmethod
    def _update_batched_logprobs(batched_logprobs: List[torch.Tensor], logits_streamer: Optional[LogitsStreamer],
                                 generate_ids: torch.Tensor, top_logprobs: int) -> None:
        seq_len = generate_ids.shape[1] - len(batched_logprobs[0])
        if logits_streamer is None or seq_len == 0:
            return

        res = []
        for i in range(seq_len):
            res.append(logits_streamer.queue.get())
        new_batched_logprobs = PtEngine.preprocess_logits(res, generate_ids[:, -seq_len:], top_logprobs)
        for logprobs, new_logprobs in zip(batched_logprobs, new_batched_logprobs):
            logprobs += new_logprobs

    def _infer_stream(self,
                      template: Template,
                      inputs: Dict[str, Any],
                      *,
                      generation_config: GenerationConfig,
                      adapter_request: Optional[AdapterRequest] = None,
                      **kwargs) -> Iterator[List[Optional[ChatCompletionStreamResponse]]]:
        """Run generation on a helper thread and yield incremental responses for the batch.

        Each yielded list has one slot per batch element: a `ChatCompletionStreamResponse`
        carrying the newly printable text, or None when that element produced nothing new
        (or has already finished). Lists in which every slot is None are not yielded.

        Args:
            template: Template used to run `generate`, post-process ids and decode text.
            inputs: Collated model inputs; `inputs['attention_mask']` provides the batch size.
            generation_config: Generation settings; beam search is rejected.
            adapter_request: Optional adapter to activate for this call.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If `generation_config.num_beams` is not 1.
        """

        if generation_config.num_beams != 1:
            error_msg = 'Streaming generation does not support beam search.'
            raise ValueError(error_msg)
        streamer = TokensIteratorStreamer()
        generate_kwargs = {
            'generation_config': generation_config,
            'streamer': streamer,
            **inputs,
        }
        adapter_names = self._get_adapter_names(adapter_request)
        if adapter_names is not None:
            generate_kwargs['adapter_names'] = adapter_names
        num_prompt_tokens = self._get_num_tokens(inputs)

        logits_streamer = None
        if generation_config.output_logits:
            # Keep a handle on the streamer: `_update_batched_logprobs` reads the per-step
            # logits back from `logits_streamer.queue`. Previously the instance was created
            # inline and never stored, so logprobs were never collected while streaming.
            logits_streamer = LogitsStreamer()
            generate_kwargs['logits_processor'] = LogitsProcessorList([logits_streamer])

        def _model_generate(**kwargs):
            if is_torch_npu_available():
                torch.npu.set_device(self.model.device)
            template.generate(self.model, **kwargs)

        generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model)
        # Generation runs on its own thread; this generator consumes tokens from `streamer`.
        thread = Thread(target=_model_generate, kwargs=generate_kwargs)
        thread.start()
        batch_size = inputs['attention_mask'].shape[0]
        all_is_finished = False
        is_finished = [False] * batch_size
        infer_streamers = [InferStreamer(template) for _ in range(batch_size)]
        request_id_list = [f'chatcmpl-{random_uuid()}' for _ in range(batch_size)]
        # Number of generated tokens already reported per batch element (used to slice logprobs).
        token_idxs = [0] * batch_size

        raw_batched_generate_ids = None  # or torch.Tensor: [batch_size, seq_len]
        batched_logprobs = [[] for _ in range(batch_size)]
        while not all_is_finished:
            try:
                batched_tokens = next(streamer)
                if batched_tokens.ndim == 1:
                    batched_tokens = batched_tokens[:, None]

                raw_batched_generate_ids = torch.concat(
                    [batched_tokens]
                    if raw_batched_generate_ids is None else [raw_batched_generate_ids, batched_tokens],
                    dim=1)
            except StopIteration:
                # The streamer is exhausted: run one last pass to flush the remaining text.
                all_is_finished = True

            batched_generate_ids = template.get_generate_ids(raw_batched_generate_ids, num_prompt_tokens)
            self._update_batched_logprobs(batched_logprobs, logits_streamer, batched_generate_ids,
                                          generation_config.top_logprobs or 1)

            res = []
            for i in range(batched_generate_ids.shape[0]):
                if is_finished[i]:
                    res.append(None)
                    continue
                generate_ids = batched_generate_ids[i]

                # ignore pad_token
                masks = generate_ids != self.tokenizer.pad_token_id
                generate_ids = generate_ids[masks].tolist()
                logprobs_list = None
                if batched_logprobs[i]:
                    logprobs_list = [logprobs for m, logprobs in zip(masks, batched_logprobs[i]) if m.item()]

                # NOTE(review): pad tokens were filtered out of `generate_ids` just above, so
                # the last clause looks unable to fire for a non-None pad id - confirm intent.
                is_finished[i] = (
                    all_is_finished or is_finished[i]
                    or len(generate_ids) > 0 and generate_ids[-1] == self.tokenizer.pad_token_id)
                delta_text = infer_streamers[i].get_printable_text(generate_ids, is_finished[i])
                if not delta_text and not is_finished[i]:
                    res.append(None)
                    continue
                logprobs = self._get_logprobs(logprobs_list, generate_ids[token_idxs[i]:],
                                              generation_config.top_logprobs)
                token_idxs[i] = len(generate_ids)

                usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
                toolcall = None
                if is_finished[i]:
                    toolcall = self._get_toolcall(template.decode(generate_ids), template.tools_prompt)
                # NOTE(review): the prompt token count is what is passed here; confirm that is
                # the quantity `_get_finish_reason` expects.
                finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens,
                                                        is_finished[i])

                choices = [
                    ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall),
                        finish_reason=finish_reason,
                        logprobs=logprobs)
                ]
                res.append(
                    ChatCompletionStreamResponse(
                        model=self.model_name, choices=choices, usage=usage_info, id=request_id_list[i]))
            if any(res):
                yield res

    def _get_adapter_names(self, adapter_request: Optional[AdapterRequest]) -> Optional[List[str]]:
        if adapter_request is None:
            if self._adapters_pool:
                return ['__base__']
            return
        adapter_name = adapter_request.name
        if adapter_name not in self._adapters_pool:
            self._adapters_pool[adapter_name] = adapter_request
            self._add_adapter(adapter_request.path, adapter_name)
        return [adapter_name]

    def _infer_forward(self,
                       template: Template,
                       inputs: Dict[str, Any],
                       adapter_request: Optional[AdapterRequest] = None,
                       **kwargs):
        """Run a single forward pass and decode the logits for `seq_cls` / `prm` templates.

        `labels` is removed from `inputs` (in place) before the model call. Each prediction
        becomes one `ChatCompletionResponse` whose content is `str(pred)`, with
        `finish_reason='stop'` and a completion-token count of 1.

        Raises:
            ValueError: If `template.mode` is neither 'seq_cls' nor 'prm' (raised after the
                forward pass has run).
        """
        forward_kwargs = {}
        adapter_names = self._get_adapter_names(adapter_request)
        if adapter_names is not None:
            forward_kwargs['adapter_names'] = adapter_names
        num_prompt_tokens = self._get_num_tokens(inputs)
        inputs.pop('labels', None)
        logits = self.model(**inputs, **forward_kwargs).logits

        mode = template.mode
        if mode == 'seq_cls':
            preds, logprobs = template.decode_seq_cls(logits)
        elif mode == 'prm':
            preds = template.decode_prm(inputs['input_ids'], logits)
            # No logprobs are produced in this mode.
            logprobs = [None] * len(preds)
        else:
            raise ValueError(f'Unsupported mode: {mode}')

        def _build_response(position, pred):
            choice = ChatCompletionResponseChoice(
                index=0,
                message=ChatMessage(role='assistant', content=str(pred), tool_calls=None),
                finish_reason='stop',
                logprobs=logprobs[position])
            return ChatCompletionResponse(
                model=self.model_name, choices=[choice], usage=self._get_usage_info(num_prompt_tokens, 1))

        return [_build_response(position, pred) for position, pred in enumerate(preds)]

    def _infer_full(self,
                    template: Template,
                    inputs: Dict[str, Any],
                    *,
                    generation_config: GenerationConfig,
                    adapter_request: Optional[AdapterRequest] = None,
                    template_inputs=None) -> List[ChatCompletionResponse]:
        """Run one non-streaming `generate` call and build a response per request in the batch.

        Every request yields `generation_config.num_return_sequences` choices. Unlike the
        usual logprobs payload, each choice's `logprobs` field is a plain list holding only
        the logprob entries of the tokens found between `<answer>` and `</answer>` in the
        generated text (empty when no `<answer>` tag was generated or no logprobs exist).

        Args:
            template: Template used to run `generate`, post-process ids and decode text.
            inputs: Collated model inputs; `inputs['attention_mask']` provides the batch size.
            generation_config: Generation settings for this call.
            adapter_request: Optional adapter to activate for this call.
            template_inputs: Per-request template inputs, indexed by batch position and
                passed to `template.decode`.

        Returns:
            One `ChatCompletionResponse` per request, in batch order.
        """
        # bos_token TODO: encoder-decoder
        generate_kwargs = {'generation_config': generation_config, **inputs}
        adapter_names = self._get_adapter_names(adapter_request)
        if adapter_names is not None:
            generate_kwargs['adapter_names'] = adapter_names
        num_prompt_tokens = self._get_num_tokens(inputs)
        generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model)

        output = dict(template.generate(self.model, **generate_kwargs))
        output.pop('past_key_values', None)
        batched_generate_ids = output['sequences']
        batched_generate_ids = template.get_generate_ids(batched_generate_ids, num_prompt_tokens)
        template.debug_logger({'generate_ids': batched_generate_ids})  # debug
        batched_logprobs = self.preprocess_logits(
            output.get('logits'), batched_generate_ids, generation_config.top_logprobs)

        res = []
        num_return_sequences = generation_config.num_return_sequences
        for i in range(inputs['attention_mask'].shape[0]):
            choices = []
            usage_info = self._get_usage_info(num_prompt_tokens, 0)
            for j in range(num_return_sequences):
                # Sequences of the same request are stored contiguously.
                batched_index = i * num_return_sequences + j
                generate_ids = batched_generate_ids[batched_index]

                # ignore pad_token
                masks = generate_ids != self.tokenizer.pad_token_id
                generate_ids = generate_ids[masks].tolist()
                logprobs_list = None
                if batched_logprobs is not None:
                    logprobs_list = [
                        logprobs for m, logprobs in zip(masks, batched_logprobs[batched_index]) if m.item()
                    ]

                logprobs = self._get_logprobs(logprobs_list, generate_ids, generation_config.top_logprobs)
                usage_info = self._update_usage_info(usage_info, len(generate_ids))
                response = template.decode(generate_ids, template_inputs=template_inputs[i])
                finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens, True)
                toolcall = self._get_toolcall(response, template.tools_prompt)

                # get answer probs; tolerate a missing logprobs payload instead of failing
                # on the subscript.
                content = logprobs.get('content') if logprobs else None
                answer_entries = self._extract_answer_entries(content or [])

                choices.append(
                    ChatCompletionResponseChoice(
                        index=j,
                        message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
                        finish_reason=finish_reason,
                        logprobs=answer_entries))
            res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
        return res

    @staticmethod
    def _extract_answer_entries(content):
        """Return the logprob entries whose tokens lie between `<answer>` and `</answer>`.

        Args:
            content: Sequence of dicts, each with a `'token'` string, in generation order.

        Returns:
            The entries following the first point where the concatenated tokens end with
            `<answer>`, cut before the first entry that reaches into `</answer>`. Without a
            closing tag everything after the opening tag is returned; without an opening
            tag the result is empty.
        """
        entries = list(content)
        tail_len = 32  # rolling window of recent text, long enough to hold the opening tag

        # Locate the entry that completes the opening tag (it may be split across tokens).
        tail = ''
        start = None
        for position, entry in enumerate(entries):
            tail = (tail + entry['token'])[-tail_len:]
            if tail.endswith('<answer>'):
                start = position + 1
                break
        if start is None:
            return []

        span = entries[start:]
        close_pos = ''.join(entry['token'] for entry in span).find('</answer>')
        if close_pos == -1:
            return span

        # Keep only the entries that end at or before the first character of the closing tag.
        kept = []
        consumed = 0
        for entry in span:
            consumed += len(entry['token'])
            if consumed > close_pos:
                break
            kept.append(entry)
        return kept

    async def infer_async(
        self,
        infer_request: InferRequest,
        request_config: Optional[RequestConfig] = None,
        *,
        template: Optional[Template] = None,
        adapter_request: Optional[AdapterRequest] = None,
        pre_infer_hook=None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
        if request_config is None:
            request_config = RequestConfig()
        queue = asyncio.Queue()
        self._queue.put((infer_request, {
            'request_config': request_config,
            'template': template,
            'adapter_request': adapter_request,
            'pre_infer_hook': pre_infer_hook
        }, (queue, asyncio.get_event_loop())))
        await asyncio.sleep(0)
        self._start_infer_worker()
        if request_config.stream:

            async def _gen_wrapper():
                while True:
                    item = await queue.get()
                    await asyncio.sleep(0)
                    if item is None:
                        break
                    yield item

            return _gen_wrapper()
        else:
            return await queue.get()

    @staticmethod
    def _add_error_list(outputs, error_list):
        for i, error in error_list:
            outputs.insert(i, error)
        return outputs

    # Ensure `template._post_encode` has no gradient.
    @torch.inference_mode()
    def _infer(
        self,
        infer_requests: List[InferRequest],
        request_config: RequestConfig,
        *,
        template: Optional[Template] = None,
        adapter_request: Optional[AdapterRequest] = None,
        pre_infer_hook=None,
    ) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
        """Encode a batch of requests, optionally extend the prompts, and run inference.

        Each element of `infer_requests` is either a 4-item list/tuple
        `(request, responses, queries, docs)` - the last three being lists of text segments
        from earlier retrieval rounds - or a plain request object, which is treated as having
        no such segments (this is what `infer_async` submits).

        When the first element carries earlier responses, the segments of every element are
        tokenized and appended, round by round (response, query, doc), to that element's
        encoded `input_ids` before collation.

        Args:
            infer_requests: The batch, as described above. May be empty.
            request_config: Per-call configuration; deep-copied before use.
            template: Template to use; defaults to `self.default_template`.
            adapter_request: Optional adapter to activate for this call.
            pre_infer_hook: Optional callable that receives and returns the kwargs passed to
                the underlying inference function.

        Returns:
            A generator of per-step result lists when `request_config.stream` is set,
            otherwise a list of results. In both cases the entries of `error_list` returned
            by `_batch_encode` are re-inserted at their positions.
        """
        self.model.eval()
        request_config = deepcopy(request_config)
        if template is None:
            template = self.default_template
        if template.use_model:
            template.model = self.model

        generation_config = None
        if self.model_info.task_type == 'causal_lm':
            template.set_mode('pt')

        # Split every entry into the request proper and its accumulated text segments.
        inputs_cache, response_cache, query_cache, doc_cache = [], [], [], []
        for item in infer_requests:
            if isinstance(item, (list, tuple)) and len(item) == 4:
                input_, response_, query_, doc_ = item
            else:
                # Plain request without retrieval history.
                input_, response_, query_, doc_ = item, [], [], []
            inputs_cache.append(input_)
            response_cache.append(response_)
            query_cache.append(query_)
            doc_cache.append(doc_)

        batched_inputs_, error_list = self._batch_encode(
            inputs_cache, template=template, strict=getattr(self, 'strict', True))

        response_ids = None
        # Guard against an empty batch before looking at the first element.
        if response_cache and len(response_cache[0]) > 0:
            response_ids = [self.tokenizer(response_c)["input_ids"] for response_c in response_cache]
            query_ids = [self.tokenizer(query_c)["input_ids"] for query_c in query_cache]
            doc_ids = [self.tokenizer(doc_c)["input_ids"] for doc_c in doc_cache]

        batched_inputs = deepcopy(batched_inputs_)
        if response_ids is not None:
            # Requests that failed to encode are absent from `batched_inputs`; `error_list`
            # holds their positions (the same positions `_add_error_list` re-inserts at), so
            # map every encoded input back to the request it came from.
            failed_positions = {position for position, _ in error_list}
            encoded_positions = [
                position for position in range(len(inputs_cache)) if position not in failed_positions
            ]
            for batched_input, position in zip(batched_inputs, encoded_positions):
                for r, q, d in zip(response_ids[position], query_ids[position], doc_ids[position]):
                    batched_input['input_ids'] = batched_input['input_ids'] + r + q + d

        if len(batched_inputs) > 0:
            template_inputs = [inputs.pop('template_inputs') for inputs in batched_inputs]
            # attention_mask is present from here on
            inputs = to_device(template.data_collator(batched_inputs), self.model.device)
            template.debug_logger(inputs)  # debug
            if self.model.model_meta.is_multimodal:
                _, inputs = template.pre_forward_hook(self.model, None, inputs)
            if self.model_info.task_type == 'causal_lm':
                self.set_default_max_tokens(request_config, inputs)
                generation_config = self._prepare_generation_config(request_config)
                self._add_stop_words(generation_config, request_config, template.template_meta)

            kwargs = {
                'template': template,
                'inputs': inputs,
                'generation_config': generation_config,
                'adapter_request': adapter_request,
                'template_inputs': template_inputs
            }
            if pre_infer_hook:
                kwargs = pre_infer_hook(kwargs)
        else:
            # Nothing could be encoded: only the errors (if any) are reported.
            kwargs = {}
        if request_config.stream:

            def _gen_wrapper():
                if len(kwargs) > 0:
                    for res in self._infer_stream(**kwargs):
                        yield self._add_error_list(res, error_list)
                else:
                    yield self._add_error_list([], error_list)

            return _gen_wrapper()
        else:
            if len(kwargs) > 0:
                infer_func = self._infer_forward if template.mode in ('seq_cls', 'prm') else self._infer_full
                res = infer_func(**kwargs)
            else:
                res = []
            return self._add_error_list(res, error_list)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
        metrics: Optional[List[Metric]] = None,
        *,
        template: Optional[Template] = None,
        use_tqdm: Optional[bool] = None,
        adapter_request: Optional[AdapterRequest] = None
    ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
        if request_config is None:
            request_config = RequestConfig()
        if request_config.stream:
            return super().infer(
                infer_requests,
                request_config,
                metrics,
                template=template,
                use_tqdm=use_tqdm,
                adapter_request=adapter_request)
        # Has higher stability than calling super().infer
        if use_tqdm is None:
            use_tqdm = not request_config.stream and len(infer_requests) > 1
        prog_bar = tqdm(total=len(infer_requests), dynamic_ncols=True, disable=not use_tqdm)
        # If self.max_batch_size is None or 0, then process all infer_requests at once.
        max_batch_size = self.max_batch_size or len(infer_requests)
        res = []
        i = 0
        while i < len(infer_requests):
            infer_requests_samples = infer_requests[i:i + max_batch_size]

            retrieve_url = "http://127.0.0.1:5003/queries"
            finished_all = []
            inputs_cache = infer_requests_samples.copy()
            inputs_num = len(infer_requests_samples)
            inputs_idxes = [[idx, input_, [], [], []] for idx, input_ in enumerate(infer_requests_samples)]
            cache_outputs = [None for _ in range(inputs_num)]
            all_input_idxes = [[idx, input_, [], [], []] for idx, input_ in enumerate(infer_requests_samples)]

            max_retrieve_times = 3
            for t in range(max_retrieve_times):
                inputs_cache = [[input_, response_, query_, doc_] for _, input_, response_, query_, doc_ in inputs_idxes]
                outputs = self._infer(inputs_cache, request_config, template=template, adapter_request=adapter_request)

                query_list, new_inputs_idxes, finished_outputs = [], [], []
                for ii, (input_, output_) in enumerate(zip(inputs_idxes, outputs)):
                    # pdb.set_trace()
                    output_text = output_.choices[0].message.content
                    # output_text = output_text + "<query> dynamic ultrasonographic sign </query>"
                        
                    if "<query>" in output_text and "</query>" in output_text:
                        ## TODO: multiple queries in one sentence
                        query = output_text.split("<query>")[1].split("</query>")[0]
                        query = query.replace('"',"").strip()
                        query = " ".join(query.split())
                        if query:
                            response = output_text.split("<query>")[0]
                            query_list.append(query)
                            
                            new_inputs_idxes.append([input_[0], input_[1], input_[2] + [response + "\n"], input_[3] + [f"<query> {query} </query>\n"], input_[4]])
                            
                            if cache_outputs[input_[0]] is None:
                                cache_outputs[input_[0]] = output_
                                prompt_tokens = output_.usage.prompt_tokens
                                all_input_idxes[input_[0]] = input_
                            else:
                                new_output = deepcopy(output_)
                                new_output.choices[0].message.content = ""
                                for r, q, d in zip(input_[2], input_[3], input_[4]):
                                    new_output.choices[0].message.content += r + q + d
                                new_output.choices[0].message.content += output_.choices[0].message.content
                                prompt_tokens = cache_outputs[input_[0]].usage.prompt_tokens
                                new_output.usage.completion_tokens += new_output.usage.prompt_tokens - prompt_tokens
                                new_output.usage.prompt_tokens = prompt_tokens

                                cache_outputs[input_[0]] = new_output
                                all_input_idxes[input_[0]] = input_
                        else:
                            if cache_outputs[input_[0]] is None:
                                cache_outputs[input_[0]] = output_
                                all_input_idxes[input_[0]] = input_
                            else:
                                new_output = deepcopy(output_)
                                new_output.choices[0].message.content = ""
                                for r, q, d in zip(input_[2], input_[3], input_[4]):
                                    new_output.choices[0].message.content += r + q + d
                                new_output.choices[0].message.content += output_.choices[0].message.content
                                prompt_tokens = cache_outputs[input_[0]].usage.prompt_tokens
                                new_output.usage.completion_tokens += new_output.usage.prompt_tokens - prompt_tokens
                                new_output.usage.prompt_tokens = prompt_tokens
                                cache_outputs[input_[0]] = new_output
                                all_input_idxes[input_[0]] = input_

                    else:
                        if cache_outputs[input_[0]] is None:
                            cache_outputs[input_[0]] = output_
                            all_input_idxes[input_[0]] = input_
                        else:
                            new_output = deepcopy(output_)
                            new_output.choices[0].message.content = ""
                            for r, q, d in zip(input_[2], input_[3], input_[4]):
                                new_output.choices[0].message.content += r + q + d
                            new_output.choices[0].message.content += output_.choices[0].message.content
                            prompt_tokens = cache_outputs[input_[0]].usage.prompt_tokens
                            new_output.usage.completion_tokens += new_output.usage.prompt_tokens - prompt_tokens
                            new_output.usage.prompt_tokens = prompt_tokens
                            cache_outputs[input_[0]] = new_output
                            all_input_idxes[input_[0]] = input_
            
                if len(query_list) > 0:
                    topk = 3
                    re_response = requests.post(retrieve_url, json={"queries": query_list, "k": topk})
                    if re_response.status_code == 200:
                        retrieve_result = re_response.json()
                        retrieve_answers = retrieve_result["answers"]

                        for k in range(len(retrieve_answers)):
                            retrieve_docs = retrieve_answers[k]
                            if len(retrieve_docs) > 0:
                                doc_content_list = []
                                for j in range(len(retrieve_docs)):
                                    doc_now = re.sub(r'^\d+\s+', '', retrieve_docs[j])
                                    doc_content_list.append(f"({j+1}) {doc_now}\n")
                                doc_content = ''.join(doc_content_list)
                            else:
                                doc_content = "None"
                            new_inputs_idxes[k][-1].append(f"<retrieve> {doc_content} </retrieve>\n")
                            
                if len(new_inputs_idxes) == 0:
                    finished_all = cache_outputs
                    assert len(finished_all) == inputs_num
                    break
                else:
                    if t < max_retrieve_times - 1:
                        inputs_idxes = new_inputs_idxes
                    else:
                        finished_all = cache_outputs

            with open("similarity_dict.json", "r") as f:
                self.similarity_dict = json.load(f)
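            # Select low-confidence answers (prob < 0.9) and re-query them with the related image and
            # conversation looked up in similarity_dict (knowledge retrieved from the multimodal database).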
            finished_all[0].choices[0].logprobs[0]['prob']
            new_inputs_idxes = []

            for ii, (input_idx, finished_item) in enumerate(zip(all_input_idxes, finished_all)):
                prob = finished_item.choices[0].logprobs[0]['prob']
                line_id = str(input_idx[1]["id"]) # get line id
                if prob < 0.9:
                    input_message = input_idx[1]
                    input_message["images"].append({'bytes': None, 'path': os.path.join("", self.similarity_dict[line_id]["related_image"])})
                    if "<image 1>" in input_message["messages"][1]:
                        input_message["messages"][1]["content"] = input_message["messages"][1]["content"].replace("<image 1>", "<image>").replace("<image 2>", "<image>").replace("<image 3>", "<image>").replace("<image 4>", "<image>").replace("<image 5>", "<image>")
                    else:
                        input_message["messages"][1]["content"] = "<image>" + input_message["messages"][1]["content"]
                    conversation = self.similarity_dict[line_id]["conversation"]
                    new_inputs_idxes.append([ii, input_message, input_idx[2] + ["<think> I need to find a consultation from a case with similar medical imaging findings to provide relevant context and reference. I will use the consultation with previously retrieved information to generate the answer. </think>\n"], input_idx[3] + ["<query> consultation from similar medical imaging case </query>"], input_idx[4] + [f"<retrieve> <image> {conversation} </retrieve>\n"]])
                else:
                    cache_outputs[input_idx[0]] = finished_item
            
            if len(new_inputs_idxes) > 0:
                inputs_idxes = new_inputs_idxes
                inputs_cache = [[input_, response_, query_, doc_] for _, input_, response_, query_, doc_ in inputs_idxes]
                outputs = self._infer(inputs_cache, request_config, template=template, adapter_request=adapter_request)

                for ii, (input_, output_) in enumerate(zip(inputs_idxes, outputs)):
                    new_output = deepcopy(output_)
                    # new_output.choices[0].message.content = ""
                    for r, q, d in zip(input_[2], input_[3], input_[4]):
                        new_output.choices[0].message.content += r + q + d
                    # pdb.set_trace()
                    new_output.choices[0].message.content += output_.choices[0].message.content
                    prompt_tokens = cache_outputs[input_[0]].usage.prompt_tokens
                    new_output.usage.completion_tokens += new_output.usage.prompt_tokens - prompt_tokens
                    new_output.usage.prompt_tokens = prompt_tokens
                    cache_outputs[input_[0]] = new_output
                # all_input_idxes[input_[0]] = input_

            finished_all = cache_outputs
            
            res += finished_all            
            i += max_batch_size
            prog_bar.update(len(infer_requests_samples))

        self._update_metrics(res, metrics)
        return res
