
import math
from typing import Optional, Tuple, List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLAttention,
    Qwen2_5_VLFlashAttention2, 
    Qwen2_5_VLSdpaAttention,
    Qwen2_5_VLDecoderLayer,
    Qwen2_5_VLModel,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2_5_VLCausalLMOutputWithPast,
    apply_multimodal_rotary_pos_emb,
    repeat_kv,
    QWEN2_5_VL_ATTENTION_CLASSES,
)
from transformers.utils import is_flash_attn_2_available, logging
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import DynamicCache
from torch.nn import CrossEntropyLoss

# flash-attn is an optional dependency: the helper below is imported only when
# flash-attn 2 is reported as available. MoHQwen2_5_VLFlashAttention2.forward
# is the caller of ``_flash_attention_forward`` in this file.
if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)


class MoHMixin:
    """Mixture-of-Heads (MoH) head selection shared by the MoH attention classes.

    The mixin is listed ahead of a concrete attention class in the MRO. On each
    forward pass it decides which attention heads stay active: a head is kept
    when the last query token's strongest raw (pre-softmax) score on the visual
    tokens exceeds a threshold taken from the ``<vision_end>`` token's strongest
    raw score on those same tokens. Head 0 is always kept.

    "Visual tokens" are the positions strictly between the *first* vision-start
    token and the *first* vision-end token of each sample; later images in the
    same sample are not inspected.

    The host class is expected to expose ``num_heads``, ``head_dim`` and
    ``layer_idx``.
    """

    # Fallback vision token ids; ``__init__`` replaces them with the values
    # found on the config when the config defines them.
    _moh_vision_start_token_id = 151652
    _moh_vision_end_token_id = 151653
    # MoH is applied to this many trailing decoder layers, final layer excluded.
    _moh_trailing_layers = 15

    def __init__(self, config, layer_idx=None, **kwargs):
        """Initialise the wrapped attention class, then the MoH state.

        Args:
            config: Model config. ``num_hidden_layers`` and ``enable_moh`` are
                read; ``vision_start_token_id`` / ``vision_end_token_id`` are
                used when present.
            layer_idx: Forwarded positionally to the next ``__init__`` in the
                MRO when it is not ``None``.
            **kwargs: Forwarded unchanged to the next ``__init__`` in the MRO.
        """
        if layer_idx is not None:
            super().__init__(config, layer_idx, **kwargs)
        else:
            super().__init__(config, **kwargs)

        try:
            total_layers = int(getattr(config, "num_hidden_layers", 0))
        except (TypeError, ValueError):
            total_layers = 0

        # Read layer_idx defensively so a parent that never sets it disables
        # MoH instead of raising (the logging code already guards for this).
        own_layer_idx = getattr(self, "layer_idx", None)
        first_moh_layer = max(0, total_layers - self._moh_trailing_layers)
        applies_here = (
            own_layer_idx is not None
            and first_moh_layer <= own_layer_idx < total_layers - 1
        )
        self.moh_enabled = bool(getattr(config, "enable_moh", False)) and applies_here

        # Prefer the token ids declared on the config over the class defaults.
        for attr_name, config_key in (
            ("_moh_vision_start_token_id", "vision_start_token_id"),
            ("_moh_vision_end_token_id", "vision_end_token_id"),
        ):
            configured = getattr(config, config_key, None)
            if configured is not None:
                setattr(self, attr_name, int(configured))

        # Full input_ids seen so far; extended by one token per decode step.
        self._cached_full_input_ids = None
        # Kept for compatibility; nothing in this class reads it.
        self._vision_token_positions = None

        # Monitoring: number of head selections performed and logging cadence.
        self._step_count = 0
        self._log_interval = 100

    # ------------------------------------------------------------------ helpers

    def _moh_all_heads(self, device) -> torch.Tensor:
        """Return a boolean mask that keeps every head."""
        return torch.ones(self.num_heads, device=device, dtype=torch.bool)

    def _moh_should_log(self, interval_multiplier: int) -> bool:
        """Tell whether the current step falls on the (scaled) logging interval."""
        if not hasattr(self, "layer_idx"):
            return False
        return self._step_count % (self._log_interval * interval_multiplier) == 0

    def _moh_resolve_input_ids(self, input_ids: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Return the full token sequence for this call, maintaining the cache.

        Multi-token inputs are cached and returned as-is. A single-token input
        is treated as a decode step and appended to the cached sequence.
        Returns ``None`` when no full sequence can be reconstructed.
        """
        if input_ids is None:
            return None
        if input_ids.shape[-1] != 1:
            self._cached_full_input_ids = input_ids.clone()
            return input_ids

        cached = self._cached_full_input_ids
        if cached is None:
            return None
        if cached.shape[:-1] != input_ids.shape[:-1]:
            # The cache belongs to a different batch; concatenating would
            # raise, so drop it and let the caller keep all heads.
            self._cached_full_input_ids = None
            return None
        full_ids = torch.cat([cached, input_ids], dim=-1)
        self._cached_full_input_ids = full_ids
        return full_ids

    def _moh_vision_span(self, ids_row: torch.Tensor, key_len: int) -> Optional[Tuple[int, int]]:
        """Locate the first visual-token span ``[start, end)`` of one sample.

        ``end`` is the position of the vision-end token itself. Returns
        ``None`` when either marker is missing, the span is empty, or the span
        does not fit inside ``key_len`` key positions.
        """
        start_hits = torch.where(ids_row == self._moh_vision_start_token_id)[0]
        end_hits = torch.where(ids_row == self._moh_vision_end_token_id)[0]
        if start_hits.numel() == 0 or end_hits.numel() == 0:
            return None
        start = int(start_hits[0]) + 1  # skip the vision-start token itself
        end = int(end_hits[0])
        if start >= end or end > key_len:
            return None
        return start, end

    def _moh_mask_from_scores(self, full_ids, query_len, key_len, score_fn, device, tag):
        """Build the head mask from per-sample raw visual scores.

        Args:
            full_ids: ``[batch, seq]`` token ids aligned with key positions.
            query_len: Number of query positions available in this call.
            key_len: Number of key positions available in this call.
            score_fn: ``score_fn(batch_idx, query_pos, start, end)`` returning
                raw scores of shape ``[heads, end - start]``.
            device: Device for fallback masks.
            tag: Text inserted after the layer index in log messages.

        Returns:
            Boolean tensor of shape ``[heads]``. All heads are kept when no
            sample has a usable visual span or when scoring raises.
        """
        try:
            per_sample_scores = []
            threshold = 0.0  # floor used when no vision_end query is available
            if query_len > 0:
                last_query_pos = query_len - 1
                for batch_idx in range(full_ids.shape[0]):
                    span = self._moh_vision_span(full_ids[batch_idx], key_len)
                    if span is None:
                        continue
                    start, end = span
                    last_to_visual = score_fn(batch_idx, last_query_pos, start, end)
                    per_sample_scores.append(last_to_visual.max(dim=-1)[0])  # [heads]

                    # The vision_end query only exists when its position lies
                    # inside the current query window (not during single-token
                    # decode steps). Checked explicitly rather than by catching
                    # the error raised on an empty slice.
                    # NOTE(review): query positions are indexed absolutely, as
                    # before; chunked prefill with a KV cache would need an offset.
                    if end < query_len:
                        end_to_visual = score_fn(batch_idx, end, start, end)
                        threshold = max(threshold, end_to_visual.max().item())

            if per_sample_scores:
                avg_head_scores = torch.stack(per_sample_scores).mean(dim=0)  # [heads]
                mask = avg_head_scores > threshold
                mask[0] = True  # head 0 acts as the always-on shared head

                self._step_count += 1
                if self._moh_should_log(1):
                    selected_count = mask.sum().item()
                    max_score = avg_head_scores.max().item()
                    min_score = avg_head_scores.min().item()
                    logger.info(f"[MoH Training] Layer {self.layer_idx}{tag}: Step {self._step_count}, Selected {selected_count}/{self.num_heads} heads, Max score: {max_score:.4f}, Min score: {min_score:.4f}, Vision-end→visual threshold: {threshold:.4f}")
                return mask
        except Exception as exc:
            # Best-effort boundary: head selection must never break the forward
            # pass. Count the step once and report the cause (rate-limited).
            self._step_count += 1
            if self._moh_should_log(10):
                logger.warning(f"[MoH Training] Layer {self.layer_idx}{tag}: Step {self._step_count}, Exception in head selection, using all heads ({type(exc).__name__}: {exc})")
            return self._moh_all_heads(device)

        # No sample had a usable visual span: keep every head.
        self._step_count += 1
        if self._moh_should_log(5):
            logger.warning(f"[MoH Training] Layer {self.layer_idx}{tag}: Step {self._step_count}, Using fallback (all heads)")
        return self._moh_all_heads(device)

    # --------------------------------------------------------------- public API

    def _moh_select_heads(self, query_states: torch.Tensor, input_ids: Optional[torch.Tensor] = None,
                         attn_weights: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        """Select heads from precomputed raw attention scores.

        Args:
            query_states: Only used for its device.
            input_ids: Token ids of this call (full prompt or one decode token).
            attn_weights: Raw scores indexed as ``[batch, heads, query, key]``.

        Returns:
            ``None`` when MoH is disabled for this layer; otherwise a boolean
            ``[heads]`` mask. The mask keeps all heads when ``input_ids`` or
            ``attn_weights`` is unavailable.
        """
        if not self.moh_enabled:
            return None

        with torch.no_grad():
            full_ids = self._moh_resolve_input_ids(input_ids)
            if full_ids is None or attn_weights is None:
                return self._moh_all_heads(query_states.device)

            def raw_scores(batch_idx, query_pos, start, end):
                return attn_weights[batch_idx, :, query_pos, start:end]

            return self._moh_mask_from_scores(
                full_ids, attn_weights.shape[-2], attn_weights.shape[-1],
                raw_scores, query_states.device, "",
            )

    def _moh_select_heads_visual_only(self, query_states: torch.Tensor, key_states: torch.Tensor,
                                     input_ids: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        """Select heads by scoring only the needed query rows against visual keys.

        Avoids materialising the full attention matrix: raw scores are
        ``q @ k^T / sqrt(head_dim)`` restricted to the visual key span.

        Args:
            query_states: Indexed as ``[batch, heads, query, head_dim]``.
            key_states: Indexed as ``[batch, heads, key, head_dim]``.
            input_ids: Token ids of this call (full prompt or one decode token).

        Returns:
            ``None`` when MoH is disabled, ``input_ids`` is missing, or a
            single-token call arrives without a cached sequence; otherwise a
            boolean ``[heads]`` mask.
        """
        if not self.moh_enabled or input_ids is None:
            return None

        with torch.no_grad():
            full_ids = self._moh_resolve_input_ids(input_ids)
            if full_ids is None:
                return None  # caller keeps all heads

            def raw_scores(batch_idx, query_pos, start, end):
                query = query_states[batch_idx, :, query_pos:query_pos + 1, :]  # [heads, 1, dim]
                visual_keys = key_states[batch_idx, :, start:end, :]  # [heads, visual, dim]
                scores = torch.matmul(query, visual_keys.transpose(-2, -1))
                return (scores / math.sqrt(self.head_dim)).squeeze(1)  # [heads, visual]

            return self._moh_mask_from_scores(
                full_ids, query_states.shape[-2], key_states.shape[-2],
                raw_scores, query_states.device, " (visual_only)",
            )


class MoHQwen2_5_VLAttention(MoHMixin, Qwen2_5_VLAttention):
    """Eager (explicit matmul + softmax) attention with MoH head gating."""

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False,
                cache_position=None, position_embeddings=None, input_ids=None):
        """Run eager attention and zero the attention rows of deselected heads.

        ``input_ids`` is the only addition to the upstream signature; it is
        handed to ``_moh_select_heads`` together with the raw (pre-softmax)
        scores. ``position_ids`` and ``use_cache`` are accepted but not read
        here.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``; the
            weights are returned only when ``output_attentions`` is truthy.
        """
        batch, seq_len, _ = hidden_states.size()

        def split_heads(projected, head_count):
            # [batch, seq, heads * dim] -> [batch, heads, seq, dim]
            return projected.view(batch, seq_len, head_count, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states), self.num_heads)
        key_states = split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        value_states = split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        # Multimodal rotary position embedding on queries and keys.
        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        # Merge the new keys/values into the KV cache when one is supplied.
        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        # Repeat KV heads so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Raw scaled dot-product scores, plus the additive mask when given.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]

        # Head selection works on the raw scores, so it must precede softmax.
        head_mask = self._moh_select_heads(query_states, input_ids, attn_weights)

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        if head_mask is not None:
            # Broadcast the per-head keep flags over batch, query and key dims.
            head_gate = head_mask.to(dtype=attn_weights.dtype).view(1, -1, 1, 1)
            attn_weights = attn_weights * head_gate

        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        # [batch, heads, seq, dim] -> [batch, seq, heads * dim], then project.
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch, seq_len, -1)
        attn_output = self.o_proj(attn_output)

        returned_weights = attn_weights if output_attentions else None
        return attn_output, returned_weights, past_key_value


class MoHQwen2_5_VLFlashAttention2(MoHMixin, Qwen2_5_VLFlashAttention2):
    """FlashAttention-2 attention with MoH head gating applied to the output."""

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False,
                cache_position=None, position_embeddings=None, input_ids=None):
        """Run flash attention and zero the outputs of deselected heads.

        The flash kernel does not expose attention weights, so heads are chosen
        with ``_moh_select_heads_visual_only`` and gating multiplies the
        per-head output. ``position_ids``, ``output_attentions`` and
        ``use_cache`` are accepted but not read here.

        Returns:
            ``(attn_output, None, past_key_value)``.
        """
        batch, seq_len, _ = hidden_states.size()

        def split_heads(projected):
            # [batch, seq, heads * dim] -> [batch, heads, seq, dim]
            return projected.view(batch, seq_len, -1, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states))
        key_states = split_heads(self.k_proj(hidden_states))
        value_states = split_heads(self.v_proj(hidden_states))

        # Multimodal rotary position embedding on queries and keys.
        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        # Merge the new keys/values into the KV cache when one is supplied.
        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        # Repeat KV heads so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Choose heads while tensors are still laid out [batch, heads, seq, dim].
        head_mask = self._moh_select_heads_visual_only(query_states, key_states, input_ids)

        # float32 inputs are cast down before the flash kernel; pick the target
        # dtype from autocast, the pre-quantization dtype, or the weights.
        if query_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # The flash helper is fed [batch, seq, heads, dim].
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Sliding-window attention only from ``max_window_layers`` onwards.
        window_active = (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        )
        sliding_window = self.config.sliding_window if window_active else None

        dropout_rate = self.attention_dropout if self.training else 0.0
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            seq_len,
            dropout=dropout_rate,
            sliding_window=sliding_window,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        if head_mask is not None:
            # Heads sit on dim 2 in the flash output layout.
            head_gate = head_mask.to(dtype=attn_output.dtype).view(1, 1, -1, 1)
            attn_output = attn_output * head_gate

        # Merge heads and apply the output projection.
        attn_output = attn_output.reshape(batch, seq_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


class MoHQwen2_5_VLSdpaAttention(MoHMixin, Qwen2_5_VLSdpaAttention):
    """SDPA attention with MoH head gating applied to the per-head output."""

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False,
                cache_position=None, position_embeddings=None, input_ids=None):
        """Run scaled-dot-product attention and zero deselected heads' outputs.

        ``input_ids`` is the only addition to the upstream signature; it feeds
        ``_moh_select_heads_visual_only``. ``position_ids`` and ``use_cache``
        are accepted but not read on the SDPA path.

        Returns:
            ``(attn_output, None, past_key_value)`` on the SDPA path. When
            ``output_attentions`` is truthy the eager MoH forward is used and
            its return value (which includes the weights) is passed through.
        """
        if output_attentions:
            # SDPA cannot return attention weights. Delegate to the eager *MoH*
            # forward (not the plain upstream one) and pass input_ids along, so
            # head gating behaves the same whether or not weights are requested.
            return MoHQwen2_5_VLAttention.forward(
                self,
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                input_ids=input_ids,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project and reshape to [batch, heads, seq, dim].
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Multimodal rotary position embedding on queries and keys.
        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        # Merge the new keys/values into the KV cache when one is supplied.
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Repeat KV heads so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Head selection scores only the visual key span, no full attention matrix.
        head_mask = self._moh_select_heads_visual_only(query_states, key_states, input_ids)

        causal_mask = None
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            # Mirrors the upstream workaround: the memory-efficient SDPA
            # backend misbehaves on non-contiguous inputs with a custom mask.
            if query_states.device.type == "cuda":
                query_states = query_states.contiguous()
                key_states = key_states.contiguous()
                value_states = value_states.contiguous()

        # Without an explicit mask a multi-token input must still be causal;
        # a hard-coded ``is_causal=False`` would let tokens attend to the future.
        is_causal = causal_mask is None and q_len > 1

        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal,
        )

        if head_mask is not None:
            # Broadcast the per-head keep flags over batch, seq and dim.
            gating = head_mask.to(dtype=attn_output.dtype).view(1, -1, 1, 1)
            attn_output = attn_output * gating

        # [batch, heads, seq, dim] -> [batch, seq, heads * dim], then project.
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


# MoH attention classes keyed by attention-implementation name.
# MoHQwen2_5_VLDecoderLayer looks up ``config._attn_implementation`` here and
# falls back to the eager MoHQwen2_5_VLAttention when the key is missing.
MOH_ATTENTION_CLASSES = {
    "eager": MoHQwen2_5_VLAttention,
    "flash_attention_2": MoHQwen2_5_VLFlashAttention2,
    "sdpa": MoHQwen2_5_VLSdpaAttention,
}


class MoHQwen2_5_VLDecoderLayer(Qwen2_5_VLDecoderLayer):
    """Decoder layer whose self-attention is swapped for the MoH variant."""

    def __init__(self, config, layer_idx: int):
        """Build the upstream layer, then replace ``self_attn`` with a MoH class.

        The class is chosen from ``MOH_ATTENTION_CLASSES`` by
        ``config._attn_implementation``; unknown values use the eager variant.
        """
        super().__init__(config, layer_idx)
        moh_attention_cls = MOH_ATTENTION_CLASSES.get(
            config._attn_implementation, MoHQwen2_5_VLAttention
        )
        self.self_attn = moh_attention_cls(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        input_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Pre-norm attention and MLP sub-layers with residual connections.

        Compared with the upstream layer, the only extra argument is
        ``input_ids``, which is forwarded to the attention module for MoH head
        selection. Extra ``**kwargs`` are accepted and ignored.

        Returns:
            A tuple starting with the new hidden states, followed by the
            attention weights when ``output_attentions`` is truthy and by the
            present key/value when ``use_cache`` is truthy.
        """
        # --- self-attention sub-layer ---
        attn_out, attn_probs, present_kv = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            input_ids=input_ids,
        )
        hidden_states = hidden_states + attn_out

        # --- feed-forward sub-layer ---
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        layer_outputs = (hidden_states,)
        if output_attentions:
            layer_outputs = layer_outputs + (attn_probs,)
        if use_cache:
            layer_outputs = layer_outputs + (present_kv,)
        return layer_outputs


