"""
Unified Vision Converter for RGB3D Model

This converter handles the conversion between geometry and semantics models
while sharing Vision Encoder features between both models for maximum efficiency.
"""

import torch
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from .geometry_to_semantics import GeometryToSemanticsConverter
from ..models.vision_based_semantics import SharedVisionProcessor, VisionFeatureAdapter


class UnifiedVisionConverter(GeometryToSemanticsConverter):
    """
    Converter that shares vision features between geometry and semantics models.

    Vision features are obtained once through the shared vision processor (from the
    geometry model) and are then resized and adapted for the semantics model, so the
    semantics side does not need a separate visual encoding pass.
    """

    def __init__(self,
                 target_resolution: int = 336,
                 vision_dim: int = 1024,
                 semantics_dim: int = 1024):
        """
        Initialize Vision Encoder-shared converter.

        Args:
            target_resolution: Target resolution for Semantics Model
            vision_dim: Vision Encoder feature dimension from Geometry Predictor
            semantics_dim: Target Language Vision-compatible dimension for Semantics Model
        """
        super().__init__(target_resolution)

        self.vision_dim = vision_dim
        self.semantics_dim = semantics_dim

        # Processor used for feature extraction, resizing and reshaping.
        self.vision_processor = SharedVisionProcessor(target_resolution)

        # Adapter mapping vision_dim features to semantics_dim features.
        # The hidden width is at least 2048, or twice the output width if larger.
        self.feature_adapter = VisionFeatureAdapter(
            vision_dim=vision_dim,
            semantics_dim=semantics_dim,
            hidden_dim=max(2048, semantics_dim * 2),
            num_layers=3,
            dropout=0.1
        )

    def extract_shared_features(self,
                                geometry_model,
                                rgb_images: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Extract shared Vision Encoder features from Geometry Predictor.

        Thin wrapper that delegates to
        ``self.vision_processor.extract_vision_features_from_geometry``.

        Args:
            geometry_model: Geometry Predictor model instance
            rgb_images: Input RGB images [B, S, 3, H, W]

        Returns:
            Whatever the vision processor returns; ``convert_with_shared_dino``
            reads the keys ``'vision_features'`` and ``'geometry_predictions'``
            from it.
        """
        return self.vision_processor.extract_vision_features_from_geometry(geometry_model, rgb_images)

    def convert_with_shared_dino(self,
                                 geometry_model,
                                 rgb_images: torch.Tensor,
                                 original_resolution: int = 518) -> Dict[str, torch.Tensor]:
        """
        Convert using shared Vision Encoder features for maximum efficiency.

        Args:
            geometry_model: Geometry Predictor model instance
            rgb_images: Input RGB images [B, S, 3, H, W]
            original_resolution: Original image resolution

        Returns:
            Dictionary with the keys:
                'geometry_outputs': dict with 'images', 'depth', 'extrinsic', 'intrinsic'
                'semantics_inputs': result of the parent ``convert`` call, extended with
                    'vision_features' (adapted, reshaped) and 'original_vision_features'
                'shared_vision_features': features as returned by the shared extraction
                'adapted_features': output of the feature adapter (before reshaping)
        """
        # Extract shared Vision Encoder features and Geometry Predictor predictions
        shared_data = self.extract_shared_features(geometry_model, rgb_images)

        vision_features = shared_data['vision_features']
        geometry_predictions = shared_data['geometry_predictions']

        # Convert pose encoding to extrinsic/intrinsic matrices.
        # Imported locally, as in the original code, so the dependency is only
        # required when this method is actually called.
        from geometry.utils.pose_enc import pose_encoding_to_extri_intri
        extrinsic, intrinsic = pose_encoding_to_extri_intri(
            geometry_predictions["pose_enc"],
            rgb_images.shape[-2:]
        )

        # Organize Geometry Predictor outputs
        geometry_outputs = {
            'images': geometry_predictions["images"],
            'depth': geometry_predictions["depth"],
            'extrinsic': extrinsic,
            'intrinsic': intrinsic
        }

        # Convert Geometry Predictor outputs to Semantics Model format (geometry data)
        semantics_inputs = super().convert(geometry_outputs, original_resolution)

        # Leading two dimensions of the input, forwarded to the reshape helper below.
        B, S = rgb_images.shape[:2]

        # Resize Vision Encoder features to match target resolution.
        # The processor created in __init__ is ``self.vision_processor``; the
        # previous ``self.dino_processor`` attribute is never assigned in this class.
        # NOTE(review): assumes SharedVisionProcessor provides both helper methods
        # used below - confirm against its definition.
        resized_vision_features = self.vision_processor.resize_vision_features_for_resolution(
            vision_features, original_resolution, self.target_resolution
        )

        # Adapt Vision Encoder features to Language Vision-compatible format
        adapted_features = self.feature_adapter(resized_vision_features)

        # Reshape features for Semantics Model
        target_features = self.vision_processor.prepare_vision_features_for_semantics3d(
            adapted_features, B, S
        )

        # Add Vision Encoder features to Semantics Model inputs
        semantics_inputs['vision_features'] = target_features
        semantics_inputs['original_vision_features'] = vision_features

        return {
            'geometry_outputs': geometry_outputs,
            'semantics_inputs': semantics_inputs,
            'shared_vision_features': vision_features,
            'adapted_features': adapted_features
        }

    def validate_dino_conversion(self,
                                 shared_data: Dict[str, torch.Tensor]) -> bool:
        """
        Validate Vision Encoder feature sharing conversion.

        Args:
            shared_data: Shared conversion results, i.e. the dictionary returned by
                ``convert_with_shared_dino``

        Returns:
            True if conversion is valid. Any exception raised while checking
            (e.g. a missing key) is swallowed and reported as False.
        """
        try:
            # Check basic conversion validity via the parent class
            geometry_outputs = shared_data['geometry_outputs']
            semantics_inputs = shared_data['semantics_inputs']

            if not super().validate_conversion(geometry_outputs, semantics_inputs):
                return False

            # Check Vision Encoder feature validity
            vision_features = shared_data['shared_vision_features']
            adapted_features = shared_data['adapted_features']

            # Last dimension must match the configured feature widths
            if vision_features.shape[-1] != self.vision_dim:
                return False

            if adapted_features.shape[-1] != self.semantics_dim:
                return False

            # All leading dimensions must agree between raw and adapted features.
            # NOTE(review): adapted features are computed from *resized* features;
            # if resizing changes the patch count this check would fail - confirm.
            if vision_features.shape[:-1] != adapted_features.shape[:-1]:
                return False

            return True

        except Exception:
            # Deliberate best-effort: a malformed result counts as invalid.
            return False


class UnifiedModelPipeline:
    """
    Unified pipeline that coordinates Vision Encoder feature sharing between Geometry Predictor and Semantics Model.

    This is the main interface for using shared Vision Encoder features throughout Vid-LLM.
    It owns a ``UnifiedVisionConverter`` and exposes small accessors for the
    individual parts of the converter's result dictionary.
    """

    def __init__(self,
                 target_resolution: int = 336,
                 vision_dim: int = 1024,
                 semantics_dim: int = 1024):
        """
        Initialize unified Vision Encoder pipeline.

        Args:
            target_resolution: Target resolution for Semantics Model
            vision_dim: Vision Encoder feature dimension from Geometry Predictor
            semantics_dim: Target dimension for Semantics Model compatibility
        """
        self.converter = UnifiedVisionConverter(
            target_resolution=target_resolution,
            vision_dim=vision_dim,
            semantics_dim=semantics_dim
        )

        # Kept on the pipeline as well so callers can inspect the configuration.
        self.target_resolution = target_resolution
        self.vision_dim = vision_dim
        self.semantics_dim = semantics_dim

    def process_rgb_sequence(self,
                             geometry_model,
                             rgb_images: torch.Tensor,
                             original_resolution: int = 518) -> Dict[str, torch.Tensor]:
        """
        Process RGB sequence through unified Vision Encoder pipeline.

        Args:
            geometry_model: Geometry Predictor model instance
            rgb_images: Input RGB images [B, S, 3, H, W]
            original_resolution: Original image resolution, forwarded to the
                converter. Defaults to 518, the converter's own default, so
                existing callers are unaffected.

        Returns:
            The converter's result dictionary with one extra key,
            ``'conversion_valid'``, holding the boolean validation outcome.
        """
        # Convert with shared Vision Encoder features
        results = self.converter.convert_with_shared_dino(
            geometry_model,
            rgb_images,
            original_resolution=original_resolution,
        )

        # Validation never raises here; failures are reported via the flag.
        results['conversion_valid'] = self.converter.validate_dino_conversion(results)

        return results

    def get_semantics_inputs(self,
                             unified_results: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Extract Semantics Model compatible inputs from unified results.

        Args:
            unified_results: Results from process_rgb_sequence

        Returns:
            The value stored under ``'semantics_inputs'``
        """
        return unified_results['semantics_inputs']

    def get_geometry_outputs(self,
                             unified_results: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Extract Geometry Predictor outputs from unified results.

        Args:
            unified_results: Results from process_rgb_sequence

        Returns:
            The value stored under ``'geometry_outputs'``
        """
        return unified_results['geometry_outputs']

    def get_shared_features(self,
                            unified_results: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Extract shared Vision Encoder features from unified results.

        Args:
            unified_results: Results from process_rgb_sequence

        Returns:
            The value stored under ``'shared_vision_features'``
        """
        return unified_results['shared_vision_features']