"""
Vision Encoder-based Semantics Model Components

This module modifies Semantics Model to use Vision Encoder features from Geometry Predictor instead of Language Vision,
creating a unified vision encoding pipeline for the entire Vid-LLM system.
"""

import torch
import torch.nn as nn
import sys
import os
from typing import Dict, List, Optional, Tuple

# Add Semantics Model to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'Semantics Model-main'))

try:
    from semantics.model.multimodal_encoder.spatial_aware_module import SpatialAwareModule
    from semantics.model.multimodal_encoder.unproject import backprojector_dataloader, voxelize
    from semantics.model.multimodal_encoder.position_encodings import PositionEmbeddingLearnedMLP
    from torch_scatter import scatter_mean
    LLAVA3D_COMPONENTS_AVAILABLE = True
except ImportError as e:
    LLAVA3D_COMPONENTS_AVAILABLE = False
    print(f"Warning: Semantics Model components not available: {e}")


class VisionFeatureAdapter(nn.Module):
    """
    MLP adapter that maps Vision Encoder features into the feature space Semantics Model expects.

    Semantics Model was built around Language Vision patch features; this adapter projects
    the Geometry Predictor's Vision Encoder features (last dimension ``vision_dim``) to
    ``semantics_dim`` so they can be fed in their place. Only the last dimension is
    transformed, so leading dimensions such as ``[B*V, num_patches]`` pass through unchanged.
    """

    def __init__(self,
                 vision_dim: int = 1024,           # Geometry Predictor Vision Encoder feature dimension
                 semantics_dim: int = 1024,        # Semantics Model expected Language Vision dimension
                 hidden_dim: int = 2048,           # Hidden layer dimension
                 num_layers: int = 2,              # Number of Linear layers in the MLP
                 dropout: float = 0.1):            # Dropout rate
        """
        Initialize the Vision Encoder -> Language Vision feature adapter.

        Args:
            vision_dim: Input Vision Encoder feature dimension from Geometry Predictor.
            semantics_dim: Output dimension to match Language Vision features.
            hidden_dim: Hidden layer dimension (unused when ``num_layers == 1``).
            num_layers: Number of Linear layers. ``1`` is a single projection
                ``vision_dim -> semantics_dim``; ``N >= 2`` is an input block,
                ``N - 2`` hidden blocks (Linear/LayerNorm/GELU/Dropout) and an
                output Linear.
            dropout: Dropout rate for regularization (unused when ``num_layers == 1``).

        Raises:
            ValueError: If ``num_layers`` is less than 1.
        """
        super().__init__()

        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.vision_dim = vision_dim
        self.semantics_dim = semantics_dim

        layers = []
        if num_layers == 1:
            # Previously num_layers=1 still produced two Linear layers because the
            # hidden-layer loop below is simply empty; honour the request instead.
            layers.append(nn.Linear(vision_dim, semantics_dim))
        else:
            # Input block: project into the hidden space.
            layers.append(nn.Linear(vision_dim, hidden_dim))
            layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout))

            # num_layers - 2 hidden blocks (the input and output Linears account for 2).
            for _ in range(num_layers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                layers.append(nn.LayerNorm(hidden_dim))
                layers.append(nn.GELU())
                layers.append(nn.Dropout(dropout))

            # Output projection (no activation, so the output range is unconstrained).
            layers.append(nn.Linear(hidden_dim, semantics_dim))

        self.adapter = nn.Sequential(*layers)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every Linear weight and zero init for its bias."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Convert Vision Encoder features to Language Vision-compatible features.

        Args:
            vision_features: Vision Encoder features from Geometry Predictor,
                e.g. [B*V, num_patches, vision_dim].

        Returns:
            Adapted features with the same leading shape and last dim ``semantics_dim``.
        """
        return self.adapter(vision_features)


class VisionBasedTower(nn.Module):
    """
    Vision tower that consumes pre-extracted Vision Encoder features.

    Stands in for Semantics Model's original VisionTower: instead of running its own
    image encoder, it adapts the Geometry Predictor's Vision Encoder features, so a
    single encoder serves both geometry (Geometry Predictor) and semantics (Semantics Model).
    """

    def __init__(self,
                 vision_dim: int = 1024,
                 target_dim: int = 1024,
                 patch_size: int = 14,
                 image_size: int = 336):
        """
        Build the tower and its feature adapter.

        Args:
            vision_dim: Vision Encoder feature dimension from Geometry Predictor.
            target_dim: Feature dimension Semantics Model should receive.
            patch_size: Patch size (should match the Vision Encoder's patch size).
            image_size: Input image size (square).
        """
        super().__init__()

        # There are no encoder weights to load: features arrive from outside.
        self.is_loaded = True
        self.vision_dim, self.target_dim = vision_dim, target_dim
        self.patch_size, self.image_size = patch_size, image_size

        # Three-layer MLP whose hidden width is at least 2048.
        adapter_hidden = max(2048, 2 * target_dim)
        self.feature_adapter = VisionFeatureAdapter(
            vision_dim=vision_dim,
            semantics_dim=target_dim,
            hidden_dim=adapter_hidden,
            num_layers=3,
            dropout=0.1,
        )

        # Stand-in image processor exposing the attributes callers read
        # (ImageNet mean/std and the configured square size).
        processor_attrs = {
            'image_mean': [0.485, 0.456, 0.406],
            'image_std': [0.229, 0.224, 0.225],
            'size': {'shortest_edge': image_size},
            'crop_size': {'height': image_size, 'width': image_size},
        }
        self.image_processor = type('MockProcessor', (), processor_attrs)()

        print(f"✓ VisionBasedTower initialized: {vision_dim} → {target_dim}")

    def load_model(self, device_map=None):
        """No-op: this tower has nothing to load (features are supplied externally)."""
        return None

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Adapt Vision Encoder features for Semantics Model.

        Args:
            vision_features: Vision Encoder features from Geometry Predictor
                [B*V, num_patches, vision_dim].

        Returns:
            Adapted features [B*V, num_patches, target_dim].
        """
        return self.feature_adapter(vision_features)

    @property
    def config(self):
        """Lightweight config object exposing hidden_size, image_size and patch_size."""
        attrs = dict(
            hidden_size=self.target_dim,
            image_size=self.image_size,
            patch_size=self.patch_size,
        )
        return type('MockConfig', (), attrs)()

    @property
    def hidden_size(self):
        """Output feature dimension (equal to ``target_dim``)."""
        return self.target_dim

    @property
    def num_patches_per_side(self):
        """Patches along one image side: ``image_size // patch_size``."""
        return self.image_size // self.patch_size

    @property
    def num_patches(self):
        """Total patches in the square grid."""
        side = self.num_patches_per_side
        return side * side

    @property
    def dtype(self):
        """dtype of the first registered parameter."""
        first_param = next(self.parameters())
        return first_param.dtype

    @property
    def device(self):
        """Device of the first registered parameter."""
        first_param = next(self.parameters())
        return first_param.device


class _PromptEncoder(nn.Module):
    """
    Wrap ``PositionEmbeddingLearnedMLP(dim=3, ...)`` so 3D coordinates can be embedded.

    Implemented as a real ``nn.Module`` subclass so ``pos_emb3d`` is assigned in
    ``__init__`` and therefore registered as a submodule: its parameters show up in
    ``parameters()``/``state_dict()`` and follow ``.to()``/``.cuda()``.
    """

    def __init__(self, latent_dim: int):
        super().__init__()
        self.latent_dim = latent_dim
        self.pos_emb3d = PositionEmbeddingLearnedMLP(dim=3, num_pos_feats=latent_dim)

    def encode_pe(self, xyz: torch.Tensor) -> torch.Tensor:
        """Return ``pos_emb3d(xyz)``."""
        return self.pos_emb3d(xyz)

    def forward(self, clicks: torch.Tensor) -> torch.Tensor:
        """Embed click coordinates via :meth:`encode_pe`."""
        return self.encode_pe(clicks)


class VisionBasedVideoTower(nn.Module):
    """
    Modified video tower that works with Vision Encoder features from Geometry Predictor.

    This replaces the original DepthVideoTower to work with shared Vision Encoder features.
    """

    def __init__(self,
                 vision_dim: int = 1024,
                 target_dim: int = 1024,
                 num_frames: int = 24,
                 voxel_size: float = 0.2):
        """
        Initialize Vision Encoder-based video tower.

        Args:
            vision_dim: Vision Encoder feature dimension from Geometry Predictor
            target_dim: Target feature dimension
            num_frames: Maximum number of frames
            voxel_size: Voxel size passed to ``voxelize`` for 3D pooling
        """
        super().__init__()

        self.is_loaded = True
        self.num_frames = num_frames
        self.pooling = 'voxelize'
        self.voxel_size = voxel_size
        self.vision_dim = vision_dim
        self.target_dim = target_dim

        # Spatial-aware module from Semantics Model, or a pass-through when unavailable.
        if LLAVA3D_COMPONENTS_AVAILABLE:
            self.video_tower = SpatialAwareModule(latent_dim=target_dim)
        else:
            self.video_tower = nn.Identity()

        # Feature adapter for Vision Encoder features
        self.feature_adapter = VisionFeatureAdapter(
            vision_dim=vision_dim,
            semantics_dim=target_dim,
            hidden_dim=max(2048, target_dim * 2),
            num_layers=2
        )

        # Prompt encoder. Previously built with type(...) and the embedding stored as a
        # *class* attribute, which left its parameters unregistered (not trained, not
        # saved, not moved across devices).
        if LLAVA3D_COMPONENTS_AVAILABLE:
            self.prompt_encoder = _PromptEncoder(target_dim)
        else:
            self.prompt_encoder = nn.Identity()

    def load_model(self, device_map=None):
        """Load model (no-op since we use external features)."""
        pass

    def forward(self,
                vision_features: torch.Tensor,
                depths: torch.Tensor,
                poses: torch.Tensor,
                intrinsics: torch.Tensor,
                lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass using Vision Encoder features from Geometry Predictor.

        Args:
            vision_features: Vision Encoder features from Geometry Predictor
                [B*V, num_patches, vision_dim]; the patches must form a square grid.
            depths: Depth maps [B, V, H, W]
            poses: Camera poses [B, V, 4, 4]
            intrinsics: Camera intrinsics [B, V, 4, 4]
            lengths: Optional sequence lengths [B]. Currently unused; accepted for
                interface compatibility.

        Returns:
            pooled_video_features: Voxel-pooled features, concatenated over the batch
            batch_offset: int32 cumulative voxel counts per sample

        Raises:
            ValueError: If the patch count is not a perfect square, or if the number
                of feature views does not equal ``B * V`` from ``depths``.
            NotImplementedError: If ``self.pooling`` is not ``'voxelize'``.
        """
        # Adapt Vision Encoder features to target dimension
        adapted_features = self.feature_adapter(vision_features)  # [B*V, num_patches, target_dim]
        BV, num_patches, C = adapted_features.shape

        # round() guards against floating-point error in the root; the product check
        # turns a non-square grid into a clear error instead of an opaque view() failure.
        patches_per_side = int(round(num_patches ** 0.5))
        if patches_per_side * patches_per_side != num_patches:
            raise ValueError(
                f"Expected a square patch grid, got {num_patches} patches per view"
            )

        B, V = depths.shape[0], depths.shape[1]
        H, W = depths.shape[2], depths.shape[3]
        if BV != B * V:
            raise ValueError(
                f"vision_features has {BV} views but depths implies B*V = {B}*{V} = {B * V}"
            )

        # Patch sequence -> spatial grid [B*V, C, h, w]
        features_spatial = adapted_features.view(BV, patches_per_side, patches_per_side, C)
        features_spatial = features_spatial.permute(0, 3, 1, 2)

        # Resize features to match depth map resolution if needed
        if features_spatial.shape[-2:] != (H, W):
            features_spatial = torch.nn.functional.interpolate(
                features_spatial,
                size=(H, W),
                mode='bilinear',
                align_corners=False
            )

        if LLAVA3D_COMPONENTS_AVAILABLE:
            # Use real Semantics Model 3D processing
            feat_xyz, xyz = backprojector_dataloader([features_spatial], depths, poses, intrinsics)
            video_features = self.video_tower([features_spatial], [feat_xyz.flatten(0, 1)], (B, V))[0]

            if self.pooling != 'voxelize':
                raise NotImplementedError(f"Unsupported pooling mode: {self.pooling!r}")

            p2v = voxelize(feat_xyz, self.voxel_size)
            pooled_video_features = torch.cat([
                scatter_mean(video_features[b], p2v[b], dim=0)
                for b in range(len(video_features))
            ])
            batch_offset = ((p2v).max(1)[0] + 1).cumsum(0).to(torch.int32)
        else:
            # Placeholder output for testing without Semantics Model components:
            # 1000 random rows per sample. Kept on the features' device/dtype so
            # downstream code sees tensors consistent with the real path.
            device = adapted_features.device
            pooled_video_features = torch.randn(
                B * 1000, C, device=device, dtype=adapted_features.dtype
            )
            batch_offset = torch.arange(1, B + 1, dtype=torch.int32, device=device) * 1000

        return pooled_video_features, batch_offset

    @property
    def config(self):
        return type('MockConfig', (), {'hidden_size': self.target_dim})()

    @property
    def hidden_size(self):
        return self.target_dim

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @property
    def device(self):
        return next(self.parameters()).device


class VisionBasedSemanticsArch:
    """
    Semantics Model architecture wired to Vision Encoder features from Geometry Predictor.

    Owns an image tower and a video tower that both consume the shared Vision Encoder
    features, replacing the Language Vision encoder path.
    """

    def __init__(self, config):
        vision_dim = getattr(config, 'vision_dim', 1024)
        hidden = getattr(config, 'mm_hidden_size', 1024)

        # Image tower is built first, then the video tower (same construction order as before).
        self.dino_vision_tower = VisionBasedTower(vision_dim=vision_dim, target_dim=hidden)
        self.dino_video_tower = VisionBasedVideoTower(
            vision_dim=vision_dim,
            target_dim=hidden,
            num_frames=getattr(config, 'num_frames', 24),
        )

    def get_vision_tower(self):
        """Return the image tower."""
        return self.dino_vision_tower

    def get_video_tower(self):
        """Return the video tower."""
        return self.dino_video_tower

    def get_prompt_encoder(self):
        """Return the video tower's prompt encoder."""
        return self.dino_video_tower.prompt_encoder

    def encode_images_with_dino(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Encode images from pre-extracted Vision Encoder features.

        Args:
            vision_features: Vision Encoder features from Geometry Predictor
                [B, num_patches, vision_dim].

        Returns:
            Adapted features [B, num_patches, target_dim].
        """
        adapted = self.dino_vision_tower(vision_features)

        inner = getattr(self.dino_video_tower, 'video_tower', None)
        if inner is None or not hasattr(inner, 'encode_pe'):
            return adapted

        # A zero-weighted position embedding of all-zero coordinates: the values are
        # left as-is, but the embedding module still takes part in the forward pass.
        batch, n_tokens, _ = adapted.shape
        zero_xyz = adapted.new_zeros((batch, n_tokens, 3))
        return adapted + inner.encode_pe(zero_xyz) * 0

    def encode_depth_videos_with_dino(self,
                                   vision_features: torch.Tensor,
                                   depths: torch.Tensor,
                                   poses: torch.Tensor,
                                   intrinsics: torch.Tensor,
                                   lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode RGBD videos by delegating to the video tower.

        Args:
            vision_features: Vision Encoder features [B*V, num_patches, vision_dim]
            depths: Depth maps [B, V, H, W]
            poses: Camera poses [B, V, 4, 4]
            intrinsics: Camera intrinsics [B, V, 4, 4]
            lengths: Optional sequence lengths [B]

        Returns:
            (video_features, batch_offset) as produced by the video tower.
        """
        features, offset = self.dino_video_tower(
            vision_features, depths, poses, intrinsics, lengths
        )
        return features, offset


class SharedVisionProcessor:
    """
    Coordinates Vision Encoder feature sharing between Geometry Predictor and Semantics Model.

    Extracts patch features from the Geometry Predictor so they can be reused for both
    geometry prediction and language understanding. It also provides helpers that
    reshape and resize those features for Semantics Model.
    """

    def __init__(self, target_resolution: int = 336):
        """
        Args:
            target_resolution: Default target image resolution used by
                ``resize_vision_features_for_resolution`` when no explicit
                ``target_resolution`` is given.
        """
        self.target_resolution = target_resolution

    @staticmethod
    def _select_patch_tokens(tokens: torch.Tensor, patch_start_idx: int) -> torch.Tensor:
        """Drop the leading special tokens and return patch tokens as [B*S, num_patches, dim]."""
        if tokens.dim() == 4:
            # NOTE(review): supports aggregators that return [B, S, tokens, dim]. There
            # the token axis is dim 2, so slicing dim 1 would drop frames, not special tokens.
            return tokens[:, :, patch_start_idx:].flatten(0, 1)
        # [B*S, tokens, dim] layout: token axis is dim 1.
        return tokens[:, patch_start_idx:]

    def extract_vision_features_from_geometry(self,
                                      geometry_model,
                                      rgb_images: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Extract Vision Encoder features from Geometry Predictor's aggregator.

        Args:
            geometry_model: Geometry Predictor model instance; must expose
                ``aggregator(images) -> (tokens_per_layer, patch_start_idx)`` and be callable.
            rgb_images: Input RGB images [B, S, 3, H, W]

        Returns:
            Dictionary with:
              - 'vision_features': last-layer patch tokens [B*S, num_patches, dim]
              - 'geometry_predictions': output of ``geometry_model(rgb_images)``
              - 'aggregated_tokens': the aggregator's per-layer token list, unmodified
              - 'patch_start_idx': index of the first patch token
        """
        with torch.no_grad():
            aggregated_tokens_list, patch_start_idx = geometry_model.aggregator(rgb_images)

            # Only the last layer is exposed as 'vision_features', so slice just that one.
            final_vision_features = self._select_patch_tokens(
                aggregated_tokens_list[-1], patch_start_idx
            )

            # NOTE(review): the full forward presumably re-runs the aggregator; reusing
            # aggregated_tokens_list with the model's heads would avoid the duplicate work.
            geometry_predictions = geometry_model(rgb_images)

        return {
            'vision_features': final_vision_features,
            'geometry_predictions': geometry_predictions,
            'aggregated_tokens': aggregated_tokens_list,
            'patch_start_idx': patch_start_idx
        }

    def prepare_vision_features_for_semantics3d(self,
                                        vision_features: torch.Tensor,
                                        batch_size: int,
                                        sequence_length: int) -> torch.Tensor:
        """
        Split the flattened batch/sequence axis of Vision Encoder features.

        Args:
            vision_features: Vision Encoder features [B*S, num_patches, dim]
            batch_size: Batch size B
            sequence_length: Sequence length S

        Returns:
            reshaped_features: [B, S, num_patches, dim]

        Raises:
            ValueError: If the leading dimension is not ``batch_size * sequence_length``.
        """
        if vision_features.shape[0] != batch_size * sequence_length:
            raise ValueError(
                f"Leading dim {vision_features.shape[0]} != batch_size * sequence_length "
                f"({batch_size} * {sequence_length})"
            )
        num_patches, dim = vision_features.shape[1], vision_features.shape[2]
        # reshape (not view) so non-contiguous inputs are accepted as well.
        return vision_features.reshape(batch_size, sequence_length, num_patches, dim)

    def resize_vision_features_for_resolution(self,
                                          vision_features: torch.Tensor,
                                          original_resolution: int,
                                          target_resolution: Optional[int] = None) -> torch.Tensor:
        """
        Bilinearly resize a square patch grid to match a new image resolution.

        The target grid side is ``floor(side * target_resolution / original_resolution)``,
        computed in exact integer arithmetic. A float scale factor plus ``int()`` can
        truncate exact results (e.g. 100 * 0.29 == 28.999...).

        Args:
            vision_features: Vision Encoder features [B*V, num_patches, dim] with a
                square patch grid.
            original_resolution: Original image resolution.
            target_resolution: Target image resolution; defaults to
                ``self.target_resolution``.

        Returns:
            Features [B*V, target_side ** 2, dim]. Returns the input object unchanged
            when the two resolutions are equal.

        Raises:
            ValueError: If the patch grid is not square, or if the target grid would
                be empty.
        """
        if target_resolution is None:
            target_resolution = self.target_resolution
        if original_resolution == target_resolution:
            return vision_features

        BV, num_patches, dim = vision_features.shape
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(f"Expected a square patch grid, got {num_patches} patches")

        target_side = int(side * target_resolution // original_resolution)
        if target_side < 1:
            raise ValueError(
                f"Resizing {side}x{side} patches from {original_resolution} to "
                f"{target_resolution} yields an empty grid"
            )

        # [B*V, N, dim] -> [B*V, dim, side, side] for interpolate()
        features_spatial = vision_features.reshape(BV, side, side, dim).permute(0, 3, 1, 2)

        resized_features_spatial = torch.nn.functional.interpolate(
            features_spatial,
            size=(target_side, target_side),
            mode='bilinear',
            align_corners=False
        )

        # Back to patch-sequence layout [B*V, target_side**2, dim]
        return resized_features_spatial.permute(0, 2, 3, 1).reshape(
            BV, target_side * target_side, dim
        )