"""
Modified LLaVA-3D Architecture with DINO Support

This file provides modified LLaVA-3D classes that can work with DINO features
from VGGT instead of relying on CLIP encoding.
"""

from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import sys
import os

# Import original LLaVA-3D components
sys.path.append(os.path.dirname(__file__))
from llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from multimodal_projector.builder import build_vision_projector
from ..constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

# Import Vid-LLM DINO components
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..', 'vid_llm'))
from models.dino_based_llava3d import DinoBasedVisionTower, DinoBasedVideoTower


class DinoLlavaMetaModel(LlavaMetaModel):
    """
    Modified LLaVA meta model that supports DINO features from VGGT.

    Extends ``LlavaMetaModel`` by attaching DINO-based vision/video towers
    (``dino_vision_tower`` / ``dino_video_tower``) when the config requests
    them. The accessor methods prefer the DINO towers and fall back to the
    parent implementation when no DINO tower was created.
    """

    def __init__(self, config):
        """
        Build the parent model, then attach DINO towers as configured.

        Args:
            config: Model configuration. Read attributes (with defaults):
                ``dino_dim`` (1024), ``mm_hidden_size`` (1024),
                ``image_size`` (336), ``num_frames`` (24). A DINO vision
                tower is created when the config *has* an
                ``mm_vision_tower`` attribute; a DINO video tower is created
                when ``mm_video_tower`` is present and not ``None``.
        """
        super().__init__(config)

        # Override vision components with DINO-based versions
        if hasattr(config, "mm_vision_tower"):
            self.dino_vision_tower = DinoBasedVisionTower(
                dino_dim=getattr(config, 'dino_dim', 1024),
                target_dim=getattr(config, 'mm_hidden_size', 1024),
                image_size=getattr(config, 'image_size', 336),
            )

            # Reuse the projector created by the parent if there is one;
            # only build a new projector when it is missing.
            if not hasattr(self, 'mm_projector'):
                self.mm_projector = build_vision_projector(config)

        if getattr(config, "mm_video_tower", None) is not None:
            self.dino_video_tower = DinoBasedVideoTower(
                dino_dim=getattr(config, 'dino_dim', 1024),
                target_dim=getattr(config, 'mm_hidden_size', 1024),
                num_frames=getattr(config, 'num_frames', 24),
            )

    @staticmethod
    def _unwrap_tower(tower):
        """Return the first element if the tower is stored wrapped in a list."""
        if isinstance(tower, list):
            return tower[0]
        return tower

    def get_vision_tower(self):
        """
        Get the DINO-based vision tower.

        Returns:
            The ``dino_vision_tower`` attribute (unwrapped if stored in a
            list), or the parent's ``get_vision_tower()`` result when no
            DINO vision tower is set.
        """
        dino_vision_tower = getattr(self, 'dino_vision_tower', None)
        if dino_vision_tower is None:
            # Fallback to original vision tower
            return super().get_vision_tower()
        return self._unwrap_tower(dino_vision_tower)

    def get_video_tower(self):
        """
        Get the DINO-based video tower.

        Returns:
            The ``dino_video_tower`` attribute (unwrapped if stored in a
            list), or the parent's ``get_video_tower()`` result when no
            DINO video tower is set.
        """
        dino_video_tower = getattr(self, 'dino_video_tower', None)
        if dino_video_tower is None:
            # Fallback to original video tower
            return super().get_video_tower()
        return self._unwrap_tower(dino_video_tower)

    def get_prompt_encoder(self):
        """
        Get the prompt encoder, preferring the one on the video tower.

        Returns:
            ``video_tower.prompt_encoder`` when the active video tower has
            that attribute; otherwise the parent's ``get_prompt_encoder()``
            result, or ``None`` if the parent lookup fails.
        """
        dino_video_tower = self.get_video_tower()
        if dino_video_tower is not None and hasattr(dino_video_tower, 'prompt_encoder'):
            return dino_video_tower.prompt_encoder

        # Fallback to original implementation. This stays best-effort, but
        # only ordinary errors are swallowed: a bare ``except`` would also
        # hide KeyboardInterrupt and SystemExit.
        try:
            return super().get_prompt_encoder()
        except Exception:
            return None


