# core/models/qwen_handler.py
"""
Qwen2.5-VL model handler.

Responsible for model loading, initialization, and inference.
"""

import os
import torch
import tempfile
import decord
from typing import List, Dict, Any, Tuple
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import smart_resize
from torchvision.transforms import InterpolationMode
from torchvision import transforms

from config.settings import QWEN_CONFIG


# Video preprocessing constants. Pixel budgets are written as multiples of
# 28 * 28 pixels, matching the factor passed to smart_resize below.
IMAGE_FACTOR = 28
FRAME_FACTOR = 2
VIDEO_MIN_PIXELS = 128 * 28 * 28  # per-frame lower bound handed to smart_resize
VIDEO_MAX_PIXELS = 768 * 28 * 28  # per-frame upper bound before the total-budget split
# Total pixel budget across all frames of one clip.
# NOTE(review): the override env var is named VIDEO_MAX_PIXELS even though it sets
# the *total* budget (this module-level VIDEO_MAX_PIXELS is unaffected) — confirm intent.
_DEFAULT_VIDEO_TOTAL_PIXELS = 128000 * 28 * 28 * 0.9
VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', _DEFAULT_VIDEO_TOTAL_PIXELS)))


class QwenModelHandler:
    """Lazily load and hold a Qwen2.5-VL model together with its processor.

    Args:
        model_path: Path or identifier passed to both ``from_pretrained`` calls.
        device_map: Device placement forwarded to the model's ``from_pretrained``.
        **kwargs: Extra options. They are stored on ``self.config`` and reported
            by ``get_model_info``; they are not forwarded to model loading.
    """

    def __init__(self, model_path: str, device_map: str = "auto", **kwargs):
        self.model_path = model_path
        self.device_map = device_map
        self.config = kwargs
        # Populated by the first call to ``load_model``.
        self.model = None
        self.processor = None

    def load_model(self):
        """Return ``(model, processor)``, loading both if either is still missing.

        When both are already set, they are returned without reloading. Otherwise
        both the model and the processor are (re)loaded from ``self.model_path``.
        """
        if self.model is not None and self.processor is not None:
            return self.model, self.processor

        print(f"[Qwen] Loading model from {self.model_path} with device_map={self.device_map}")
        # QWEN_CONFIG["torch_dtype"] holds the name of a torch attribute, resolved via getattr.
        dtype = getattr(torch, QWEN_CONFIG["torch_dtype"])
        attn_impl = QWEN_CONFIG["attention_implementation"]
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_path,
            torch_dtype=dtype,
            device_map=self.device_map,
            trust_remote_code=True,
            attn_implementation=attn_impl,
        )
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        print("[Qwen] Model loaded successfully")
        return self.model, self.processor

    def get_model_info(self) -> Dict[str, Any]:
        """Summarize this handler's configuration.

        Returns:
            Dict with the model path, device map, the dtype and attention
            implementation names from QWEN_CONFIG, and the extra kwargs dict
            (returned by reference, not copied).
        """
        return dict(
            model_path=self.model_path,
            device_map=self.device_map,
            torch_dtype=QWEN_CONFIG["torch_dtype"],
            attention_implementation=QWEN_CONFIG["attention_implementation"],
            config=self.config,
        )


def initialize_model_specific(model_path: str, device_map: str = "auto") -> Tuple[Any, Any]:
    """
    Create a QwenModelHandler for one checkpoint and load it right away.

    Args:
        model_path: Checkpoint path or identifier to load from.
        device_map: Device placement forwarded to the handler.

    Returns:
        The (model, processor) tuple produced by QwenModelHandler.load_model.
    """
    return QwenModelHandler(model_path, device_map).load_model()


def initialize_model(model_path: str = None) -> Tuple[Any, Any]:
    """
    Load the default Qwen model; kept as the backward-compatible entry point.

    Args:
        model_path: Checkpoint to load. When None, QWEN_CONFIG["default_single_model"]
            is used instead.

    Returns:
        (model, processor), loaded with device_map="auto".
    """
    resolved_path = QWEN_CONFIG["default_single_model"] if model_path is None else model_path
    return initialize_model_specific(resolved_path, "auto")


def _empty_qwen_result() -> Dict[str, Any]:
    """Return the result dict used when no inference could be run."""
    return {"text": "", "tokens": {"prompt": 0, "output": 0, "total": 0}}


