import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from abc import ABC, abstractmethod
from typing import Optional, List, Union


class PoolingStrategy(ABC):
    """Interface for reducing per-token hidden states to one vector per sequence."""

    @abstractmethod
    def __call__(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Pool the hidden states of every sequence in the batch.

        Args:
            last_hidden_state: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]

        Returns:
            pooled_output: [batch_size, hidden_size]
        """

    def setup(self, hidden_size: int) -> None:
        """Hook for strategies that need to create parameters; a no-op by default."""
        return None


class MeanPooling(PoolingStrategy):
    """Average the hidden states of the attended tokens of each sequence."""

    def __init__(self, **kwargs):
        # Extra keyword arguments are tolerated (and discarded) for compatibility.
        super().__init__()

    def __call__(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Broadcast the mask over the hidden dimension as floats.
        mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
        masked_total = (last_hidden_state * mask).sum(dim=1)
        # Clamp the per-sequence token count so an all-zero mask cannot divide by zero.
        token_count = mask.sum(dim=1).clamp(min=1e-9)
        return masked_total / token_count


class MaxPooling(PoolingStrategy):
    """Max pooling over sequence length, ignoring positions whose mask is 0."""

    def __init__(self, **kwargs):
        super().__init__()
        # Accept but ignore any extra kwargs for compatibility

    def __call__(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Take the element-wise maximum over the attended tokens.

        Args:
            last_hidden_state: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]; positions equal to 0 are excluded

        Returns:
            pooled_output: [batch_size, hidden_size]. A sequence whose mask is
            entirely 0 yields the fill value in every feature.
        """
        dtype = last_hidden_state.dtype
        fill_value = -1e9
        if dtype.is_floating_point:
            # -1e9 is outside the float16 range, so never go below the most
            # negative finite value of the input dtype. float32 keeps -1e9.
            fill_value = max(fill_value, torch.finfo(dtype).min)

        # [batch, seq_len, 1] boolean mask broadcasts over the hidden dimension;
        # masked_fill returns a new tensor, so the input is left untouched.
        padding = (attention_mask == 0).unsqueeze(-1)
        masked_embeddings = last_hidden_state.masked_fill(padding, fill_value)
        return masked_embeddings.max(dim=1).values


