"""
Modified Semantics Model Architecture for Vision Encoder Feature Integration

This module provides a modified version of Semantics Model's architecture that uses
Vision Encoder features from Geometry Predictor instead of Language Vision features.
"""

import torch
import torch.nn as nn
import sys
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple

# Add Semantics Model to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'Semantics Model-main'))

try:
    from semantics.model.semantics_arch import SemanticsMetaModel, SemanticsMetaForCausalLM
    from semantics.model.multimodal_projector.builder import build_vision_projector
    from semantics.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
    LLAVA3D_ARCH_AVAILABLE = True
except ImportError as e:
    LLAVA3D_ARCH_AVAILABLE = False
    print(f"Warning: Semantics Model architecture not available: {e}")

from .dino_based_semantics3d import VisionBasedTower, VisionBasedVideoTower, VisionFeatureAdapter


class DinoSemanticsMetaModel:
    """
    Modified LLaVA meta model that uses Vision Encoder features instead of Language Vision.

    ``__init__`` forwards ``config`` to the next class in the MRO, so this class
    is presumably meant to be mixed into a model class whose constructor takes a
    config (NOTE(review): confirm against the concrete model classes).

    Attributes created on demand:
        dino_vision_tower: ``VisionBasedTower`` instance.
        dino_video_tower: ``VisionBasedVideoTower`` instance.
        mm_projector: Projector returned by ``build_vision_projector``.
        image_newline: Parameter of size ``config.hidden_size``; only created
            when ``config.mm_patch_merge_type`` contains ``'unpad'``.
    """

    def __init__(self, config):
        """
        Build the Vision Encoder-based towers requested by ``config``.

        Args:
            config: Model configuration. ``mm_vision_tower`` (presence) enables the
                vision tower and projector; a non-None ``mm_video_tower`` enables
                the video tower. ``vision_dim``, ``mm_hidden_size``, ``image_size``
                and ``num_frames`` are read with defaults when absent.
        """
        if LLAVA3D_ARCH_AVAILABLE:
            super().__init__(config)

        # Replace vision tower with Vision Encoder-based version
        if hasattr(config, "mm_vision_tower"):
            self.dino_vision_tower = VisionBasedTower(
                vision_dim=getattr(config, 'vision_dim', 1024),
                target_dim=getattr(config, 'mm_hidden_size', 1024),
                image_size=getattr(config, 'image_size', 336)
            )

            # Build projector for adapted features
            self.mm_projector = build_vision_projector(config)

            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                # torch.empty leaves the values uninitialised; they are presumably
                # filled from a checkpoint or by training (TODO confirm).
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=getattr(self, 'dtype', torch.float32))
                )

        # Replace video tower with Vision Encoder-based version
        if getattr(config, "mm_video_tower", None) is not None:
            self.dino_video_tower = VisionBasedVideoTower(
                vision_dim=getattr(config, 'vision_dim', 1024),
                target_dim=getattr(config, 'mm_hidden_size', 1024),
                num_frames=getattr(config, 'num_frames', 24)
            )

    def get_vision_tower(self):
        """Get Vision Encoder-based vision tower, or None when it was never built."""
        dino_vision_tower = getattr(self, 'dino_vision_tower', None)
        # The tower may be stored wrapped in a list; unwrap the first entry.
        if isinstance(dino_vision_tower, list):
            dino_vision_tower = dino_vision_tower[0]
        return dino_vision_tower

    def get_video_tower(self):
        """Get Vision Encoder-based video tower, or None when it was never built."""
        dino_video_tower = getattr(self, 'dino_video_tower', None)
        # The tower may be stored wrapped in a list; unwrap the first entry.
        if isinstance(dino_video_tower, list):
            dino_video_tower = dino_video_tower[0]
        return dino_video_tower

    def get_prompt_encoder(self):
        """Get prompt encoder from Vision Encoder video tower (None without a video tower)."""
        dino_video_tower = self.get_video_tower()
        if dino_video_tower is None:
            return None
        return dino_video_tower.prompt_encoder

    def initialize_dino_vision_modules(self, model_args, fsdp=None):
        """
        Initialize Vision Encoder-based vision modules.

        Writes the multimodal settings into ``self.config`` and creates whichever
        of the vision tower, video tower and projector do not exist yet, so the
        model is usable even when ``__init__`` built none of them.

        Args:
            model_args: Model arguments containing configuration. ``vision_dim``,
                ``mm_hidden_size`` (both default 1024) and ``mm_projector_type``
                (default ``'linear'``) are read from it.
            fsdp: FSDP configuration. Accepted for interface compatibility;
                not used by this method.
        """
        vision_dim = getattr(model_args, 'vision_dim', 1024)
        mm_hidden_size = getattr(model_args, 'mm_hidden_size', 1024)

        # Set configuration
        self.config.mm_vision_tower = "dino_shared"
        self.config.mm_video_tower = "dino_shared"
        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = mm_hidden_size
        self.config.vision_dim = vision_dim

        # Initialize Vision Encoder-based towers. A missing attribute and an
        # attribute explicitly set to None both mean "not built yet".
        if getattr(self, 'dino_vision_tower', None) is None:
            self.dino_vision_tower = VisionBasedTower(
                vision_dim=vision_dim,
                target_dim=mm_hidden_size
            )

        if getattr(self, 'dino_video_tower', None) is None:
            self.dino_video_tower = VisionBasedVideoTower(
                vision_dim=vision_dim,
                target_dim=mm_hidden_size
            )

        # The projector is only built by __init__ when the config already had
        # mm_vision_tower; build it here otherwise so encoding does not fail on
        # a missing mm_projector. Skipped when the builder could not be imported.
        if getattr(self, 'mm_projector', None) is None and LLAVA3D_ARCH_AVAILABLE:
            self.mm_projector = build_vision_projector(self.config)

        print("✓ Vision Encoder-based vision modules initialized")


