"""
Cross-Task Adaptor for Vid-LLM

This module implements a Cross-Task Adaptor that enables feature alignment and
cross-task communication between Geometry Predictor and Semantics Model through bridge tokens.

Architecture:
Vision Encoder Features → Dual MLP Alignment → Bridge Token Attention → Task-specific Features
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict, Optional
import math


class DualMLPAligner(nn.Module):
    """
    Project Vision Encoder features into two task-specific feature spaces.

    Two independent MLP branches read the same input. One produces features for
    the Geometry Predictor, the other for the Semantics Model. Each branch is
    Linear -> LayerNorm -> GELU -> Dropout -> Linear -> LayerNorm.
    """

    def __init__(self,
                 vision_dim: int = 1024,
                 geometry_dim: int = 1024,
                 semantics3d_dim: int = 1024,
                 hidden_dim: int = 2048,
                 dropout: float = 0.1):
        """
        Build both alignment branches and initialise their linear layers.

        Args:
            vision_dim: Channel size of the incoming Vision Encoder features.
            geometry_dim: Channel size produced by the Geometry Predictor branch.
            semantics3d_dim: Channel size produced by the Semantics Model branch.
            hidden_dim: Width of the hidden layer inside each branch.
            dropout: Dropout probability applied after the hidden activation.
        """
        super().__init__()

        self.vision_dim = vision_dim
        self.geometry_dim = geometry_dim
        self.semantics3d_dim = semantics3d_dim

        # Construction order (geometry first, then semantics) fixes both the
        # RNG consumption order and the state_dict key order.
        self.geometry_mlp = self._make_branch(vision_dim, hidden_dim, geometry_dim, dropout)
        self.semantics3d_mlp = self._make_branch(vision_dim, hidden_dim, semantics3d_dim, dropout)

        self._init_weights()

    @staticmethod
    def _make_branch(in_dim: int, mid_dim: int, out_dim: int, p_drop: float) -> nn.Sequential:
        """Return one branch: Linear -> LayerNorm -> GELU -> Dropout -> Linear -> LayerNorm."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(mid_dim, out_dim),
            nn.LayerNorm(out_dim),
        )

    def _init_weights(self):
        """Xavier-uniform every Linear weight and zero every Linear bias."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, vision_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the input through both branches.

        Args:
            vision_features: Vision Encoder features whose last dimension is
                ``vision_dim`` (documented elsewhere as [B*S, num_patches, vision_dim]).

        Returns:
            A ``(geometry_features, semantics3d_features)`` tuple. The last
            dimensions are ``geometry_dim`` and ``semantics3d_dim`` respectively;
            all leading dimensions are preserved.
        """
        # Tuple elements are evaluated left to right: geometry branch first.
        return self.geometry_mlp(vision_features), self.semantics3d_mlp(vision_features)


