from __future__ import annotations

import os
import sys
import warnings
import math
import logging

import torch

from ..base import BaseModel
from .prompt import Qwen2VLPromptMixin
from ...smp import get_rank_and_world_size, get_gpu_memory, auto_split_flag


def ensure_image_url(image: str) -> str:
    """Return ``image`` in a URL form accepted by the vision preprocessing step.

    Args:
        image: An ``http(s)://`` URL, a ``file://`` URL, a ``data:image...``
            URL, or a path to an existing local file.

    Returns:
        ``image`` unchanged when it already carries a recognised prefix,
        otherwise ``'file://' + image`` when it names an existing path.

    Raises:
        ValueError: If ``image`` is neither a recognised URL nor an existing path.
    """
    # 'data:image' (no trailing ';') also admits MIME-typed data URLs such as
    # 'data:image/png;base64,...', not only the bare 'data:image;base64,...' form.
    url_prefixes = ('http://', 'https://', 'file://', 'data:image')
    if image.startswith(url_prefixes):
        return image
    # A plain local path is turned into a file:// URL.
    if os.path.exists(image):
        return 'file://' + image
    raise ValueError(f'Invalid image: {image}')


def ensure_video_url(video: str) -> str:
    """Return ``video`` in URL form, mirroring ``ensure_image_url`` for videos.

    Strings starting with ``http://``, ``https://``, ``file://`` or
    ``data:video;`` are returned as-is; an existing local path is prefixed
    with ``file://``.

    Raises:
        ValueError: If ``video`` matches none of the above.
    """
    recognised = ('http://', 'https://', 'file://', 'data:video;')
    already_url = video.startswith(recognised)
    if not already_url and not os.path.exists(video):
        raise ValueError(f'Invalid video: {video}')
    return video if already_url else 'file://' + video


def split_model():
    """Build a ``device_map`` that spreads the decoder layers over this rank's GPUs.

    The visible GPUs are divided evenly between processes; process ``rank``
    uses devices ``rank, rank + world_size, rank + 2 * world_size, ...``.

    Returns:
        dict mapping module names (``model.layers.N``, ``visual``,
        ``model.embed_tokens``, ``model.norm``, ``model.rotary_emb``,
        ``lm_head``) to CUDA device indices.
    """
    gpu_total = torch.cuda.device_count()
    rank, world_size = get_rank_and_world_size()
    gpus_per_proc = gpu_total // world_size

    # 80 decoder layers plus 8 "virtual" layers that stand in for the memory
    # taken by the visual tower.
    layer_total = 80 + 8
    share = math.ceil(layer_total / gpus_per_proc)
    budgets = [share for _ in range(gpus_per_proc)]
    # The 8 virtual layers are taken back here: 6 from the first device and
    # 2 from the last (presumably headroom for the modules pinned there below).
    budgets[0] -= 6
    budgets[-1] -= 2

    device_map = {}
    next_layer = 0
    for slot, budget in enumerate(budgets):
        target = rank + slot * world_size
        for _ in range(budget):
            device_map[f'model.layers.{next_layer}'] = target
            next_layer += 1

    first_gpu = rank
    last_gpu = rank + (gpus_per_proc - 1) * world_size
    # Input-side modules go on the first device, output-side ones on the last.
    device_map['visual'] = first_gpu
    device_map['model.embed_tokens'] = first_gpu
    device_map['model.norm'] = last_gpu
    device_map['model.rotary_emb'] = last_gpu
    device_map['lm_head'] = last_gpu
    return device_map

# Detect whether a string contains Chinese characters.
def cn_string(s):
    """Return True if ``s`` has at least one character in U+4E00..U+9FFF."""
    import re
    return re.search('[\u4e00-\u9fff]', s) is not None