class SumPooling(PoolingStrategy):
    """Add up the hidden states of the attended tokens of each sequence."""

    def __init__(self, **kwargs):
        # Extra keyword arguments are tolerated (and discarded) for compatibility.
        super().__init__()

    def __call__(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Broadcast the mask over the hidden dimension as floats, then zero out
        # the masked positions before summing along the sequence axis.
        mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
        return (last_hidden_state * mask).sum(dim=1)


class LastTokenPooling(PoolingStrategy):
    """Use the last non-padding token for each sequence."""

    def __init__(self, **kwargs):
        super().__init__()
        # Accept but ignore any extra kwargs for compatibility

    def __call__(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Select the hidden state at the last attended position of each sequence.

        Works for both right- and left-padded batches because the position is
        taken from where the mask is non-zero rather than from the token count.

        Args:
            last_hidden_state: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]; non-zero marks a real token

        Returns:
            pooled_output: [batch_size, hidden_size]. A sequence whose mask is
            entirely 0 falls back to position 0.
        """
        batch_size, seq_len, _ = last_hidden_state.shape

        # Multiply each position index by whether it is attended; the maximum
        # per row is then the index of the last attended token (0 if none).
        # The product is an integer tensor even when the mask is float.
        positions = torch.arange(seq_len, device=attention_mask.device)
        last_positions = (positions * (attention_mask != 0)).max(dim=1).values
        last_positions = last_positions.to(last_hidden_state.device)

        # Gather one hidden vector per batch row
        batch_indices = torch.arange(batch_size, device=last_hidden_state.device)
        return last_hidden_state[batch_indices, last_positions]


class FirstTokenPooling(PoolingStrategy):
    """Return the hidden state at sequence position 0 (typically [CLS] for BERT-like models)."""

    def __init__(self, **kwargs):
        # Extra keyword arguments are tolerated (and discarded) for compatibility.
        super().__init__()

    def __call__(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # The attention mask is not consulted: position 0 is always returned.
        return last_hidden_state.select(1, 0)


class AttentionPooling(PoolingStrategy, nn.Module):
    """Learned attention pooling: a linear layer scores each token, softmax weights the sum."""

    def __init__(self, **kwargs):
        PoolingStrategy.__init__(self)
        nn.Module.__init__(self)
        # Created in setup(), once the hidden size is known.
        self.attention_layer = None
        # Accept but ignore any extra kwargs for compatibility

    def setup(self, hidden_size: int):
        """Create the scoring layer that maps each hidden vector to one logit."""
        self.attention_layer = nn.Linear(hidden_size, 1)

    def __call__(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Pool with softmax-normalized learned scores.

        Args:
            last_hidden_state: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]; positions equal to 0 are masked out

        Returns:
            pooled_output: [batch_size, hidden_size]

        Raises:
            RuntimeError: if setup() has not been called.
        """
        if self.attention_layer is None:
            raise RuntimeError("AttentionPooling requires setup() to be called first")

        # One logit per token
        scores = self.attention_layer(last_hidden_state).squeeze(-1)  # [batch, seq_len]

        # -1e9 is outside the float16 range, so never go below the most negative
        # finite value of the score dtype. float32 scores keep -1e9.
        fill_value = max(-1e9, torch.finfo(scores.dtype).min)
        scores = scores.masked_fill(attention_mask == 0, fill_value)

        # Masked positions get (numerically) zero weight. A row that is fully
        # masked has equal scores everywhere and so receives uniform weights.
        attention_weights = torch.softmax(scores, dim=-1)  # [batch, seq_len]

        # Weighted sum over the sequence axis
        return torch.sum(last_hidden_state * attention_weights.unsqueeze(-1), dim=1)


class WeightedAveragePooling(PoolingStrategy, nn.Module):
    """Weighted average pooling with learnable position weights."""

    def __init__(self, max_seq_length: int = 512, **kwargs):
        """
        Args:
            max_seq_length: Number of position weights created by setup();
                longer input sequences are rejected.
        """
        PoolingStrategy.__init__(self)
        nn.Module.__init__(self)
        self.max_seq_length = max_seq_length
        # Created in setup().
        self.position_weights = None
        # Accept but ignore any other extra kwargs for compatibility

    def setup(self, hidden_size: int):
        """Create one learnable weight per position, initialized to 1 (hidden_size is unused)."""
        self.position_weights = nn.Parameter(torch.ones(self.max_seq_length))

    def __call__(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Average the hidden states using the masked, normalized position weights.

        Args:
            last_hidden_state: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]

        Returns:
            pooled_output: [batch_size, hidden_size]

        Raises:
            RuntimeError: if setup() has not been called, or if seq_len exceeds
                the number of position weights.
        """
        if self.position_weights is None:
            raise RuntimeError("WeightedAveragePooling requires setup() to be called first")

        batch_size, seq_len, hidden_size = last_hidden_state.shape

        # Fail with a clear message instead of a broadcasting error (or a silent
        # broadcast when there is only a single weight).
        num_weights = self.position_weights.shape[0]
        if seq_len > num_weights:
            raise RuntimeError(
                f"WeightedAveragePooling received a sequence of length {seq_len}, "
                f"but only {num_weights} position weights are available "
                f"(increase max_seq_length)"
            )

        # Position weights for the current sequence length
        pos_weights = self.position_weights[:seq_len].unsqueeze(0)  # [1, seq_len]

        # Zero out masked positions, then normalize each row; the epsilon keeps
        # an all-zero row from dividing by zero.
        weights = pos_weights * attention_mask.float()  # [batch, seq_len]
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9)

        # Weighted sum over the sequence axis
        return torch.sum(last_hidden_state * weights.unsqueeze(-1), dim=1)


class _PoolingModuleAdapter(nn.Module):
    """Wrap a pooling strategy that is not an nn.Module so nn.ModuleList can hold it."""

    def __init__(self, strategy: PoolingStrategy):
        super().__init__()
        # The wrapped object is not an nn.Module, so this is a plain attribute.
        self.strategy = strategy

    def setup(self, hidden_size: int):
        self.strategy.setup(hidden_size)

    def forward(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.strategy(last_hidden_state, attention_mask)


class MultiPooling(PoolingStrategy, nn.Module):
    """Combine multiple pooling strategies."""

    def __init__(self, strategies: List[str], combination: str = 'concat', **kwargs):
        """
        Args:
            strategies: List of pooling strategy names
            combination: How to combine ('concat', 'mean', 'max', 'learned')
            **kwargs: Forwarded to get_pooling_strategy for every strategy name

        Raises:
            ValueError: if no strategy is given.
        """
        PoolingStrategy.__init__(self)
        nn.Module.__init__(self)

        built = [get_pooling_strategy(s, **kwargs) for s in strategies]
        if not built:
            raise ValueError("MultiPooling requires at least one pooling strategy")

        # nn.ModuleList only accepts nn.Module instances. Strategies that are
        # plain objects are wrapped; strategies that already are modules are
        # stored directly, so their parameters are registered under
        # "strategies.<index>".
        self.strategies = nn.ModuleList(
            [s if isinstance(s, nn.Module) else _PoolingModuleAdapter(s) for s in built]
        )
        self.combination = combination
        # Created in setup() when combination == 'learned'.
        self.combination_layer = None

    def setup(self, hidden_size: int):
        """Set up every sub-strategy and, for 'learned', the projection back to hidden_size."""
        for strategy in self.strategies:
            strategy.setup(hidden_size)

        if self.combination == 'learned':
            total_dim = len(self.strategies) * hidden_size
            self.combination_layer = nn.Linear(total_dim, hidden_size)

    def __call__(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Run every sub-strategy and combine the results.

        Args:
            last_hidden_state: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]

        Returns:
            For 'concat', the sub-strategy outputs joined along the last
            dimension; for 'mean' and 'max', their element-wise mean / maximum;
            for 'learned', the concatenation passed through the linear layer.

        Raises:
            RuntimeError: if combination is 'learned' and setup() was not called.
            ValueError: if combination is not one of the supported methods.
        """
        pooled_outputs = [
            strategy(last_hidden_state, attention_mask) for strategy in self.strategies
        ]

        if self.combination == 'concat':
            return torch.cat(pooled_outputs, dim=-1)
        elif self.combination == 'mean':
            return torch.stack(pooled_outputs, dim=0).mean(dim=0)
        elif self.combination == 'max':
            return torch.stack(pooled_outputs, dim=0).max(dim=0)[0]
        elif self.combination == 'learned':
            if self.combination_layer is None:
                raise RuntimeError(
                    "MultiPooling with combination='learned' requires setup() to be called first"
                )
            concatenated = torch.cat(pooled_outputs, dim=-1)
            return self.combination_layer(concatenated)
        else:
            raise ValueError(f"Unknown combination method: {self.combination}")


def get_pooling_strategy(strategy_name: str, **kwargs) -> PoolingStrategy:
    """
    Factory function to get pooling strategy by name.

    Args:
        strategy_name: One of 'mean', 'max', 'sum', 'last', 'first', 'cls',
            'attention', 'weighted_average', or a multi-pooling spec of the
            form "multi:mean,max,attention".
        **kwargs: Optional settings. 'max_seq_length' is passed to
            'weighted_average'; 'combination' (default 'concat') selects how a
            multi-pooling spec combines its parts. Other keys are ignored by
            the single strategies.

    Returns:
        A new, not yet set-up pooling strategy instance.

    Raises:
        ValueError: if strategy_name is not recognized.
    """
    if strategy_name.startswith('multi:'):
        # Handle multi-pooling: "multi:mean,max,attention"
        names = [name.strip() for name in strategy_name.split(':')[1].split(',')]
        # Take 'combination' out of the forwarded kwargs: it is passed
        # explicitly, and supplying it twice would raise a TypeError.
        multi_kwargs = dict(kwargs)
        combination = multi_kwargs.pop('combination', 'concat')
        return MultiPooling(names, combination, **multi_kwargs)

    strategy_map = {
        'mean': MeanPooling,
        'max': MaxPooling,
        'sum': SumPooling,
        'last': LastTokenPooling,
        'first': FirstTokenPooling,
        'cls': FirstTokenPooling,  # Alias for first
        'attention': AttentionPooling,
        'weighted_average': WeightedAveragePooling,
    }

    strategy_class = strategy_map.get(strategy_name)
    if strategy_class is None:
        available = list(strategy_map.keys()) + ['multi:strategy1,strategy2,...']
        raise ValueError(f"Unknown pooling strategy: {strategy_name}. Available: {available}")

    # Only 'weighted_average' takes a setting; every other strategy is built
    # without arguments.
    if strategy_name == 'weighted_average':
        filtered_kwargs = {k: v for k, v in kwargs.items() if k == 'max_seq_length'}
    else:
        filtered_kwargs = {}

    return strategy_class(**filtered_kwargs)


# class GenericSequenceClassifier(nn.Module):
#     """
#     Generic sequence classifier with configurable pooling and classification head.
#     """
    
#     def __init__(
#         self,
#         model_name: str,
#         num_labels: int,
#         pooling_strategy: Union[str, PoolingStrategy] = 'mean',
#         hidden_dims: Optional[List[int]] = None,
#         dropout_rate: float = 0.1,
#         activation: str = 'relu',
#         hf_token: Optional[str] = None,
#         **pooling_kwargs
#     ):
#         super().__init__()
        
#         # Load the base model (without classification head)
#         self.base_model = AutoModel.from_pretrained(model_name, token=hf_token)
#         self.config = AutoConfig.from_pretrained(model_name, token=hf_token)
        
#         # Update config for classification
#         self.config.num_labels = num_labels
        
#         # Setup pooling strategy
#         if isinstance(pooling_strategy, str):
#             self.pooling = get_pooling_strategy(pooling_strategy, **pooling_kwargs)
#         else:
#             self.pooling = pooling_strategy
        
#         # Setup pooling strategy (may add learnable parameters)
#         hidden_size = self.base_model.config.hidden_size
#         self.pooling.setup(hidden_size)
        
#         # Register pooling strategy as a submodule if it's a nn.Module
#         if isinstance(self.pooling, nn.Module):
#             self.add_module('pooling_strategy', self.pooling)
        
#         # Determine input size for classifier
#         if hasattr(self.pooling, 'strategies') and self.pooling.combination == 'concat':
#             # Multi-pooling with concatenation
#             classifier_input_size = len(self.pooling.strategies) * hidden_size
#         else:
#             classifier_input_size = hidden_size
        
#         # Build classification head
#         self.classifier = self._build_classifier(
#             classifier_input_size, num_labels, hidden_dims, dropout_rate, activation
#         )
        
#         # Dropout for pooled representation
#         self.dropout = nn.Dropout(dropout_rate)
    
#     def _build_classifier(
#         self, 
#         input_size: int, 
#         num_labels: int, 
#         hidden_dims: Optional[List[int]], 
#         dropout_rate: float,
#         activation: str
#     ) -> nn.Module:
#         """Build the classification head."""
        
#         # Activation function mapping
#         activation_map = {
#             'relu': nn.ReLU,
#             'gelu': nn.GELU,
#             'tanh': nn.Tanh,
#             'leaky_relu': nn.LeakyReLU,
#             'swish': nn.SiLU,
#             'silu': nn.SiLU,
#         }
        
#         if activation not in activation_map:
#             raise ValueError(f"Unknown activation: {activation}. Available: {list(activation_map.keys())}")
        
#         activation_fn = activation_map[activation]
        
#         if hidden_dims is None or len(hidden_dims) == 0:
#             # Simple single layer classifier
#             return nn.Linear(input_size, num_labels)
#         else:
#             # Multi-layer classifier
#             layers = []
#             prev_dim = input_size
            
#             for dim in hidden_dims:
#                 layers.extend([
#                     nn.Linear(prev_dim, dim),
#                     activation_fn(),
#                     nn.Dropout(dropout_rate)
#                 ])
#                 prev_dim = dim
            
#             # Final classification layer
#             layers.append(nn.Linear(prev_dim, num_labels))
            
#             return nn.Sequential(*layers)
    
#     def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
#         # Filter kwargs to only include those expected by the base model
#         # Common base model forward arguments
#         base_model_kwargs = {}
#         expected_args = {
#             'token_type_ids', 'position_ids', 'head_mask', 'inputs_embeds', 
#             'output_attentions', 'output_hidden_states', 'return_dict'
#         }
        
#         for key, value in kwargs.items():
#             if key in expected_args:
#                 base_model_kwargs[key] = value
        
#         # Get outputs from base model with filtered kwargs
#         outputs = self.base_model(
#             input_ids=input_ids,
#             attention_mask=attention_mask,
#             **base_model_kwargs
#         )
        
#         # Apply pooling strategy
#         pooled_output = self.pooling(outputs.last_hidden_state, attention_mask)
        
#         # Apply dropout
#         pooled_output = self.dropout(pooled_output)
        
#         # Get logits from classifier
#         logits = self.classifier(pooled_output)
        
#         loss = None
#         if labels is not None:
#             loss_fct = nn.CrossEntropyLoss()
#             loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        
#         return SequenceClassifierOutput(
#             loss=loss,
#             logits=logits,
#             hidden_states=outputs.hidden_states,
#             attentions=outputs.attentions,
#         )
    
#     def resize_token_embeddings(self, new_num_tokens):
#         """Resize token embeddings if needed."""
#         self.base_model.resize_token_embeddings(new_num_tokens)
    
#     def get_pooling_output_size(self):
#         """Get the output size of the pooling layer."""
#         hidden_size = self.base_model.config.hidden_size
#         if hasattr(self.pooling, 'strategies') and self.pooling.combination == 'concat':
#             return len(self.pooling.strategies) * hidden_size
#         return hidden_size