
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any


class SegmentImportanceScorer(nn.Module):
    """Score each segment of a sequence with a learned importance in [0, 1].

    Four scoring heads are always constructed (so the parameter set does not
    depend on the selected method); ``scoring_method`` picks which one
    ``forward`` uses:

    * ``'energy'``    - three-layer MLP over the segment features.
    * ``'attention'`` - self-attention across segments followed by a
      two-layer projection.
    * ``'entropy'``   - two-layer MLP over the segment features. Despite the
      name, no explicit entropy term enters the score.
    * ``'gradient'``  - two-layer MLP over segment features concatenated with
      externally supplied gradients (or with themselves when none are given).

    The legacy name ``'probability'`` is accepted as an alias of ``'energy'``.

    Args:
        feature_dim: Size of the last dimension of the segment features.
        scoring_method: Name of the scoring head to use. Not validated here;
            an unknown name raises ``ValueError`` when ``forward`` is called.
        hidden_dim: Width of the hidden layers in the scoring heads.
        num_heads: Number of heads of the attention scorer.
    """

    # Names dispatched by ``forward`` and accepted by ``set_scoring_method``.
    _VALID_METHODS = ('energy', 'attention', 'entropy', 'gradient')

    def __init__(self, feature_dim: int, scoring_method: str = 'energy',
                 hidden_dim: int = 128, num_heads: int = 8):
        super().__init__()
        self.scoring_method = self._canonical_method(scoring_method)
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Energy-based scoring
        self.energy_net = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Attention-based scoring
        self.attention_net = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            batch_first=True,
            dropout=0.1
        )
        self.importance_proj = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Entropy-based scoring
        self.entropy_net = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Gradient-based scoring: input is [segments, gradients] concatenated.
        self.gradient_net = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    @staticmethod
    def _canonical_method(method: str) -> str:
        """Map the legacy alias ``'probability'`` onto ``'energy'``."""
        return 'energy' if method == 'probability' else method

    def forward(self, segments: torch.Tensor,
                time_stamps: Optional[torch.Tensor] = None,
                additional_features: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
        """Compute importance scores with the currently selected method.

        Args:
            segments: Either ``[batch, n_segments, feature_dim]`` or
                ``[batch, n_segments, seq_len, feature_dim]``; a 4-D input is
                mean-pooled over ``seq_len`` first.
            time_stamps: Accepted for interface compatibility; currently unused.
            additional_features: Optional extra inputs. Only the ``'gradients'``
                entry is read, and only by the ``'gradient'`` method.

        Returns:
            Sigmoid-squashed scores of shape ``[batch, n_segments]``.

        Raises:
            ValueError: If ``self.scoring_method`` is not a supported method.
        """
        if segments.dim() == 4:
            # [batch, n_segments, seq_len, feature_dim] -> [batch, n_segments, feature_dim]
            segments = segments.mean(dim=2)

        # Re-resolve the alias in case the attribute was assigned directly.
        method = self._canonical_method(self.scoring_method)

        if method == 'energy':
            return self._energy_scoring(segments)
        if method == 'attention':
            return self._attention_scoring(segments)
        if method == 'entropy':
            return self._entropy_scoring(segments)
        if method == 'gradient':
            return self._gradient_scoring(segments, additional_features)
        raise ValueError(f"Unknown scoring method: {self.scoring_method}")

    def _energy_scoring(self, segments: torch.Tensor) -> torch.Tensor:
        """Score segments with the energy MLP; returns ``[batch, n_segments]``."""
        energy_scores = self.energy_net(segments).squeeze(-1)
        return torch.sigmoid(energy_scores)

    def _attention_scoring(self, segments: torch.Tensor) -> torch.Tensor:
        """Score segments from their self-attended representations.

        Only the attention output is used; the attention weights returned by
        the layer do not contribute to the score.
        """
        attn_output, _ = self.attention_net(segments, segments, segments)
        final_scores = self.importance_proj(attn_output).squeeze(-1)  # [batch, n_segments]
        return torch.sigmoid(final_scores)

    def _entropy_scoring(self, segments: torch.Tensor) -> torch.Tensor:
        """Score segments with the "entropy" MLP.

        The score is produced solely by ``entropy_net`` applied to the raw
        features; no softmax entropy of the features is involved.
        """
        entropy_scores = self.entropy_net(segments).squeeze(-1)  # [batch, n_segments]
        return torch.sigmoid(entropy_scores)

    def _gradient_scoring(self, segments: torch.Tensor,
                          additional_features: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
        """Score segments from features concatenated with gradients.

        When ``additional_features['gradients']`` is absent the segment
        features are concatenated with themselves, so the input width of
        ``gradient_net`` (``2 * feature_dim``) is still satisfied. Supplied
        gradients must therefore concatenate with ``segments`` along the last
        dimension to that same width.
        """
        if additional_features and 'gradients' in additional_features:
            companion = additional_features['gradients']
        else:
            companion = segments
        combined_features = torch.cat([segments, companion], dim=-1)

        gradient_scores = self.gradient_net(combined_features).squeeze(-1)  # [batch, n_segments]
        return torch.sigmoid(gradient_scores)

    def get_scoring_method(self) -> str:
        """Return the name of the currently selected scoring method."""
        return self.scoring_method

    def set_scoring_method(self, method: str):
        """Select the scoring method used by ``forward``.

        Args:
            method: One of ``'energy'``, ``'attention'``, ``'entropy'``,
                ``'gradient'``, or the alias ``'probability'`` (-> ``'energy'``).

        Raises:
            ValueError: If ``method`` is not supported; the current selection
                is left unchanged in that case.
        """
        method = self._canonical_method(method)
        if method not in self._VALID_METHODS:
            raise ValueError(f"Unknown scoring method: {method}")
        self.scoring_method = method


class DynamicComputeAllocator(nn.Module):
    """Route segments through a wide ("high") or narrow ("low") MLP by importance.

    Two precision networks are built: ``'high'`` expands to ``2 * feature_dim``
    in its hidden layer, ``'low'`` contracts to ``feature_dim // 2``. Both map
    back to ``feature_dim``. The allocation strategy decides which segments
    each network processes:

    * ``'top_k'``      - the ``top_k_ratio`` highest-scoring segments per sample
      are marked high precision.
    * ``'threshold'``  - segments scoring at or above the per-sample median are
      marked high precision.
    * ``'continuous'`` - every segment goes through both networks and the
      outputs are blended with softmax-normalised scores.

    Args:
        feature_dim: Size of the last dimension of the segment features.
        top_k_ratio: Fraction of segments treated as high precision by the
            ``'top_k'`` strategy.
        precision_levels: Names of the precision levels; stored on the
            instance but not used to build the networks. Defaults to
            ``['high', 'low']`` (a fresh list per instance).
        allocation_strategy: ``'top_k'``, ``'threshold'`` or ``'continuous'``.
    """

    def __init__(self, feature_dim: int = 256, top_k_ratio: float = 0.5,
                 precision_levels: Optional[list] = None,
                 allocation_strategy: str = 'top_k'):
        super().__init__()
        self.feature_dim = feature_dim
        self.top_k_ratio = top_k_ratio
        # A list literal as the default would be shared by every instance.
        self.precision_levels = (
            ['high', 'low'] if precision_levels is None else precision_levels
        )
        self.allocation_strategy = allocation_strategy

        self.precision_nets = nn.ModuleDict({
            'high': nn.Sequential(
                nn.Linear(feature_dim, feature_dim * 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(feature_dim * 2, feature_dim)
            ),
            'low': nn.Sequential(
                nn.Linear(feature_dim, feature_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(feature_dim // 2, feature_dim)
            )
        })

    def forward(self, segments: torch.Tensor,
                importance_scores: torch.Tensor) -> Dict[str, Any]:
        """Allocate compute to segments according to the configured strategy.

        Args:
            segments: Segment features whose two leading dimensions are
                ``[batch, n_segments]`` and whose last dimension is
                ``feature_dim``.
            importance_scores: Per-segment scores of shape
                ``[batch, n_segments]``.

        Returns:
            A dict whose keys depend on the strategy; it always contains
            ``'strategy'``.

        Raises:
            ValueError: If ``importance_scores`` is not 2-D or the allocation
                strategy is unknown.
        """
        if importance_scores.dim() != 2:
            raise ValueError(
                "importance_scores must have shape [batch, n_segments], "
                f"got {tuple(importance_scores.shape)}"
            )

        if self.allocation_strategy == 'top_k':
            return self._top_k_allocation(segments, importance_scores)
        if self.allocation_strategy == 'threshold':
            return self._threshold_allocation(segments, importance_scores)
        if self.allocation_strategy == 'continuous':
            return self._continuous_allocation(segments, importance_scores)
        raise ValueError(f"Unknown allocation strategy: {self.allocation_strategy}")

    def _apply_precision_nets(self, segments: torch.Tensor,
                              allocation_mask: torch.Tensor) -> tuple:
        """Run masked segments through 'high' and unmasked ones through 'low'.

        Returns two copies of ``segments``: in the first, positions where
        ``allocation_mask`` is True are replaced by the 'high' network output;
        in the second, positions where it is False are replaced by the 'low'
        network output. All other positions keep their original values.
        """
        high_precision_segments = segments.clone()
        low_precision_segments = segments.clone()
        remainder_mask = ~allocation_mask

        # The nets only act on the last dimension, so all selected rows of the
        # whole batch can be processed in a single call.
        if allocation_mask.any():
            hp_out = self.precision_nets['high'](segments[allocation_mask])
            high_precision_segments[allocation_mask] = hp_out.to(high_precision_segments.dtype)

        if remainder_mask.any():
            lp_out = self.precision_nets['low'](segments[remainder_mask])
            low_precision_segments[remainder_mask] = lp_out.to(low_precision_segments.dtype)

        return high_precision_segments, low_precision_segments

    def _top_k_allocation(self, segments: torch.Tensor,
                          importance_scores: torch.Tensor) -> Dict[str, Any]:
        """Mark the k highest-scoring segments of each sample as high precision.

        ``k`` is ``int(n_segments * top_k_ratio)``, kept within
        ``[1, n_segments]`` so that ``torch.topk`` never receives a ``k``
        larger than the number of segments.
        """
        n_segments = importance_scores.shape[-1]
        k = min(n_segments, max(1, int(n_segments * self.top_k_ratio)))

        _, top_k_indices = torch.topk(importance_scores, k, dim=-1)  # [batch, k]

        allocation_mask = torch.zeros_like(importance_scores, dtype=torch.bool)
        allocation_mask.scatter_(-1, top_k_indices, True)

        high_precision_segments, low_precision_segments = self._apply_precision_nets(
            segments, allocation_mask
        )

        return {
            'allocation_mask': allocation_mask,
            'high_precision_segments': high_precision_segments,
            'low_precision_segments': low_precision_segments,
            'strategy': 'top_k',
            'k': k
        }

    def _threshold_allocation(self, segments: torch.Tensor,
                              importance_scores: torch.Tensor) -> Dict[str, Any]:
        """Mark segments scoring at or above the per-sample median as high precision."""
        threshold = torch.median(importance_scores, dim=-1, keepdim=True)[0]  # [batch, 1]
        allocation_mask = importance_scores >= threshold

        high_precision_segments, low_precision_segments = self._apply_precision_nets(
            segments, allocation_mask
        )

        return {
            'allocation_mask': allocation_mask,
            'high_precision_segments': high_precision_segments,
            'low_precision_segments': low_precision_segments,
            'strategy': 'threshold',
            'threshold': threshold
        }

    def _continuous_allocation(self, segments: torch.Tensor,
                               importance_scores: torch.Tensor) -> Dict[str, Any]:
        """Blend the outputs of both networks for every segment.

        Scores are softmax-normalised across segments; each segment's output is
        ``w * high(segment) + (1 - w) * low(segment)`` with ``w`` its
        normalised score.
        """
        normalized_scores = F.softmax(importance_scores, dim=-1)
        blend = normalized_scores.unsqueeze(-1)  # broadcast over the feature dimension

        high_precision_output = self.precision_nets['high'](segments)
        low_precision_output = self.precision_nets['low'](segments)

        weighted_output = blend * high_precision_output + (1 - blend) * low_precision_output

        return {
            'weighted_output': weighted_output,
            'importance_weights': normalized_scores,
            'strategy': 'continuous'
        }


def test_segment_importance_scorer():
    """Smoke-test every supported scoring method on random segments.

    For each method a fresh scorer is built, run once on a random
    ``[batch, n_segments, feature_dim]`` tensor, and its output shape and
    score range are printed. The output shape is also checked.
    """
    batch_size = 2
    n_segments = 8
    feature_dim = 256

    segments = torch.randn(batch_size, n_segments, feature_dim)

    # 'variance' is no longer a supported method and made the scorer raise
    # ValueError; 'gradient' is supported and works without extra features.
    methods = ['energy', 'attention', 'entropy', 'gradient']

    for method in methods:
        scorer = SegmentImportanceScorer(feature_dim=feature_dim, scoring_method=method)
        scores = scorer(segments)

        assert scores.shape == (batch_size, n_segments), (
            f"{method}: unexpected score shape {tuple(scores.shape)}"
        )

        print(f"{method} scoring:")
        print(f"  Input shape: {segments.shape}")
        print(f"  Output shape: {scores.shape}")
        print(f"  Score range: [{scores.min().item():.4f}, {scores.max().item():.4f}]")
        print()


def test_dynamic_compute_allocator():
    """Run each allocation strategy once on random data and print a summary.

    A default-sized allocator is created per strategy; when the result carries
    an allocation mask, the number of high-precision segments is reported too.
    """
    batch_size, n_segments, feature_dim = 2, 8, 256

    segments = torch.randn(batch_size, n_segments, feature_dim)
    importance_scores = torch.rand(batch_size, n_segments)

    for strategy in ('top_k', 'threshold', 'continuous'):
        allocator = DynamicComputeAllocator(allocation_strategy=strategy)
        result = allocator(segments, importance_scores)

        print(f"{strategy} allocation:")
        print(f"  Strategy: {result['strategy']}")
        # Only the mask-based strategies report a high-precision count.
        mask = result.get('allocation_mask')
        if mask is not None:
            print(f"  High precision segments: {mask.sum().item()}")
        print()


if __name__ == "__main__":
    # Announce and run each smoke test in order when executed as a script.
    for banner, run_check in (
        ("Testing SegmentImportanceScorer...", test_segment_importance_scorer),
        ("Testing DynamicComputeAllocator...", test_dynamic_compute_allocator),
    ):
        print(banner)
        run_check()