"""
Enhanced Adaptive Alpha and Gating Mechanisms for HPC

This module implements sophisticated adaptive alpha selection and gating
mechanisms described in the HPC paper, including:
- Multi-layer adaptive alpha networks
- Confidence-based gating
- Uncertainty-aware alpha adjustment
- Dynamic mixing parameter selection
- Ensemble-based gating strategies
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple, Dict, List
import math


class EnhancedAdaptiveAlpha(nn.Module):
    """
    Enhanced adaptive alpha network with multiple gating strategies.

    An MLP maps the input probabilities to a base alpha in [0, 1]. Two
    auxiliary heads estimate a confidence (sigmoid output) and an
    uncertainty (softplus output), and the configured gating strategy uses
    them to scale the base alpha.
    """

    def __init__(
        self,
        input_dim: int = 10,
        hidden_dims: Optional[List[int]] = None,
        dropout_rate: float = 0.1,
        use_attention: bool = True,
        gating_strategy: str = "confidence_based"
    ):
        """
        Initialize enhanced adaptive alpha network.

        Args:
            input_dim: Dimensionality of input probabilities
            hidden_dims: List of hidden layer dimensions; ``None`` selects
                the default ``[64, 32]``
            dropout_rate: Dropout rate for regularization
            use_attention: Whether to use attention mechanism
            gating_strategy: Type of gating ("confidence_based",
                "uncertainty_based", "ensemble"). Any other value leaves
                the base alpha unscaled.
        """
        super().__init__()

        # Resolve the default here rather than in the signature so that no
        # list object is shared between instances.
        if hidden_dims is None:
            hidden_dims = [64, 32]

        self.input_dim = input_dim
        self.gating_strategy = gating_strategy
        self.use_attention = use_attention

        # Build multi-layer network
        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            ])
            prev_dim = hidden_dim

        # Final layer for alpha prediction
        layers.append(nn.Linear(prev_dim, 1))
        layers.append(nn.Sigmoid())  # Alpha in [0, 1]

        self.alpha_network = nn.Sequential(*layers)

        # Attention mechanism for feature importance
        if use_attention:
            # MultiheadAttention requires embed_dim to be divisible by
            # num_heads; fall back to a single head for odd input_dim
            # instead of failing at construction time.
            num_heads = 2 if input_dim % 2 == 0 else 1
            self.attention = nn.MultiheadAttention(
                embed_dim=input_dim,
                num_heads=num_heads,
                dropout=dropout_rate,
                batch_first=True
            )
            self.attention_norm = nn.LayerNorm(input_dim)

        # Confidence estimation network
        self.confidence_estimator = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        # Uncertainty estimation network
        self.uncertainty_estimator = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Softplus()  # Positive uncertainty values
        )

        # Gate network for deciding whether to apply HPC
        self.gate_network = nn.Sequential(
            nn.Linear(input_dim + 2, 16),  # +2 for confidence and uncertainty
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(
        self,
        probabilities: torch.Tensor,
        return_components: bool = False
    ) -> Tuple[torch.Tensor, Optional[Dict]]:
        """
        Compute adaptive alpha values with gating.

        Args:
            probabilities: Model probabilities (batch_size, num_classes)
            return_components: Whether to return intermediate components

        Returns:
            (alpha_values, components_dict); ``components_dict`` is ``None``
            unless ``return_components`` is true. Its ``'base_alpha'`` entry
            is the exact tensor the gating was applied to.
        """
        # Apply attention if enabled
        if self.use_attention:
            # The probability vector is fed as a length-1 sequence, with a
            # residual connection and layer norm around the attention.
            prob_seq = probabilities.unsqueeze(1)  # (batch, 1, num_classes)
            attended_probs, attention_weights = self.attention(prob_seq, prob_seq, prob_seq)
            attended_probs = attended_probs.squeeze(1)  # (batch, num_classes)
            features = self.attention_norm(attended_probs + probabilities)
        else:
            features = probabilities
            attention_weights = None

        # Compute the base alpha once and keep it: re-running the network
        # later would draw a new dropout mask in training mode and report a
        # value that was never used.
        base_alpha = self.alpha_network(features).squeeze(-1)  # (batch_size,)

        # Estimate confidence and uncertainty
        confidence = self.confidence_estimator(features).squeeze(-1)
        uncertainty = self.uncertainty_estimator(features).squeeze(-1)

        # Apply gating strategy
        if self.gating_strategy == "confidence_based":
            # Lower alpha for high-confidence predictions (scale in [0.5, 1])
            alpha_values = base_alpha * (1.0 - confidence * 0.5)

        elif self.gating_strategy == "uncertainty_based":
            # Higher alpha for high-uncertainty predictions. The softplus
            # output is non-negative, so the sigmoid is at least 0.5.
            normalized_uncertainty = torch.sigmoid(uncertainty)
            alpha_values = base_alpha * (0.5 + normalized_uncertainty * 0.5)

        elif self.gating_strategy == "ensemble":
            # Use gate network to decide alpha scaling
            gate_input = torch.cat(
                [features, confidence.unsqueeze(-1), uncertainty.unsqueeze(-1)],
                dim=-1
            )
            gate_values = self.gate_network(gate_input).squeeze(-1)
            alpha_values = base_alpha * gate_values

        else:
            # Unrecognized strategy: pass the base alpha through unscaled.
            alpha_values = base_alpha

        # Ensure alpha values are in valid range
        alpha_values = torch.clamp(alpha_values, 0.0, 1.0)

        if return_components:
            components = {
                'base_alpha': base_alpha,
                'confidence': confidence,
                'uncertainty': uncertainty,
                'attention_weights': attention_weights,
                'final_alpha': alpha_values
            }
            return alpha_values, components

        return alpha_values, None


class UncertaintyAwareGating(nn.Module):
    """
    Uncertainty-aware gating mechanism that modulates HPC application
    based on model uncertainty and prediction confidence.

    Three signals are averaged: predictive entropy, one minus the maximum
    class probability, and the output of a small learned network. The
    average is optionally divided by a learnable temperature, squashed
    with a sigmoid, and used to raise alpha above ``base_alpha``.
    """

    def __init__(
        self,
        input_dim: int = 10,
        temperature_scaling: bool = True,
        monte_carlo_samples: int = 10,
        mc_dropout_rate: float = 0.1
    ):
        """
        Initialize the gating module.

        Args:
            input_dim: Number of classes (width of the logits)
            temperature_scaling: Whether to divide the combined uncertainty
                by a learnable temperature
            monte_carlo_samples: Number of stochastic samples drawn by the
                "mutual_info" uncertainty method
            mc_dropout_rate: Dropout probability applied to the logits for
                each Monte Carlo sample of the "mutual_info" method
        """
        super().__init__()

        self.input_dim = input_dim
        self.temperature_scaling = temperature_scaling
        self.monte_carlo_samples = monte_carlo_samples
        self.mc_dropout_rate = mc_dropout_rate

        # Uncertainty estimation network
        self.uncertainty_net = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Softplus()
        )

        # Temperature parameter for uncertainty calibration
        if temperature_scaling:
            self.temperature = nn.Parameter(torch.ones(1))

    @staticmethod
    def _entropy(probabilities: torch.Tensor) -> torch.Tensor:
        """Return the Shannon entropy over the last dimension (eps-stabilized log)."""
        return -torch.sum(probabilities * torch.log(probabilities + 1e-8), dim=-1)

    def estimate_predictive_uncertainty(
        self,
        logits: torch.Tensor,
        method: str = "entropy"
    ) -> torch.Tensor:
        """
        Estimate predictive uncertainty using different methods.

        Args:
            logits: Model logits (batch_size, num_classes)
            method: Uncertainty estimation method ("entropy", "max_prob", "mutual_info")

        Returns:
            Uncertainty estimates (batch_size,)

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method == "entropy":
            # Predictive entropy
            return self._entropy(F.softmax(logits, dim=-1))

        if method == "max_prob":
            # Inverse of max probability (lower confidence = higher uncertainty)
            max_prob = torch.max(F.softmax(logits, dim=-1), dim=-1)[0]
            return 1.0 - max_prob

        if method == "mutual_info":
            if not self.training:
                # Fall back to entropy for inference
                return self._entropy(F.softmax(logits, dim=-1))

            # Only logits are available here, so stochasticity is injected
            # by applying dropout to the logits themselves. Without it every
            # sample is identical and the mutual information is always zero.
            # NOTE(review): this is a heuristic stand-in for MC dropout in
            # the upstream model - confirm it matches the intended method.
            mc_predictions = torch.stack(
                [
                    F.softmax(
                        F.dropout(logits, p=self.mc_dropout_rate, training=True),
                        dim=-1
                    )
                    for _ in range(self.monte_carlo_samples)
                ],
                dim=0
            )  # (samples, batch, classes)

            # Mean prediction
            mean_pred = torch.mean(mc_predictions, dim=0)

            # Mutual information = H[mean prediction] - mean of per-sample H
            entropy_of_mean = self._entropy(mean_pred)
            mean_of_entropies = torch.mean(self._entropy(mc_predictions), dim=0)

            # Non-negative in exact arithmetic; clamp away rounding noise.
            return torch.clamp(entropy_of_mean - mean_of_entropies, min=0.0)

        raise ValueError(f"Unknown uncertainty method: {method}")

    def forward(
        self,
        logits: torch.Tensor,
        base_alpha: float = 0.3
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute uncertainty-modulated alpha values.

        Args:
            logits: Model logits (batch_size, num_classes)
            base_alpha: Base alpha value

        Returns:
            (modulated_alpha, uncertainty_info)
        """
        # Softmax once; both closed-form estimates below are derived from it.
        probabilities = F.softmax(logits, dim=-1)

        # Estimate different types of uncertainty
        entropy_uncertainty = self._entropy(probabilities)
        confidence_uncertainty = 1.0 - torch.max(probabilities, dim=-1)[0]

        # Learn additional uncertainty features
        learned_uncertainty = self.uncertainty_net(probabilities).squeeze(-1)

        # Combine uncertainty estimates
        combined_uncertainty = (entropy_uncertainty + confidence_uncertainty + learned_uncertainty) / 3.0

        # Apply temperature scaling to uncertainty
        # NOTE(review): the temperature is an unconstrained parameter; if
        # training drives it to zero or below, this division misbehaves.
        if self.temperature_scaling:
            combined_uncertainty = combined_uncertainty / self.temperature

        # Modulate alpha based on uncertainty
        # High uncertainty -> higher alpha (more human prior influence)
        normalized_uncertainty = torch.sigmoid(combined_uncertainty)
        modulated_alpha = base_alpha + (1.0 - base_alpha) * normalized_uncertainty * 0.5

        # Ensure valid range
        modulated_alpha = torch.clamp(modulated_alpha, 0.0, 1.0)

        uncertainty_info = {
            'entropy_uncertainty': entropy_uncertainty,
            'confidence_uncertainty': confidence_uncertainty,
            'learned_uncertainty': learned_uncertainty,
            'combined_uncertainty': combined_uncertainty,
            'normalized_uncertainty': normalized_uncertainty,
            'temperature': self.temperature.item() if self.temperature_scaling else 1.0
        }

        return modulated_alpha, uncertainty_info


class MultiScaleGating(nn.Module):
    """
    Multi-scale gating mechanism that considers both local (per-sample)
    and global (batch-level) patterns for alpha adjustment.

    The global gate is computed from the mean of the encoded batch, so the
    alpha of one sample depends on the other samples in the same batch.
    """

    def __init__(
        self,
        input_dim: int = 10,
        global_context_dim: int = 32,
        use_batch_norm: bool = True
    ):
        """
        Initialize the multi-scale gating module.

        Args:
            input_dim: Number of classes (width of the probability vector)
            global_context_dim: Hidden width of the global context encoder;
                the encoder output has half this width
            use_batch_norm: Whether to batch-normalize the probabilities
                before the gating networks
        """
        super().__init__()

        self.input_dim = input_dim
        self.global_context_dim = global_context_dim

        # Local (per-sample) gating network
        self.local_gate = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

        # Global context encoder
        self.global_encoder = nn.Sequential(
            nn.Linear(input_dim, global_context_dim),
            nn.ReLU(),
            nn.Linear(global_context_dim, global_context_dim // 2),
            nn.ReLU()
        )

        # Global context decoder
        self.global_decoder = nn.Sequential(
            nn.Linear(global_context_dim // 2, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

        # Fusion network
        self.fusion_net = nn.Sequential(
            nn.Linear(2, 8),  # 2 inputs: local + global gates
            nn.ReLU(),
            nn.Linear(8, 1),
            nn.Sigmoid()
        )

        if use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(input_dim)
        else:
            self.batch_norm = None

    def _normalize(self, probabilities: torch.Tensor) -> torch.Tensor:
        """Batch-normalize the probabilities, tolerating a one-sample training batch."""
        bn = self.batch_norm
        if bn is None:
            return probabilities

        # BatchNorm1d raises in training mode when it sees a single sample,
        # because batch statistics are undefined. Use the running statistics
        # for that case; they are not updated from the lone sample.
        if self.training and probabilities.dim() == 2 and probabilities.shape[0] == 1:
            return F.batch_norm(
                probabilities,
                bn.running_mean,
                bn.running_var,
                weight=bn.weight,
                bias=bn.bias,
                training=False,
                eps=bn.eps
            )

        return bn(probabilities)

    def forward(
        self,
        probabilities: torch.Tensor,
        base_alpha: float = 0.3
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute multi-scale gated alpha values.

        Args:
            probabilities: Model probabilities (batch_size, num_classes)
            base_alpha: Base alpha value

        Returns:
            (gated_alpha, gating_info)
        """
        batch_size = probabilities.shape[0]

        # Apply batch normalization if enabled
        normalized_probs = self._normalize(probabilities)

        # Local gating (per-sample)
        local_gates = self.local_gate(normalized_probs).squeeze(-1)  # (batch_size,)

        # Global context (batch-level patterns)
        global_features = self.global_encoder(normalized_probs)  # (batch_size, context_dim//2)

        # Aggregate global context (mean pooling)
        global_context = torch.mean(global_features, dim=0, keepdim=True)  # (1, context_dim//2)
        global_context = global_context.expand(batch_size, -1)  # (batch_size, context_dim//2)

        # Global gate (identical for every sample in the batch)
        global_gates = self.global_decoder(global_context).squeeze(-1)  # (batch_size,)

        # Fuse local and global gates
        gate_input = torch.stack([local_gates, global_gates], dim=-1)  # (batch_size, 2)
        fused_gates = self.fusion_net(gate_input).squeeze(-1)  # (batch_size,)

        # Apply gating to base alpha
        gated_alpha = base_alpha * fused_gates

        # Add adaptive component based on prediction confidence. This uses
        # the raw probabilities, not the batch-normalized ones.
        max_probs = torch.max(probabilities, dim=-1)[0]
        confidence_adjustment = (1.0 - max_probs) * 0.2  # Boost alpha for low-confidence predictions
        gated_alpha = gated_alpha + confidence_adjustment

        # Ensure valid range
        gated_alpha = torch.clamp(gated_alpha, 0.0, 1.0)

        gating_info = {
            'local_gates': local_gates,
            'global_gates': global_gates,
            'fused_gates': fused_gates,
            'confidence_adjustment': confidence_adjustment,
            'max_probs': max_probs
        }

        return gated_alpha, gating_info


class HierarchicalGating(nn.Module):
    """
    Hierarchical gating mechanism that applies different gating strategies
    at multiple levels (class-level, semantic-level, instance-level).

    The three gate levels are multiplied together, clamped to [0, 1] and
    used to scale ``base_alpha``; an entropy-based boost is then added.
    """

    def __init__(
        self,
        input_dim: int = 10,
        semantic_groups: Optional[List[List[int]]] = None,
        use_class_specific: bool = True
    ):
        """
        Initialize the hierarchical gating module.

        Args:
            input_dim: Number of classes (width of the probability vector)
            semantic_groups: Lists of class indices; each list gets its own
                gate network. ``None`` or empty disables group gating.
            use_class_specific: Whether to create one small gate network
                per class
        """
        super().__init__()

        self.input_dim = input_dim
        self.semantic_groups = semantic_groups or []
        self.use_class_specific = use_class_specific

        # Instance-level gating
        self.instance_gate = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        # Class-specific gating (if enabled)
        if use_class_specific:
            self.class_gates = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(1, 8),
                    nn.ReLU(),
                    nn.Linear(8, 1),
                    nn.Sigmoid()
                ) for _ in range(input_dim)
            ])

        # Semantic group gating
        if semantic_groups:
            self.group_gates = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(len(group), 16),
                    nn.ReLU(),
                    nn.Linear(16, 1),
                    nn.Sigmoid()
                ) for group in semantic_groups
            ])

    def forward(
        self,
        probabilities: torch.Tensor,
        base_alpha: float = 0.3
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Apply hierarchical gating to compute final alpha values.

        Args:
            probabilities: Model probabilities (batch_size, num_classes)
            base_alpha: Base alpha value

        Returns:
            (hierarchical_alpha, gating_components)
        """
        batch_size = probabilities.shape[0]

        # Instance-level gating
        instance_gates = self.instance_gate(probabilities).squeeze(-1)

        # Class-specific gating: each class probability goes through its
        # own gate network. The columns are concatenated rather than
        # written in place into a pre-filled buffer.
        if self.use_class_specific:
            class_gates = torch.cat(
                [
                    gate(probabilities[:, class_idx:class_idx + 1])
                    for class_idx, gate in enumerate(self.class_gates)
                ],
                dim=-1
            )  # (batch_size, num_classes)
        else:
            class_gates = torch.ones_like(probabilities)  # Default to 1.0

        # Aggregate class gates (weighted by probabilities)
        weighted_class_gates = torch.sum(probabilities * class_gates, dim=-1)

        # Semantic group gating: each group's gate is weighted by the total
        # probability mass of that group, then summed over groups.
        if self.semantic_groups:
            group_influences = []
            for group, gate in zip(self.semantic_groups, self.group_gates):
                group_probs = probabilities[:, group]
                group_gate = gate(group_probs).squeeze(-1)
                group_influences.append(torch.sum(group_probs, dim=-1) * group_gate)
            group_gates = torch.stack(group_influences, dim=-1).sum(dim=-1)
        else:
            group_gates = torch.ones(batch_size, device=probabilities.device)

        # Combine all gating levels
        combined_gates = instance_gates * weighted_class_gates * group_gates
        combined_gates = torch.clamp(combined_gates, 0.0, 1.0)

        # Apply to base alpha
        hierarchical_alpha = base_alpha * combined_gates

        # Add uncertainty-based adjustment
        entropy = -torch.sum(probabilities * torch.log(probabilities + 1e-8), dim=-1)
        # Normalize by max entropy log(num_classes). That is zero for a
        # single class, so use 1.0 there to avoid dividing by zero, which
        # would turn every alpha into NaN.
        max_entropy = math.log(self.input_dim) if self.input_dim > 1 else 1.0
        normalized_entropy = entropy / max_entropy
        uncertainty_boost = normalized_entropy * 0.1
        hierarchical_alpha = hierarchical_alpha + uncertainty_boost

        # Final clamping
        hierarchical_alpha = torch.clamp(hierarchical_alpha, 0.0, 1.0)

        gating_components = {
            'instance_gates': instance_gates,
            'weighted_class_gates': weighted_class_gates,
            'group_gates': group_gates,
            'combined_gates': combined_gates,
            'uncertainty_boost': uncertainty_boost,
            'entropy': entropy
        }

        return hierarchical_alpha, gating_components


# Convenience function for creating different gating strategies
def create_gating_mechanism(
    strategy: str,
    input_dim: int = 10,
    **kwargs
) -> nn.Module:
    """
    Build the gating module registered under ``strategy``.

    Args:
        strategy: One of "enhanced_adaptive", "uncertainty_aware",
            "multi_scale" or "hierarchical"
        input_dim: Input dimensionality forwarded to the module
        **kwargs: Extra keyword arguments forwarded to the module's
            constructor

    Returns:
        Initialized gating mechanism

    Raises:
        ValueError: If ``strategy`` does not name a known mechanism.
    """
    # Each entry pairs a strategy name with a thunk yielding its class, so
    # a class is only looked up once its strategy has actually matched.
    registry = (
        ("enhanced_adaptive", lambda: EnhancedAdaptiveAlpha),
        ("uncertainty_aware", lambda: UncertaintyAwareGating),
        ("multi_scale", lambda: MultiScaleGating),
        ("hierarchical", lambda: HierarchicalGating),
    )

    for known_strategy, resolve_class in registry:
        if strategy == known_strategy:
            gating_cls = resolve_class()
            return gating_cls(input_dim=input_dim, **kwargs)

    raise ValueError(f"Unknown gating strategy: {strategy}")


# Example usage and testing
if __name__ == "__main__":
    # Smoke test: build each gating mechanism, run one forward pass on
    # random data and report the resulting alpha range. A failure in one
    # strategy is printed and does not stop the remaining strategies.
    print("Testing enhanced adaptive gating mechanisms...")

    # Synthetic batch: random logits and the matching softmax probabilities
    batch_size, num_classes = 32, 10
    logits = torch.randn(batch_size, num_classes)
    probabilities = F.softmax(logits, dim=-1)

    print(f"Testing with batch_size={batch_size}, num_classes={num_classes}")

    for strategy in ("enhanced_adaptive", "uncertainty_aware", "multi_scale", "hierarchical"):
        print(f"\nTesting {strategy} gating...")

        try:
            # Only the hierarchical mechanism takes semantic groups.
            # NOTE(review): assuming the standard CIFAR-10 label order,
            # [0, 1, 8, 9] are the vehicle classes and [2..7] the animals.
            extra_kwargs = {}
            if strategy == "hierarchical":
                extra_kwargs["semantic_groups"] = [[0, 1, 8, 9], [2, 3, 4, 5, 6, 7]]

            gating = create_gating_mechanism(
                strategy,
                input_dim=num_classes,
                **extra_kwargs
            )

            # Forward pass
            if hasattr(gating, 'forward'):
                is_enhanced = strategy == "enhanced_adaptive"

                if is_enhanced:
                    alpha_values, components = gating(probabilities, return_components=True)
                else:
                    # Strategies with 'uncertainty' in their name take raw
                    # logits; the others take probabilities.
                    model_input = logits if 'uncertainty' in strategy else probabilities
                    alpha_values, info = gating(model_input)

                print(f"  Alpha range: [{alpha_values.min():.3f}, {alpha_values.max():.3f}]")

                if is_enhanced:
                    if components:
                        print(f"  Mean confidence: {components['confidence'].mean():.3f}")
                        print(f"  Mean uncertainty: {components['uncertainty'].mean():.3f}")
                else:
                    print(f"  Info keys: {list(info.keys())}")

            print(f"  {strategy} gating: ✓ Success")

        except Exception as e:
            print(f"  {strategy} gating: ✗ Failed - {e}")

    print("\nGating mechanism tests completed.")