class DinoLlavaMetaForCausalLM(LlavaMetaForCausalLM):
    """
    Modified LLaVA causal LM that supports DINO features from VGGT.

    Provides encoding, input-preparation and generation entry points that
    take pre-extracted DINO features instead of raw images, so no separate
    CLIP encoding pass is needed.
    """

    def encode_images_with_dino(self, dino_features: torch.Tensor) -> torch.Tensor:
        """
        Encode images using pre-extracted DINO features from VGGT.

        Args:
            dino_features: DINO features from VGGT [B, num_patches, dino_dim]

        Returns:
            processed_features: [B, num_patches, target_dim]

        Raises:
            ValueError: If the model exposes no vision tower.
        """
        model = self.get_model()
        dino_vision_tower = model.get_vision_tower()

        if dino_vision_tower is None:
            raise ValueError("DINO vision tower not available")

        # Process DINO features through adapted vision tower
        adapted_features = dino_vision_tower(dino_features)

        dino_video_tower = model.get_video_tower()
        if dino_video_tower is not None and hasattr(dino_video_tower, 'video_tower'):
            # The unpacking also requires the adapted features to be rank 3.
            B, num_patches, _ = adapted_features.shape
            pseudo_xyz = adapted_features.new_zeros((B, num_patches, 3))
            pos_embed = dino_video_tower.video_tower.encode_pe(pseudo_xyz)
            # The positional embedding is scaled by zero, so (for finite
            # values) the features are numerically unchanged.
            # NOTE(review): presumably this exists to keep the positional
            # encoder's parameters in the autograd graph - confirm before
            # removing the call.
            adapted_features = adapted_features + pos_embed * 0

        # Apply multimodal projector
        return model.mm_projector(adapted_features)

    def encode_rgbd_videos_with_dino(self,
                                     dino_features: torch.Tensor,
                                     depths: torch.Tensor,
                                     poses: torch.Tensor,
                                     intrinsics: torch.Tensor,
                                     lengths=None) -> tuple:
        """
        Encode RGBD videos using shared DINO features from VGGT.

        Args:
            dino_features: DINO features from VGGT [B*V, num_patches, dino_dim]
            depths: Depth maps [B, V, H, W]
            poses: Camera poses [B, V, 4, 4]
            intrinsics: Camera intrinsics [B, V, 4, 4]
            lengths: Optional sequence lengths, forwarded to the video tower

        Returns:
            video_features: Video-tower output after the multimodal projector
            batch_offset: Second value returned by the video tower, passed
                through unchanged (used by the caller to split samples)

        Raises:
            ValueError: If the model exposes no video tower.
        """
        model = self.get_model()
        dino_video_tower = model.get_video_tower()

        if dino_video_tower is None:
            raise ValueError("DINO video tower not available")

        # Process through DINO-based video tower
        video_features, batch_offset = dino_video_tower(
            dino_features, depths, poses, intrinsics, lengths
        )

        # Apply multimodal projector
        video_features = model.mm_projector(video_features)

        return video_features, batch_offset

    @staticmethod
    def _split_voxel_features(video_features, batch_offset, max_tokens):
        """Split flat features at each offset, randomly capping each chunk to max_tokens rows."""
        chunks = []
        start = 0
        for end in batch_offset:
            feats = video_features[start:end]
            if max_tokens is not None and feats.shape[0] > max_tokens:
                # Random subsample without replacement; row order is not kept.
                keep = torch.randperm(feats.size(0))[:max_tokens]
                feats = feats[keep]
            start = end
            chunks.append(feats)
        return chunks

    def prepare_inputs_labels_for_multimodal_dino(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        labels,
        dino_features,     # DINO features instead of images
        depths=None,
        poses=None,
        intrinsics=None,
        lengths=None,
        clicks=None,
        image_sizes=None,
        max_voxel_tokens=2560
    ):
        """
        Prepare inputs using DINO features instead of raw images.

        Encodes ``dino_features`` through the 3D video path when ``depths``,
        ``poses`` and ``intrinsics`` are all given, otherwise through the 2D
        image path. This method currently does NOT splice the visual
        features into the token embeddings; it returns them alongside the
        untouched text inputs.

        Args:
            input_ids: Token ids; only ``shape[1]`` is inspected here.
            position_ids: Returned unchanged.
            attention_mask: Returned unchanged.
            past_key_values: Returned unchanged.
            labels: Returned unchanged.
            dino_features: Pre-extracted DINO features, or ``None``.
            depths: Optional depth maps for the 3D path.
            poses: Optional camera poses for the 3D path.
            intrinsics: Optional camera intrinsics for the 3D path.
            lengths: Optional sequence lengths forwarded to the video tower.
            clicks: Optional click prompts; a zero-row ``(0, 3)`` tensor is
                substituted when ``None``.
            image_sizes: Accepted for interface compatibility; unused.
            max_voxel_tokens: Maximum number of feature rows kept per sample
                on the voxelized 3D path (default 2560, limits memory).
                ``None`` disables the cap.

        Returns:
            Tuple ``(input_ids, position_ids, attention_mask,
            past_key_values, image_features, labels)``. ``image_features``
            is ``None`` when there is no vision tower, no DINO features, or
            ``input_ids.shape[1] == 1``; otherwise it is a list of feature
            tensors.
        """
        vision_tower = self.get_vision_tower()

        if vision_tower is None or dino_features is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        # Determine if we have video (3D) or image (2D) data
        has_3d_data = depths is not None and poses is not None and intrinsics is not None

        if has_3d_data:
            # Process DINO features for 3D video understanding
            video_features, batch_offset = self.encode_rgbd_videos_with_dino(
                dino_features, depths, poses, intrinsics, lengths
            )

            if batch_offset is not None:
                # Voxelized output: one chunk per sample, each capped in size.
                image_features = self._split_voxel_features(
                    video_features, batch_offset, max_voxel_tokens
                )
            else:
                image_features = [video_features]
        else:
            # Process DINO features for 2D image understanding
            image_features = [self.encode_images_with_dino(dino_features)]

        # Handle prompt encoding for clicks. Without clicks an empty (0, 3)
        # tensor is encoded instead.
        # NOTE: prompt_features is computed but not consumed yet; it is kept
        # for the pending port of the original LLaVA-3D preparation logic.
        if self.get_prompt_encoder():
            prompt_input = clicks if clicks is not None else dino_features.new_zeros((0, 3))
            prompt_features = self.encode_prompts(prompt_input)
        else:
            prompt_features = None

        # The rest follows the original LLaVA-3D multimodal preparation logic
        # but using DINO-derived features instead of CLIP features

        # For now, return modified inputs with DINO features
        return input_ids, position_ids, attention_mask, past_key_values, image_features, labels

    def generate_with_dino(self,
                           input_ids,
                           dino_features,
                           depths=None,
                           poses=None,
                           intrinsics=None,
                           lengths=None,
                           clicks=None,
                           **generation_kwargs):
        """
        Generate response using DINO features.

        Prepares visual features from ``dino_features`` and forwards them to
        ``self.generate`` through its ``images`` keyword.

        Args:
            input_ids: Prompt token ids.
            dino_features: Pre-extracted DINO features from VGGT.
            depths: Optional depth maps (enables the 3D path together with
                ``poses`` and ``intrinsics``).
            poses: Optional camera poses.
            intrinsics: Optional camera intrinsics.
            lengths: Optional sequence lengths.
            clicks: Optional click prompts.
            **generation_kwargs: Forwarded to ``self.generate``.

        Returns:
            Whatever ``self.generate`` returns (the generated token ids).

        Note:
            Only the first entry of the prepared feature list is passed to
            ``generate``; ``None`` is passed when no features were prepared.
        """
        # Feature preparation runs inside inference mode as well, so no
        # autograd graph is recorded through the towers and the projector
        # while generating.
        with torch.inference_mode():
            # Prepare inputs with DINO features
            (input_ids, position_ids, attention_mask,
             past_key_values, image_features, labels) = self.prepare_inputs_labels_for_multimodal_dino(
                input_ids, None, None, None, None,
                dino_features,
                depths=depths,
                poses=poses,
                intrinsics=intrinsics,
                lengths=lengths,
                clicks=clicks,
            )

            # Call the underlying model's generate method
            output_ids = self.generate(
                input_ids,
                images=image_features[0] if image_features else None,
                **generation_kwargs
            )

        return output_ids