from __future__ import annotations

import os
import sys
import warnings
import math
import logging
import json
from datetime import datetime
import re
import argparse
import json_repair
from typing import Dict, Optional, Union, List, Callable, Tuple


import torch
from transformers import StoppingCriteria

from ..base import BaseModel
from .prompt import Qwen2VLPromptMixin
from ...smp import get_rank_and_world_size, get_gpu_memory, listinstr
from ...dataset import DATASET_MODALITY

# Per-prompt image cap for the vLLM backend. Qwen2VLChat passes it to the engine as
# ``limit_mm_per_prompt``, uses it to drop extra images and to size the per-video
# frame budget in ``_prepare_content_vllm``, and compares video frame counts against
# it to print a "may be too long" warning.
VLLM_MAX_IMAGE_INPUT_NUM = 24


def parse_json_text_with_remaining(
    raw_response: str,
    print_fail_resp: bool = False,
) -> Tuple[object, Optional[str]]:
    """Extract JSON from a model response and return it with the leftover text.

    JSON is read from fenced ```json ... ``` blocks; when the response has no
    such block, the whole response is parsed instead. Parsing uses
    ``json_repair.loads``.

    Args:
        raw_response: raw model output.
        print_fail_resp: print the raw response when parsing fails, or when a
            single block parses to an empty/falsy value.

    Returns:
        ``(parsed, remaining_text)``:
        - ``parsed`` is the value decoded from the single block (or the whole
          response), or a list of decoded values when several ```json blocks
          are present. When it is a dict, literal ``\\n`` sequences inside its
          non-empty string values are turned into real newlines.
        - ``remaining_text`` is ``raw_response`` with every fenced json block
          removed, then stripped.
        On any exception, returns ``(None, None)``.
    """
    pattern = re.compile(r'```json(.*?)```', re.DOTALL)
    # Deliberate best-effort parser: any failure is reported and mapped to (None, None).
    try:
        matches = pattern.findall(raw_response)
        if not matches:
            # No fenced block: treat the whole response as JSON.
            matches = [raw_response]
        if len(matches) == 1:
            formatted_response = json_repair.loads(matches[0].strip())
            if not formatted_response and print_fail_resp:
                print("=====")
                print(f"{raw_response}")
                print("=====")
        else:
            formatted_response = [json_repair.loads(match_text) for match_text in matches]

        if isinstance(formatted_response, dict):
            # Models often emit escaped newlines inside JSON strings; unescape them.
            for k, v in formatted_response.items():
                if isinstance(v, str) and v:
                    formatted_response[k] = v.replace("\\n", "\n")
        remaining_text = pattern.sub('', raw_response).strip()
        return formatted_response, remaining_text
    except Exception:
        print("Fail to parse one output.")
        if print_fail_resp:
            print("=====")
            print(f"{raw_response}")
            print("=====")
        return None, None

def ensure_image_url(image: str) -> str:
    """Normalise an image reference into a URL string.

    References that already look like URLs are returned unchanged:
    ``http://``, ``https://``, ``file://`` and ``data:image`` URIs. Both the
    bare ``data:image;base64,...`` form and the standard
    ``data:image/<subtype>;base64,...`` form are accepted. An existing local
    path is prefixed with ``file://``.

    Raises:
        ValueError: if *image* is neither a recognised URL nor an existing path.
    """
    # 'data:image/' accepts standard data URIs carrying a MIME subtype, such as
    # the 'data:image/jpeg;base64,...' strings produced by create_image_content.
    prefixes = ('http://', 'https://', 'file://', 'data:image;', 'data:image/')
    if image.startswith(prefixes):
        return image
    # Local file path: convert it to the file:// form.
    if os.path.exists(image):
        return 'file://' + image
    raise ValueError(f'Invalid image: {image}')


def ensure_video_url(video: str) -> str:
    """Normalise a video reference into a URL string.

    Mirrors ``ensure_image_url`` for videos. ``http://``, ``https://``,
    ``file://`` and ``data:video`` URIs are returned unchanged; both
    ``data:video;...`` and the standard ``data:video/<subtype>;...`` forms are
    accepted. An existing local path is prefixed with ``file://``.

    Raises:
        ValueError: if *video* is neither a recognised URL nor an existing path.
    """
    prefixes = ('http://', 'https://', 'file://', 'data:video;', 'data:video/')
    if video.startswith(prefixes):
        return video
    if os.path.exists(video):
        return 'file://' + video
    raise ValueError(f'Invalid video: {video}')


def create_image_content(image_path, min_pixels, max_pixels):
    """Load *image_path* and wrap it as an inline base64 image content item.

    Returns a dict with ``type`` set to ``"image"``, ``image`` holding a
    ``data:<mime>;base64,<payload>`` URI, and the given ``min_pixels`` /
    ``max_pixels`` values copied through unchanged.
    """
    payload, mime = encode_image(image_path)
    content = {"type": "image", "image": "data:" + mime + ";base64," + payload}
    content["min_pixels"] = min_pixels
    content["max_pixels"] = max_pixels
    return content


def encode_image(image_path, max_side=None):
    """Read an image file and return ``(base64_payload, mime_type)``.

    The MIME type is guessed from the file name and defaults to
    ``image/jpeg``. Its subtype, upper-cased, is the format the image is
    re-encoded in. Any alpha channel is flattened onto a white background
    before encoding.

    Args:
        image_path: path of the image file.
        max_side: if truthy, the image is rescaled so its longer side equals
            this value before encoding.
    """
    from mimetypes import guess_type

    from PIL import Image

    mime_type, _ = guess_type(image_path)
    if mime_type is None:
        mime_type = "image/jpeg"
    image_format = mime_type.split("/")[-1].upper()

    # Context manager closes the underlying file handle deterministically.
    with Image.open(image_path) as image:
        has_alpha = image.mode in ("RGBA", "LA", "PA") or (
            image.mode == "P" and "transparency" in image.info
        )
        if has_alpha:
            # A plain convert("RGB") would drop alpha and expose whatever colour
            # transparent pixels store; composite onto white instead.
            if image.mode != "RGBA":
                image = image.convert("RGBA")
            image = _rgba_to_rgb(image)
        if max_side:
            image = _resize_image(image, max_side)
        encoded_image = _encode_image(image, image_format)

    return encoded_image, mime_type


def _encode_image(image, image_format):
    from io import BytesIO
    with BytesIO() as output:
        image.convert("RGB").save(output, format=image_format)
        import base64
        base64_encoded_data = base64.b64encode(output.getvalue()).decode("utf-8")
    return base64_encoded_data


def _rgba_to_rgb(image):
    """Composite an RGBA *image* over an opaque white canvas and return it as RGB."""
    from PIL import Image

    white_canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
    flattened = Image.alpha_composite(white_canvas, image)
    return flattened.convert("RGB")


def _resize_image(image, max_side):
    resize_scale = max_side / max(image.size)
    new_size = (
        int(image.size[0] * resize_scale),
        int(image.size[1] * resize_scale),
    )
    return image.resize(new_size)


def process_video(video_path, num_frames, min_pixels, max_pixels):
    """Sample frames from a video file and return them as inline image items.

    A frame is kept every ``sampling_gap`` frames. The gap is the larger of
    ``ceil(fps / 5)``, which keeps at most about 5 frames per second, and
    ``ceil(frame_count / num_frames)``. The second term approximately caps the
    total at ``num_frames``, based on the frame count reported by OpenCV.
    A falsy ``num_frames`` disables that cap.

    Args:
        video_path: path of a video readable by OpenCV.
        num_frames: approximate upper bound on the number of sampled frames.
        min_pixels, max_pixels: pixel budgets stored in every image item.

    Returns:
        list of image-content dicts built by ``create_image_content``.

    Raises:
        ValueError: if no frame could be read from ``video_path``.
    """
    import tempfile

    import cv2

    cap = cv2.VideoCapture(video_path)
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)  # Frames per second

        # Gap implied by the frame budget (1 = keep every frame when uncapped).
        sampling_gap_maxframe = (
            1 if not num_frames else math.ceil(frame_count / num_frames)
        )
        # Clamp to >= 1: a video reporting fps / frame count as 0 would otherwise
        # produce a gap of 0 and a ZeroDivisionError in the modulo below.
        sampling_gap = max(1, math.ceil(fps / 5), sampling_gap_maxframe)

        frame_number = 0
        images = []
        while True:
            success, frame = cap.read()
            if not success:
                break
            if frame_number % sampling_gap == 0:
                # Close the temp file before cv2 writes to it by name, and always
                # delete it, even if encoding the frame fails.
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_frame:
                    temp_path = temp_frame.name
                try:
                    cv2.imwrite(temp_path, frame)
                    images.append(create_image_content(temp_path, min_pixels, max_pixels))
                finally:
                    os.remove(temp_path)
            frame_number += 1
    finally:
        # Release the capture on every exit path.
        cap.release()

    if frame_number == 0:
        raise ValueError(f"Failed to read video from {video_path}, check data...")
    logging.info(
        f"Sampled {len(images)}/{frame_number} frames from video {video_path}"
    )
    return images