class DinoSemanticsMetaForCausalLM(ABC):
    """
    Modified LLaVA meta class for causal LM that uses Vision Encoder features.

    Subclasses implement ``get_model()``; the returned object must provide
    ``get_vision_tower()``, ``get_video_tower()``, ``get_prompt_encoder()`` and
    an ``mm_projector`` callable, all of which are used below.
    """

    # Upper bound on the number of feature rows kept per sample after
    # voxelization; larger samples are randomly subsampled down to this size.
    MAX_VIDEO_TOKENS = 2560

    @abstractmethod
    def get_model(self):
        """Return the underlying model object that owns the towers and projector."""
        pass

    def get_vision_tower(self):
        """Return the model's vision tower (may be None)."""
        return self.get_model().get_vision_tower()

    def get_video_tower(self):
        """Return the model's video tower (may be None)."""
        return self.get_model().get_video_tower()

    def get_prompt_encoder(self):
        """Return the model's prompt encoder (may be None)."""
        return self.get_model().get_prompt_encoder()

    def encode_images_with_dino(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Encode images using Vision Encoder features from Geometry Predictor.

        Args:
            vision_features: Vision Encoder features [B, num_patches, vision_dim]

        Returns:
            processed_features: [B, num_patches, target_dim]

        Raises:
            ValueError: If the vision tower has not been initialized.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None:
            raise ValueError("Vision Encoder vision tower not initialized")

        # Adapt the features with the vision tower, then project them.
        adapted_features = vision_tower(vision_features)
        return self.get_model().mm_projector(adapted_features)

    def encode_depth_videos_with_dino(self,
                                   vision_features: torch.Tensor,
                                   depths: torch.Tensor,
                                   poses: torch.Tensor,
                                   intrinsics: torch.Tensor,
                                   lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode RGBD videos using Vision Encoder features from Geometry Predictor.

        Args:
            vision_features: Vision Encoder features from Geometry Predictor [B*V, num_patches, vision_dim]
            depths: Depth maps [B, V, H, W]
            poses: Camera poses [B, V, 4, 4]
            intrinsics: Camera intrinsics [B, V, 4, 4]
            lengths: Optional sequence lengths [B]

        Returns:
            video_features: Processed 3D-aware features (after the projector)
            batch_offset: Batch offset for voxelization, passed through from the
                video tower unchanged

        Raises:
            ValueError: If the video tower has not been initialized.
        """
        video_tower = self.get_video_tower()
        if video_tower is None:
            raise ValueError("Vision Encoder video tower not initialized")

        # Process through Vision Encoder-based video tower
        video_features, batch_offset = video_tower(
            vision_features, depths, poses, intrinsics, lengths
        )

        # Apply multimodal projector
        video_features = self.get_model().mm_projector(video_features)

        return video_features, batch_offset

    def _split_video_features(self, video_features: torch.Tensor, batch_offset) -> List[torch.Tensor]:
        """Split flat video features into per-sample chunks, capping each at MAX_VIDEO_TOKENS rows."""
        chunks = []
        start = 0
        # Each entry of batch_offset is used as the exclusive end row of one sample.
        for end in batch_offset:
            feats = video_features[start:end]
            if feats.shape[0] > self.MAX_VIDEO_TOKENS:
                # Random subsample without replacement; row order is not preserved.
                keep = torch.randperm(feats.size(0))[:self.MAX_VIDEO_TOKENS]
                feats = feats[keep]
            start = end
            chunks.append(feats)
        return chunks

    def prepare_inputs_labels_for_multimodal_with_dino(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        labels,
        vision_features: torch.Tensor,  # Vision Encoder features instead of images
        depths: torch.Tensor,
        poses: torch.Tensor,
        intrinsics: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
        clicks: Optional[torch.Tensor] = None,
        image_sizes: Optional[List] = None
    ):
        """
        Prepare inputs and labels for multimodal training/inference using Vision Encoder features.

        Args:
            input_ids: Tokenized text input
            position_ids: Position indices
            attention_mask: Attention mask
            past_key_values: Past key values for generation
            labels: Labels for training
            vision_features: Vision Encoder features from Geometry Predictor [B*V, num_patches, vision_dim]
            depths: Depth maps [B, V, H, W]
            poses: Camera poses [B, V, 4, 4]
            intrinsics: Camera intrinsics [B, V, 4, 4]
            lengths: Optional sequence lengths [B]
            clicks: Optional click inputs
            image_sizes: Optional image sizes (currently unused)

        Returns:
            Tuple ``(input_ids, position_ids, attention_mask, past_key_values,
            image_features, labels)``. ``image_features`` is a list of feature
            tensors, or None when there is no vision tower, no vision features,
            or ``input_ids`` has a single token per row.
        """
        vision_tower = self.get_vision_tower()

        # Nothing to encode: pass the text inputs through untouched.
        if vision_tower is None or vision_features is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if depths is not None and poses is not None and intrinsics is not None:
            # Use Vision Encoder features for 3D understanding
            video_features, batch_offset = self.encode_depth_videos_with_dino(
                vision_features, depths, poses, intrinsics, lengths
            )

            if batch_offset is not None:
                # Voxelization method: features of all samples are concatenated.
                image_features = self._split_video_features(video_features, batch_offset)
            else:
                image_features = [video_features]
        else:
            # Use Vision Encoder features for 2D understanding
            image_features = [self.encode_images_with_dino(vision_features)]

        # Handle prompt encoding. The encoder comes from the video tower and is
        # None when that tower is absent (e.g. the 2D-only path), so it must not
        # be called unconditionally.
        prompt_encoder = self.get_prompt_encoder()
        if prompt_encoder is not None:
            if clicks is None:
                # Empty (0, 3) placeholder standing in for "no clicks".
                clicks = vision_features.new_zeros((0, 3))
            # Not consumed yet; the encoder is still run as before so the call
            # pattern is unchanged once the integration below is filled in.
            prompt_features = prompt_encoder(clicks)

        # Continue with standard Semantics Model processing
        # (This would integrate with the rest of Semantics Model's multimodal processing)

        return input_ids, position_ids, attention_mask, past_key_values, image_features, labels


class UnifiedVisionModel(nn.Module):
    """
    Unified Vid-LLM model that uses a single Vision Encoder encoder for both Geometry Predictor and Semantics Model.

    This is the main model class that coordinates the entire RGB → 3D understanding pipeline
    using shared Vision Encoder features.

    NOTE: ``forward`` and ``_generate_with_semantics3d_dino`` currently return
    placeholder strings; the actual Semantics Model inference is not wired in yet.
    """

    def __init__(self,
                 geometry_model,
                 semantics_model,
                 vision_dim: int = 1024,
                 semantics_dim: int = 1024,
                 target_resolution: int = 336):
        """
        Initialize unified Vid-LLM model.

        Args:
            geometry_model: Pre-trained Geometry Predictor model
            semantics_model: Pre-trained Semantics Model model
            vision_dim: Vision Encoder feature dimension
            semantics_dim: Target Language Vision-compatible dimension
            target_resolution: Target resolution for processing
        """
        super().__init__()

        self.geometry_model = geometry_model
        self.semantics_model = semantics_model

        # Kept so the replacement video tower is built with the same dimensions
        # as the pipeline rather than with the tower's own defaults.
        self.vision_dim = vision_dim
        self.semantics_dim = semantics_dim

        # Initialize unified pipeline (imported here, at construction time,
        # rather than at module import).
        from ..converters.dino_shared_converter import UnifiedModelPipeline
        self.unified_pipeline = UnifiedModelPipeline(
            target_resolution=target_resolution,
            vision_dim=vision_dim,
            semantics_dim=semantics_dim
        )

        # Replace Semantics Model's vision components with Vision Encoder-based ones
        self._replace_semantics3d_vision_components()

        print("✓ UnifiedVisionModel initialized with shared Vision Encoder features")

    def _replace_semantics3d_vision_components(self):
        """
        Replace Semantics Model's Language Vision components with Vision Encoder-based ones.

        Only attributes that already exist on ``semantics_model.model`` are
        replaced; a semantics model without a ``model`` attribute is left as is.
        """
        inner_model = getattr(self.semantics_model, 'model', None)
        if inner_model is None:
            return

        # Replace vision tower with the pipeline's feature adapter
        if hasattr(inner_model, 'vision_tower'):
            inner_model.vision_tower = self.unified_pipeline.converter.feature_adapter

        # Replace video tower if exists, sized to match the pipeline's dimensions
        if hasattr(inner_model, 'video_tower'):
            inner_model.video_tower = VisionBasedVideoTower(
                vision_dim=self.vision_dim,
                target_dim=self.semantics_dim
            )

    def forward(self,
                rgb_images: torch.Tensor,
                input_ids: torch.Tensor,
                query: str,
                **kwargs) -> str:
        """
        Forward pass for RGB-only 3D vision-language understanding.

        Args:
            rgb_images: Input RGB sequence [B, S, 3, H, W]
            input_ids: Tokenized text input (not used by the placeholder)
            query: Text query (not used by the placeholder)
            **kwargs: Additional arguments (ignored)

        Returns:
            Placeholder status string reporting ``rgb_images.shape[1]`` frames.
        """
        # Process through unified Vision Encoder pipeline
        unified_results = self.unified_pipeline.process_rgb_sequence(
            self.geometry_model, rgb_images
        )

        # Extract components. These are not consumed yet; the calls are kept so
        # the pipeline accessors still run as they will once inference is wired in.
        geometry_outputs = self.unified_pipeline.get_geometry_outputs(unified_results)
        semantics_inputs = self.unified_pipeline.get_semantics_inputs(unified_results)
        shared_vision_features = self.unified_pipeline.get_shared_features(unified_results)

        # Use shared Vision Encoder features for Semantics Model inference
        # (This would call the modified Semantics Model with Vision Encoder features)

        # For now, return a placeholder
        return f"Processed {rgb_images.shape[1]} frames using shared Vision Encoder features"

    def generate_with_vision_features(self,
                                  rgb_images: torch.Tensor,
                                  query: str,
                                  **generation_kwargs) -> str:
        """
        Generate response using shared Vision Encoder features.

        Args:
            rgb_images: Input RGB sequence [B, S, 3, H, W]
            query: Text query
            **generation_kwargs: Generation parameters

        Returns:
            Generated response

        Raises:
            ValueError: If the pipeline reports ``conversion_valid`` as falsy.
        """
        # Process RGB through unified pipeline
        unified_results = self.unified_pipeline.process_rgb_sequence(
            self.geometry_model, rgb_images
        )

        if not unified_results['conversion_valid']:
            raise ValueError("Vision Encoder feature conversion failed validation")

        # Extract Vision Encoder features and geometry data
        vision_features = unified_results['adapted_features']  # [B*V, num_patches, semantics_dim]
        semantics_inputs = unified_results['semantics_inputs']

        # Process through modified Semantics Model using Vision Encoder features
        return self._generate_with_semantics3d_dino(
            vision_features,
            semantics_inputs['depths'],
            semantics_inputs['poses'],
            semantics_inputs['intrinsics'],
            query,
            **generation_kwargs
        )

    def _generate_with_semantics3d_dino(self,
                                  vision_features: torch.Tensor,
                                  depths: torch.Tensor,
                                  poses: torch.Tensor,
                                  intrinsics: torch.Tensor,
                                  query: str,
                                  **kwargs) -> str:
        """
        Generate response using Semantics Model with Vision Encoder features.

        This method would interface with the modified Semantics Model model that
        accepts Vision Encoder features instead of running Language Vision encoding.
        For now it only reports the token count; ``depths``, ``poses``,
        ``intrinsics`` and ``**kwargs`` are not used.

        Raises:
            ValueError: If ``vision_features`` does not have exactly three
                dimensions (the shape unpacking below fails).
        """
        # This is where the actual Semantics Model inference would happen
        # using the shared Vision Encoder features instead of Language Vision features

        # Placeholder implementation
        batch_views, num_patches, _ = vision_features.shape
        total_tokens = batch_views * num_patches

        return f"Generated response using {total_tokens} Vision Encoder tokens from shared encoder for query: '{query}'"