class Qwen2VLChat(Qwen2VLPromptMixin, BaseModel):
    """Chat wrapper that runs a Qwen2-VL or Qwen2.5-VL checkpoint on image/video/text input."""

    INSTALL_REQ = False
    INTERLEAVE = True
    VIDEO_LLM = True

    def __init__(
        self,
        model_path: str,
        min_pixels: int | None = None,
        max_pixels: int | None = None,
        max_new_tokens=2048,
        top_p=0.001,
        top_k=1,
        temperature=0.01,
        repetition_penalty=1.0,
        use_custom_prompt: bool = True,
        system_prompt: str | None = None,
        post_process: bool = False,  # if True, will try to only extract stuff in the last \boxed{}.
        verbose: bool = False,
    ):
        """Load the processor and model and store generation settings.

        Args:
            model_path: Checkpoint path or hub id. A name containing ``2.5``,
                ``2_5`` or ``qwen25`` (case-insensitive) selects the Qwen2.5-VL
                classes; a name containing ``72b`` selects the manual
                multi-GPU split from ``split_model``.
            min_pixels: Optional per-image ``min_pixels`` passed with each image.
            max_pixels: Optional per-image ``max_pixels`` passed with each image.
            max_new_tokens, top_p, top_k, temperature, repetition_penalty:
                Forwarded to ``model.generate``.
            use_custom_prompt: Forwarded to the parent initialiser.
            system_prompt: Optional system message prepended to every request.
            post_process: If True, keep only the content of the last ``\\boxed{}``.
            verbose: If True, print the request messages and the response.
        """
        super().__init__(use_custom_prompt=use_custom_prompt)
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.generate_kwargs = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )
        self.system_prompt = system_prompt
        self.verbose = verbose
        self.post_process = post_process
        # Video sampling settings. fps is always set here, so _prepare_content
        # only falls back to nframe if fps is reset to None afterwards.
        self.fps = 2.0
        self.nframe = 64
        self.FRAME_FACTOR = 2
        _, world_size = get_rank_and_world_size()
        assert model_path is not None
        self.model_path = model_path

        # Plain substring test; the file does not import a `listinstr` helper,
        # so calling one here would raise NameError.
        lowered_path = model_path.lower()
        if any(tag in lowered_path for tag in ('2.5', '2_5', 'qwen25')):
            from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
            MODEL_CLS = Qwen2_5_VLForConditionalGeneration
            self.processor = AutoProcessor.from_pretrained(model_path)
        else:
            from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
            MODEL_CLS = Qwen2VLForConditionalGeneration
            self.processor = Qwen2VLProcessor.from_pretrained(model_path)

        # Require at least one GPU that reports a positive memory size.
        gpu_mems = get_gpu_memory()
        max_gpu_mem = max(gpu_mems) if gpu_mems != [] else -1
        assert max_gpu_mem > 0

        if '72b' in lowered_path:
            # 72B checkpoints: spread layers over this rank's GPUs by hand.
            self.model = MODEL_CLS.from_pretrained(
                model_path, torch_dtype='auto', device_map=split_model(), attn_implementation='flash_attention_2'
            )
            self.model.eval()
        elif auto_split_flag():
            assert world_size == 1, 'Only support world_size == 1 when AUTO_SPLIT is set for non-72B Qwen2-VL'
            # Will Use All GPUs to run one model
            self.model = MODEL_CLS.from_pretrained(
                model_path, torch_dtype='auto', device_map='auto', attn_implementation='flash_attention_2'
            )
        else:
            # Default: load on CPU, then move the whole model to the current GPU.
            self.model = MODEL_CLS.from_pretrained(
                model_path, torch_dtype='auto', device_map='cpu', attn_implementation='flash_attention_2'
            )
            self.model.cuda().eval()

        torch.cuda.empty_cache()

    def _prepare_content(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
        """Convert message items into the content list used by the chat template.

        Args:
            inputs: list[dict[str, str]], each dict has keys: ['type', 'value'];
                ``type`` is one of ``image``, ``video`` or ``text``.
            dataset: Dataset name; ``'OCRBench'`` forces a fixed ``min_pixels``.

        Returns:
            One content dict per input item, in the same order.

        Raises:
            ValueError: If an item has an unknown ``type`` or an invalid
                image/video location.
        """
        content = []
        for s in inputs:
            if s['type'] == 'image':
                item = {'type': 'image', 'image': ensure_image_url(s['value'])}
                if dataset == 'OCRBench':
                    # OCRBench overrides the configured min_pixels.
                    item['min_pixels'] = 10 * 10 * 28 * 28
                    warnings.warn(f"OCRBench dataset uses custom min_pixels={item['min_pixels']}")
                    if self.max_pixels is not None:
                        item['max_pixels'] = self.max_pixels
                else:
                    if self.min_pixels is not None:
                        item['min_pixels'] = self.min_pixels
                    if self.max_pixels is not None:
                        item['max_pixels'] = self.max_pixels
            elif s['type'] == 'video':
                item = {'type': 'video', 'video': ensure_video_url(s['value'])}
                if self.fps is not None:
                    item['fps'] = self.fps
                elif self.nframe is not None:
                    import cv2
                    video = cv2.VideoCapture(s['value'])
                    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
                    video.release()
                    if frame_count < self.nframe:
                        # Short clip: round the frame count down to a multiple of FRAME_FACTOR.
                        new_frame_count = frame_count // self.FRAME_FACTOR * self.FRAME_FACTOR
                        print(f"use {new_frame_count} for {s['value']}")
                        item['nframes'] = new_frame_count
                    else:
                        item['nframes'] = self.nframe
            elif s['type'] == 'text':
                item = {'type': 'text', 'text': s['value']}
            else:
                raise ValueError(f"Invalid message type: {s['type']}, {s}")
            content.append(item)
        return content

    def generate_inner(self, message, dataset=None):
        """Run one generation for ``message`` and return the decoded response.

        Args:
            message: List of ``{'type', 'value'}`` items (see ``_prepare_content``).
            dataset: Optional dataset name forwarded to ``_prepare_content``.

        Returns:
            The generated text; when ``post_process`` is set, only the content
            following the last ``\\boxed{`` up to its matching brace.
        """
        try:
            from qwen_vl_utils import process_vision_info
        except Exception:
            logging.critical("qwen_vl_utils not found, please install it via 'pip install qwen-vl-utils'")
            raise

        messages = []
        if self.system_prompt is not None:
            messages.append({'role': 'system', 'content': self.system_prompt})
        messages.append({'role': 'user', 'content': self._prepare_content(message, dataset=dataset)})
        if self.verbose:
            print(f'\033[31m{messages}\033[0m')

        text = self.processor.apply_chat_template([messages], tokenize=False, add_generation_prompt=True)
        images, videos = process_vision_info([messages])
        inputs = self.processor(text=text, images=images, videos=videos, padding=True, return_tensors='pt')
        inputs = inputs.to('cuda')

        generated_ids = self.model.generate(
            **inputs,
            **self.generate_kwargs,
        )
        # Drop the prompt tokens so only newly generated tokens are decoded.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        out = self.processor.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        response = out[0]
        if self.post_process:
            # Take the text after the last '\boxed{' and cut it at the brace that
            # closes it (or keep everything if the braces never balance).
            resp = response.split('\\boxed{')[-1]
            lt = len(resp)
            counter, end = 1, None
            for i in range(lt):
                if resp[i] == '{':
                    counter += 1
                elif resp[i] == '}':
                    counter -= 1
                if counter == 0:
                    end = i
                    break
                elif i == lt - 1:
                    end = lt
                    break
            if end is not None:
                response = resp[:end]

        if self.verbose:
            print(f'\033[32m{response}\033[0m')
        return response

class Qwen2VLCaptionQwen2LLM(Qwen2VLPromptMixin, BaseModel):
    """Two-model visual QA: a Qwen2-VL model describes the visuals, a text LLM reasons.

    ``generate_inner`` asks the VL model for an initial description, lets the
    LLM propose follow-up questions (at most two rounds), has the VL model
    answer them, and finally asks the LLM for the answer based on all
    collected descriptions.
    """

    INSTALL_REQ = False
    INTERLEAVE = True
    VIDEO_LLM = True

    def __init__(
        self,
        vl_model_path: str,  # path of the vision-language model
        llm_model_path: str,  # path of the text-only LLM
        min_pixels: int | None = None,
        max_pixels: int | None = None,
        max_new_tokens=2048,
        top_p=0.001,
        top_k=1,
        temperature=0.01,
        repetition_penalty=1.0,
        use_custom_prompt: bool = True,
        system_prompt: str | None = None,
        post_process: bool = False,
        verbose: bool = False,
    ):
        """Load both models.

        Args:
            vl_model_path: Checkpoint of the Qwen2-VL / Qwen2.5-VL model.
            llm_model_path: Checkpoint of the causal LLM used for reasoning.
            min_pixels, max_pixels: Optional per-image pixel bounds.
            max_new_tokens, top_p, top_k, temperature, repetition_penalty:
                Generation settings for the VL model.
            use_custom_prompt: Forwarded to the parent initialiser.
            system_prompt, post_process, verbose: Stored on the instance;
                only ``verbose`` is read by ``generate_inner``.
        """
        super().__init__(use_custom_prompt=use_custom_prompt)

        # Forward every option explicitly: init_vl_model re-runs the parent
        # initialiser, so omitting use_custom_prompt here would silently reset
        # it to that method's default (True).
        self.init_vl_model(
            vl_model_path,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            use_custom_prompt=use_custom_prompt,
            system_prompt=system_prompt,
            post_process=post_process,
            verbose=verbose,
        )

        self.init_llm_model(llm_model_path)

        self.system_prompt = system_prompt
        self.verbose = verbose
        self.post_process = post_process

    def init_llm_model(self, model_path: str):
        """Load the text LLM and its tokenizer, then move the model to the GPU."""
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.llm_tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="cpu",
            attn_implementation="flash_attention_2"
        )
        self.llm_model.cuda().eval()

    def init_vl_model(
        self,
        model_path: str,
        min_pixels: int | None = None,
        max_pixels: int | None = None,
        max_new_tokens=2048,
        top_p=0.001,
        top_k=1,
        temperature=0.01,
        repetition_penalty=1.0,
        use_custom_prompt: bool = True,
        system_prompt: str | None = None,
        post_process: bool = False,  # if True, will try to only extract stuff in the last \boxed{}.
        verbose: bool = False,
    ):
        """Load the Qwen VL processor/model and store its generation settings.

        Args:
            model_path: Checkpoint path or hub id. A name containing ``2.5``,
                ``2_5`` or ``qwen25`` (case-insensitive) selects the Qwen2.5-VL
                classes; a name containing ``72b`` selects ``split_model``.
            min_pixels: Optional minimum pixel count per image.
            max_pixels: Optional maximum pixel count per image.
            max_new_tokens: Maximum number of generated tokens.
            top_p: Nucleus sampling parameter.
            top_k: Number of highest-probability tokens kept.
            temperature: Sampling temperature.
            repetition_penalty: Repetition penalty factor.
            use_custom_prompt: Forwarded to the parent initialiser.
            system_prompt: Stored on the instance.
            post_process: Stored on the instance.
            verbose: Stored on the instance.
        """
        super().__init__(use_custom_prompt=use_custom_prompt)
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.generate_kwargs = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )
        self.system_prompt = system_prompt
        self.verbose = verbose
        self.post_process = post_process
        # Video sampling settings. fps is always set here, so _prepare_content
        # only falls back to nframe if fps is reset to None afterwards.
        self.fps = 2.0
        self.nframe = 64
        self.FRAME_FACTOR = 2
        _, world_size = get_rank_and_world_size()
        assert model_path is not None
        self.model_path = model_path

        # Same version tags as Qwen2VLChat, matched case-insensitively.
        lowered_path = model_path.lower()
        if any(tag in lowered_path for tag in ('2.5', '2_5', 'qwen25')):
            from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
            MODEL_CLS = Qwen2_5_VLForConditionalGeneration
            self.processor = AutoProcessor.from_pretrained(model_path)
        else:
            from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
            MODEL_CLS = Qwen2VLForConditionalGeneration
            self.processor = Qwen2VLProcessor.from_pretrained(model_path)

        # Require at least one GPU that reports a positive memory size.
        gpu_mems = get_gpu_memory()
        max_gpu_mem = max(gpu_mems) if gpu_mems != [] else -1
        assert max_gpu_mem > 0

        if '72b' in lowered_path:
            # 72B checkpoints: spread layers over this rank's GPUs by hand.
            self.model = MODEL_CLS.from_pretrained(
                model_path, torch_dtype='auto', device_map=split_model(), attn_implementation='flash_attention_2'
            )
            self.model.eval()
        elif auto_split_flag():
            assert world_size == 1, 'Only support world_size == 1 when AUTO_SPLIT is set for non-72B Qwen2-VL'
            # Will Use All GPUs to run one model
            self.model = MODEL_CLS.from_pretrained(
                model_path, torch_dtype='auto', device_map='auto', attn_implementation='flash_attention_2'
            )
        else:
            # Default: load on CPU, then move the whole model to the current GPU.
            self.model = MODEL_CLS.from_pretrained(
                model_path, torch_dtype='auto', device_map='cpu', attn_implementation='flash_attention_2'
            )
            self.model.cuda().eval()

        torch.cuda.empty_cache()

    def _prepare_content(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
        """Convert message items into the content list used by the chat template.

        Args:
            inputs: list[dict[str, str]], each dict has keys: ['type', 'value'];
                ``type`` is one of ``image``, ``video`` or ``text``.
            dataset: Dataset name; ``'OCRBench'`` forces a fixed ``min_pixels``.

        Returns:
            One content dict per input item, in the same order.

        Raises:
            ValueError: If an item has an unknown ``type`` or an invalid
                image/video location.
        """
        content = []
        for s in inputs:
            if s['type'] == 'image':
                item = {'type': 'image', 'image': ensure_image_url(s['value'])}
                if dataset == 'OCRBench':
                    # OCRBench overrides the configured min_pixels.
                    item['min_pixels'] = 10 * 10 * 28 * 28
                    warnings.warn(f"OCRBench dataset uses custom min_pixels={item['min_pixels']}")
                    if self.max_pixels is not None:
                        item['max_pixels'] = self.max_pixels
                else:
                    if self.min_pixels is not None:
                        item['min_pixels'] = self.min_pixels
                    if self.max_pixels is not None:
                        item['max_pixels'] = self.max_pixels
            elif s['type'] == 'video':
                item = {'type': 'video', 'video': ensure_video_url(s['value'])}
                if self.fps is not None:
                    item['fps'] = self.fps
                elif self.nframe is not None:
                    import cv2
                    video = cv2.VideoCapture(s['value'])
                    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
                    video.release()
                    if frame_count < self.nframe:
                        # Short clip: round the frame count down to a multiple of FRAME_FACTOR.
                        new_frame_count = frame_count // self.FRAME_FACTOR * self.FRAME_FACTOR
                        print(f"use {new_frame_count} for {s['value']}")
                        item['nframes'] = new_frame_count
                    else:
                        item['nframes'] = self.nframe
            elif s['type'] == 'text':
                item = {'type': 'text', 'text': s['value']}
            else:
                raise ValueError(f"Invalid message type: {s['type']}, {s}")
            content.append(item)
        return content

    @staticmethod
    def _uses_chinese(question, hint=None, options=None) -> bool:
        """Return True if the question, hint or any option value contains Chinese."""
        return bool(
            cn_string(question)
            or (hint and cn_string(hint))
            or (options and any(cn_string(v) for v in options.values()))
        )

    def _llm_generate(self, messages) -> str:
        """Run the text LLM on chat ``messages`` and return the newly generated text."""
        text = self.llm_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = self.llm_tokenizer([text], return_tensors="pt").to('cuda')
        generated_ids = self.llm_model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9
        )
        # Drop the prompt tokens so only newly generated tokens are decoded.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.llm_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _vl_generate(self, instruction: str, visual_items, dataset, process_vision_info) -> str:
        """Run the VL model with ``instruction`` as system message over ``visual_items``.

        ``process_vision_info`` is the function imported from ``qwen_vl_utils``
        by the caller.
        """
        vl_messages = [
            {'role': 'system', 'content': instruction},
            {'role': 'user', 'content': self._prepare_content(visual_items, dataset=dataset)},
        ]
        text = self.processor.apply_chat_template(vl_messages, tokenize=False, add_generation_prompt=True)
        images, videos = process_vision_info([vl_messages])
        inputs = self.processor(
            text=text, images=images, videos=videos, padding=True, return_tensors='pt'
        ).to('cuda')
        generated_ids = self.model.generate(**inputs, **self.generate_kwargs)
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def generate_caption_requirement(self, question: str, hint: str = None, options: dict = None, image_count: int = 1) -> str:
        """Ask the LLM for an analysis plan describing what to look for in the images.

        The prompt is written in Chinese when the question, hint or options
        contain Chinese characters, otherwise in English.

        Args:
            question: The question to be answered.
            hint: Optional hint text appended to the prompt.
            options: Optional mapping of option key to option text.
            image_count: Number of images mentioned in the prompt.

        Returns:
            The LLM's generated plan.
        """
        use_chinese = self._uses_chinese(question, hint, options)

        if use_chinese:
            prompt = (
                f"作为一个专业的计算机视觉分析专家，我需要你帮助分析{image_count}张图片来回答问题。\n\n"
                "请按照以下步骤进行分析：\n"
                "1. 问题分析：\n"
                "   - 理解问题的核心要求\n"
                "   - 确定需要关注的关键信息\n"
                "   - 识别问题类型（描述、比较、计数等）\n\n"
                "2. 视觉重点：\n"
                "   - 列出需要重点关注的图像区域\n"
                "   - 指出需要识别的具体视觉元素\n"
                "   - 确定需要分析的视觉特征或属性\n\n"
                "3. 分析策略：\n"
                "   - 提出观察和分析的具体步骤\n"
                "   - 说明需要关注的细节程度\n"
                "   - 如果涉及多张图片，说明如何进行对比\n\n"
                f"问题：{question}\n"
            )

            if hint:
                prompt += f"提示信息：{hint}\n"

            if options:
                prompt += "可选答案：\n"
                prompt += "".join(f"{key}. {value}\n" for key, value in options.items())

            prompt += "\n请生成一个详细的分析计划，说明如何通过观察图片来回答这个问题。"

            system_prompt = (
                "你是一个专业的计算机视觉分析专家。你的任务是：\n"
                "1. 仔细分析问题的需求\n"
                "2. 确定需要在图片中寻找的具体视觉元素\n"
                "3. 提供清晰的观察和分析步骤\n"
                "4. 确保分析计划能够帮助准确回答问题"
            )
        else:
            prompt = (
                f"As a professional computer vision analysis expert, I need your help to analyze {image_count} images to answer a question.\n\n"
                "Please follow these steps for analysis:\n"
                "1. Question Analysis:\n"
                "   - Understand the core requirements of the question\n"
                "   - Identify key information needed\n"
                "   - Recognize question type (description, comparison, counting, etc.)\n\n"
                "2. Visual Focus:\n"
                "   - List image regions that need special attention\n"
                "   - Point out specific visual elements to identify\n"
                "   - Determine visual features or attributes to analyze\n\n"
                "3. Analysis Strategy:\n"
                "   - Propose specific steps for observation and analysis\n"
                "   - Indicate required level of detail\n"
                "   - If multiple images involved, explain comparison approach\n\n"
                f"Question: {question}\n"
            )

            if hint:
                prompt += f"Hint: {hint}\n"

            if options:
                prompt += "Options:\n"
                prompt += "".join(f"{key}. {value}\n" for key, value in options.items())

            prompt += "\nPlease generate a detailed analysis plan explaining how to answer this question through image observation."

            system_prompt = (
                "You are a professional computer vision analysis expert. Your task is to:\n"
                "1. Carefully analyze the question requirements\n"
                "2. Identify specific visual elements to look for in the images\n"
                "3. Provide clear observation and analysis steps\n"
                "4. Ensure the analysis plan helps answer the question accurately"
            )

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]
        return self._llm_generate(messages)

    def generate_final_answer(self, question: str, caption: str, hint: str = None, options: dict = None) -> str:
        """Ask the LLM to answer ``question`` from a textual image description.

        Args:
            question: The question to be answered.
            caption: Description of the image(s) used as the only visual evidence.
            hint: Optional hint text appended to the prompt.
            options: Optional mapping of option key to option text.

        Returns:
            The LLM's generated answer.
        """
        use_chinese = self._uses_chinese(question, hint, options)

        if use_chinese:
            prompt = (
                "请基于以下信息，通过清晰的推理步骤来回答问题。\n\n"
                "推理步骤：\n"
                "1. 问题理解：\n"
                "   - 明确问题的核心要求\n"
                "   - 确定需要从图像描述中提取的关键信息\n\n"
                "2. 信息分析：\n"
                "   - 从图像描述中提取相关的视觉细节\n"
                "   - 将这些细节与问题要求对应\n"
                "   - 考虑提示信息（如果有）\n\n"
                "3. 推理过程：\n"
                "   - 基于提取的信息进行逻辑推理\n"
                "   - 解释推理的每个步骤\n"
                "   - 说明如何得出结论\n\n"
                f"问题：{question}\n\n"
                f"图像详细描述：{caption}\n"
            )

            if hint:
                prompt += f"\n提示信息：{hint}"

            if options:
                prompt += "\n可选答案：\n"
                prompt += "".join(f"{key}. {value}\n" for key, value in options.items())
                prompt += "\n请通过清晰的推理过程，说明为什么选择特定答案。"
            else:
                prompt += "\n请通过清晰的推理过程，得出完整的答案。"

            system_prompt = (
                "你是一个专业的视觉问答专家。你的回答应该：\n"
                "1. 展示清晰的推理过程\n"
                "2. 解释每个推理步骤\n"
                "3. 明确说明如何从图像描述得出结论\n"
                "4. 确保答案与问题紧密相关\n"
                "5. 使用图像描述中的具体细节支持你的结论"
            )
        else:
            prompt = (
                "Please answer the question through clear reasoning steps based on the following information.\n\n"
                "Reasoning Steps:\n"
                "1. Question Understanding:\n"
                "   - Clarify core requirements of the question\n"
                "   - Identify key information needed from image description\n\n"
                "2. Information Analysis:\n"
                "   - Extract relevant visual details from image description\n"
                "   - Map these details to question requirements\n"
                "   - Consider hint information (if any)\n\n"
                "3. Reasoning Process:\n"
                "   - Conduct logical reasoning based on extracted information\n"
                "   - Explain each step of reasoning\n"
                "   - Show how conclusions are reached\n\n"
                f"Question: {question}\n\n"
                f"Detailed Image Description: {caption}\n"
            )

            if hint:
                prompt += f"\nHint: {hint}"

            if options:
                prompt += "\nOptions:\n"
                prompt += "".join(f"{key}. {value}\n" for key, value in options.items())
                prompt += "\nPlease explain through clear reasoning process why you choose a specific answer."
            else:
                prompt += "\nPlease arrive at a complete answer through clear reasoning process."

            system_prompt = (
                "You are a professional visual QA expert. Your answer should:\n"
                "1. Demonstrate clear reasoning process\n"
                "2. Explain each reasoning step\n"
                "3. Clearly show how conclusions are drawn from image description\n"
                "4. Ensure answers are closely related to questions\n"
                "5. Use specific details from image description to support your conclusions"
            )

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]
        return self._llm_generate(messages)

    def generate_inner(self, message, dataset=None):
        """Answer a visual question through an iterative VL-model / LLM exchange.

        Args:
            message: List of dict items. ``text`` items form the question,
                ``image``/``video`` items are the visual input; any other item
                may carry a ``hint`` or ``options`` entry.
            dataset: Optional dataset name forwarded to ``_prepare_content``.

        Returns:
            The LLM's final answer.

        Raises:
            ValueError: If ``message`` contains no non-empty text.
        """
        try:
            from qwen_vl_utils import process_vision_info
        except Exception:
            logging.critical("qwen_vl_utils not found, please install it via 'pip install qwen-vl-utils'")
            raise

        # Split the message into question text, visual items, hint and options.
        text_parts = []
        hint = None
        options = {}
        image_messages = []

        for msg in message:
            if msg['type'] == 'text':
                text_parts.append(msg['value'])
            elif msg['type'] in ['image', 'video']:
                image_messages.append(msg)
            elif msg.get('hint'):
                hint = msg['hint']
            elif msg.get('options'):
                options = msg['options']

        # Interleaved messages may carry several text segments; keep all of
        # them rather than only the last one.
        question = "\n".join(text_parts)
        if not question:
            raise ValueError("No question found in message")

        all_visual_info = []  # every description produced by the VL model
        max_iterations = 2    # upper bound on follow-up rounds

        use_chinese = self._uses_chinese(question, hint, options)

        # 1. First pass: obtain the initial visual description.
        if use_chinese:
            initial_prompt = (
                "你是一个专业的视觉分析助手。请仔细观察图片并描述：\n"
                "1. 图片中的主要视觉元素\n"
                "2. 这些元素之间的空间关系\n"
                "3. 任何可能对回答问题有帮助的细节\n"
                "4. 如果有多张图片，请分别描述并说明它们之间的关系\n"
                f"\n问题：{question}"
            )
        else:
            initial_prompt = (
                "You are a professional visual analysis assistant. Please carefully observe the images and describe:\n"
                "1. Main visual elements in the images\n"
                "2. Spatial relationships between these elements\n"
                "3. Any details that might help answer the question\n"
                "4. If there are multiple images, describe each one and their relationships\n"
                f"\nQuestion: {question}"
            )

        latest_visual_info = self._vl_generate(initial_prompt, image_messages, dataset, process_vision_info)
        all_visual_info.append(latest_visual_info)

        # 2. Iterate: the LLM reviews the latest description and may ask follow-ups.
        for _ in range(max_iterations):
            if use_chinese:
                llm_prompt = (
                    "基于以下视觉信息，请分析是否存在需要进一步确认的关键点：\n\n"
                    f"问题：{question}\n\n"
                    f"当前视觉信息：\n{latest_visual_info}\n\n"
                    "请考虑：\n"
                    "1. 是否有模糊或不确定的描述\n"
                    "2. 是否有需要更详细观察的部分\n"
                    "3. 是否有需要验证的假设\n\n"
                    "如果发现需要进一步确认的点，请提出3个最关键的追问。\n"
                    "如果认为当前信息足够完整，请回复'信息完整'。"
                )
            else:
                llm_prompt = (
                    "Based on the following visual information, please analyze if there are key points that need further confirmation:\n\n"
                    f"Question: {question}\n\n"
                    f"Current Visual Information:\n{latest_visual_info}\n\n"
                    "Consider:\n"
                    "1. Are there any ambiguous or uncertain descriptions?\n"
                    "2. Are there parts that need more detailed observation?\n"
                    "3. Are there assumptions that need verification?\n\n"
                    "If you find points that need further confirmation, please propose 3 most critical follow-up questions.\n"
                    "If you think the current information is sufficient, please reply with 'Information Complete'."
                )

            llm_response = self._llm_generate([{"role": "user", "content": llm_prompt}])

            # Stop as soon as the LLM declares the information complete.
            if "信息完整" in llm_response or "Information Complete" in llm_response:
                break

            # 3. Let the VL model answer the follow-up questions.
            if use_chinese:
                follow_up_prompt = (
                    "请针对以下具体问题，对图片进行更细致的观察和分析：\n\n"
                    f"{llm_response}\n\n"
                    "请提供详细、准确的回答。"
                )
            else:
                follow_up_prompt = (
                    "Please provide detailed and accurate answers to the following specific questions through careful observation:\n\n"
                    f"{llm_response}\n\n"
                    "Please provide detailed and accurate answers."
                )

            latest_visual_info = self._vl_generate(follow_up_prompt, image_messages, dataset, process_vision_info)
            all_visual_info.append(latest_visual_info)

        # 4. Merge all descriptions and ask the LLM for the final answer.
        combined_visual_info = "\n\n".join(all_visual_info)

        if use_chinese:
            final_prompt = (
                "请基于以下所有视觉信息，通过清晰的推理步骤来回答问题。\n\n"
                "推理步骤：\n"
                "1. 信息整合：\n"
                "   - 整理所有视觉观察结果\n"
                "   - 识别关键信息点\n"
                "   - 建立信息之间的联系\n\n"
                "2. 逻辑推理：\n"
                "   - 基于整合的信息进行推理\n"
                "   - 解释推理过程\n"
                "   - 说明结论的可靠性\n\n"
                f"问题：{question}\n\n"
                f"所有视觉信息：\n{combined_visual_info}\n"
            )
        else:
            final_prompt = (
                "Please answer the question through clear reasoning steps based on all the following visual information.\n\n"
                "Reasoning Steps:\n"
                "1. Information Integration:\n"
                "   - Organize all visual observations\n"
                "   - Identify key information points\n"
                "   - Establish connections between information\n\n"
                "2. Logical Reasoning:\n"
                "   - Conduct reasoning based on integrated information\n"
                "   - Explain the reasoning process\n"
                "   - Demonstrate conclusion reliability\n\n"
                f"Question: {question}\n\n"
                f"All Visual Information:\n{combined_visual_info}\n"
            )

        if hint:
            final_prompt += f"\n提示信息：{hint}" if use_chinese else f"\nHint: {hint}"

        if options:
            option_lines = "".join(f"{key}. {value}\n" for key, value in options.items())
            if use_chinese:
                final_prompt += "\n可选答案：\n"
                final_prompt += option_lines
                final_prompt += "\n请通过清晰的推理过程，说明为什么选择特定答案。"
            else:
                final_prompt += "\nOptions:\n"
                final_prompt += option_lines
                final_prompt += "\nPlease explain through clear reasoning process why you choose a specific answer."

        final_answer = self._llm_generate([{"role": "user", "content": final_prompt}])

        if self.verbose:
            print(f"\033[34m初始视觉信息：{all_visual_info[0]}\033[0m")
            for i, info in enumerate(all_visual_info[1:], 1):
                print(f"\033[35m第{i}轮补充信息：{info}\033[0m")
            print(f"\033[33m最终答案：{final_answer}\033[0m")

        return final_answer