def setup_visible_devices_per_rank():
    """Restrict ``CUDA_VISIBLE_DEVICES`` to the GPUs assigned to this rank.

    Only a single process (``world_size == 1``) is supported for vLLM, so
    every GPU visible to the process is assigned to it. If
    ``CUDA_VISIBLE_DEVICES`` is already set, the assigned slots are mapped
    through that existing mask. This keeps the physical GPUs the user
    selected instead of replacing them with ``0..n-1``.

    Returns:
        int: the number of GPUs assigned to this rank.
    """
    total_gpus = torch.cuda.device_count()
    # Current process rank and total number of processes.
    rank, world_size = get_rank_and_world_size()
    assert world_size == 1, "Only support world_size == 1 for vLLM inference"
    num_gpus = total_gpus // world_size
    start_idx = rank * num_gpus

    # device_count() only sees GPUs inside an existing mask, so the indices
    # above are positions within that mask, not physical device ids.
    existing_mask = os.environ.get("CUDA_VISIBLE_DEVICES")
    if existing_mask:
        visible_ids = [dev.strip() for dev in existing_mask.split(",") if dev.strip()]
    else:
        visible_ids = [str(i) for i in range(total_gpus)]
    assigned_devices = visible_ids[start_idx:start_idx + num_gpus]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(assigned_devices)
    logging.info(f"[Rank {rank}] Visible GPUs: {assigned_devices}")
    return num_gpus


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any keyword appears at the end of the generated text.

    Each keyword is checked in two ways:
    - its token ids against the tail of the generated ids;
    - its string against the decoded last ``max_keyword_len`` generated tokens.

    Only tokens generated after the prompt are inspected. Only batch size 1 is
    supported.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: stop strings.
            tokenizer: tokenizer used to encode the keywords and decode outputs.
            input_ids: prompt token ids; ``input_ids.shape[1]`` marks where
                generated tokens start.
        """
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS the tokenizer may prepend so ids match mid-sequence text.
            if (
                len(cur_keyword_ids) > 1
                and cur_keyword_ids[0] == tokenizer.bos_token_id
            ):
                cur_keyword_ids = cur_keyword_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(cur_keyword_ids))
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        seq_len = output_ids.shape[1]
        generated_len = seq_len - self.start_len
        if generated_len <= 0:
            # Nothing generated yet; never match against the prompt.
            return False
        offset = min(generated_len, self.max_keyword_len)
        self.keyword_ids = [
            keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
        ]
        for keyword_id in self.keyword_ids:
            keyword_len = keyword_id.shape[0]
            # Compare only against generated tokens, never the tail of the prompt.
            if 0 < keyword_len <= generated_len and (
                output_ids[0, seq_len - keyword_len:] == keyword_id
            ).all():
                return True
        # Explicit start index instead of ``-offset`` so offset == 0 yields an
        # empty slice rather than the whole sequence (``-0:`` == ``0:``).
        outputs = self.tokenizer.batch_decode(
            output_ids[:, seq_len - offset:], skip_special_tokens=True
        )[0]
        return any(keyword in outputs for keyword in self.keywords)


# Jinja chat template passed explicitly to ``apply_chat_template`` by
# ``Qwen2VLChatAguvis``. Each message is wrapped as
# ``<|im_start|>{role}\n ... <|im_end|>\n``. A message whose ``content`` is a
# plain string is emitted verbatim. Otherwise:
# - image items become ``<|vision_start|><|image_pad|><|vision_end|>``;
# - video items become ``<|vision_start|><|video_pad|><|vision_end|>``;
# - either kind is prefixed with "Picture N: " / "Video N: " when
#   ``add_vision_id`` is set;
# - text items are inlined.
CHAT_TEMPLATE = "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"  # noqa: E501

# Stop strings for Aguvis generation. In the visible code they are only
# referenced from commented-out truncation logic in ``Qwen2VLChatAguvis``.
UNTIL = ["<|diff_marker|>"]

# Detect whether a string contains Chinese characters.
def cn_string(s):
    """Return True if *s* contains at least one CJK Unified Ideograph (U+4E00..U+9FFF)."""
    return re.search('[\u4e00-\u9fff]', s) is not None


class Qwen2VLChat(Qwen2VLPromptMixin, BaseModel):
    """Chat wrapper for Qwen2-VL, Qwen2.5-VL and Qwen2.5-Omni checkpoints.

    The model family is chosen from the lower-cased ``model_path``:
    - names containing ``omni`` load Qwen2.5-Omni;
    - names containing ``2.5`` / ``2_5`` / ``qwen25`` load Qwen2.5-VL;
    - anything else loads Qwen2-VL.

    Inference runs through HuggingFace ``transformers`` by default, or through
    vLLM / LMDeploy when ``use_vllm`` / ``use_lmdeploy`` is passed as a keyword
    argument.
    """

    INSTALL_REQ = False
    INTERLEAVE = True
    VIDEO_LLM = True

    def __init__(
        self,
        model_path: str,
        min_pixels: int | None = None,
        max_pixels: int | None = None,
        total_pixels: int | None = None,
        max_new_tokens=2048,
        top_p=0.001,
        top_k=1,
        temperature=0.01,
        repetition_penalty=1.0,
        use_custom_prompt: bool = True,
        system_prompt: str | None = None,
        post_process: bool = False,  # if True, will try to only extract stuff in the last \boxed{}.
        verbose: bool = False,
        use_audio_in_video: bool = False,
        **kwargs,
    ):
        """Load the processor and the model for the selected backend.

        Args:
            model_path: HuggingFace id or local path of the checkpoint.
            min_pixels, max_pixels, total_pixels: optional pixel budgets copied
                into every image/video content item.
            max_new_tokens, top_p, top_k, temperature, repetition_penalty:
                generation settings for the transformers and LMDeploy backends
                (the vLLM backend decodes with ``temperature=0.0``).
            use_custom_prompt: forwarded to ``super().__init__``.
            system_prompt: optional system message prepended to every request.
            post_process: keep only the content of the last ``\\boxed{...}``.
            verbose: print each request and response.
            use_audio_in_video: Omni only; also feed the audio track of videos.
            **kwargs: ``fps`` (default 2) and ``nframe`` (default 128) control
                video sampling; ``use_vllm`` / ``use_lmdeploy`` select the
                backend; ``gpu_utils`` sets vLLM's ``gpu_memory_utilization``
                (default 0.9).
        """
        super().__init__(use_custom_prompt=use_custom_prompt)
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.total_pixels = total_pixels
        self.max_new_tokens = max_new_tokens
        if self.total_pixels and self.total_pixels > 24576 * 28 * 28:
            print('The total number of video tokens might become too large, resulting in an overly long input sequence. We recommend lowering **total_pixels** to below **24576 × 28 × 28**.')  # noqa: E501
        self.generate_kwargs = dict(
            max_new_tokens=self.max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )
        self.system_prompt = system_prompt
        self.verbose = verbose
        self.post_process = post_process
        self.fps = kwargs.pop('fps', 2)
        self.nframe = kwargs.pop('nframe', 128)
        if self.fps is None and self.nframe is None:
            print("Warning: fps and nframe are both None, \
                  using default nframe/fps setting in qwen-vl-utils/qwen-omni-utils, \
                  the fps/nframe setting in video dataset is omitted")
        self.use_audio_in_video = use_audio_in_video
        # Short videos have their frame count rounded down to a multiple of this.
        self.FRAME_FACTOR = 2
        assert model_path is not None
        self.model_path = model_path

        model_name = model_path.lower()
        if listinstr(['omni'], model_name):
            try:
                from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
            except Exception as err:
                logging.critical("pip install git+https://github.com/huggingface/transformers@3a1ead0aabed473eafe527915eea8c197d424356")  # noqa: E501
                raise err
            MODEL_CLS = Qwen2_5OmniForConditionalGeneration
            self.processor = Qwen2_5OmniProcessor.from_pretrained(model_path)
        elif listinstr(['2.5', '2_5', 'qwen25'], model_name):
            from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
            MODEL_CLS = Qwen2_5_VLForConditionalGeneration
            self.processor = AutoProcessor.from_pretrained(model_path)
        else:
            from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
            MODEL_CLS = Qwen2VLForConditionalGeneration
            self.processor = Qwen2VLProcessor.from_pretrained(model_path)

        gpu_mems = get_gpu_memory()
        max_gpu_mem = max(gpu_mems) if gpu_mems != [] else -1
        assert max_gpu_mem > 0
        self.use_vllm = kwargs.get('use_vllm', False)
        self.use_lmdeploy = kwargs.get('use_lmdeploy', False)
        self.limit_mm_per_prompt = VLLM_MAX_IMAGE_INPUT_NUM
        assert self.use_vllm + self.use_lmdeploy <= 1, "You can only set one flag between `use_vllm` and `use_lmdeploy` to True"  # noqa: E501

        if self.use_vllm:
            from vllm import LLM
            gpu_count = setup_visible_devices_per_rank()
            # Largest tensor-parallel size in {8, 4, 2, 1} that fits the GPUs.
            if gpu_count >= 8:
                tp_size = 8
            elif gpu_count >= 4:
                tp_size = 4
            elif gpu_count >= 2:
                tp_size = 2
            else:
                tp_size = 1
            logging.info(
                f'Using vLLM for {self.model_path} inference with {tp_size} GPUs (available: {gpu_count})'
            )
            if os.environ.get('VLLM_WORKER_MULTIPROC_METHOD') != 'spawn':
                logging.warning(
                    'VLLM_WORKER_MULTIPROC_METHOD is not set to spawn. '
                    'Use \'export VLLM_WORKER_MULTIPROC_METHOD=spawn\' to avoid potential multi-process issues'
                )
            self.llm = LLM(
                model=self.model_path,
                max_num_seqs=5,
                max_model_len=32768,
                limit_mm_per_prompt={"image": self.limit_mm_per_prompt},
                tensor_parallel_size=tp_size,
                gpu_memory_utilization=kwargs.get("gpu_utils", 0.9),
            )

        elif self.use_lmdeploy:
            from lmdeploy import TurbomindEngineConfig, pipeline
            num_gpus = torch.cuda.device_count()
            self.model = pipeline(
                model_path,
                backend_config=TurbomindEngineConfig(session_len=32768, cache_max_entry_count=0.1, tp=num_gpus)
            )
            torch.cuda.set_device(0)
            self.device = 'cuda'
        else:
            self.model = MODEL_CLS.from_pretrained(
                model_path, torch_dtype='auto', device_map="auto", attn_implementation='flash_attention_2'
            )
            self.model.eval()

        torch.cuda.empty_cache()

    def _build_image_item(self, image: str, dataset: str | None = None) -> dict:
        """Build an image content item carrying the configured pixel budgets."""
        item = {'type': 'image', 'image': ensure_image_url(image)}
        if dataset == 'OCRBench':
            # OCRBench overrides the configured minimum resolution.
            item['min_pixels'] = 10 * 10 * 28 * 28
            warnings.warn(f"OCRBench dataset uses custom min_pixels={item['min_pixels']}")
        elif self.min_pixels is not None:
            item['min_pixels'] = self.min_pixels
        if self.max_pixels is not None:
            item['max_pixels'] = self.max_pixels
        if self.total_pixels is not None:
            item['total_pixels'] = self.total_pixels
        return item

    def _build_video_item(self, video: str) -> dict:
        """Build a video content item with pixel budgets and an fps or nframes setting."""
        item = {'type': 'video', 'video': ensure_video_url(video)}
        if self.min_pixels is not None:
            item['min_pixels'] = self.min_pixels
        if self.max_pixels is not None:
            item['max_pixels'] = self.max_pixels
        if self.total_pixels is not None:
            item['total_pixels'] = self.total_pixels
        if self.fps is not None:
            item['fps'] = self.fps
        elif self.nframe is not None:
            import cv2
            cap = cv2.VideoCapture(video)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cap.release()
            if frame_count < self.nframe:
                # Video shorter than the requested nframe: use all frames,
                # rounded down to a multiple of FRAME_FACTOR.
                new_frame_count = frame_count // self.FRAME_FACTOR * self.FRAME_FACTOR
                print(f"use {new_frame_count} for {video}")
                item['nframes'] = new_frame_count
            else:
                item['nframes'] = self.nframe
        return item

    def _video_as_frames(self, video: str, video_count: int, dataset: str | None = None) -> list[dict]:
        """Sample *video* into image items so several videos fit in one vLLM prompt.

        Each video gets ``limit_mm_per_prompt // video_count`` frames (at least
        1). The frames are wrapped in ``<video frames start>`` /
        ``<video frames end>`` text markers.
        """
        if dataset == 'OCRBench':
            min_pixels = 10 * 10 * 28 * 28
            warnings.warn(f"OCRBench dataset uses custom min_pixels={min_pixels}")
        else:
            min_pixels = self.min_pixels
        max_pixels = self.max_pixels
        frames_per_video = max(1, self.limit_mm_per_prompt // video_count)
        frames = process_video(video, frames_per_video, min_pixels, max_pixels)
        # process_video always writes both pixel keys; drop unset (None) budgets
        # rather than forwarding None to the vision preprocessor.
        for frame in frames:
            for key in ('min_pixels', 'max_pixels'):
                if frame.get(key) is None:
                    frame.pop(key, None)
        return [
            {"type": "text", "text": "<video frames start>"},
            *frames,
            {"type": "text", "text": "<video frames end>"},
        ]

    @staticmethod
    def _extract_last_boxed(response: str) -> str:
        """Return the content of the last ``\\boxed{...}`` in *response*.

        The text after the last ``\\boxed{`` is scanned with a brace counter
        that starts at 1. The prefix up to the closing ``}`` is returned; if no
        closing brace is found, the whole tail is returned. With no
        ``\\boxed{`` present, the same scan runs over the whole string. An empty
        tail leaves *response* unchanged.
        """
        resp = response.split('\\boxed{')[-1]
        lt = len(resp)
        counter, end = 1, None
        for i in range(lt):
            if resp[i] == '{':
                counter += 1
            elif resp[i] == '}':
                counter -= 1
            if counter == 0:
                end = i
                break
            elif i == lt - 1:
                end = lt
                break
        if end is not None:
            return resp[:end]
        return response

    def _prepare_content(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
        """
        inputs list[dict[str, str]], each dict has keys: ['type', 'value']

        Converts each input into a content item ('image', 'video', 'text' or
        'audio'); any other type raises ValueError.
        """
        content = []
        for s in inputs:
            if s['type'] == 'image':
                item = self._build_image_item(s['value'], dataset)
            elif s['type'] == 'video':
                item = self._build_video_item(s['value'])
            elif s['type'] == 'text':
                item = {'type': 'text', 'text': s['value']}
            elif s['type'] == 'audio':
                item = {'type': 'audio', 'audio': s['value']}
            else:
                raise ValueError(f"Invalid message type: {s['type']}, {s}")
            content.append(item)
        return content

    def _prepare_content_vllm(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
        """
        inputs list[dict[str, str]], each dict has keys: ['type', 'value']

        Like ``_prepare_content`` but respects vLLM's per-prompt image limit:
        - images beyond ``limit_mm_per_prompt`` are dropped with a warning;
        - when the prompt holds several videos, each is converted into sampled
          frame images.
        Audio items are not supported here.
        """
        content = []
        video_count = sum(1 for s in inputs if s['type'] == 'video')
        cur_image_count = 0
        for s in inputs:
            if s['type'] == 'image':
                item = self._build_image_item(s['value'], dataset)
                if cur_image_count < self.limit_mm_per_prompt:
                    content.append(item)
                    cur_image_count += 1
                else:
                    logging.warning(
                        f"Number of images exceeds the limit of {self.limit_mm_per_prompt}. "
                        f"Only the first {self.limit_mm_per_prompt} images will be used."
                    )
            elif s['type'] == 'video':
                if video_count > 1:
                    logging.warning(
                        "Multiple videos detected. Using video frames for each video"
                    )
                    content.extend(self._video_as_frames(s['value'], video_count, dataset))
                else:
                    content.append(self._build_video_item(s['value']))
            elif s['type'] == 'text':
                content.append({'type': 'text', 'text': s['value']})
            else:
                raise ValueError(f"Invalid message type: {s['type']}, {s}")
        return content

    def generate_inner_transformers(self, message, dataset=None):
        """Generate a response with the HuggingFace ``transformers`` model."""
        is_omni = listinstr(['omni'], self.model_path.lower())
        if is_omni:
            try:
                from qwen_omni_utils import process_mm_info
            except Exception as err:
                logging.critical("qwen_omni_utils not found, please install it via 'pip install qwen-omni-utils[decord]'")  # noqa: E501
                raise err
        else:
            try:
                from qwen_vl_utils import process_vision_info
            except Exception as err:
                logging.critical("qwen_vl_utils not found, please install it via 'pip install qwen-vl-utils'")  # noqa: E501
                raise err

        messages = []
        if self.system_prompt is not None:
            messages.append({'role': 'system', 'content': self.system_prompt})
        messages.append({'role': 'user', 'content': self._prepare_content(message, dataset=dataset)})
        if self.verbose:
            print(f'\033[31m{messages}\033[0m')

        text = self.processor.apply_chat_template([messages], tokenize=False, add_generation_prompt=True)
        # Per-call copy so Omni-only options never leak into self.generate_kwargs.
        gen_kwargs = dict(self.generate_kwargs)
        if is_omni:
            audios, images, videos = process_mm_info([messages], use_audio_in_video=self.use_audio_in_video)
            inputs = self.processor(text=text, images=images, audio=audios, videos=videos, padding=True, return_tensors='pt', use_audio_in_video=self.use_audio_in_video)  # noqa: E501
            gen_kwargs['use_audio_in_video'] = self.use_audio_in_video
            gen_kwargs['return_audio'] = False
        else:
            images, videos = process_vision_info([messages])
            inputs = self.processor(text=text, images=images, videos=videos, padding=True, return_tensors='pt')  # noqa: E501
        inputs = inputs.to('cuda')

        generated_ids = self.model.generate(
            **inputs,
            **gen_kwargs,
        )
        # Keep only the newly generated tokens (drop the echoed prompt).
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        out = self.processor.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        response = out[0]
        if self.post_process:
            response = self._extract_last_boxed(response)

        if self.verbose:
            print(f'\033[32m{response}\033[0m')
        return response

    def generate_inner_lmdeploy(self, message, dataset=None):
        """Generate a response with the LMDeploy pipeline.

        NOTE(review): ``message_to_lmdeploy`` is not defined in this class;
        presumably it is inherited. TODO confirm.
        """
        from lmdeploy import GenerationConfig
        gen_config = GenerationConfig(
            max_new_tokens=self.max_new_tokens,
            top_p=self.generate_kwargs['top_p'],
            top_k=self.generate_kwargs['top_k'],
            temperature=self.generate_kwargs['temperature'],
            repetition_penalty=self.generate_kwargs['repetition_penalty'],
        )
        gen_config.random_seed = None
        messages_list = self.message_to_lmdeploy(message, system_prompt=self.system_prompt)
        assert len(messages_list) == 1
        response = self.model(messages_list, gen_config=gen_config)[0]
        response = response.text
        return response

    def generate_inner_vllm(self, message, dataset=None):
        """Generate a response with the vLLM engine (``temperature=0.0``)."""
        from vllm import SamplingParams

        is_omni = listinstr(['omni'], self.model_path.lower())
        if is_omni:
            try:
                from qwen_omni_utils import process_mm_info
            except Exception as err:
                logging.critical("qwen_omni_utils not found, please install it via 'pip install qwen-omni-utils[decord]'")  # noqa: E501
                raise err
        else:
            try:
                from qwen_vl_utils import process_vision_info
            except Exception as err:
                logging.critical("qwen_vl_utils not found, please install it via 'pip install qwen-vl-utils'")  # noqa: E501
                raise err

        messages = []
        if self.system_prompt is not None:
            messages.append({'role': 'system', 'content': self.system_prompt})
        messages.append({'role': 'user', 'content': self._prepare_content_vllm(message, dataset=dataset)})
        if self.verbose:
            print(f'\033[31m{messages}\033[0m')

        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Depending on the processor this may be a str or a one-element batch
        # list; vLLM needs a single prompt string.
        if isinstance(text, (list, tuple)):
            text = text[0]

        audios = None
        if is_omni:
            audios, images, videos = process_mm_info(messages, use_audio_in_video=self.use_audio_in_video)
        else:
            images, videos = process_vision_info(messages)
        if self.verbose:
            print('finishing process vision info in vllm.')

        video_inputs = None
        if videos:
            # _prepare_content_vllm only emits a video item for single-video
            # prompts (several videos are converted to frame images instead).
            assert len(videos) == 1
            # Assumes the video tensor is (T, C, H, W); reorder to (T, H, W, C).
            video_frames = videos[0].detach().cpu().numpy().transpose(0, 2, 3, 1)
            video_inputs = {
                "prompt": text,
                "multi_modal_data": {"video": video_frames},
                "mm_processor_kwargs": {},
            }
            if self.use_audio_in_video:
                import vllm
                assert not vllm.envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. Please launch this example with `VLLM_USE_V1=0`.")  # noqa: E501
                video_inputs["multi_modal_data"]["audio"] = audios[0]
                video_inputs['mm_processor_kwargs']['use_audio_in_video'] = True
            if video_frames.shape[0] > VLLM_MAX_IMAGE_INPUT_NUM:
                print('video input sequence may be too long for vllm, Maybe cannot generate response for VLLM')

        sampling_params = SamplingParams(
            temperature=0.0, max_tokens=self.max_new_tokens, stop_token_ids=None
        )
        if images:
            request = {
                "prompt": text,
                "multi_modal_data": {"image": images},
            }
        elif video_inputs is not None:
            request = video_inputs
        else:
            request = {"prompt": text}
        outputs = self.llm.generate(request, sampling_params=sampling_params)
        # A single request produces a single RequestOutput.
        generated_text = outputs[0].outputs[0].text

        if self.post_process:
            generated_text = self._extract_last_boxed(generated_text)

        if self.verbose:
            print(f'\033[32m{generated_text}\033[0m')
        return generated_text

    def generate_inner(self, message, dataset=None):
        """Dispatch to the vLLM, LMDeploy or transformers backend."""
        if self.use_vllm:
            return self.generate_inner_vllm(message, dataset=dataset)
        elif self.use_lmdeploy:
            return self.generate_inner_lmdeploy(message, dataset=dataset)
        else:
            return self.generate_inner_transformers(message, dataset=dataset)


class Qwen2VLChatAguvis(Qwen2VLChat):
    """Qwen2VLChat variant driven by ``CHAT_TEMPLATE`` and a forced assistant header.

    The conversation is rendered without a generation prompt. The assistant
    prefix selected by ``mode`` (see ``_RECIPIENT_TEXT``) is then appended, so
    generation continues from that header.
    """

    # Assistant-turn prefix appended after the rendered chat, per supported mode.
    _RECIPIENT_TEXT = {
        "force-plan": "<|im_start|>assistant<|recipient|>all\nThought: ",
        "force-plan-l1": "<|im_start|>assistant<|recipient|>all\nAction: ",
        "force-plan-l3": "<|im_start|>assistant<|recipient|>all\nObservation: ",
        "grounding": "<|im_start|>assistant<|recipient|>os\n",
        "force-plan-free": "<|im_start|>assistant<|recipient|>all\n",
        "self-plan": "<|im_start|>assistant<|recipient|>",
    }

    def __init__(self, mode=None, **kwargs):
        """
        Args:
            mode: key of ``_RECIPIENT_TEXT``; validated at generation time.
            **kwargs: forwarded to ``Qwen2VLChat.__init__``.
        """
        self.mode = mode
        super().__init__(**kwargs)
        self.processor.max_pixels = self.max_pixels
        self.processor.min_pixels = self.min_pixels

    def generate_inner(self, message, dataset=None):
        """Generate with the transformers model using the mode-specific assistant prefix.

        An item with ``role == "system"`` supplies the system prompt for this
        call only; otherwise ``self.system_prompt`` is used. Other items have
        their ``role`` key dropped (on a copy) and are sent as user content.

        Raises:
            ValueError: if ``self.mode`` is not a key of ``_RECIPIENT_TEXT``.
        """
        try:
            from qwen_vl_utils import process_vision_info
        except Exception as err:
            logging.critical(
                "qwen_vl_utils not found, please install it via 'pip install qwen-vl-utils'"
            )
            raise err

        # Fail fast before any preprocessing or model work.
        if self.mode not in self._RECIPIENT_TEXT:
            raise ValueError(f"Invalid mode: {self.mode}")

        # Per-call system prompt; never mutate instance state or caller dicts.
        system_prompt = self.system_prompt
        user_message = []
        for item in message:
            if item.get("role") == "system":
                system_prompt = item["value"]
            elif "role" in item:
                user_message.append({k: v for k, v in item.items() if k != "role"})
            else:
                user_message.append(item)

        messages = []
        if system_prompt is not None:
            messages.append({"role": "system", "content": system_prompt})
        messages.append(
            {"role": "user", "content": self._prepare_content(user_message, dataset=dataset)}
        )
        if self.verbose:
            print(f"\033[31m{messages}\033[0m")

        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
            chat_template=CHAT_TEMPLATE,
        )
        # TODO: provide current action's low-level instruction, e.g. force
        # "<|im_start|>assistant<|recipient|>all\nAction: {low_level_instruction}\n".
        text += self._RECIPIENT_TEXT[self.mode]

        images, videos = process_vision_info([messages])
        inputs = self.processor(
            text=[text], images=images, videos=videos, padding=True, return_tensors="pt"
        )
        inputs = inputs.to("cuda")

        # NOTE: stopping on UNTIL ("<|diff_marker|>") via KeywordsStoppingCriteria
        # is currently disabled.
        generated_ids = self.model.generate(
            **inputs,
            **self.generate_kwargs,
        )
        # Keep only the newly generated tokens (drop the echoed prompt).
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        out = self.processor.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        response = out[0]

        if self.post_process:
            # Keep only the content of the last \boxed{...} (brace-balanced scan).
            resp = response.split("\\boxed{")[-1]
            lt = len(resp)
            counter, end = 1, None
            for i in range(lt):
                if resp[i] == "{":
                    counter += 1
                elif resp[i] == "}":
                    counter -= 1
                if counter == 0:
                    end = i
                    break
                elif i == lt - 1:
                    end = lt
                    break
            if end is not None:
                response = resp[:end]

        if self.verbose:
            print(f"\033[32m{response}\033[0m")
        return response


class Qwen2VLCaptionQwen2LLM(Qwen2VLPromptMixin, BaseModel):
    INSTALL_REQ = False
    INTERLEAVE = True
    VIDEO_LLM = True

    def __init__(
        self,
        vl_model_path: str,  # path to the Qwen2-VL / Qwen2.5-VL checkpoint
        llm_model_path: str,  # path to the text-only Qwen LLM checkpoint
        min_pixels: int | None = None,
        max_pixels: int | None = None,
        max_new_tokens=2048,
        llm_max_new_tokens=8192,
        top_p=0.001,
        top_k=1,
        temperature=0.01,
        repetition_penalty=1.0,
        use_custom_prompt: bool = True,
        system_prompt: str | None = None,
        post_process: bool = False,
        verbose: bool = False,
    ):
        """Build the two-model pipeline: a Qwen VL model for vision plus a Qwen text LLM.

        Args:
            vl_model_path: checkpoint of the vision-language model (loaded by ``init_vl_model``).
            llm_model_path: checkpoint of the text LLM (loaded by ``init_llm_model``).
            min_pixels: optional lower pixel bound for image inputs of the VL model.
            max_pixels: optional upper pixel bound for image inputs of the VL model.
            max_new_tokens: generation budget of the VL model.
            llm_max_new_tokens: generation budget stored in ``self.llm_generate_kwargs``.
            top_p, top_k, temperature, repetition_penalty: generation settings passed
                to both models.
            use_custom_prompt: forwarded to the base-class initializer.
            system_prompt: stored on the instance.
            post_process: stored on the instance.
            verbose: stored on the instance; enables colored progress prints.
        """
        super().__init__(use_custom_prompt=use_custom_prompt)

        # init_vl_model calls super().__init__ a second time with its own
        # ``use_custom_prompt`` argument (default True). Forward the caller's value
        # so that second call does not replace it with the default.
        self.init_vl_model(
            vl_model_path,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            use_custom_prompt=use_custom_prompt,
        )

        self.init_llm_model(
            llm_model_path,
            max_new_tokens=llm_max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )

        # Assigned after init_vl_model, which sets its own defaults for these attributes.
        self.system_prompt = system_prompt
        self.verbose = verbose
        self.post_process = post_process

    def init_llm_model(
        self,
        model_path: str,
        max_new_tokens=8192,
        top_p=0.001,  # nucleus-sampling threshold
        top_k=1,  # keep only the k most likely tokens
        temperature=0.01,  # sampling temperature
        repetition_penalty=1.0,
        use_custom_prompt: bool = True,  # accepted for signature parity; unused here
        system_prompt: str | None = None,  # accepted for signature parity; unused here
        post_process: bool = False,  # accepted for signature parity; unused here
        verbose: bool = False,  # accepted for signature parity; unused here
    ):
        """Load the text-only Qwen LLM and its tokenizer, and record its generation settings.

        The weights are loaded on CPU with FlashAttention-2 and then moved to CUDA
        in eval mode. The sampling arguments are stored in
        ``self.llm_generate_kwargs`` for later ``generate`` calls.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.llm_tokenizer = tokenizer

        causal_lm = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="cpu",
            attn_implementation="flash_attention_2",
        )
        self.llm_model = causal_lm

        self.llm_generate_kwargs = {
            "max_new_tokens": max_new_tokens,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
        }

        # Load on CPU first, then move the whole model to the GPU.
        causal_lm.cuda().eval()

    def init_vl_model(
        self,
        model_path: str,
        min_pixels: int | None = None,  # lower bound on image pixels (None = processor default)
        max_pixels: int | None = None,  # upper bound on image pixels (None = processor default)
        max_new_tokens=2048,  # generation budget
        top_p=0.001,  # nucleus-sampling threshold
        top_k=1,  # keep only the k most likely tokens
        temperature=0.01,  # sampling temperature
        repetition_penalty=1.0,
        use_custom_prompt: bool = True,  # forwarded to the base-class initializer
        system_prompt: str | None = None,
        post_process: bool = False,  # if True, will try to only extract stuff in the last \boxed{}.
        verbose: bool = False,
    ):
        """Load the Qwen VL model and processor, and record VL generation settings.

        The model family is chosen from the checkpoint path: paths containing
        ``'2.5'`` use ``Qwen2_5_VLForConditionalGeneration`` + ``AutoProcessor``,
        all others use ``Qwen2VLForConditionalGeneration`` + ``Qwen2VLProcessor``.
        Paths containing ``'72b'`` (case-insensitive) are loaded with the device map
        returned by ``split_model()``. All other paths are loaded on CPU and then
        moved to CUDA.

        Note: this method re-invokes ``super().__init__`` with ``use_custom_prompt``.

        Raises:
            AssertionError: if ``model_path`` is None or no GPU memory is reported.
        """
        super().__init__(use_custom_prompt=use_custom_prompt)

        self.min_pixels, self.max_pixels = min_pixels, max_pixels
        self.vl_generate_kwargs = {
            'max_new_tokens': max_new_tokens,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty,
        }
        self.system_prompt = system_prompt
        self.verbose = verbose
        self.post_process = post_process

        # Video sampling defaults consumed by _prepare_content.
        self.fps = 2.0
        self.nframe = 64
        self.FRAME_FACTOR = 2

        _rank, _world_size = get_rank_and_world_size()  # currently unused
        assert model_path is not None
        self.model_path = model_path

        # Select model / processor classes by checkpoint generation.
        if '2.5' in model_path:
            from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
            model_cls, processor_cls = Qwen2_5_VLForConditionalGeneration, AutoProcessor
        else:
            from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
            model_cls, processor_cls = Qwen2VLForConditionalGeneration, Qwen2VLProcessor
        self.processor = processor_cls.from_pretrained(model_path)

        gpu_mems = get_gpu_memory()
        max_gpu_mem = -1 if gpu_mems == [] else max(gpu_mems)
        assert max_gpu_mem > 0

        is_72b = '72b' in self.model_path.lower()
        # 72B checkpoints are sharded with split_model(); smaller ones go to CPU first.
        device_map = split_model() if is_72b else 'cpu'
        self.model = model_cls.from_pretrained(
            model_path, torch_dtype='auto', device_map=device_map, attn_implementation='flash_attention_2'
        )
        if is_72b:
            self.model.eval()
        else:
            self.model.cuda().eval()

        torch.cuda.empty_cache()

    def _prepare_content(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
        """Convert ``{'type', 'value'}`` messages into Qwen-VL chat content items.

        Args:
            inputs: list of dicts with keys ``'type'`` (``'image'``, ``'video'`` or
                ``'text'``) and ``'value'``.
            dataset: dataset name. For ``'OCRBench'``, images get a fixed
                ``min_pixels`` of 10*10*28*28 instead of ``self.min_pixels``.

        Returns:
            One content dict per input, in the same order.

        Raises:
            ValueError: if an input has an unsupported ``'type'``.
        """
        content = []
        for entry in inputs:
            kind = entry['type']

            if kind == 'text':
                item = {'type': 'text', 'text': entry['value']}

            elif kind == 'image':
                item = {'type': 'image', 'image': ensure_image_url(entry['value'])}
                if dataset == 'OCRBench':
                    item['min_pixels'] = 10 * 10 * 28 * 28
                    warnings.warn(f"OCRBench dataset uses custom min_pixels={item['min_pixels']}")
                elif self.min_pixels is not None:
                    item['min_pixels'] = self.min_pixels
                # The upper bound applies regardless of dataset.
                if self.max_pixels is not None:
                    item['max_pixels'] = self.max_pixels

            elif kind == 'video':
                item = {'type': 'video', 'video': ensure_video_url(entry['value'])}
                # A fixed fps takes precedence; nframes is only used when fps is None.
                if self.fps is not None:
                    item['fps'] = self.fps
                elif self.nframe is not None:
                    import cv2
                    capture = cv2.VideoCapture(entry['value'])
                    total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
                    capture.release()
                    if total_frames >= self.nframe:
                        item['nframes'] = self.nframe
                    else:
                        # Shorter videos: round the frame count down to a multiple of FRAME_FACTOR.
                        rounded = total_frames // self.FRAME_FACTOR * self.FRAME_FACTOR
                        print(f"use {rounded} for {entry['value']}")
                        item['nframes'] = rounded

            else:
                raise ValueError(f"Invalid message type: {kind}, {entry}")

            content.append(item)
        return content
    
    def generate_caption_requirement(self, question: str, hint: str = None, options: dict = None, image_count: int = 1) -> str:
        """Ask the text LLM to draft a plan for analysing the images behind ``question``.

        The prompt is written in Chinese if ``cn_string`` flags the question, the
        hint, or any option value, and in English otherwise. Generation uses fixed
        settings (``max_new_tokens=512``, ``temperature=0.7``, ``top_p=0.9``), not
        ``self.llm_generate_kwargs``.

        Args:
            question: the question to be answered from the images.
            hint: optional hint appended to the prompt.
            options: optional mapping of option label -> option text.
            image_count: number of images, quoted in the prompt.

        Returns:
            The decoded LLM continuation with special tokens removed.
        """
        use_chinese = cn_string(question) or (hint and cn_string(hint)) or (options and any(cn_string(v) for v in options.values()))

        if use_chinese:
            intro = (
                f"作为一个专业的计算机视觉分析专家，我需要你帮助分析{image_count}张图片来回答问题。\n\n"
                "请按照以下步骤进行分析：\n"
                "1. 问题分析：\n"
                "   - 理解问题的核心要求\n"
                "   - 确定需要关注的关键信息\n"
                "   - 识别问题类型（描述、比较、计数等）\n\n"
                "2. 视觉重点：\n"
                "   - 列出需要重点关注的图像区域\n"
                "   - 指出需要识别的具体视觉元素\n"
                "   - 确定需要分析的视觉特征或属性\n\n"
                "3. 分析策略：\n"
                "   - 提出观察和分析的具体步骤\n"
                "   - 说明需要关注的细节程度\n"
                "   - 如果涉及多张图片，说明如何进行对比\n\n"
                f"问题：{question}\n"
            )
            hint_label = "提示信息："
            options_heading = "可选答案：\n"
            closing = "\n请生成一个详细的分析计划，说明如何通过观察图片来回答这个问题。"
            system_prompt = (
                "你是一个专业的计算机视觉分析专家。你的任务是：\n"
                "1. 仔细分析问题的需求\n"
                "2. 确定需要在图片中寻找的具体视觉元素\n"
                "3. 提供清晰的观察和分析步骤\n"
                "4. 确保分析计划能够帮助准确回答问题"
            )
        else:
            intro = (
                f"As a professional computer vision analysis expert, I need your help to analyze {image_count} images to answer a question.\n\n"
                "Please follow these steps for analysis:\n"
                "1. Question Analysis:\n"
                "   - Understand the core requirements of the question\n"
                "   - Identify key information needed\n"
                "   - Recognize question type (description, comparison, counting, etc.)\n\n"
                "2. Visual Focus:\n"
                "   - List image regions that need special attention\n"
                "   - Point out specific visual elements to identify\n"
                "   - Determine visual features or attributes to analyze\n\n"
                "3. Analysis Strategy:\n"
                "   - Propose specific steps for observation and analysis\n"
                "   - Indicate required level of detail\n"
                "   - If multiple images involved, explain comparison approach\n\n"
                f"Question: {question}\n"
            )
            hint_label = "Hint: "
            options_heading = "Options:\n"
            closing = "\nPlease generate a detailed analysis plan explaining how to answer this question through image observation."
            system_prompt = (
                "You are a professional computer vision analysis expert. Your task is to:\n"
                "1. Carefully analyze the question requirements\n"
                "2. Identify specific visual elements to look for in the images\n"
                "3. Provide clear observation and analysis steps\n"
                "4. Ensure the analysis plan helps answer the question accurately"
            )

        # Assemble: intro, optional hint line, optional option list, closing request.
        pieces = [intro]
        if hint:
            pieces.append(f"{hint_label}{hint}\n")
        if options:
            pieces.append(options_heading)
            pieces.extend(f"{key}. {value}\n" for key, value in options.items())
        pieces.append(closing)
        prompt = "".join(pieces)

        chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        chat_text = self.llm_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        encoded = self.llm_tokenizer([chat_text], return_tensors="pt").to('cuda')

        full_ids = self.llm_model.generate(**encoded, max_new_tokens=512, temperature=0.7, top_p=0.9)
        # generate() returns prompt + continuation; keep only the new tokens.
        continuation = [out[len(prompt_ids):] for prompt_ids, out in zip(encoded.input_ids, full_ids)]
        return self.llm_tokenizer.batch_decode(continuation, skip_special_tokens=True)[0]

    def generate_final_answer(self, question: str, caption: str, hint: str = None, options: dict = None) -> str:
        """Ask the text LLM to answer ``question`` from an image description.

        The prompt is written in Chinese if ``cn_string`` flags the question, the
        hint, or any option value, and in English otherwise. Generation uses fixed
        settings (``max_new_tokens=512``, ``temperature=0.7``, ``top_p=0.9``), not
        ``self.llm_generate_kwargs``.

        Args:
            question: the question to answer.
            caption: textual description of the image(s) used as evidence.
            hint: optional hint appended to the prompt.
            options: optional mapping of option label -> option text; when given,
                the prompt asks the model to justify its choice.

        Returns:
            The decoded LLM continuation with special tokens removed.
        """
        use_chinese = cn_string(question) or (hint and cn_string(hint)) or (options and any(cn_string(v) for v in options.values()))

        if use_chinese:
            intro = (
                "请基于以下信息，通过清晰的推理步骤来回答问题。\n\n"
                "推理步骤：\n"
                "1. 问题理解：\n"
                "   - 明确问题的核心要求\n"
                "   - 确定需要从图像描述中提取的关键信息\n\n"
                "2. 信息分析：\n"
                "   - 从图像描述中提取相关的视觉细节\n"
                "   - 将这些细节与问题要求对应\n"
                "   - 考虑提示信息（如果有）\n\n"
                "3. 推理过程：\n"
                "   - 基于提取的信息进行逻辑推理\n"
                "   - 解释推理的每个步骤\n"
                "   - 说明如何得出结论\n\n"
                f"问题：{question}\n\n"
                f"图像详细描述：{caption}\n"
            )
            hint_label = "提示信息："
            options_heading = "\n可选答案：\n"
            closing_with_options = "\n请通过清晰的推理过程，说明为什么选择特定答案。"
            closing_without_options = "\n请通过清晰的推理过程，得出完整的答案。"
            system_prompt = (
                "你是一个专业的视觉问答专家。你的回答应该：\n"
                "1. 展示清晰的推理过程\n"
                "2. 解释每个推理步骤\n"
                "3. 明确说明如何从图像描述得出结论\n"
                "4. 确保答案与问题紧密相关\n"
                "5. 使用图像描述中的具体细节支持你的结论"
            )
        else:
            intro = (
                "Please answer the question through clear reasoning steps based on the following information.\n\n"
                "Reasoning Steps:\n"
                "1. Question Understanding:\n"
                "   - Clarify core requirements of the question\n"
                "   - Identify key information needed from image description\n\n"
                "2. Information Analysis:\n"
                "   - Extract relevant visual details from image description\n"
                "   - Map these details to question requirements\n"
                "   - Consider hint information (if any)\n\n"
                "3. Reasoning Process:\n"
                "   - Conduct logical reasoning based on extracted information\n"
                "   - Explain each step of reasoning\n"
                "   - Show how conclusions are reached\n\n"
                f"Question: {question}\n\n"
                f"Detailed Image Description: {caption}\n"
            )
            hint_label = "Hint: "
            options_heading = "\nOptions:\n"
            closing_with_options = "\nPlease explain through clear reasoning process why you choose a specific answer."
            closing_without_options = "\nPlease arrive at a complete answer through clear reasoning process."
            system_prompt = (
                "You are a professional visual QA expert. Your answer should:\n"
                "1. Demonstrate clear reasoning process\n"
                "2. Explain each reasoning step\n"
                "3. Clearly show how conclusions are drawn from image description\n"
                "4. Ensure answers are closely related to questions\n"
                "5. Use specific details from image description to support your conclusions"
            )

        # Assemble: intro, optional hint, then either the option list + choice request
        # or a plain request for a complete answer.
        pieces = [intro]
        if hint:
            pieces.append(f"\n{hint_label}{hint}")
        if options:
            pieces.append(options_heading)
            pieces.extend(f"{key}. {value}\n" for key, value in options.items())
            pieces.append(closing_with_options)
        else:
            pieces.append(closing_without_options)
        prompt = "".join(pieces)

        chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        chat_text = self.llm_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        encoded = self.llm_tokenizer([chat_text], return_tensors="pt").to('cuda')

        full_ids = self.llm_model.generate(**encoded, max_new_tokens=512, temperature=0.7, top_p=0.9)
        # generate() returns prompt + continuation; keep only the new tokens.
        continuation = [out[len(prompt_ids):] for prompt_ids, out in zip(encoded.input_ids, full_ids)]
        return self.llm_tokenizer.batch_decode(continuation, skip_special_tokens=True)[0]

    def generate_inner(self, message, dataset=None):
        """实现迭代式的视觉问答系统"""
        try:
            from qwen_vl_utils import process_vision_info
            import json
            import os
            from datetime import datetime
        except Exception as err:
            logging.critical("qwen_vl_utils not found, please install it via 'pip install qwen-vl-utils'")
            raise err

        # 创建用于存储所有信息的字典
        qa_process = {
            "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
            "images": [],
            "question": None,
            "hint": None,
            "options": {},
            "process": [],
            "final_answer": None
        }

        # 提取问题、提示和选项
        question = None
        hint = None
        options = {}
        image_messages = []
        
        for msg in message:
            if msg['type'] == 'text':
                question = msg['value']
                qa_process["question"] = question
            elif msg['type'] in ['image', 'video']:
                image_messages.append(msg)
                qa_process["images"].append(msg['value'])
            elif msg.get('hint'):
                hint = msg['hint']
                qa_process["hint"] = hint
            elif msg.get('options'):
                options = msg['options']
                qa_process["options"] = options

        if not question:
            raise ValueError("No question found in message")
        
        image_count = len(image_messages)
        all_visual_info = []  # 存储所有视觉信息
        max_iterations = 2    # 最大迭代次数
        current_iteration = 0 # 当前迭代次数

        # 1. 第一轮：获取初始视觉信息
        use_chinese = cn_string(question) or (hint and cn_string(hint)) or (options and any(cn_string(v) for v in options.values()))
        
        if use_chinese:
            initial_prompt = (
                "你是一个专业的视觉分析助手。请以JSON格式提供分析结果，包含以下两个部分：\n\n"
                "1. reasoning: 详细的分析依据，需要从以下角度进行分析：\n"
                "   - 图片中的主要视觉元素\n"
                "   - 这些元素之间的空间关系\n"
                "   - 任何可能对回答问题有帮助的细节\n"
                "   - 如果有多张图片，请分别描述并说明它们之间的关系\n"
                "2. answer: 对问题的直接回答\n\n"
                f"问题：{question}\n\n"
                "请按照以下格式返回结果：\n"
                "```json\n"
                "{\n"
                '    "reasoning": "详细的分析过程"\n' ## 先reasoning再answer
                '    "answer": "你的答案",\n'
                "}\n"
                "```"
            )
        else:
            initial_prompt = (
                "You are a professional visual analysis assistant. Please provide your analysis in JSON format with two parts:\n\n"
                "1. reasoning: Detailed analysis considering:\n"
                "   - Main visual elements in the image(s)\n"
                "   - Spatial relationships between these elements\n"
                "   - Any details that might help answer the question\n"
                "   - If multiple images, describe each and their relationships\n"
                "2. answer: Direct response to the question\n\n"
                f"Question: {question}\n\n"
                "Please return your response in this format:\n"
                "```json\n"
                "{\n"
                '    "reasoning": "detailed analysis process"\n'
                '    "answer": "your answer",\n'
                "}\n"
                "```"
            )

        vl_messages = [{'role': 'system', 'content': initial_prompt},
                      {'role': 'user', 'content': self._prepare_content(image_messages, dataset=dataset)}]
        
        text = self.processor.apply_chat_template(vl_messages, tokenize=False, add_generation_prompt=True)
        images, videos = process_vision_info([vl_messages])
        inputs = self.processor(text=text, images=images, videos=videos, padding=True, return_tensors='pt').to('cuda')
        
        generated_ids = self.model.generate(**inputs, **self.vl_generate_kwargs)
        generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)]
        initial_visual_info = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # 尝试解析JSON格式的回答
        try:
        

            # 解析JSON部分
            visual_info_dict, _ = parse_json_text_with_remaining(initial_visual_info)
            vlm_initial_answer = visual_info_dict.get('answer', '')
            vlm_initial_reasoning = visual_info_dict.get('reasoning', '')
                
        except json.JSONDecodeError as e:
            print(e)
            if self.verbose:
                print("\033[33m警告：JSON解析失败，使用原始响应\033[0m")
            vlm_initial_answer = ''
            vlm_initial_reasoning = initial_visual_info

        # 将解析后的信息存储为字典格式
        parsed_visual_info = {
            'answer': vlm_initial_answer,
            'reasoning': vlm_initial_reasoning
        }
        all_visual_info.append(parsed_visual_info)

        if self.verbose:
            print(f"\033[34m初始视觉信息：\n答案：{vlm_initial_answer}\n推理过程：{vlm_initial_reasoning}\033[0m")

        # 更新qa_process中的记录方式
        qa_process["process"].append({
            "step": "initial_vision",
            "content": {
                "answer": vlm_initial_answer,
                "reasoning": vlm_initial_reasoning
            }
        })

        # 初始化历史分析字符串
        analysis_history = ""

        # 修改迭代过程
        while current_iteration < max_iterations:
            iteration_info = {
                "iteration": current_iteration + 1,
                "llm_question": None,
                "vlm_answer": None
            }

            # 构建LLM提示
            if use_chinese:
                llm_prompt = (
                    f"问题：{question}\n\n"
                    f"初始推理过程：{vlm_initial_reasoning}\n"
                    f"初始答案：{vlm_initial_answer}\n"
                    f"{analysis_history}\n"  # 添加累积的分析历史
                    "请对上述答案和推理过程进行质疑，提出3个需要VLM进行更细致图像分析的问题。\n\n"
                    "请按照以下JSON格式返回问题：\n"
                    "```json\n"
                    "{\n"
                    '    "question1": "第一个问题",\n'
                    '    "question2": "第二个问题",\n'
                    '    "question3": "第三个问题"\n'
                    "}\n"
                    "```"
                )
            else:
                llm_prompt = (
                    f"Question: {question}\n\n"
                    f"Initial Reasoning: {vlm_initial_reasoning}\n"
                    f"Initial Answer: {vlm_initial_answer}\n"
                    f"{analysis_history}\n"  # 添加累积的分析历史
                    "Please question the above answer and reasoning process, and raise 3 questions that require VLM to conduct more detailed image analysis.\n\n"
                    "Please return the questions in the following JSON format:\n"
                    "```json\n"
                    "{\n"
                    '    "question1": "first question",\n'
                    '    "question2": "second question",\n'
                    '    "question3": "third question"\n'
                    "}\n" ### ```json ```匹配
                    "```"
                )

            # 使用LLM生成追问
            messages = [{"role": "user", "content": llm_prompt}]
            text = self.llm_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = self.llm_tokenizer([text], return_tensors="pt").to('cuda')
            
            generated_ids = self.llm_model.generate(
                **inputs, **self.llm_generate_kwargs
            )
            generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)]
            llm_response = self.llm_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

            questions = []
            # 尝试提取推理过程和JSON部分
            try:
                # 查找最后一个JSON结构的位置
                json_start = llm_response.rfind('{')
                
                if json_start != -1:
                    # 分离推理过程和JSON部分
                    thought_process = llm_response[:json_start].strip()

                     # 解析JSON部分
                    questions_dict, _ = parse_json_text_with_remaining(llm_response)
                    
                    # 获取问题列表
                    if isinstance(questions_dict, dict):
                        questions = [
                            questions_dict.get('question1', ''),
                            questions_dict.get('question2', ''),
                            questions_dict.get('question3', '')
                        ]
                    elif isinstance(questions_dict, list):
                        questions = []
                        for question_item in questions_dict:
                            if isinstance(question_item, dict):
                                questions = [
                                    question_item.get('question1', ''),
                                    question_item.get('question2', ''),
                                    question_item.get('question3', '')
                                ]
                        
                    else:
                        questions = []
                    
                    # 过滤掉空问题
                    questions = [q for q in questions if q]
                    
                    if self.verbose:
                        print(f"\033[34m推理过程：\n{thought_process}\033[0m")
                        for i, q in enumerate(questions, 1):
                            print(f"\033[36m问题{i}：{q}\033[0m")
                    
                    # 将问题合并为一个字符串，用于后续处理
                    llm_response = "\n".join(questions)
                    
                else:
                    if self.verbose:
                        print("\033[33m警告：未找到完整的JSON结构，使用原始响应\033[0m")
                    
            except json.JSONDecodeError:
                if self.verbose:
                    print("\033[33m警告：JSON解析失败，使用原始响应\033[0m")

            if self.verbose:
                print(f"\033[36m第{current_iteration + 1}轮LLM追问：\n{llm_response}\033[0m")
            
            iteration_info["llm_question"] = llm_response

            # 3. 使用VLM回答追问
            # 获取已解析的问题列表
            if not questions:  # 如果没有成功解析出问题，跳过这一轮
                if self.verbose:
                    print("\033[33m警告：本轮未获取到有效问题，跳过\033[0m")
                continue

            follow_up_answers = []  # 存储每个问题的回答
            
            # 对每个问题分别进行VLM分析
            for i, sub_question in enumerate(questions, 1):
                if use_chinese:
                    follow_up_prompt = (
                        "你是一个专业的视觉分析助手。请以JSON格式提供分析结果，包含以下两个部分：\n\n"
                        "1. reasoning: 详细的分析依据，需要从以下角度进行分析：\n"
                        "   - 图片中的主要视觉元素\n"
                        "   - 这些元素之间的空间关系\n"
                        "   - 任何可能对回答问题有帮助的细节\n"
                        "   - 如果有多张图片，请分别描述并说明它们之间的关系\n"
                        "2. answer: 对问题的直接回答\n\n"
                        f"问题：{sub_question}\n\n"
                        "请按照以下格式返回结果：\n"
                        "```json\n"
                        "{\n"
                        '    "reasoning": "详细的分析过程"\n'
                        '    "answer": "你的答案",\n'
                        "}\n"
                        "```"
                    )
                else:
                    follow_up_prompt = (
                        "You are a professional visual analysis assistant. Please provide your analysis in JSON format with two parts:\n\n"
                        "1. reasoning: Detailed analysis considering:\n"
                        "   - Main visual elements in the image(s)\n"
                        "   - Spatial relationships between these elements\n"
                        "   - Any details that might help answer the question\n"
                        "   - If multiple images, describe each and their relationships\n"
                        "2. answer: Direct response to the question\n\n"
                        f"Question: {sub_question}\n\n"
                        "Please return your response in this format:\n"
                        "```json\n"
                        "{\n"
                        '    "reasoning": "detailed analysis process"\n'
                        '    "answer": "your answer",\n'
                        "}\n"
                        "```"
                    )

                vl_messages = [{'role': 'system', 'content': follow_up_prompt},
                             {'role': 'user', 'content': self._prepare_content(image_messages, dataset=dataset)}]
                
                text = self.processor.apply_chat_template(vl_messages, tokenize=False, add_generation_prompt=True)
                images, videos = process_vision_info([vl_messages])
                inputs = self.processor(text=text, images=images, videos=videos, padding=True, return_tensors='pt').to('cuda')
                
                generated_ids = self.model.generate(**inputs, **self.vl_generate_kwargs)
                generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)]
                follow_up_info = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

                # 尝试解析JSON格式的回答
                try:
                    # 解析JSON部分
                    response_dict, _ = parse_json_text_with_remaining(follow_up_info)
                    vlm_followed_answer = response_dict.get('answer', '')
                    vlm_followed_reasoning = response_dict.get('reasoning', '')
                    
                        
                except json.JSONDecodeError:
                    if self.verbose:
                        print(f"\033[33m警告：问题{i}的回答JSON解析失败，使用原始响应\033[0m")
                    vlm_followed_answer = ''
                    vlm_followed_reasoning = follow_up_info

                follow_up_answers.append({
                    'focus_point': sub_question,
                    'answer': vlm_followed_answer,
                    'reasoning': vlm_followed_reasoning
                })
                
                if self.verbose:
                    print(f"\033[35m关注点{i}的分析：\关注的问题{sub_question}\n推理：{vlm_followed_reasoning}\n答案：{vlm_followed_answer}\033[0m")

            # 将本次分析添加到当前轮次的字符串中
            current_round_analysis = (
                f"\n第{current_iteration + 1}轮分析：\n" if use_chinese else f"\nRound {current_iteration + 1} Analysis:\n"
            )
            
            # 对每个问题进行VLM分析
            for i, sub_question in enumerate(questions, 1):
                # 将本次分析添加到当前轮次的字符串中
                current_round_analysis += (
                    f"关注点{i}：{sub_question}\n" if use_chinese else f"Focus Point {i}: {sub_question}\n"
                )
                current_round_analysis += (
                    f"推理过程：{vlm_followed_reasoning}\n"
                    f"分析结果：{vlm_followed_answer}\n\n"
                    if use_chinese else
                    f"Reasoning Process: {vlm_followed_reasoning}\n"
                    f"Analysis Result: {vlm_followed_answer}\n\n"
                )
                # current_round_analysis += (
                #     f"分析结果：{follow_up_info}\n\n" if use_chinese else f"Analysis Result: {follow_up_info}\n\n"
                # )

            # 更新iteration_info和qa_process
            iteration_info["vlm_answer"] = follow_up_answers
            qa_process["process"].append(iteration_info)
            
            # 更新视觉信息，将所有回答合并
            all_visual_info.extend(follow_up_answers)

            # 更新历史分析字符串
            analysis_history += current_round_analysis

            current_iteration += 1

        if self.verbose:
            print("\033[33m=== 开始生成最终答案 ===\033[0m")

        # 4. 整合所有视觉信息并生成最终答案
        combined_visual_info = "\n\n".join([f"{info['answer']}\n{info['reasoning']}" for info in all_visual_info])
        
        if use_chinese:
            final_prompt = (
                "请基于以下所有视觉信息，通过清晰的推理步骤来回答问题。\n\n"
                "推理步骤：\n"
                "1. 信息整合：\n"
                "   - 整理所有视觉观察结果\n"
                "   - 识别关键信息点\n"
                "   - 建立信息之间的联系\n\n"
                "2. 逻辑推理：\n"
                "   - 基于整合的信息进行推理\n"
                "   - 解释推理过程\n"
                "   - 说明结论的可靠性\n\n"
                f"问题：{question}\n\n"
                f"初始推理过程：{vlm_initial_reasoning}\n"
                f"初始答案：{vlm_initial_answer}\n"
                f"{analysis_history}\n"  # 添加累积的分析历史
                # f"问题：{question}\n\n"
                # f"所有视觉信息：\n{combined_visual_info}\n"
            )
        else:
            final_prompt = (
                "Please answer the question through clear reasoning steps based on all the following visual information.\n\n"
                "Reasoning Steps:\n"
                "1. Information Integration:\n"
                "   - Organize all visual observations\n"
                "   - Identify key information points\n"
                "   - Establish connections between information\n\n"
                "2. Logical Reasoning:\n"
                "   - Conduct reasoning based on integrated information\n"
                "   - Explain the reasoning process\n"
                "   - Demonstrate conclusion reliability\n\n"
                f"Question: {question}\n\n"
                f"Initial Reasoning: {vlm_initial_reasoning}\n"
                f"Initial Answer: {vlm_initial_answer}\n"
                f"{analysis_history}\n"  # 添加累积的分析历史
                # f"Question: {question}\n\n"
                # f"All Visual Information:\n{combined_visual_info}\n"
            )

        if hint:
            final_prompt += f"\n提示信息：{hint}" if use_chinese else f"\nHint: {hint}"
        
        if options:
            if use_chinese:
                final_prompt += "\n可选答案：\n"
                for key, value in options.items():
                    final_prompt += f"{key}. {value}\n"
                final_prompt += "\n请通过清晰的推理过程，说明为什么选择特定答案。"
            else:
                final_prompt += "\nOptions:\n"
                for key, value in options.items():
                    final_prompt += f"{key}. {value}\n"
                final_prompt += "\nPlease explain through clear reasoning process why you choose a specific answer."

        # 生成最终答案
        messages = [{"role": "user", "content": final_prompt}]
        text = self.llm_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.llm_tokenizer([text], return_tensors="pt").to('cuda')
        
        generated_ids = self.llm_model.generate(
            **inputs,
            **self.llm_generate_kwargs
        )
        generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)]
        final_answer = self.llm_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        # 记录最终答案
        qa_process["final_answer"] = final_answer

        if self.verbose:
            print(f"\033[34m初始视觉信息：{all_visual_info[0]}\033[0m")
            for i, info in enumerate(all_visual_info[1:], 1):
                print(f"\033[35m第{i}轮补充信息：{info}\033[0m")
            print(f"\033[33m最终答案：{final_answer}\033[0m")

        # 保存问答过程到JSON文件
        os.makedirs('qa_logs', exist_ok=True)
        json_filename = f'qa_logs/qa_process_{qa_process["timestamp"]}.json'
        with open(json_filename, 'w', encoding='utf-8') as f:
            json.dump(qa_process, f, ensure_ascii=False, indent=2)

        return final_answer