class MoHQwen2_5_VLModel(Qwen2_5_VLModel):
    """MoH Model - direct transformers inheritance with layer replacement.

    The parent constructor builds the stock decoder stack, which is then
    replaced by ``MoHQwen2_5_VLDecoderLayer`` instances. ``forward`` is
    overridden so that ``input_ids`` are handed to every decoder layer.
    """

    def __init__(self, config):
        """Build the parent model, then swap in MoH decoder layers.

        Args:
            config: Model configuration; ``config.num_hidden_layers`` decides
                how many MoH decoder layers are created.
        """
        # Call parent __init__ first
        super().__init__(config)
        # Replace all layers with MoH layers
        self.layers = torch.nn.ModuleList([
            MoHQwen2_5_VLDecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        # Post-init after layer replacement
        self.post_init()

        # Log MoH initialization for training.
        # Count the layers whose attention module actually reports MoH as
        # enabled, rather than re-deriving the layer-range rule here: that way
        # the message stays truthful when ``config.enable_moh`` is off or the
        # selection rule changes.
        total_layers = len(self.layers)
        moh_layers = sum(
            1 for layer in self.layers if getattr(layer.self_attn, "moh_enabled", False)
        )
        logger.info(f"[MoH Training] Model initialized: {moh_layers}/{total_layers} layers with MoH enabled (vision-end→visual threshold)")

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack, passing ``input_ids`` to each decoder layer.

        Unlike a forward that requires exactly one of ``input_ids`` /
        ``inputs_embeds``, both may be supplied here: ``inputs_embeds`` (when
        given) is what flows through the network, while ``input_ids`` is only
        forwarded to the layers.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` resolves to true,
            otherwise a tuple of the non-``None`` values among
            (last hidden state, cache, all hidden states, all attentions).

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        # Mirror upstream logic, but pass input_ids into each decoder layer.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Allow both input_ids and inputs_embeds (we need input_ids for MoH head selection)
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You must specify at least one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            # Continue numbering after whatever the cache already holds.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional order must match MoHQwen2_5_VLDecoderLayer.forward,
                # with input_ids as the last argument.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                    input_ids,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    input_ids=input_ids,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache sits after the attention weights when those are returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class MoHQwen2_5_VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """MoH Conditional Generation - direct transformers inheritance.

    Replaces the language model with ``MoHQwen2_5_VLModel`` and overrides
    ``forward`` / ``prepare_inputs_for_generation`` so that ``input_ids`` are
    still handed to the language model when ``inputs_embeds`` are in use.
    """

    def __init__(self, config):
        """Build the parent model, then swap in the MoH language model.

        Args:
            config: Model configuration, passed to both the parent constructor
                and ``MoHQwen2_5_VLModel``.
        """
        # Call parent init with all necessary args
        super().__init__(config)
        # Replace model with MoH version
        self.model = MoHQwen2_5_VLModel(config)
        # The parent's post_init() tied lm_head to the embedding matrix of the
        # model that was just discarded. Re-tie against the new embeddings so
        # tied-embedding configs stay tied when the class is built directly.
        # tie_weights() consults the config itself, so untied configs are
        # unaffected.
        self.tie_weights()
        # Mark as MoH model
        self._is_moh_model = True

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        """Compute logits (and optionally the LM loss), keeping ``input_ids`` for MoH.

        When ``inputs_embeds`` is not supplied, token embeddings are looked up
        and the positions of image / video placeholder tokens are overwritten
        with the visual encoder's features. ``input_ids`` is then passed to
        ``self.model`` together with ``inputs_embeds``.

        Returns:
            ``Qwen2_5_VLCausalLMOutputWithPast`` when ``return_dict`` resolves
            to true; otherwise a tuple ``(logits, *model_outputs[1:])`` with
            the loss prepended when ``labels`` is given.

        Raises:
            ValueError: If the number of image (or video) placeholder tokens in
                ``input_ids`` differs from the number of visual feature rows.
        """
        # Mirror upstream, but ensure we retain input_ids and pass to self.model
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                # Overwrite the embeddings at image placeholder positions.
                mask = input_ids == self.config.image_token_id
                inputs_embeds = inputs_embeds.masked_scatter(mask.unsqueeze(-1).expand_as(inputs_embeds), image_embeds.to(inputs_embeds.device, inputs_embeds.dtype))

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                # Overwrite the embeddings at video placeholder positions.
                mask = input_ids == self.config.video_token_id
                inputs_embeds = inputs_embeds.masked_scatter(mask.unsqueeze(-1).expand_as(inputs_embeds), video_embeds.to(inputs_embeds.device, inputs_embeds.dtype))

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # Position ids are only derived here for a missing or 2D attention mask.
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # Recompute the rope index at the start of a sequence (or when no
            # deltas are stored yet); otherwise reuse the stored deltas.
            if (
                (cache_position is not None and cache_position[0] == 0)
                or getattr(self, "rope_deltas", None) is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:
                    # Otherwise `delta` is the plain int 0 and needs no repeat.
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=input_ids,  # keep input_ids for MoH
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            logits = logits.float()
            # Next-token objective: position i predicts label i + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2_5_VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=getattr(self, "rope_deltas", None),
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        **kwargs,
    ):
        """Build the per-step model inputs, re-attaching ``input_ids`` when embeddings are used.

        All work is delegated to the parent implementation. The only change is
        that, when the parent chose to feed ``inputs_embeds``, the caller's
        ``input_ids`` are put back into the returned dict, provided they cover
        the same number of positions as the embeddings.

        Returns:
            The dict produced by the parent, possibly with ``"input_ids"``
            restored.
        """
        # Delegate to upstream to get all standard behaviors
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            **kwargs,
        )
        # Critical MoH fix: keep input_ids alongside inputs_embeds during prefill so MoH can see tokens.
        # Only do so when the ids line up with the embeddings: if generation is
        # driven by embeddings alone, the ids passed in here do not cover the
        # embedded positions, and forwarding them would misalign token ids and
        # embeddings downstream.
        step_embeds = model_inputs.get("inputs_embeds")
        if (
            step_embeds is not None
            and input_ids is not None
            and input_ids.shape[-1] == step_embeds.shape[1]
        ):
            model_inputs["input_ids"] = input_ids
        return model_inputs