def _resize_video_frames(video_tensor: torch.Tensor) -> torch.Tensor:
    """Resize (frames, channels, H, W) video frames with smart_resize bounds and return float32."""
    nframes, _, height, width = video_tensor.shape
    min_pixels = VIDEO_MIN_PIXELS
    # Per-frame pixel cap: the lesser of VIDEO_MAX_PIXELS and the total budget spread
    # across the frames (scaled by FRAME_FACTOR). It is floored just above
    # min_pixels so smart_resize always receives max_pixels > min_pixels.
    max_pixels = max(
        min(VIDEO_MAX_PIXELS, VIDEO_TOTAL_PIXELS / max(1, nframes) * FRAME_FACTOR),
        int(min_pixels * 1.05),
    )
    resized_h, resized_w = smart_resize(
        height, width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels
    )
    return transforms.functional.resize(
        video_tensor, [resized_h, resized_w], interpolation=InterpolationMode.BICUBIC, antialias=True
    ).float()


@torch.no_grad()
def get_qwen_response_generic(
    model, processor, prompt: str, video_path: str, frame_indices: List[int], generation_kwargs: Dict
) -> Dict[str, Any]:
    """
    Generic Qwen video inference function.

    Decodes the requested frames with decord, resizes them within the module's
    pixel budget, and runs ``model.generate`` on a single-turn chat prompt.

    Args:
        model: Qwen model.
        processor: Qwen processor.
        prompt: Prompt text.
        video_path: Video path.
        frame_indices: Frame indices to sample. Out-of-range indices are dropped;
            frame 0 is used when none remain.
        generation_kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Returns:
        {"text": str, "tokens": {"prompt": int, "output": int, "total": int}}.
        If the video cannot be opened or reports no frames, the text is empty
        and all token counts are 0.
    """
    try:
        vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
    except Exception as e:
        print(f"[Qwen] Decord error: {e}")
        return _empty_qwen_result()

    num_frames = len(vr)
    if num_frames == 0:
        # The [0] fallback below would be out of range, so bail out like a decode error.
        print(f"[Qwen] Video has no frames: {video_path}")
        return _empty_qwen_result()

    valid_indices = sorted(i for i in frame_indices if 0 <= i < num_frames) or [0]
    video_np = vr.get_batch(valid_indices).asnumpy()
    del vr  # drop our reference to the reader; only the decoded frames are needed from here on
    # from_numpy shares memory with the decoded array rather than copying every frame.
    video_tensor = _resize_video_frames(torch.from_numpy(video_np).permute(0, 3, 1, 2))

    # The "video" entry presumably only lets the chat template place the video
    # placeholder; the actual frames are passed through ``videos=`` below.
    messages = [{"role": "user", "content": [{"type": "video", "video": video_path}, {"type": "text", "text": prompt}]}]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], videos=[video_tensor], padding=True, return_tensors="pt")
    inputs = inputs.to(model.device)

    # Prompt length includes the expanded video tokens.
    prompt_tokens = int(inputs.input_ids.shape[1])

    gen_ids = model.generate(**inputs, **generation_kwargs)
    # Strip the echoed prompt ids so only newly generated tokens remain.
    trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, gen_ids)]
    out_texts = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    output_text = out_texts[0] if out_texts else ""
    output_tokens = int(trimmed[0].shape[0]) if trimmed and hasattr(trimmed[0], "shape") else 0

    return {
        "text": output_text,
        "tokens": {"prompt": prompt_tokens, "output": output_tokens, "total": prompt_tokens + output_tokens},
    }


# Module-level state retained for backward compatibility with older callers.
# NOTE(review): neither cache slot is read or assigned in the visible code — confirm external users.
_global_model = _global_processor = None
# Created as an import-time side effect; tempfile.TemporaryDirectory removes the
# directory when this object is cleaned up or finalized.
temp_dir = tempfile.TemporaryDirectory()


def get_qwen_response(model, processor, prompt: str, video_path: str, frame_indices: List[int],
                      generation_kwargs: Dict) -> str:
    """
    Run Qwen video inference and return only the generated text.

    Backward-compatible wrapper around get_qwen_response_generic that discards
    the token counts.

    Args:
        model: Qwen model.
        processor: Qwen processor.
        prompt: Prompt text.
        video_path: Video path.
        frame_indices: Frame indices to sample from the video.
        generation_kwargs: Keyword arguments forwarded to generation.

    Returns:
        str: Generated text; empty when the generic call yields no text.
    """
    full_result = get_qwen_response_generic(
        model, processor, prompt, video_path, frame_indices, generation_kwargs
    )
    generated_text = full_result["text"]
    return generated_text
