#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Async interface for local multimodal models like QwenVL, InternVL, LLaVA.
Supports both text-only and multimodal inputs with GPU acceleration.

Changes (OOM-focused):
- Force FP16 on CUDA (BF16 is not native on RTX 3090).
- Real multi-GPU sharding via device_map="auto" + max_memory.
- Fix conditional logic to avoid moving a sharded model back onto a single GPU.
- Provide a default PYTORCH_CUDA_ALLOC_CONF to reduce fragmentation if not set.
- (Minor) use_cache=False during generation to lower peak memory.
- (Minor) LLaVA decoding now uses tokenizer and decodes only newly generated tokens.
"""

import asyncio
import logging
import os
from typing import Dict, List, Any, Optional
import concurrent.futures

# ------------------ Fragmentation mitigation (env) ------------------
# If user hasn't set PYTORCH_CUDA_ALLOC_CONF, set a helpful default.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# Optional dependency: PIL is imported eagerly here, but its absence is tolerated
try:
    from PIL import Image
    PIL_AVAILABLE = True
except ImportError:
    Image = None
    PIL_AVAILABLE = False
    logging.warning("PIL not available. Image processing will be limited.")

# torch will be imported when needed

from config.model_config import LocalModelConfig
from models.async_base import AsyncModelInterface

logger = logging.getLogger(__name__)


class AsyncLocalModel(AsyncModelInterface):
    """Async interface for local multimodal models with GPU support."""
    
    def __init__(self, config: LocalModelConfig):
        """Initialise bookkeeping state and schedule the background model load.

        Args:
            config: Model configuration. ``get_params()``, ``model_path`` and
                ``name`` are read from it here.

        Raises:
            RuntimeError: Propagated from ``asyncio.create_task`` when no event
                loop is running, so instances must be created from async code.
        """
        super().__init__(config)
        self.model_params = config.get_params()
        self.model_path = config.model_path
        self.model = None
        self.tokenizer = None
        self.processor = None
        self.device = None  # Will be set when torch is imported

        # Flag to indicate if the model is sharded by accelerate device_map
        self._use_device_map = False

        # Single worker: model loading is submitted here so it never blocks the loop
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

        # Load appropriate model based on name
        self._model_type = self._detect_model_type(config.name)
        logger.info(f"Detected model type: {self._model_type} for {config.name}")

        # Keep a strong reference to the loading task. The event loop only holds
        # weak references to tasks, so a task nobody references may be garbage
        # collected before it finishes.
        self._load_task = asyncio.create_task(self._async_load_model())

    @staticmethod
    def _pick_dtype_for_cuda():
        """
        Always use FP16 on consumer GPUs like RTX 3090.
        BF16 is not natively supported and may silently fall back to FP32.
        """
        import torch
        return torch.float16
    
    def _detect_model_type(self, model_name: str) -> str:
        """Detect model type from name."""
        model_name_lower = model_name.lower()
        if "qwen" in model_name_lower:
            if "qwq" in model_name_lower:
                return "qwen_qwq"
            elif "vl" in model_name_lower:
                return "qwen_vl"
            else:
                return "qwen_text"
        elif "internvl2.5" in model_name_lower:
            return "internvl2.5"
        elif "internvl3-78b" in model_name_lower:
            return "internvl3-78b"
        elif "internvl3-14b" in model_name_lower:
            return "internvl3-14b"
        elif "internvl3" in model_name_lower:
            return "internvl3"
        elif "intern" in model_name_lower:
            return "internvl"
        elif "llava-onevision" in model_name_lower or "onevision" in model_name_lower:
            return "llava_onevision"
        elif "llava" in model_name_lower:
            return "llava"
        elif "kimi" in model_name_lower and "vl" in model_name_lower:
            return "kimi_vl"
        elif "deepseek" in model_name_lower and "vl2" in model_name_lower:
            return "deepseek_vl2"
        else:
            return "unknown"
    
    async def _async_load_model(self):
        """Run the blocking model load on ``self.executor`` without stalling the loop.

        Raises:
            Exception: Whatever ``_load_model_sync`` raised, re-raised after
                being logged.
        """
        try:
            # Inside a coroutine the running loop is the one we want;
            # get_running_loop() is the documented way to obtain it.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(self.executor, self._load_model_sync)
            logger.info(f"Successfully loaded {self._model_type} model: {self.name}")
        except Exception as e:
            logger.error(f"Failed to load model {self.name}: {e}")
            raise
    
    def _load_model_sync(self):
        """Synchronous model loading (runs in thread pool)."""
        if self._model_type in ["qwen_vl", "qwen_text", "qwen_qwq"]:
            self._load_qwen_model()
        elif self._model_type in ["internvl", "internvl2.5", "internvl3", "internvl3-14b", "internvl3-78b"]:
            self._load_internvl_model()
        elif self._model_type == "llava":
            self._load_llava_model()
        elif self._model_type == "llava_onevision":
            self._load_llava_onevision_model()
        elif self._model_type == "kimi_vl":
            self._load_kimi_vl_model()
        else:
            raise ValueError(f"Unsupported model type: {self._model_type}")
    
    def _create_llava_onevision_device_map(self, config):
        """Build an explicit module-path -> GPU-index map for LLaVA-OneVision.

        The decoder layer count is read from ``config.text_config`` or
        ``config.language_model`` when available; otherwise it is guessed from
        ``self.config.name`` (40 for names containing "13b" but not "7b",
        28 in every other case). GPU 0 receives half the layer quota of the
        other GPUs because the vision components are pinned there, along with
        the embeddings, final norm, LM head and the last decoder layer. Several
        alternative module paths are registered for each component.

        Args:
            config: Model configuration object inspected for the layer count.

        Returns:
            dict mapping module paths to integer GPU indices.

        Note:
            Assumes ``torch.cuda.device_count()`` is at least 1; with zero
            devices the per-GPU quota list is empty and indexing it raises
            ``IndexError``.
        """
        import math
        import torch

        device_map = {}
        world_size = torch.cuda.device_count()

        # Try to get the number of layers from config
        try:
            # LLaVA-OneVision typically uses language_model structure
            if hasattr(config, 'text_config'):
                num_layers = config.text_config.num_hidden_layers
            elif hasattr(config, 'language_model') and hasattr(config.language_model, 'num_hidden_layers'):
                num_layers = config.language_model.num_hidden_layers
            else:
                # Fallback: inspect the model name for layer count estimation
                name_lower = self.config.name.lower()
                if "7b" in name_lower:
                    num_layers = 28  # Observed from the log output
                elif "13b" in name_lower:
                    num_layers = 40  # Typical for 13B models
                else:
                    num_layers = 28  # Default based on observation
        except Exception:
            # Deliberately best-effort, but no longer a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            num_layers = 28  # Safe fallback based on log

        logger.info(f"LLaVA-OneVision: Detected {num_layers} layers for multi-GPU sharding")

        # GPU 0 also hosts the vision model, so it counts as half a GPU when
        # splitting the decoder layers.
        num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
        num_layers_per_gpu = [num_layers_per_gpu] * world_size
        num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)

        # Distribute language model layers across GPUs
        layer_cnt = 0
        for gpu_index, quota in enumerate(num_layers_per_gpu):
            for _ in range(quota):
                if layer_cnt < num_layers:
                    # LLaVA-OneVision uses language_model.model.layers structure
                    device_map[f'language_model.model.layers.{layer_cnt}'] = gpu_index
                    # Also map with model. prefix in case the structure uses it
                    device_map[f'model.language_model.model.layers.{layer_cnt}'] = gpu_index
                    layer_cnt += 1

        # Vision and core components on GPU 0
        device_map['vision_tower'] = 0
        device_map['vision_model'] = 0  # Alternative naming
        device_map['multi_modal_projector'] = 0
        device_map['mm_projector'] = 0  # Alternative naming

        # Language model core components on GPU 0 - with correct paths
        device_map['language_model.model.embed_tokens'] = 0
        device_map['language_model.model.norm'] = 0
        device_map['language_model.lm_head'] = 0

        # Alternative paths based on actual model structure
        device_map['model.language_model.embed_tokens'] = 0
        device_map['model.language_model.model.embed_tokens'] = 0
        device_map['model.language_model.model.norm'] = 0
        device_map['model.language_model.lm_head'] = 0

        # Special LLaVA-OneVision components that caused the error
        device_map['image_newline'] = 0  # The component that was missing
        device_map['model.image_newline'] = 0  # Alternative path

        # Additional potential LLaVA components
        device_map['vision_resampler'] = 0
        device_map['image_projector'] = 0
        device_map['projector'] = 0
        device_map['model.vision_tower'] = 0
        device_map['model.vision_model'] = 0
        device_map['model.multi_modal_projector'] = 0
        device_map['model.mm_projector'] = 0

        # Pin the last decoder layer to GPU 0, overriding the split above, so it
        # sits next to the norm and LM head.
        if num_layers > 0:
            device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
            device_map[f'model.language_model.model.layers.{num_layers - 1}'] = 0

        logger.info(f"LLaVA-OneVision device map created: {len(device_map)} components across {world_size} GPUs")
        return device_map

    def _setup_device_and_sharding_kwargs(self) -> Dict[str, Any]:
        """
        Pick the target device and assemble ``from_pretrained`` keyword arguments.

        Side effects: sets ``self.device`` ("cuda" or "cpu") and
        ``self._use_device_map`` (True only when CUDA is available and
        ``accelerate`` can be imported).

        Returns:
            The keyword-argument dict. When sharding is enabled it additionally
            carries ``device_map``, ``max_memory`` and ``offload_folder``.
        """
        import torch

        cuda_present = torch.cuda.is_available()
        if cuda_present:
            self.device = "cuda"
            gpu_count = torch.cuda.device_count()
            gpu_names = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
            logger.info(f"CUDA available. GPUs: {gpu_names} (Count: {gpu_count})")
        else:
            self.device = "cpu"
            logger.warning("No GPU available, using CPU (this will be slow for large models)")

        # Half precision on GPU, full precision on CPU.
        if self.device.startswith("cuda"):
            weight_dtype = self._pick_dtype_for_cuda()
        else:
            weight_dtype = torch.float32

        model_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
            "torch_dtype": weight_dtype,
        }

        # Sharding needs accelerate; without it (or without CUDA) the plain
        # kwargs are returned and the caller places the model itself.
        self._use_device_map = False
        try:
            import accelerate  # noqa: F401
        except ImportError:
            logger.info("accelerate not available; will load on a single device.")
            return model_kwargs

        if self.device != "cuda":
            return model_kwargs

        self._use_device_map = True

        if self._model_type == "qwen_vl" and "32b" in self.config.name.lower():
            # Qwen2.5-VL-32B is confined to GPUs 4-7; every other visible GPU is
            # given a zero budget so nothing is placed on it.
            allowed_gpus = [4, 5, 6, 7]
            per_gpu_mem_gib = os.environ.get("PER_GPU_MAX_MEMORY_GIB", "40")

            memory_budget = {gpu: f"{per_gpu_mem_gib}GiB" for gpu in allowed_gpus}
            memory_budget.update(
                {gpu: "0GiB" for gpu in range(torch.cuda.device_count()) if gpu not in allowed_gpus}
            )
            memory_budget["cpu"] = "50GiB"  # CPU spill-over budget

            model_kwargs["max_memory"] = memory_budget
            model_kwargs["device_map"] = "auto"
            model_kwargs["offload_folder"] = os.environ.get("OFFLOAD_FOLDER", "./offload_cache")

            logger.info(f"Qwen2.5-VL-32B: Restricting to GPUs {allowed_gpus} with max_memory={model_kwargs['max_memory']}")
            return model_kwargs

        # Every other model: shard across all visible GPUs with a default cap
        # that leaves headroom on each card.
        model_kwargs["device_map"] = "auto"
        per_gpu_mem_gib = os.environ.get("PER_GPU_MAX_MEMORY_GIB", "20")
        memory_budget = {gpu: f"{per_gpu_mem_gib}GiB" for gpu in range(torch.cuda.device_count())}
        memory_budget["cpu"] = "30GiB"  # CPU spill-over budget
        model_kwargs["max_memory"] = memory_budget
        model_kwargs["offload_folder"] = os.environ.get("OFFLOAD_FOLDER", "./offload_cache")
        logger.info(f"Using accelerate sharding with max_memory={model_kwargs['max_memory']} (device indices as integers)")

        return model_kwargs

    def _load_qwen_model(self):
        """Load a Qwen checkpoint (VL, text-only or QwQ) plus its tokenizer/processor.

        The VL path is chosen when ``self.model_path`` contains "vl" or
        "vision". It tries generation-capable classes in order
        (``Qwen2_5_VLForConditionalGeneration``, ``Qwen2VLForConditionalGeneration``,
        ``AutoModelForVision2Seq``) and never falls back to ``AutoModel``.
        The text path uses ``AutoModelForCausalLM`` with a left-padding tokenizer.

        FlashAttention-2 is requested only when running on CUDA *and* the
        ``flash_attn`` package can be found; otherwise the default attention
        implementation is used and a warning is logged.

        Side effects: sets ``self.model``, ``self.tokenizer`` and, for VL
        models, ``self.processor``. When the model is not sharded via
        ``device_map`` it is moved to ``self.device``.

        Raises:
            RuntimeError: If no generation-capable VL class could be loaded.
            Exception: Any other loading error, re-raised after logging.
        """
        try:
            import importlib.util
            from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer

            # Determine if it's a VL (Vision-Language) model
            path_lower = self.model_path.lower()
            is_vl_model = "vl" in path_lower or "vision" in path_lower

            # Build common kwargs (includes dtype and potential sharding settings)
            model_kwargs = self._setup_device_and_sharding_kwargs()

            # Actually probe for flash-attn instead of assuming it is installed:
            # requesting flash_attention_2 without the package makes loading fail.
            try:
                flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
            except (ImportError, ValueError):
                flash_attn_installed = False
            on_cuda = self.device.startswith("cuda")

            if is_vl_model:
                # ---- Qwen2.5-VL path ----
                logger.info(f"Loading Qwen2.5-VL model from {self.model_path}")

                # For 32B model, use official recommended settings
                if "32b" in self.config.name.lower():
                    # Use recommended pixel settings for 32B model
                    min_pixels = 256 * 28 * 28
                    max_pixels = 1280 * 28 * 28
                    self.processor = AutoProcessor.from_pretrained(
                        self.model_path,
                        min_pixels=min_pixels,
                        max_pixels=max_pixels,
                        trust_remote_code=True
                    )
                    logger.info(f"Qwen2.5-VL-32B: Using optimized pixel settings min_pixels={min_pixels}, max_pixels={max_pixels}")

                    # Use official recommended settings for 32B
                    model_kwargs["torch_dtype"] = "auto"  # Official recommendation
                    if on_cuda:
                        if flash_attn_installed:
                            model_kwargs["attn_implementation"] = "flash_attention_2"
                            logger.info("Qwen2.5-VL-32B: Using torch_dtype='auto' and flash_attention_2")
                        else:
                            logger.warning(
                                "Qwen2.5-VL-32B: flash_attn is not installed; "
                                "using torch_dtype='auto' with the default attention implementation"
                            )
                else:
                    # Load processor with default settings
                    self.processor = AutoProcessor.from_pretrained(
                        self.model_path,
                        trust_remote_code=True
                    )

                    # Use flash attention for other models if available
                    if on_cuda:
                        if flash_attn_installed:
                            model_kwargs["attn_implementation"] = "flash_attention_2"
                            logger.info("Using flash_attention_2 for better performance")
                        else:
                            logger.warning("Could not use flash_attention_2: flash_attn is not installed")

                # Strictly try generation-capable classes only
                model_loaded = False
                last_err = None
                # Candidate classes are collected first; one that is missing from
                # the installed transformers version is simply skipped.
                try_order = []
                try:
                    from transformers import Qwen2_5_VLForConditionalGeneration  # new-style
                    try_order.append(("Qwen2_5_VLForConditionalGeneration", Qwen2_5_VLForConditionalGeneration))
                except Exception as e:
                    last_err = e
                # Try older alias
                try:
                    from transformers import Qwen2VLForConditionalGeneration  # older alias
                    try_order.append(("Qwen2VLForConditionalGeneration", Qwen2VLForConditionalGeneration))
                except Exception as e:
                    last_err = e
                # Generic VLM generation head
                try:
                    from transformers import AutoModelForVision2Seq
                    try_order.append(("AutoModelForVision2Seq", AutoModelForVision2Seq))
                except Exception as e:
                    last_err = e

                for cls_name, cls in try_order:
                    try:
                        logger.info(f"Trying {cls_name} for Qwen VL")
                        self.model = cls.from_pretrained(self.model_path, **model_kwargs).eval()
                        # MUST have .generate()
                        if not hasattr(self.model, "generate"):
                            raise RuntimeError(f"{cls_name} loaded but has no `.generate()`")
                        model_loaded = True
                        logger.info(f"Qwen VL loaded with {cls_name}")
                        break
                    except Exception as e:
                        last_err = e
                        logger.warning(f"{cls_name} failed: {e}")

                # Absolutely DO NOT fall back to AutoModel for VL (it often lacks generate)
                if not model_loaded:
                    raise RuntimeError(
                        f"Failed to load a generation-capable Qwen VL model. "
                        f"Last error: {last_err}"
                    )

                # Set tokenizer from processor for compatibility
                self.tokenizer = self.processor.tokenizer
                logger.info("Qwen VL processor/tokenizer set successfully")
            else:
                # ---- Text-only Qwen (including QWQ) ----
                logger.info(f"Loading text-only Qwen model from {self.model_path}")

                # Load tokenizer
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_path,
                    trust_remote_code=True,
                    padding_side="left"
                )

                # For QWQ models, prefer flash-attn if on CUDA. Use the documented
                # `attn_implementation` kwarg (same as the VL path) rather than
                # the private `_attn_implementation`.
                if self._model_type == "qwen_qwq" and self.device == "cuda":
                    if flash_attn_installed:
                        model_kwargs["attn_implementation"] = "flash_attention_2"
                    else:
                        logger.warning("Could not use flash_attention_2: flash_attn is not installed")

                # Load model
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    **model_kwargs
                ).eval()
                logger.info("Text-only Qwen model loaded successfully")

            # IMPORTANT: if sharded by accelerate (device_map="auto"), do NOT move the model.
            if not self._use_device_map and self.model is not None:
                logger.info(f"Moving Qwen model to device: {self.device}")
                self.model = self.model.to(self.device)

            logger.info(
                f"Qwen model ready. device={self.device}, "
                f"sharded={self._use_device_map}, dtype={getattr(self.model, 'dtype', 'n/a')}"
            )

        except Exception as e:
            logger.error(f"Failed to load Qwen model: {e}")
            logger.error(f"Model path: {self.model_path}")
            logger.error(f"Model type: {self._model_type}")
            raise
    
    def _load_internvl_model(self):
        """Load an InternVL checkpoint and its tokenizer.

        InternVL3 variants on a multi-GPU CUDA host use a hand-built device map
        from ``_create_internvl3_device_map``; every other case uses the kwargs
        from ``_setup_device_and_sharding_kwargs``. InternVL3-14B/78B are loaded
        in bfloat16 with ``use_flash_attn=True`` (78B additionally with
        ``load_in_8bit=True``); other InternVL3 models use float16.

        For InternVL3-78B a tokenizer load failure is tolerated and several
        fallbacks are tried after the model has been loaded.

        Side effects: sets ``self.model``, ``self.tokenizer``, ``self.device``
        and ``self._use_device_map``.

        Raises:
            Exception: Any loading error, re-raised after logging.
        """
        try:
            import torch
            from transformers import AutoModel, AutoTokenizer, AutoConfig

            # For InternVL3, InternVL3-14B, InternVL3-78B, we need special device mapping for multi-GPU
            if self._model_type in ["internvl3", "internvl3-14b", "internvl3-78b"] and torch.cuda.is_available() and torch.cuda.device_count() > 1:
                # This branch bypasses _setup_device_and_sharding_kwargs(), which is
                # what normally assigns self.device; set it here so it is not left
                # as None. CUDA availability is guaranteed by the condition above.
                self.device = "cuda"

                # Custom device map for InternVL3
                config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
                device_map = self._create_internvl3_device_map(config)

                # InternVL3-14B and InternVL3-78B use bfloat16 and flash attention according to official docs
                if self._model_type in ["internvl3-14b", "internvl3-78b"]:
                    model_kwargs = {
                        "torch_dtype": torch.bfloat16,
                        "low_cpu_mem_usage": True,
                        "trust_remote_code": True,
                        "device_map": device_map,
                        "use_flash_attn": True,  # InternVL3-14B and InternVL3-78B officially use flash attention
                    }

                    # For InternVL3-78B, add 8-bit quantization
                    if self._model_type == "internvl3-78b":
                        model_kwargs["load_in_8bit"] = True
                        logger.info("InternVL3-78B: Using 8-bit quantization for reduced memory usage")

                else:
                    model_kwargs = {
                        "torch_dtype": torch.float16,
                        "low_cpu_mem_usage": True,
                        "trust_remote_code": True,
                        "device_map": device_map,
                    }

                    # Don't use flash attention for InternVL3-8B by default
                    if hasattr(config, "use_flash_attn"):
                        model_kwargs["use_flash_attn"] = False

                self._use_device_map = True
            else:
                # Use standard loading for other InternVL versions
                model_kwargs = self._setup_device_and_sharding_kwargs()

            # For InternVL3-78B, copy the modeling file to model directory first
            if self._model_type == "internvl3-78b":
                self._ensure_internvl_modeling_file()

            # Load tokenizer - InternVL3-78B needs special handling
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_path,
                    trust_remote_code=True,
                    use_fast=False  # InternVL3 example uses use_fast=False
                )
                logger.info(f"Successfully loaded tokenizer for {self._model_type}")
            except Exception as e:
                logger.warning(f"Failed to load tokenizer with AutoTokenizer for {self._model_type}: {e}")

                # For InternVL3-78B, try loading model and tokenizer together
                if self._model_type == "internvl3-78b":
                    logger.info("InternVL3-78B: Will load tokenizer along with the model")
                    self.tokenizer = None  # Will be set when model is loaded
                else:
                    # Bare raise keeps the original traceback intact.
                    raise

            # Set pad_token to avoid warning (only if tokenizer is loaded)
            if self.tokenizer is not None and self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModel.from_pretrained(
                self.model_path,
                **model_kwargs
            ).eval()

            # For InternVL3-78B, extract tokenizer from the loaded model if not already loaded
            if self._model_type == "internvl3-78b" and self.tokenizer is None:
                try:
                    # Try to get tokenizer from model's tokenizer attribute
                    if hasattr(self.model, 'tokenizer'):
                        self.tokenizer = self.model.tokenizer
                        logger.info("InternVL3-78B: Successfully extracted tokenizer from model")
                    elif hasattr(self.model, 'language_model') and hasattr(self.model.language_model, 'tokenizer'):
                        self.tokenizer = self.model.language_model.tokenizer
                        logger.info("InternVL3-78B: Successfully extracted tokenizer from language model")
                    else:
                        # Retry from the model directory, restricted to local files
                        logger.info("InternVL3-78B: Creating tokenizer manually from model path")
                        tokenizer_files = ['tokenizer.json', 'tokenizer.model', 'vocab.txt']
                        if any(os.path.exists(os.path.join(self.model_path, f)) for f in tokenizer_files):
                            self.tokenizer = AutoTokenizer.from_pretrained(
                                self.model_path,
                                trust_remote_code=True,
                                use_fast=False,
                                force_download=False,
                                local_files_only=True
                            )
                        else:
                            # Final fallback: use a compatible tokenizer
                            # NOTE(review): this borrows the tokenizer of an
                            # unrelated checkpoint; confirm its vocabulary matches
                            # this model before trusting generated text.
                            logger.warning("InternVL3-78B: Using Llama-2 tokenizer as fallback")
                            self.tokenizer = AutoTokenizer.from_pretrained(
                                "meta-llama/Llama-2-7b-hf",
                                trust_remote_code=True,
                                use_fast=False
                            )

                    # Set pad_token if needed
                    if self.tokenizer.pad_token is None:
                        self.tokenizer.pad_token = self.tokenizer.eos_token

                except Exception as e:
                    logger.error(f"Failed to set up tokenizer for InternVL3-78B: {e}")
                    raise

            # IMPORTANT: do not move if sharded
            if not self._use_device_map:
                self.model = self.model.to(self.device)

            logger.info(
                f"InternVL model ready. device={self.device}, "
                f"sharded={self._use_device_map}, dtype={getattr(self.model, 'dtype', 'n/a')}"
            )
        except Exception as e:
            logger.error(f"Failed to load InternVL model: {e}")
            raise
    
    
    
    
    def _create_internvl3_device_map(self, config):
        """Build an explicit module-path -> GPU-index map for InternVL3.

        Decoder layers are spread over all visible GPUs, with GPU 0 counted as
        half a GPU because it also hosts the vision model. The vision model,
        ``mlp1``, embeddings, norm, rotary embedding, output heads and the last
        decoder layer are pinned to GPU 0.

        Args:
            config: Model configuration; ``config.llm_config.num_hidden_layers``
                is used when present, otherwise 80 (78B) or 32 is assumed.

        Returns:
            dict mapping module paths to integer GPU indices.

        Note:
            Assumes ``torch.cuda.device_count()`` is at least 1; with zero
            devices the per-GPU quota list is empty and indexing it raises
            ``IndexError``.
        """
        import math
        import torch

        device_map = {}
        world_size = torch.cuda.device_count()

        try:
            num_layers = config.llm_config.num_hidden_layers
        except Exception:
            # Best-effort fallback when the config is shaped differently. Not a
            # bare ``except`` so KeyboardInterrupt / SystemExit still propagate.
            if self._model_type == "internvl3-78b":
                num_layers = 80  # Default for 78B model
            else:
                num_layers = 32  # Default for 8B/14B model

        logger.info(f"InternVL3 ({self._model_type}): Detected {num_layers} layers for multi-GPU sharding across {world_size} GPUs")

        # Since the first GPU will be used for ViT, treat it as half a GPU.
        num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
        num_layers_per_gpu = [num_layers_per_gpu] * world_size
        num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)

        layer_cnt = 0
        for gpu_index, quota in enumerate(num_layers_per_gpu):
            for _ in range(quota):
                if layer_cnt < num_layers:
                    device_map[f'language_model.model.layers.{layer_cnt}'] = gpu_index
                    layer_cnt += 1

        # Vision model and related components on first GPU
        device_map['vision_model'] = 0
        device_map['mlp1'] = 0
        device_map['language_model.model.tok_embeddings'] = 0
        device_map['language_model.model.embed_tokens'] = 0
        device_map['language_model.output'] = 0
        device_map['language_model.model.norm'] = 0
        device_map['language_model.model.rotary_emb'] = 0
        device_map['language_model.lm_head'] = 0
        # The last decoder layer is pinned to GPU 0, overriding the split above.
        device_map[f'language_model.model.layers.{num_layers - 1}'] = 0

        logger.info(f"InternVL3 device map created: {len(device_map)} components, layers distributed as {num_layers_per_gpu}")
        return device_map

    def _ensure_internvl_modeling_file(self):
        """Copy the bundled ``modeling_internvl_chat.py`` into the model directory if absent.

        The source file is expected next to this module; the destination is
        ``self.model_path``. Nothing is copied when the destination already
        exists.

        Raises:
            Exception: Any copy failure, re-raised after logging.
        """
        import os
        import shutil

        modeling_name = "modeling_internvl_chat.py"
        bundled_copy = os.path.join(os.path.dirname(__file__), modeling_name)
        destination = os.path.join(self.model_path, modeling_name)

        # Nothing to do when the model directory already has the file.
        if os.path.exists(destination):
            logger.info(f"Modeling file already exists in model directory: {destination}")
            return

        try:
            shutil.copy2(bundled_copy, destination)
            logger.info(f"Copied modeling_internvl_chat.py to model directory: {destination}")
        except Exception as e:
            logger.error(f"Failed to copy modeling file to model directory: {e}")
            raise

    def _build_transform_for_internvl(self, input_size=448):
        """Return the torchvision pipeline applied to each InternVL image tile.

        Steps, in order: convert to RGB when needed, bicubic resize to
        ``input_size`` x ``input_size``, convert to tensor, normalise with the
        ImageNet mean and standard deviation.

        Args:
            input_size: Side length of the square output.

        Returns:
            A ``torchvision.transforms.Compose`` instance.
        """
        import torchvision.transforms as T
        from torchvision.transforms.functional import InterpolationMode

        imagenet_mean = (0.485, 0.456, 0.406)
        imagenet_std = (0.229, 0.224, 0.225)

        def _ensure_rgb(img):
            # Images already in RGB are passed through untouched.
            if img.mode == 'RGB':
                return img
            return img.convert('RGB')

        steps = [
            T.Lambda(_ensure_rgb),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
        return T.Compose(steps)
    
    def _find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
        """Find closest aspect ratio for dynamic preprocessing."""
        best_ratio_diff = float('inf')
        best_ratio = (1, 1)
        area = width * height
        for ratio in target_ratios:
            target_aspect_ratio = ratio[0] / ratio[1]
            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
            if ratio_diff < best_ratio_diff:
                best_ratio_diff = ratio_diff
                best_ratio = ratio
            elif ratio_diff == best_ratio_diff:
                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                    best_ratio = ratio
        return best_ratio
    
    def _dynamic_preprocess_internvl(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
        """Dynamic preprocessing for InternVL3 following official implementation."""
        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height

        # calculate the existing image aspect ratio
        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
            i * j <= max_num and i * j >= min_num)
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        # find the closest aspect ratio to the target
        target_aspect_ratio = self._find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size)

        # calculate the target width and height
        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

        # resize the image
        resized_img = image.resize((target_width, target_height))
        processed_images = []
        for i in range(blocks):
            box = (
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size
            )
            # split the image
            split_img = resized_img.crop(box)
            processed_images.append(split_img)
        assert len(processed_images) == blocks
        if use_thumbnail and len(processed_images) != 1:
            thumbnail_img = image.resize((image_size, image_size))
            processed_images.append(thumbnail_img)
        return processed_images
    
    def _process_image_for_internvl3(self, image, input_size=448, max_num=12):
        """Process a single image for InternVL3."""
        import torch
        
        transform = self._build_transform_for_internvl(input_size=input_size)
        images = self._dynamic_preprocess_internvl(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
        pixel_values = [transform(img) for img in images]
        pixel_values = torch.stack(pixel_values)
        return pixel_values
    
    def _load_llava_model(self):
        """Load the LLaVA-NeXT processor and model from ``self.model_path``.

        Loading kwargs come from ``_setup_device_and_sharding_kwargs``
        (presumably this also sets ``self.device`` and
        ``self._use_device_map`` -- confirm against that helper). The model
        is put in eval mode and moved onto ``self.device`` only when it is
        not sharded.

        Raises:
            Exception: Any loading failure is logged and re-raised.
        """
        try:
            import torch
            from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

            load_kwargs = self._setup_device_and_sharding_kwargs()

            self.processor = LlavaNextProcessor.from_pretrained(self.model_path)
            loaded_model = LlavaNextForConditionalGeneration.from_pretrained(
                self.model_path,
                **load_kwargs
            )
            self.model = loaded_model.eval()

            # A sharded model must stay where device_map placed it.
            if self._use_device_map:
                pass
            else:
                self.model = self.model.to(self.device)

            logger.info(
                f"LLaVA model ready. device={self.device}, "
                f"sharded={self._use_device_map}, dtype={getattr(self.model, 'dtype', 'n/a')}"
            )
        except Exception as e:
            logger.error(f"Failed to load LLaVA model: {e}")
            raise
    
    def _load_llava_onevision_model(self):
        """Load the LLaVA-OneVision processor and model.

        Loading kwargs come from ``_setup_device_and_sharding_kwargs``.
        Flash Attention 2 is requested when the configured model name contains
        ``"llava-onevision"`` and either several CUDA devices are visible or
        ``self.device`` starts with ``"cuda"``.

        The model is first loaded with
        ``LlavaOnevisionForConditionalGeneration``; if that raises, the failure
        is logged and ``AutoModelForVision2Seq`` is tried instead. The model is
        put in eval mode and moved onto ``self.device`` only when it is not
        sharded. ``self.tokenizer`` is taken from the processor.

        Raises:
            Exception: Any loading failure is logged and re-raised.
        """
        try:
            import torch
            from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

            multi_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 1
            model_kwargs = self._setup_device_and_sharding_kwargs()

            # Add attention implementation for faster inference. In the
            # multi-GPU case CUDA is known to be present, so self.device is
            # only consulted for the single-device case.
            is_onevision = "llava-onevision" in self.config.name.lower()
            if is_onevision and (multi_gpu or self.device.startswith("cuda")):
                model_kwargs["attn_implementation"] = "flash_attention_2"
                logger.info("LLaVA-OneVision: Using Flash Attention 2 for faster inference")

            if multi_gpu:
                logger.info("Using auto device mapping for LLaVA-OneVision multi-GPU setup")

            # The processor does not depend on which model class is used, so
            # it is loaded exactly once.
            self.processor = AutoProcessor.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )

            try:
                self.model = LlavaOnevisionForConditionalGeneration.from_pretrained(
                    self.model_path,
                    **model_kwargs
                ).eval()
            except Exception as primary_error:
                # Catch Exception (not a bare except) so KeyboardInterrupt and
                # SystemExit still propagate, and record why we fell back.
                logger.warning(
                    f"LlavaOnevisionForConditionalGeneration load failed: {primary_error}; "
                    "falling back to AutoModelForVision2Seq"
                )
                from transformers import AutoModelForVision2Seq
                self.model = AutoModelForVision2Seq.from_pretrained(
                    self.model_path,
                    **model_kwargs
                ).eval()

            # Set tokenizer from processor
            self.tokenizer = self.processor.tokenizer

            # IMPORTANT: do not move if sharded
            if not self._use_device_map:
                self.model = self.model.to(self.device)

            logger.info(
                f"LLaVA-OneVision model ready. device={self.device}, "
                f"sharded={self._use_device_map}, dtype={getattr(self.model, 'dtype', 'n/a')}"
            )
        except Exception as e:
            logger.error(f"Failed to load LLaVA-OneVision model: {e}")
            raise
    
    def _load_kimi_vl_model(self):
        """Load the Kimi-VL processor and model with ``device_map="auto"``.

        When more than one CUDA device is visible, a ``max_memory`` budget is
        added: ``PER_GPU_MAX_MEMORY_GIB`` GiB (environment variable, default
        ``"20"``) per GPU plus ``"30GiB"`` for the CPU, and
        ``self._use_device_map`` is set to ``True``; otherwise it is set to
        ``False``. ``self.tokenizer`` is taken from the processor.

        Raises:
            Exception: Any loading failure is logged and re-raised.
        """
        try:
            import torch
            from transformers import AutoProcessor, AutoModelForCausalLM

            model_kwargs = dict(
                torch_dtype="auto",
                device_map="auto",
                trust_remote_code=True,
            )

            gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
            if gpu_count > 1:
                # Cap each GPU below its capacity to leave headroom against OOM.
                per_gpu_mem_gib = os.environ.get("PER_GPU_MAX_MEMORY_GIB", "20")
                memory_budget = {}
                for gpu_index in range(gpu_count):
                    memory_budget[gpu_index] = f"{per_gpu_mem_gib}GiB"
                # CPU budget allows offloading when the GPUs are full.
                memory_budget["cpu"] = "30GiB"
                model_kwargs["max_memory"] = memory_budget
                self._use_device_map = True
            else:
                self._use_device_map = False

            processor = AutoProcessor.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )
            self.processor = processor
            self.tokenizer = processor.tokenizer

            loaded_model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                **model_kwargs
            )
            self.model = loaded_model.eval()

            logger.info(
                f"Kimi-VL model ready with official settings. "
                f"device_map={model_kwargs.get('device_map', 'n/a')}, "
                f"torch_dtype={model_kwargs.get('torch_dtype', 'n/a')}"
            )
        except Exception as e:
            logger.error(f"Failed to load Kimi-VL model: {e}")
            raise
    
    async def generate_text_async(
        self, 
        prompt: str, 
        images: Optional[List[Any]] = None, 
        use_thinking: bool = False
    ) -> str:
        """
        Generate text using local model asynchronously.
        
        Args:
            prompt: Text prompt
            images: Optional list of images (file paths or PIL Images)
            use_thinking: Whether to use Chain-of-Thought reasoning
            
        Returns:
            Generated text
        """
        # Wait for model to be loaded
        while self.model is None:
            await asyncio.sleep(0.1)
        
        try:
            # Process images if provided
            processed_images = None
            if images:
                processed_images = await self._process_images_async(images)
            
            # Generate text based on model type
            if self._model_type in ["qwen_vl", "qwen_text", "qwen_qwq"]:
                return await self._generate_with_qwen_async(prompt, processed_images, use_thinking)
            elif self._model_type in ["internvl", "internvl2.5", "internvl3", "internvl3-14b", "internvl3-78b"]:
                return await self._generate_with_internvl_async(prompt, processed_images)
            elif self._model_type == "llava":
                return await self._generate_with_llava_async(prompt, processed_images)
            elif self._model_type == "llava_onevision":
                return await self._generate_with_llava_onevision_async(prompt, processed_images)
            elif self._model_type == "kimi_vl":
                return await self._generate_with_kimi_vl_async(prompt, processed_images)
            else:
                raise ValueError(f"Unsupported model type: {self._model_type}")
                
        except Exception as e:
            logger.error(f"Error generating text with {self.name}: {e}")
            raise
    
    async def _process_images_async(self, images: List[Any]) -> List[Any]:
        """Load the given images and convert each one to RGB.

        Decoding and conversion are blocking, so each image is handled in the
        event loop's default thread pool; this keeps the loop responsive and
        lets several images load in parallel.

        Args:
            images: Items that are either a file path (``str`` or
                ``os.PathLike``) or an object exposing ``convert`` (e.g. a PIL
                image).

        Returns:
            List of RGB images in the same order as ``images``.

        Raises:
            RuntimeError: If Pillow is not installed.
            ValueError: If an item is neither a path nor image-like.
        """
        def _load_rgb(image_input):
            if not PIL_AVAILABLE:
                raise RuntimeError("PIL not installed. Install pillow to enable image inputs.")
            if isinstance(image_input, (str, os.PathLike)):
                # convert() returns an independent image, so the file handle
                # can be released as soon as the conversion is done.
                with Image.open(image_input) as opened:
                    return opened.convert("RGB")
            if hasattr(image_input, "convert"):
                # Duck-typing to avoid referencing Image.Image when PIL is missing
                return image_input.convert("RGB")
            raise ValueError(f"Unsupported image type: {type(image_input)}")

        loop = asyncio.get_event_loop()
        pending = [loop.run_in_executor(None, _load_rgb, img) for img in images]
        return await asyncio.gather(*pending)
    
    async def _generate_with_qwen_async(
        self,
        prompt: str,
        images: Optional[List[Any]] = None,
        use_thinking: bool = False
    ) -> str:
        """Generate text with Qwen models asynchronously.

        The blocking tokenisation / ``generate`` / decode work runs in
        ``self.executor`` so the event loop is not blocked.

        Args:
            prompt: User prompt text.
            images: Optional list of images. They are only used when the model
                type is ``"qwen_vl"`` and a processor is loaded; otherwise the
                request is handled as text-only.
            use_thinking: For ``"qwen_qwq"`` models, adds ``thinking=True`` to
                the generation kwargs.

        Returns:
            The decoded, stripped text of the newly generated tokens.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        def _resolve_tokenizer():
            """Return the standalone tokenizer, else the processor's tokenizer."""
            # __init__ always creates self.tokenizer (as None), so an
            # attribute-existence check can never select the processor's
            # tokenizer; test the value instead.
            if getattr(self, "tokenizer", None) is not None:
                return self.tokenizer
            return self.processor.tokenizer

        def _move_to_input_device(inputs, is_32b):
            """Place the model inputs on the device generation starts from."""
            if self._use_device_map:
                # NOTE(review): the 32B model is assumed to be sharded over
                # GPUs 4-7 with GPU 4 as primary; every other sharded model
                # uses cuda:0. Confirm against the loader configuration.
                target = "cuda:4" if is_32b else "cuda:0"
                return {k: v.to(target) if hasattr(v, 'to') else v for k, v in inputs.items()}
            return inputs.to(self.device)

        def _generate_sync():
            try:
                # Import inside sync section
                import torch

                # Import vision utilities if available (for official Qwen2.5-VL support)
                try:
                    from qwen_vl_utils import process_vision_info
                    has_qwen_vl_utils = True
                    logger.info("Using qwen_vl_utils for vision processing")
                except ImportError:
                    has_qwen_vl_utils = False
                    logger.info("qwen_vl_utils not available, using fallback method")

                # Check if we have a processor (VL model) or just tokenizer (text model)
                has_processor = getattr(self, 'processor', None) is not None

                is_32b = "32b" in self.config.name.lower()
                is_32b_vl = is_32b and self._model_type == "qwen_vl"

                if self._model_type == "qwen_vl" and images and has_processor:
                    # Multimodal input for Qwen-VL following official pattern
                    messages = [
                        {"role": "user", "content": []}
                    ]
                    # Add images first, then text
                    for image in images:
                        messages[0]["content"].append({"type": "image", "image": image})
                    messages[0]["content"].append({"type": "text", "text": prompt})

                    tokenizer = self.processor.tokenizer

                    if has_qwen_vl_utils:
                        try:
                            text = self.processor.apply_chat_template(
                                messages, tokenize=False, add_generation_prompt=True
                            )
                            # Official qwen_vl_utils approach (exactly as in official example)
                            image_inputs, video_inputs = process_vision_info(messages)
                            inputs = self.processor(
                                text=[text],
                                images=image_inputs,
                                videos=video_inputs,
                                padding=True,
                                return_tensors="pt"
                            )
                            logger.debug("Successfully used qwen_vl_utils for vision processing")
                        except Exception as e:
                            logger.warning(f"qwen_vl_utils processing failed: {e}, using fallback")
                            has_qwen_vl_utils = False

                    if not has_qwen_vl_utils:
                        # Fallback processing when qwen_vl_utils is not available or failed
                        try:
                            text = self.processor.apply_chat_template(
                                messages, tokenize=False, add_generation_prompt=True
                            )
                            inputs = self.processor(
                                text=[text],
                                images=images,
                                padding=True,
                                return_tensors="pt"
                            )
                            logger.debug("Using fallback vision processing with chat template")
                        except Exception as e:
                            logger.warning(f"Chat template failed: {e}, using simple processing")
                            # Last resort: simple text processing
                            inputs = self.processor(
                                text=prompt,
                                images=images,
                                padding=True,
                                return_tensors="pt"
                            )
                else:
                    # Text-only input
                    messages = [{"role": "user", "content": prompt}]
                    tokenizer = _resolve_tokenizer()
                    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                    inputs = tokenizer([text], return_tensors="pt", padding=True)

                # Move inputs to appropriate device
                inputs = _move_to_input_device(inputs, is_32b)

                # Generation parameters following official example
                if is_32b_vl:
                    # Use minimal parameters like official example for 32B model
                    generation_kwargs = {
                        "max_new_tokens": self.model_params.get("max_new_tokens", 128),  # Official example uses 128
                    }
                else:
                    # Standard parameters for other Qwen models
                    generation_kwargs = {
                        "max_new_tokens": self.model_params.get("max_new_tokens", 512),
                        "do_sample": False,
                        "use_cache": False,   # lower VRAM usage
                    }
                    if self.model_params.get("use_sampling", False):
                        generation_kwargs.update({
                            "do_sample": True,
                            "temperature": self.model_params.get("temperature", 0.7),
                            "top_p": self.model_params.get("top_p", 0.9),
                        })

                # For Qwen2.5-VL-32B, completely skip token IDs to avoid vocab range issues
                if is_32b_vl:
                    logger.info("Qwen2.5-VL-32B: Skipping token ID settings to avoid vocab range errors")
                    # Don't set any special token IDs for 32B model
                else:
                    # Safe tokenizer ID handling for other models
                    def get_safe_token_id(token_attr, fallback_attr=None):
                        """Safely get token ID, ensuring it's valid for the vocabulary."""
                        if not hasattr(tokenizer, token_attr):
                            return None
                        token_id = getattr(tokenizer, token_attr)
                        if token_id is None and fallback_attr and hasattr(tokenizer, fallback_attr):
                            token_id = getattr(tokenizer, fallback_attr)
                        # Verify token ID is within vocabulary range
                        if token_id is not None and hasattr(tokenizer, 'vocab_size'):
                            if token_id >= tokenizer.vocab_size or token_id < 0:
                                logger.debug(f"Token ID {token_id} out of vocab range [0, {tokenizer.vocab_size}), skipping")
                                return None
                        return token_id

                    # Set token IDs safely for non-32B models
                    safe_eos_token_id = get_safe_token_id('eos_token_id')
                    safe_pad_token_id = get_safe_token_id('pad_token_id', 'eos_token_id')

                    if safe_eos_token_id is not None:
                        generation_kwargs["eos_token_id"] = safe_eos_token_id
                    if safe_pad_token_id is not None:
                        generation_kwargs["pad_token_id"] = safe_pad_token_id
                if self._model_type == "qwen_qwq" and use_thinking:
                    # Available only for QWQ variants that support it
                    generation_kwargs["thinking"] = True

                # Generate following official example pattern
                with torch.no_grad():
                    try:
                        # Simple generation like official example
                        output_ids = self.model.generate(**inputs, **generation_kwargs)
                    except Exception as e:
                        if not is_32b:
                            raise
                        logger.error(f"Qwen2.5-VL-32B generation failed: {e}")
                        # Try with even simpler parameters as fallback
                        try:
                            logger.info("Qwen2.5-VL-32B: Attempting fallback with minimal parameters")
                            minimal_kwargs = {"max_new_tokens": 64}
                            output_ids = self.model.generate(**inputs, **minimal_kwargs)
                            logger.info("Qwen2.5-VL-32B: Fallback generation succeeded")
                        except Exception as e2:
                            logger.error(f"Qwen2.5-VL-32B fallback also failed: {e2}")
                            # Surface the original error, keeping the fallback
                            # failure attached as its cause.
                            raise e from e2

                # Decode only new tokens
                input_len = inputs["input_ids"].shape[1]
                gen_ids = [out[input_len:] for out in output_ids]
                response = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)[0]
                return response.strip()

            except Exception as e:
                logger.error(f"Error in Qwen generation: {e}")
                logger.error(f"Model type: {self._model_type}")
                logger.error(f"Has processor: {getattr(self, 'processor', None) is not None}")
                logger.error(f"Has images: {bool(images)}")
                raise

        # Run generation in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, _generate_sync)
    
    async def _generate_with_internvl_async(
        self, 
        prompt: str, 
        images: Optional[List[Any]] = None
    ) -> str:
        """Generate text with InternVL asynchronously."""
        def _generate_sync():
            try:
                import torch
                
                
                # For all InternVL models, use consistent logic
                if images:
                    # Multimodal chat API for InternVL
                    from PIL import Image
                    
                    # Ensure images are PIL Images
                    pil_images = []
                    for img in images:
                        if isinstance(img, Image.Image):
                            pil_images.append(img)
                        else:
                            pil_images.append(img)
                    
                    # InternVL3, InternVL3-14B, InternVL3-78B use pixel values instead of PIL images
                    if self._model_type in ["internvl3", "internvl3-14b", "internvl3-78b"]:
                        # Process all images to pixel values and concatenate
                        all_pixel_values = []
                        for img in pil_images:
                            pixel_values = self._process_image_for_internvl3(img)
                            all_pixel_values.append(pixel_values)
                        
                        # Concatenate all pixel values along batch dimension
                        pixel_values = torch.cat(all_pixel_values, dim=0)
                        
                        # Move pixel values to appropriate device with correct dtype
                        if self._use_device_map:
                            # For device_map, pixel values should go to cuda:0 where vision model is
                            if self._model_type in ["internvl3-14b", "internvl3-78b"]:
                                pixel_values = pixel_values.to(torch.bfloat16).cuda(0)
                            else:
                                pixel_values = pixel_values.to(torch.float16).cuda(0)
                        else:
                            if self._model_type in ["internvl3-14b", "internvl3-78b"]:
                                pixel_values = pixel_values.to(torch.bfloat16).to(self.device)
                            else:
                                pixel_values = pixel_values.to(torch.float16).to(self.device)
                        
                        # Format prompt with multiple image placeholders
                        image_placeholders = "<image>" * len(pil_images)
                        question = f"{image_placeholders}\n{prompt}"
                        
                        # Call chat method following the official example
                        response = self.model.chat(
                            self.tokenizer,
                            pixel_values,
                            question,
                            generation_config=dict(
                                max_new_tokens=self.model_params.get("max_new_tokens", 1024),
                                do_sample=True,
                                temperature=self.model_params.get("temperature", 0.1),
                            )
                        )
                    else:
                        # For other InternVL versions (InternVL2.5, etc.)
                        # Format prompt with multiple image placeholders
                        image_placeholders = "<image>" * len(pil_images)
                        question = f"{image_placeholders}\n{prompt}"
                        
                        response = self.model.chat(
                            self.tokenizer,
                            pixel_values=None,
                            question=question,
                            generation_config=dict(
                                max_new_tokens=self.model_params.get("max_new_tokens", 1024),
                                temperature=self.model_params.get("temperature", 0.1),
                                do_sample=True,
                            ),
                            images=pil_images
                        )
                    return response
                else:
                    # Text-only generation
            
                    if self._use_device_map:
                        # When using device_map, send inputs to the first GPU where embeddings typically reside
                        inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda:0")
                    else:
                        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
                    
                    with torch.no_grad():
                        response = self.model.chat(
                                self.tokenizer,
                                pixel_values=None,  # None for text-only mode
                                question=prompt,
                                generation_config=dict(
                                    max_new_tokens=self.model_params.get("max_new_tokens", 1024),
                                    temperature=self.model_params.get("temperature", 0.1),
                                    do_sample=True,
                                ),
                                history=None,  # No history for single-turn conversations
                                return_history=False  # We don't need history
                            )
                    return response
            except Exception as e:
                logger.error(f"Error in InternVL generation: {e}")
                raise
        
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, _generate_sync)
    
    async def _generate_with_llava_async(
        self,
        prompt: str,
        images: Optional[List[Any]] = None
    ) -> str:
        """Generate text with LLaVA asynchronously following official LLaVA format."""
        def _generate_sync():
            try:
                import torch
                from PIL import Image

                # Prepare conversation format following official LLaVA example (transformers>=v4.48 style)
                if images:
                    # Multimodal conversation
                    # LLaVA typically processes one image at a time, use the first image
                    image = images[0] if images else None
                    if isinstance(image, str):
                        # If it's a file path, load the image
                        image = Image.open(image).convert("RGB")

                    # Use the newer chat template format (transformers>=v4.48)
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": image},  # Pass PIL image directly
                                {"type": "text", "text": prompt},
                            ]
                        }
                    ]

                    # Apply chat template with tokenization
                    inputs = self.processor.apply_chat_template(
                        messages,
                        add_generation_prompt=True,
                        tokenize=True,
                        return_dict=True,
                        return_tensors="pt"
                    )
                else:
                    # Text-only conversation using newer format
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt}
                            ]
                        }
                    ]

                    # Apply chat template with tokenization for text-only
                    inputs = self.processor.apply_chat_template(
                        messages,
                        add_generation_prompt=True,
                        tokenize=True,
                        return_dict=True,
                        return_tensors="pt"
                    )

                # Move inputs to appropriate device
                if not self._use_device_map:
                    inputs = {k: v.to(self.device) if hasattr(v, 'to') else v for k, v in inputs.items()}

                # Generate with optimized settings following official example
                with torch.no_grad():
                    output_ids = self.model.generate(
                        **inputs,
                        max_new_tokens=self.model_params.get("max_new_tokens", 512),
                        temperature=self.model_params.get("temperature", 0.1),
                        do_sample=True,
                        use_cache=True,  # Enable cache for better performance
                    )

                # Decode response following official pattern
                # For LLaVA, decode the full output and skip special tokens
                full_response = self.processor.decode(output_ids[0], skip_special_tokens=True).strip()

                # Extract only the assistant's response from the full conversation
                # Look for ASSISTANT: marker and extract everything after it
                if "ASSISTANT:" in full_response:
                    response = full_response.split("ASSISTANT:", 1)[1].strip()
                elif "Assistant:" in full_response:
                    response = full_response.split("Assistant:", 1)[1].strip()
                else:
                    # Fallback: try to remove the input prompt more carefully
                    # Remove the processed prompt text if it appears at the beginning
                    response = full_response
                    if prompt_text in response:
                        # Find where the actual response starts (after the prompt)
                        prompt_end_pos = response.find(prompt_text) + len(prompt_text)
                        if prompt_end_pos < len(response):
                            response = response[prompt_end_pos:].strip()
                        else:
                            # If we can't find a clean separation, keep the full response
                            response = full_response

                return response

            except Exception as e:
                logger.error(f"Error in LLaVA generation: {e}")
                logger.error(f"Model type: {self._model_type}")
                logger.error(f"Has images: {bool(images)}")
                raise

        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, _generate_sync)
    
    async def _generate_with_llava_onevision_async(
        self, 
        prompt: str, 
        images: Optional[List[Any]] = None
    ) -> str:
        """Generate text with LLaVA-OneVision asynchronously."""
        def _generate_sync():
            try:
                import torch
                
                if images:
                    # Multimodal input with images
                    # LLaVA-OneVision typically processes one image at a time
                    image = images[0]
                    
                    # Prepare conversation format
                    conversation = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image"},
                                {"type": "text", "text": prompt}
                            ]
                        }
                    ]
                    
                    # Apply chat template
                    prompt_text = self.processor.apply_chat_template(
                        conversation, 
                        add_generation_prompt=True
                    )
                    
                    # Prepare inputs
                    inputs = self.processor(
                        text=prompt_text,
                        images=image,
                        return_tensors="pt"
                    )
                else:
                    # Text-only input - use model.chat following official pattern
                    logger.info("LLaVA-OneVision: Processing text-only input (no images provided)")

                    # Prepare conversation format for text-only
                    conversation = [
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ]

                    # Apply chat template for text-only
                    prompt_text = self.processor.apply_chat_template(
                        conversation,
                        add_generation_prompt=True
                    )

                    # Use model.chat for text-only generation
                    with torch.no_grad():
                        # Special optimization for llava-onevision models
                        if "llava-onevision" in self.config.name.lower():
                            # Optimized settings for math problems
                            generation_config = {
                                "max_new_tokens": self.model_params.get("max_new_tokens", 512),  # Reasonable limit for math
                                "temperature": 0.1,  # Low temperature for consistent answers
                                "do_sample": True,  # Enable sampling to avoid repetition
                                "use_cache": True,  # Enable KV cache for faster generation
                                "pad_token_id": self.tokenizer.pad_token_id,
                                "eos_token_id": self.tokenizer.eos_token_id,
                                "top_p": 0.95,  # Nucleus sampling
                                "top_k": 50,  # Top-k sampling
                            }
                        else:
                            # Default settings for other models
                            generation_config = {
                                "max_new_tokens": self.model_params.get("max_new_tokens", 1024),
                                "temperature": self.model_params.get("temperature", 0.1),
                                "do_sample": True,
                                "use_cache": False,  # lower VRAM usage
                            }

                    # Use model.generate for text-only generation following official pattern
                    inputs = self.processor(text=prompt_text, return_tensors='pt')
                    if not self._use_device_map:
                        inputs = {k: v.to(self.device) if hasattr(v, 'to') else v for k, v in inputs.items()}

                    # Generate using the same pattern as official example
                    output_ids = self.model.generate(
                        **inputs,
                        **generation_config
                    )

                    # Decode only newly generated tokens
                    input_len = inputs["input_ids"].shape[1]
                    gen_ids = output_ids[0][input_len:]
                    response = self.processor.decode(gen_ids, skip_special_tokens=True).strip()
                    return response
                
            except Exception as e:
                logger.error(f"Error in LLaVA-OneVision generation: {e}")
                raise
        
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, _generate_sync)
    
    async def _generate_with_kimi_vl_async(
        self,
        prompt: str,
        images: Optional[List[Any]] = None
    ) -> str:
        """Generate text with Kimi-VL asynchronously following official implementation."""
        def _generate_sync():
            try:
                import torch

                if images and self.processor:
                    # Multimodal input following official pattern
                    # For multiple images, use the first one (Kimi-VL typically processes one image at a time)
                    image = images[0]

                    # Create messages following official format
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": image},  # Pass PIL Image directly
                                {"type": "text", "text": prompt}
                            ]
                        }
                    ]

                    # Apply chat template
                    text = self.processor.apply_chat_template(
                        messages,
                        add_generation_prompt=True,
                        tokenize=False
                    )

                    # Process inputs following official pattern
                    inputs = self.processor(
                        images=image,
                        text=text,
                        return_tensors="pt",
                        padding=True,
                        truncation=True
                    )

                    # Move to model device if not using device_map
                    if not self._use_device_map:
                        inputs = {k: v.to(self.model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}
                else:
                    # Text-only input
                    messages = [{"role": "user", "content": prompt}]
                    text = self.processor.apply_chat_template(
                        messages,
                        add_generation_prompt=True,
                        tokenize=False
                    )
                    inputs = self.processor(
                        text=text,
                        return_tensors="pt",
                        padding=True
                    )
                    # Move to model device if not using device_map
                    if not self._use_device_map:
                        inputs = {k: v.to(self.model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}

                # Generation parameters following official pattern
                generation_kwargs = {
                    "max_new_tokens": self.model_params.get("max_new_tokens", 512),
                }

                # Add sampling parameters if enabled
                if self.model_params.get("use_sampling", False):
                    generation_kwargs.update({
                        "do_sample": True,
                        "temperature": self.model_params.get("temperature", 0.7),
                        "top_p": self.model_params.get("top_p", 0.9),
                    })

                # Generate following official pattern
                with torch.no_grad():
                    generated_ids = self.model.generate(**inputs, **generation_kwargs)

                # Decode following official pattern
                generated_ids_trimmed = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                response = self.processor.batch_decode(
                    generated_ids_trimmed,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False
                )[0]

                return response.strip()

            except Exception as e:
                logger.error(f"Error in Kimi-VL generation: {e}")
                raise

        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, _generate_sync)
    
    def __del__(self):
        """Cleanup resources."""
        if hasattr(self, 'executor'):
            self.executor.shutdown(wait=False)