class BridgeTokenAttention(nn.Module):
    """
    Bridge Token Attention mechanism for cross-task feature communication.

    A fixed set of learnable bridge tokens is used as the query in two
    independent cross-attention passes: one over the Geometry Predictor
    features and one over the Semantics Model features (each first projected
    to ``bridge_dim``). Each pass adds a residual connection and then applies
    LayerNorm. The two results are averaged, projected back to each task's
    dimension, and appended to that task's feature sequence.
    """

    def __init__(self,
                 geometry_dim: int = 1024,
                 semantics3d_dim: int = 1024,
                 bridge_dim: int = 1024,
                 num_bridge_tokens: int = 8,
                 num_heads: int = 8,
                 dropout: float = 0.1):
        """
        Initialize bridge token attention.

        Args:
            geometry_dim: Geometry Predictor feature dimension
            semantics3d_dim: Semantics Model feature dimension
            bridge_dim: Bridge token dimension; must be divisible by
                ``num_heads`` (an ``nn.MultiheadAttention`` requirement)
            num_bridge_tokens: Number of bridge tokens
            num_heads: Number of attention heads
            dropout: Dropout rate passed to ``nn.MultiheadAttention``
        """
        super().__init__()

        self.geometry_dim = geometry_dim
        self.semantics3d_dim = semantics3d_dim
        self.bridge_dim = bridge_dim
        self.num_bridge_tokens = num_bridge_tokens
        self.num_heads = num_heads

        # Learnable bridge tokens shared across the batch (small-std normal init).
        self.bridge_tokens = nn.Parameter(
            torch.randn(1, num_bridge_tokens, bridge_dim) * 0.02
        )

        # Separate attention modules so each task gets its own projections.
        self.geometry_bridge_attention = nn.MultiheadAttention(
            embed_dim=bridge_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self.semantics3d_bridge_attention = nn.MultiheadAttention(
            embed_dim=bridge_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Input projections into the bridge space (identity when dims already match).
        self.geometry_proj = nn.Linear(geometry_dim, bridge_dim) if geometry_dim != bridge_dim else nn.Identity()
        self.semantics3d_proj = nn.Linear(semantics3d_dim, bridge_dim) if semantics3d_dim != bridge_dim else nn.Identity()

        # Output projections back to each task's dimension.
        self.geometry_out_proj = nn.Linear(bridge_dim, geometry_dim) if geometry_dim != bridge_dim else nn.Identity()
        self.semantics3d_out_proj = nn.Linear(bridge_dim, semantics3d_dim) if semantics3d_dim != bridge_dim else nn.Identity()

        # Post-residual normalization for each attention pass.
        self.geometry_norm = nn.LayerNorm(bridge_dim)
        self.semantics3d_norm = nn.LayerNorm(bridge_dim)

        print(f"✓ BridgeTokenAttention initialized: {num_bridge_tokens} bridge tokens")

    def forward(self,
                geometry_features: torch.Tensor,
                semantics3d_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply bridge token attention for cross-task communication.

        Args:
            geometry_features: Geometry Predictor features [B*S, num_patches, geometry_dim]
            semantics3d_features: Semantics Model features [B*S, num_patches_sem, semantics3d_dim].
                Must share the batch dimension with ``geometry_features``. The
                sequence length may differ, since it only serves as attention keys/values.

        Returns:
            enhanced_geometry_features: ``geometry_features`` with the bridge tokens
                appended along dim 1 [B*S, num_patches + num_bridge, geometry_dim]
            enhanced_semantics3d_features: ``semantics3d_features`` with the bridge
                tokens appended along dim 1 [B*S, num_patches_sem + num_bridge, semantics3d_dim]
            bridge_tokens: Averaged, updated bridge tokens [B*S, num_bridge_tokens, bridge_dim]
        """
        batch_size = geometry_features.shape[0]

        # Project task features into the shared bridge space.
        geometry_proj_features = self.geometry_proj(geometry_features)
        semantics3d_proj_features = self.semantics3d_proj(semantics3d_features)

        # Broadcast the learned tokens over the batch (expand: no copy).
        bridge_tokens = self.bridge_tokens.expand(batch_size, -1, -1)

        # need_weights=False: the attention maps are discarded, so don't ask
        # PyTorch to materialize/average them (also enables the fused kernel).
        bridge_from_geometry, _ = self.geometry_bridge_attention(
            bridge_tokens,           # query
            geometry_proj_features,  # key
            geometry_proj_features,  # value
            need_weights=False,
        )
        bridge_from_geometry = self.geometry_norm(bridge_from_geometry + bridge_tokens)

        bridge_from_semantics3d, _ = self.semantics3d_bridge_attention(
            bridge_tokens,              # query
            semantics3d_proj_features,  # key
            semantics3d_proj_features,  # value
            need_weights=False,
        )
        bridge_from_semantics3d = self.semantics3d_norm(bridge_from_semantics3d + bridge_tokens)

        # Equal-weight average mixes what the tokens gathered from both tasks.
        enhanced_bridge_tokens = (bridge_from_geometry + bridge_from_semantics3d) / 2

        # Project back to task-specific dimensions and append after the original tokens.
        geometry_bridge_proj = self.geometry_out_proj(enhanced_bridge_tokens)
        semantics3d_bridge_proj = self.semantics3d_out_proj(enhanced_bridge_tokens)

        enhanced_geometry_features = torch.cat([geometry_features, geometry_bridge_proj], dim=1)
        enhanced_semantics3d_features = torch.cat([semantics3d_features, semantics3d_bridge_proj], dim=1)

        return enhanced_geometry_features, enhanced_semantics3d_features, enhanced_bridge_tokens


class CrossTaskAdaptor(nn.Module):
    """
    Cross-Task Adaptor for Vid-LLM.

    Chains two stages:
      1. ``DualMLPAligner`` turns Vision Encoder features into a geometry view
         and a semantics view.
      2. ``BridgeTokenAttention`` lets learnable bridge tokens attend to both
         views and appends the resulting tokens to each view.
    """

    def __init__(self,
                 vision_dim: int = 1024,
                 geometry_dim: int = 1024,
                 semantics3d_dim: int = 1024,
                 bridge_dim: int = 1024,
                 num_bridge_tokens: int = 8,
                 num_attention_heads: int = 8,
                 hidden_dim: int = 2048,
                 dropout: float = 0.1):
        """
        Build the alignment and bridge-attention stages.

        Args:
            vision_dim: Channel size of the incoming Vision Encoder features.
            geometry_dim: Channel size of the Geometry Predictor branch.
            semantics3d_dim: Channel size of the Semantics Model branch.
            bridge_dim: Channel size of the bridge tokens.
            num_bridge_tokens: How many bridge tokens are learned.
            num_attention_heads: Head count used by the bridge attention.
            hidden_dim: Hidden width of the aligner MLPs.
            dropout: Dropout probability shared by both stages.
        """
        super().__init__()

        self.vision_dim = vision_dim
        self.geometry_dim = geometry_dim
        self.semantics3d_dim = semantics3d_dim
        self.num_bridge_tokens = num_bridge_tokens

        # Stage 1: task-specific alignment of the shared vision features.
        self.dual_aligner = DualMLPAligner(
            vision_dim=vision_dim,
            geometry_dim=geometry_dim,
            semantics3d_dim=semantics3d_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

        # Stage 2: bridge tokens exchange information between the two views.
        self.bridge_attention = BridgeTokenAttention(
            geometry_dim=geometry_dim,
            semantics3d_dim=semantics3d_dim,
            bridge_dim=bridge_dim,
            num_bridge_tokens=num_bridge_tokens,
            num_heads=num_attention_heads,
            dropout=dropout,
        )

        summary_lines = (
            "✓ CrossTaskAdaptor initialized:",
            f"  - Vision Encoder dim: {vision_dim}",
            f"  - Geometry Predictor dim: {geometry_dim}",
            f"  - Semantics Model dim: {semantics3d_dim}",
            f"  - Bridge tokens: {num_bridge_tokens}",
        )
        for line in summary_lines:
            print(line)

    def forward(self, vision_features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Align the vision features, then enrich both views with bridge tokens.

        Args:
            vision_features: Vision Encoder features [B*S, num_patches, vision_dim].

        Returns:
            Dict with keys:
              - ``'geometry_features'``: aligned geometry features plus bridge tokens
                [B*S, num_patches + num_bridge, geometry_dim]
              - ``'semantics3d_features'``: aligned semantics features plus bridge tokens
                [B*S, num_patches + num_bridge, semantics3d_dim]
              - ``'bridge_tokens'``: updated bridge tokens [B*S, num_bridge_tokens, bridge_dim]
              - ``'aligned_geometry'``: stage-1 geometry output [B*S, num_patches, geometry_dim]
              - ``'aligned_semantics3d'``: stage-1 semantics output [B*S, num_patches, semantics3d_dim]
        """
        geometry_aligned, semantics3d_aligned = self.dual_aligner(vision_features)

        geometry_out, semantics3d_out, bridge_out = self.bridge_attention(
            geometry_aligned, semantics3d_aligned
        )

        return dict(
            geometry_features=geometry_out,
            semantics3d_features=semantics3d_out,
            bridge_tokens=bridge_out,
            aligned_geometry=geometry_aligned,
            aligned_semantics3d=semantics3d_aligned,
        )

    def get_geometry_features(self, adaptor_outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the bridge-enhanced Geometry Predictor features from ``forward``'s output."""
        key = "geometry_features"
        return adaptor_outputs[key]

    def get_semantics3d_features(self, adaptor_outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the bridge-enhanced Semantics Model features from ``forward``'s output."""
        key = "semantics3d_features"
        return adaptor_outputs[key]

    def get_bridge_tokens(self, adaptor_outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the updated bridge tokens from ``forward``'s output (e.g. for analysis)."""
        key = "bridge_tokens"
        return adaptor_outputs[key]


class CrossTaskAttentionModule(nn.Module):
    """
    Advanced cross-task attention module for deeper feature interaction.

    Runs two cross-attention passes. Geometry tokens query the semantics
    tokens, and semantics tokens query the geometry tokens. Each pass adds a
    residual connection and then applies LayerNorm. It is an optional addition
    on top of the bridge-token mechanism.
    """

    def __init__(self,
                 geometry_dim: int = 1024,
                 semantics3d_dim: int = 1024,
                 num_heads: int = 8,
                 dropout: float = 0.1):
        """
        Initialize cross-task attention module.

        Args:
            geometry_dim: Geometry Predictor feature dimension; must be divisible
                by ``num_heads`` (``nn.MultiheadAttention`` requirement)
            semantics3d_dim: Semantics Model feature dimension; must be divisible
                by ``num_heads``
            num_heads: Number of attention heads
            dropout: Dropout rate passed to ``nn.MultiheadAttention``
        """
        super().__init__()

        self.geometry_dim = geometry_dim
        self.semantics3d_dim = semantics3d_dim

        # Geometry features are the query; (projected) semantics are key/value.
        self.geometry_to_semantics3d_attention = nn.MultiheadAttention(
            embed_dim=geometry_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Semantics features are the query; (projected) geometry are key/value.
        self.semantics3d_to_geometry_attention = nn.MultiheadAttention(
            embed_dim=semantics3d_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Map each side into the other's dimension so it can serve as key/value
        # (identity when the dimensions already match).
        self.geometry_to_semantics3d_proj = nn.Linear(geometry_dim, semantics3d_dim) if geometry_dim != semantics3d_dim else nn.Identity()
        self.semantics3d_to_geometry_proj = nn.Linear(semantics3d_dim, geometry_dim) if semantics3d_dim != geometry_dim else nn.Identity()

        # Post-residual normalization for each direction.
        self.geometry_norm = nn.LayerNorm(geometry_dim)
        self.semantics3d_norm = nn.LayerNorm(semantics3d_dim)

    def forward(self,
                geometry_features: torch.Tensor,
                semantics3d_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply cross-task attention in both directions.

        Args:
            geometry_features: Geometry Predictor features [B*S, num_patches, geometry_dim]
            semantics3d_features: Semantics Model features [B*S, num_patches_sem, semantics3d_dim].
                Must share the batch dimension with ``geometry_features``. The
                sequence lengths may differ.

        Returns:
            cross_enhanced_geometry: ``LayerNorm(geometry + attn(geometry -> semantics))``,
                same shape as ``geometry_features``
            cross_enhanced_semantics3d: ``LayerNorm(semantics + attn(semantics -> geometry))``,
                same shape as ``semantics3d_features``
        """
        # need_weights=False in both calls: the attention maps are discarded,
        # so skip materializing/averaging them (also enables the fused kernel).

        # Geometry Predictor features attend to Semantics Model features.
        semantics3d_proj = self.semantics3d_to_geometry_proj(semantics3d_features)
        geometry_cross_attended, _ = self.geometry_to_semantics3d_attention(
            geometry_features,  # query
            semantics3d_proj,   # key
            semantics3d_proj,   # value
            need_weights=False,
        )
        enhanced_geometry = self.geometry_norm(geometry_features + geometry_cross_attended)

        # Semantics Model features attend to Geometry Predictor features.
        geometry_proj = self.geometry_to_semantics3d_proj(geometry_features)
        semantics3d_cross_attended, _ = self.semantics3d_to_geometry_attention(
            semantics3d_features,  # query
            geometry_proj,         # key
            geometry_proj,         # value
            need_weights=False,
        )
        enhanced_semantics3d = self.semantics3d_norm(semantics3d_features + semantics3d_cross_attended)

        return enhanced_geometry, enhanced_semantics3d


class AdvancedCrossTaskAdaptor(nn.Module):
    """
    Cross-Task Adaptor with an optional extra cross-attention stage.

    Wraps a ``CrossTaskAdaptor``. When enabled, it also runs
    ``CrossTaskAttentionModule`` on the stage-1 aligned features and stores the
    results under additional output keys. The bridge-enhanced
    ``'geometry_features'`` / ``'semantics3d_features'`` entries are left as
    produced by the basic adaptor.
    """

    def __init__(self,
                 vision_dim: int = 1024,
                 geometry_dim: int = 1024,
                 semantics3d_dim: int = 1024,
                 bridge_dim: int = 1024,
                 num_bridge_tokens: int = 8,
                 num_attention_heads: int = 8,
                 hidden_dim: int = 2048,
                 enable_cross_attention: bool = True,
                 dropout: float = 0.1):
        """
        Build the basic adaptor and, optionally, the cross-attention module.

        Args:
            vision_dim: Channel size of the incoming Vision Encoder features.
            geometry_dim: Channel size of the Geometry Predictor branch.
            semantics3d_dim: Channel size of the Semantics Model branch.
            bridge_dim: Channel size of the bridge tokens.
            num_bridge_tokens: How many bridge tokens are learned.
            num_attention_heads: Head count for both bridge and cross attention.
            hidden_dim: Hidden width of the aligner MLPs.
            enable_cross_attention: If true, create ``self.cross_attention`` and
                run it in ``forward``. If false, that attribute is never created.
            dropout: Dropout probability shared by all sub-modules.
        """
        super().__init__()

        self.basic_adaptor = CrossTaskAdaptor(
            vision_dim=vision_dim,
            geometry_dim=geometry_dim,
            semantics3d_dim=semantics3d_dim,
            bridge_dim=bridge_dim,
            num_bridge_tokens=num_bridge_tokens,
            num_attention_heads=num_attention_heads,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

        self.enable_cross_attention = enable_cross_attention
        if self.enable_cross_attention:
            self.cross_attention = CrossTaskAttentionModule(
                geometry_dim=geometry_dim,
                semantics3d_dim=semantics3d_dim,
                num_heads=num_attention_heads,
                dropout=dropout,
            )

        status = 'Enabled' if enable_cross_attention else 'Disabled'
        print("✓ AdvancedCrossTaskAdaptor initialized:")
        print(f"  - Cross-attention: {status}")

    def forward(self, vision_features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run the basic adaptor and, if enabled, the extra cross-attention stage.

        Args:
            vision_features: Vision Encoder features [B*S, num_patches, vision_dim].

        Returns:
            The dict produced by ``CrossTaskAdaptor.forward``. When cross-attention
            is enabled, that same dict gets two extra keys:
            ``'cross_attended_geometry'`` and ``'cross_attended_semantics3d'``.
        """
        results = self.basic_adaptor(vision_features)

        aligned_geometry = results['aligned_geometry']
        aligned_semantics3d = results['aligned_semantics3d']

        if not self.enable_cross_attention:
            return results

        cross_geometry, cross_semantics3d = self.cross_attention(aligned_geometry, aligned_semantics3d)
        results['cross_attended_geometry'] = cross_geometry
        results['cross_attended_semantics3d'] = cross_semantics3d
        return results


class TaskSpecificHead(nn.Module):
    """
    Task-specific head that can be attached after Cross-Task Adaptor.

    Stacks ``num_layers - 1`` hidden blocks, each Linear -> LayerNorm -> GELU ->
    Dropout at width ``input_dim``. A final Linear then projects to
    ``output_dim``. With ``num_layers == 0`` the head is an identity, which is
    only allowed when ``input_dim == output_dim``.
    """

    def __init__(self,
                 input_dim: int = 1024,
                 output_dim: int = 1024,
                 num_layers: int = 2,
                 dropout: float = 0.1):
        """
        Initialize task-specific head.

        Args:
            input_dim: Input feature dimension
            output_dim: Output feature dimension
            num_layers: Number of Linear layers (hidden blocks + final projection)
            dropout: Dropout rate used in the hidden blocks

        Raises:
            ValueError: if ``num_layers`` is negative, or if ``num_layers == 0``
                while ``input_dim != output_dim``. An empty head would silently
                return features of the wrong width.
        """
        super().__init__()

        if num_layers < 0:
            raise ValueError(f"num_layers must be non-negative, got {num_layers}")
        if num_layers == 0 and input_dim != output_dim:
            raise ValueError(
                "num_layers=0 yields an identity head, which cannot map "
                f"input_dim={input_dim} to output_dim={output_dim}"
            )

        layers = []
        # Hidden blocks keep the input width. Layer order and indices match the
        # original layout, so state_dict keys are unchanged.
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(input_dim, input_dim),
                nn.LayerNorm(input_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
        if num_layers > 0:
            # Final projection to the requested output width.
            layers.append(nn.Linear(input_dim, output_dim))

        self.head = nn.Sequential(*layers)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Apply the head along the last dimension of ``features``."""
        return self.head(features)