import inspect
import json
import logging
import os
from typing import Optional, List, Union, Any, Dict

import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel, 
    AutoModel, 
    AutoConfig,
    PretrainedConfig
)
from transformers.modeling_outputs import SequenceClassifierOutput


class GenericSequenceClassifierConfig(PretrainedConfig):
    """
    Configuration class for GenericSequenceClassifier.

    Args:
        base_model_name: HuggingFace id or local path of the backbone that
            GenericSequenceClassifier loads with ``AutoConfig``/``AutoModel``.
        num_labels: Number of output classes of the classification head.
        pooling_strategy: Name of the pooling strategy (e.g. ``'mean'``, ``'max'``).
        pooling_kwargs: Extra keyword arguments for the pooling strategy. A copy
            is stored; ``None`` becomes an empty dict.
        hidden_dims: Hidden layer sizes of an optional MLP head; ``None`` means a
            single linear layer. A list copy is stored.
        dropout_rate: Dropout probability used by the classification head.
        activation: Name of the activation used between MLP head layers.
        **kwargs: Forwarded to ``PretrainedConfig`` (e.g. attributes copied from
            the backbone's config).
    """
    model_type = "generic_sequence_classifier"

    def __init__(
        self,
        base_model_name: Optional[str] = None,
        num_labels: int = 2,
        pooling_strategy: str = 'mean',
        pooling_kwargs: Optional[Dict[str, Any]] = None,
        hidden_dims: Optional[List[int]] = None,
        dropout_rate: float = 0.1,
        activation: str = 'relu',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        # Assigned after super().__init__ so PretrainedConfig's num_labels setter
        # regenerates id2label/label2id when their length does not match.
        self.num_labels = num_labels
        self.pooling_strategy = pooling_strategy
        # Copies keep later mutation of the caller's objects out of the config.
        self.pooling_kwargs = dict(pooling_kwargs) if pooling_kwargs else {}
        self.hidden_dims = list(hidden_dims) if hidden_dims is not None else None
        self.dropout_rate = dropout_rate
        self.activation = activation


class GenericSequenceClassifier(PreTrainedModel):
    """
    Generic sequence classifier with configurable pooling and classification head.
    Inherits from PreTrainedModel for HuggingFace ecosystem compatibility.

    Data flow: ``transformer`` (an ``AutoModel`` backbone without a task head)
    -> ``pooling`` (token states -> one vector per sequence) -> ``dropout``
    -> ``classifier`` (a linear layer or an MLP).
    """

    config_class = GenericSequenceClassifierConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True

    # Diagnostics are logged rather than printed so callers control verbosity.
    _logger = logging.getLogger(__name__)

    # forward() keyword arguments that may be passed through to the backbone.
    _PASSTHROUGH_ARGS = frozenset({
        'token_type_ids', 'position_ids', 'head_mask', 'inputs_embeds',
        'output_attentions', 'output_hidden_states', 'return_dict',
    })

    # Attributes mirrored from the backbone config that TextAttack might need.
    _MIRRORED_CONFIG_ATTRS = (
        'max_position_embeddings', 'vocab_size', 'model_type',
        'architectures', '_name_or_path', 'torch_dtype',
    )

    # Fallback max_position_embeddings (needed for A2T attacks), matched as a
    # substring of model_type, first match wins.
    _FALLBACK_MAX_POSITIONS = (
        ('gpt2', 1024),
        ('bert', 512),
        ('roberta', 512),
        ('distilbert', 512),
        ('albert', 512),
        ('electra', 512),
    )

    # Embedding weight names probed in checkpoints when the module-derived name
    # is absent (kept for checkpoints written by earlier versions of this class).
    _LEGACY_EMBEDDING_KEYS = (
        'transformer.wte.weight',
        'transformer.embed_tokens.weight',
        'embeddings.word_embeddings.weight',
    )

    def __init__(
        self,
        config: GenericSequenceClassifierConfig = None,
        model_name: str = None,
        num_labels: int = 2,
        pooling_strategy: Union[str, 'PoolingStrategy'] = 'mean',
        hidden_dims: Optional[List[int]] = None,
        dropout_rate: float = 0.1,
        activation: str = 'relu',
        hf_token: Optional[str] = None,
        **pooling_kwargs
    ):
        """
        Build the model from a config or from direct hyper-parameters.

        Args:
            config: Complete model config. When given, the head is built from it
                and ``model_name``/``num_labels``/``hidden_dims``/``dropout_rate``/
                ``activation``/``pooling_kwargs`` are ignored. A string
                ``pooling_strategy`` is used only as a fallback for
                ``config.pooling_strategy``.
            model_name: Backbone id or path; required when ``config`` is None.
            num_labels: Number of classes (config-less construction).
            pooling_strategy: Name understood by ``get_pooling_strategy`` or a
                ready-made pooling object exposing ``setup`` and ``__call__``.
            hidden_dims: Hidden sizes of an MLP head; None/[] = one linear layer.
            dropout_rate: Dropout after pooling and inside the MLP head.
            activation: MLP activation name (see ``_build_classifier``).
            hf_token: Token forwarded to the HuggingFace ``from_pretrained`` calls.
            **pooling_kwargs: Extra options for the pooling strategy.

        Raises:
            ValueError: If ``config`` is None and ``model_name`` or ``num_labels``
                is missing.
        """
        if config is None:
            config = self._config_from_arguments(
                model_name, num_labels, pooling_strategy, hidden_dims,
                dropout_rate, activation, hf_token, pooling_kwargs,
            )

        super().__init__(config)

        # Store initialization parameters for saving/loading
        self.hf_token = hf_token
        self._pooling_strategy_obj = pooling_strategy if not isinstance(pooling_strategy, str) else None

        # Load the base model (without classification head)
        base_model_config = self._prepare_backbone_config(config, hf_token)
        self.transformer = AutoModel.from_pretrained(
            config.base_model_name,
            config=base_model_config,
            token=hf_token,
            torch_dtype='auto'
        )
        self._backbone_accepted_args = self._accepted_forward_args(self.transformer)

        self.pooling = self._make_pooling(config, pooling_strategy)
        # Setup pooling strategy (may add learnable parameters)
        hidden_size = self.transformer.config.hidden_size
        self.pooling.setup(hidden_size)

        # Also registered under this name; existing checkpoints may use it.
        if isinstance(self.pooling, nn.Module):
            self.add_module('pooling_strategy', self.pooling)

        self.classifier = self._build_classifier(
            self.get_pooling_output_size(),
            config.num_labels,
            config.hidden_dims,
            config.dropout_rate,
            config.activation
        )

        # Dropout for pooled representation
        self.dropout = nn.Dropout(config.dropout_rate)

    @staticmethod
    def _config_from_arguments(model_name, num_labels, pooling_strategy, hidden_dims,
                               dropout_rate, activation, hf_token, pooling_kwargs):
        """Synthesize a GenericSequenceClassifierConfig from direct constructor arguments."""
        if model_name is None or num_labels is None:
            raise ValueError("Either provide config or both model_name and num_labels")

        base_config = AutoConfig.from_pretrained(model_name, token=hf_token)

        # Backbone keys colliding with our own config parameters are dropped so
        # the explicit arguments below win and no keyword is passed twice.
        own_params = set(inspect.signature(GenericSequenceClassifierConfig.__init__).parameters)
        own_params -= {'self', 'kwargs'}
        inherited = {
            key: value for key, value in base_config.to_dict().items()
            if key not in own_params
        }

        return GenericSequenceClassifierConfig(
            base_model_name=model_name,
            num_labels=num_labels,
            pooling_strategy=pooling_strategy if isinstance(pooling_strategy, str) else str(pooling_strategy),
            pooling_kwargs=pooling_kwargs,
            hidden_dims=hidden_dims,
            dropout_rate=dropout_rate,
            activation=activation,
            **inherited
        )

    def _prepare_backbone_config(self, config, hf_token):
        """
        Load the backbone config and synchronise it with ``config``.

        Forces ``num_labels``/``id2label``/``label2id`` onto the backbone config.
        Copies ``hidden_size`` and ``_MIRRORED_CONFIG_ATTRS`` onto ``config``.
        Sets a fallback ``max_position_embeddings`` on ``config`` if it is still
        missing.

        Returns:
            The backbone config to pass to ``AutoModel.from_pretrained``.
        """
        base_model_config = AutoConfig.from_pretrained(config.base_model_name, token=hf_token)

        # Classification attributes are added even if the original config lacks them.
        self._logger.debug("config.num_labels: %s", config.num_labels)
        base_model_config.num_labels = config.num_labels
        base_model_config.id2label = {i: f"LABEL_{i}" for i in range(config.num_labels)}
        base_model_config.label2id = {f"LABEL_{i}": i for i in range(config.num_labels)}

        config.hidden_size = base_model_config.hidden_size
        for attr in self._MIRRORED_CONFIG_ATTRS:
            if hasattr(base_model_config, attr):
                setattr(config, attr, getattr(base_model_config, attr))

        if not hasattr(config, 'max_position_embeddings'):
            model_type = getattr(config, 'model_type', 'bert').lower()
            for key, max_positions in self._FALLBACK_MAX_POSITIONS:
                if key in model_type:
                    config.max_position_embeddings = max_positions
                    self._logger.info(
                        "Setting fallback max_position_embeddings to %s for model type %s",
                        max_positions, model_type,
                    )
                    break
            else:
                config.max_position_embeddings = 512  # Safe default
                self._logger.info(
                    "Setting default max_position_embeddings to 512 for unknown model type %s",
                    model_type,
                )

        self._logger.debug("Updated base_model_config.num_labels: %s", base_model_config.num_labels)
        self._logger.debug("Updated base_model_config.id2label length: %s", len(base_model_config.id2label))
        self._logger.debug("Config max_position_embeddings: %s", config.max_position_embeddings)
        return base_model_config

    def _make_pooling(self, config, pooling_strategy):
        """
        Return the pooling strategy to use.

        A non-string ``pooling_strategy`` is returned unchanged. Otherwise the
        name stored in ``config`` takes precedence over the constructor argument.
        This ensures a reload through ``from_pretrained`` (which passes no
        ``pooling_strategy``) restores the saved strategy instead of the
        ``'mean'`` default.
        """
        if not isinstance(pooling_strategy, str):
            return pooling_strategy

        pooling_kwargs = getattr(config, 'pooling_kwargs', None) or {}
        strategy_name = getattr(config, 'pooling_strategy', None) or pooling_strategy
        try:
            return get_pooling_strategy(strategy_name, **pooling_kwargs)
        except ValueError:
            if strategy_name == pooling_strategy:
                raise
            # e.g. a config that stored str() of a custom pooling object.
            self._logger.warning(
                "Unknown pooling strategy %r in config; falling back to %r",
                strategy_name, pooling_strategy,
            )
            return get_pooling_strategy(pooling_strategy, **pooling_kwargs)

    @classmethod
    def _accepted_forward_args(cls, module):
        """
        Return the subset of ``_PASSTHROUGH_ARGS`` that ``module.forward`` can take.

        If the signature cannot be inspected or it has ``**kwargs``, every
        pass-through argument is allowed.
        """
        try:
            params = list(inspect.signature(module.forward).parameters.values())
        except (TypeError, ValueError):
            return cls._PASSTHROUGH_ARGS
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
            return cls._PASSTHROUGH_ARGS
        return cls._PASSTHROUGH_ARGS & {p.name for p in params}

    def _build_classifier(
        self, 
        input_size: int, 
        num_labels: int, 
        hidden_dims: Optional[List[int]], 
        dropout_rate: float,
        activation: str
    ) -> nn.Module:
        """
        Build the classification head.

        Returns ``nn.Linear(input_size, num_labels)`` when ``hidden_dims`` is
        empty. Otherwise returns an ``nn.Sequential`` of
        (Linear, activation, Dropout) per hidden size followed by a final
        Linear layer.

        Raises:
            ValueError: If ``activation`` is not a supported name.
        """
        activation_map = {
            'relu': nn.ReLU,
            'gelu': nn.GELU,
            'tanh': nn.Tanh,
            'leaky_relu': nn.LeakyReLU,
            'swish': nn.SiLU,
            'silu': nn.SiLU,
        }

        if activation not in activation_map:
            raise ValueError(f"Unknown activation: {activation}. Available: {list(activation_map.keys())}")

        activation_fn = activation_map[activation]

        if not hidden_dims:
            # Simple single layer classifier
            return nn.Linear(input_size, num_labels)

        layers = []
        prev_dim = input_size
        for dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, dim),
                activation_fn(),
                nn.Dropout(dropout_rate)
            ])
            prev_dim = dim

        # Final classification layer
        layers.append(nn.Linear(prev_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """
        Classify a batch of sequences.

        Args:
            input_ids: Token ids for the backbone.
            attention_mask: Backbone attention mask; also handed to the pooling
                strategy.
            labels: Optional class indices; when given, a cross-entropy loss is
                returned.
            **kwargs: Entries from ``_PASSTHROUGH_ARGS`` that the backbone's
                forward accepts are passed through; everything else is dropped.

        Returns:
            SequenceClassifierOutput with ``loss``, ``logits`` and, when the
            backbone provides them, ``hidden_states`` and ``attentions``.
        """
        accepted = getattr(self, '_backbone_accepted_args', self._PASSTHROUGH_ARGS)
        backbone_kwargs = {key: value for key, value in kwargs.items() if key in accepted}

        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **backbone_kwargs
        )

        # Index 0 is last_hidden_state for both ModelOutput and tuple returns
        # (the latter happens when the caller passes return_dict=False).
        last_hidden_state = outputs[0]

        pooled_output = self.pooling(last_hidden_state, attention_mask)
        pooled_output = self.dropout(pooled_output)

        # The backbone may load in fp16/bf16 (torch_dtype='auto') while the head
        # is created in the default dtype; align them before the matmul.
        classifier_dtype = next(self.classifier.parameters()).dtype
        logits = self.classifier(pooled_output.to(classifier_dtype))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(outputs, 'hidden_states', None),
            attentions=getattr(outputs, 'attentions', None),
        )

    def resize_token_embeddings(self, new_num_tokens):
        """Resize the backbone's token embeddings (delegates to ``self.transformer``)."""
        return self.transformer.resize_token_embeddings(new_num_tokens)

    def get_pooling_output_size(self):
        """Get the output size of the pooling layer."""
        hidden_size = self.transformer.config.hidden_size
        if hasattr(self.pooling, 'strategies') and self.pooling.combination == 'concat':
            # Multi-pooling with concatenation
            return len(self.pooling.strategies) * hidden_size
        return hidden_size

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        num_labels=None,
        *model_args,
        config=None,
        cache_dir=None,
        ignore_mismatched_sizes=False,
        force_download=False,
        local_files_only=False,
        token=None,
        revision="main",
        use_safetensors=None,
        **kwargs
    ):
        """
        Load a model saved by ``save_pretrained``.

        The config is loaded from ``pretrained_model_name_or_path`` and the
        backbone is rebuilt from ``config.base_model_name``. Weights are read
        only from a local directory (``model.safetensors`` preferred unless
        ``use_safetensors`` is False, else ``pytorch_model.bin``). If none are
        found, a warning is logged and the head stays randomly initialized.

        Args:
            num_labels: Overrides the saved ``num_labels`` when not None.
            ignore_mismatched_sizes: Skip checkpoint tensors whose shape differs
                from the model's instead of failing in ``load_state_dict``.
            **kwargs: Forwarded to the config's ``from_pretrained``.

        ``model_args`` is accepted for signature compatibility and ignored.
        """
        if config is None:
            config_kwargs = dict(kwargs)
            # Only override the saved value when explicitly requested; a fixed
            # default here would clobber num_labels of every saved checkpoint.
            if num_labels is not None:
                config_kwargs['num_labels'] = num_labels
            config = cls.config_class.from_pretrained(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                **config_kwargs
            )

        model = cls(config=config, hf_token=token)

        state_dict = cls._load_local_state_dict(pretrained_model_name_or_path, use_safetensors)
        if state_dict is None:
            cls._logger.warning(
                "No model.safetensors or pytorch_model.bin found in %s; "
                "pooling and classifier weights are randomly initialized.",
                pretrained_model_name_or_path,
            )
            return model

        model._match_vocab_size(state_dict)

        if ignore_mismatched_sizes:
            current_state = model.state_dict()
            mismatched = [
                key for key, tensor in state_dict.items()
                if key in current_state and current_state[key].shape != tensor.shape
            ]
            for key in mismatched:
                cls._logger.warning(
                    "Skipping %s: checkpoint shape %s != model shape %s",
                    key, tuple(state_dict[key].shape), tuple(current_state[key].shape),
                )
                del state_dict[key]

        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if missing_keys:
            cls._logger.warning("Missing keys when loading model: %s", missing_keys)
        if unexpected_keys:
            cls._logger.warning("Unexpected keys when loading model: %s", unexpected_keys)

        return model

    @staticmethod
    def _load_local_state_dict(directory, use_safetensors):
        """Return the state dict stored in ``directory``, or None if there is none."""
        if not os.path.isdir(directory):
            return None

        if use_safetensors is not False:
            safetensors_file = os.path.join(directory, "model.safetensors")
            if os.path.isfile(safetensors_file):
                from safetensors.torch import load_file
                return load_file(safetensors_file)

        pytorch_file = os.path.join(directory, "pytorch_model.bin")
        if os.path.isfile(pytorch_file):
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # checkpoints from trusted sources (consider weights_only=True).
            return torch.load(pytorch_file, map_location="cpu")
        return None

    def _match_vocab_size(self, state_dict):
        """Resize input embeddings when the checkpoint's vocabulary size differs."""
        try:
            embeddings = self.get_input_embeddings()
        except NotImplementedError:
            return
        if embeddings is None or not hasattr(embeddings, 'weight'):
            return

        # The module's own name (e.g. transformer.embeddings.word_embeddings)
        # works for any architecture; legacy names are tried afterwards.
        candidate_keys = [
            f"{name}.weight" for name, module in self.named_modules() if module is embeddings
        ]
        candidate_keys.extend(self._LEGACY_EMBEDDING_KEYS)
        saved_key = next((key for key in candidate_keys if key in state_dict), None)
        if saved_key is None:
            return

        saved_vocab_size = state_dict[saved_key].shape[0]
        current_vocab_size = embeddings.weight.shape[0]
        if saved_vocab_size != current_vocab_size:
            self._logger.warning(
                "Vocabulary size mismatch detected: saved=%s, current=%s; "
                "resizing token embeddings to match saved model.",
                saved_vocab_size, current_vocab_size,
            )
            self.resize_token_embeddings(saved_vocab_size)
            # Update config to reflect the actual vocabulary size
            self.config.vocab_size = saved_vocab_size

    @staticmethod
    def _without_aliased_tensors(state_dict):
        """
        Drop repeated names of tensors that share memory and make the rest contiguous.

        safetensors rejects shared or non-contiguous tensors. Aliases arise when
        one module is registered under several attribute names or when weights
        are tied. The first name of each tensor is kept. Loading with
        ``strict=False`` still fills aliases that refer to the same Parameter.
        """
        kept = {}
        seen = set()
        for name, tensor in state_dict.items():
            if tensor.numel() > 0:
                identity = (tensor.device, tensor.data_ptr(), tensor.dtype,
                            tuple(tensor.shape), tensor.stride())
                if identity in seen:
                    continue
                seen.add(identity)
            kept[name] = tensor.contiguous()
        return kept

    def save_pretrained(
        self,
        save_directory,
        is_main_process=True,
        state_dict=None,
        save_function=torch.save,
        push_to_hub=False,
        max_shard_size="10GB",
        safe_serialization=True,
        variant=None,
        token=None,
        save_peft_format=True,
        **kwargs
    ):
        """
        Save the model and its configuration to a directory (single, unsharded file).

        Only ``is_main_process`` (gates the config write), ``state_dict``,
        ``save_function`` and ``safe_serialization`` are honoured. The remaining
        parameters mirror ``PreTrainedModel.save_pretrained`` and are ignored.
        """
        if os.path.isfile(save_directory):
            self._logger.error("Provided path (%s) should be a directory, not a file", save_directory)
            return

        os.makedirs(save_directory, exist_ok=True)

        if is_main_process:
            self.config.save_pretrained(save_directory)

        if state_dict is None:
            state_dict = self.state_dict()

        if safe_serialization:
            try:
                from safetensors.torch import save_file
            except ImportError:
                self._logger.warning("safetensors not available, saving as pytorch_model.bin")
            else:
                save_file(
                    self._without_aliased_tensors(state_dict),
                    os.path.join(save_directory, "model.safetensors"),
                    metadata={"format": "pt"},
                )
                self._logger.info("Model saved in %s", save_directory)
                return

        save_function(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
        self._logger.info("Model saved in %s", save_directory)

    def get_input_embeddings(self):
        """Return the backbone's input embedding module."""
        return self.transformer.get_input_embeddings()

    def set_input_embeddings(self, new_embeds):
        """Replace the backbone's input embedding module."""
        return self.transformer.set_input_embeddings(new_embeds)

    def _set_gradient_checkpointing(self, module, value: bool = False):
        """Toggle gradient checkpointing on the backbone (old-style HF hook signature)."""
        if hasattr(self.transformer, "gradient_checkpointing_enable") and value:
            self.transformer.gradient_checkpointing_enable()
        elif hasattr(self.transformer, "gradient_checkpointing_disable") and not value:
            self.transformer.gradient_checkpointing_disable()
        if hasattr(self.transformer, "gradient_checkpointing"):
            self.transformer.gradient_checkpointing = value


# Pooling-strategy factory. Only 'mean' and 'max' are implemented; the
# MultiPoolingStrategy at the end of this file is commented out.
def get_pooling_strategy(strategy_name: str, **kwargs):
    """
    Create a pooling strategy instance from its name.

    Args:
        strategy_name: ``'mean'`` or ``'max'``.
        **kwargs: Options for the strategy. ``combination`` is discarded first
            because it only configures multi-pooling strategies.

    Returns:
        A new ``MeanPoolingStrategy`` or ``MaxPoolingStrategy``.

    Raises:
        ValueError: If ``strategy_name`` is not a known strategy.
    """
    options = dict(kwargs)
    options.pop('combination', None)

    # Checked in order with ==; a 'multi' entry can be added once
    # MultiPoolingStrategy is enabled.
    builders = (
        ('mean', lambda opts: MeanPoolingStrategy(**opts)),
        ('max', lambda opts: MaxPoolingStrategy(**opts)),
    )
    for name, build in builders:
        if strategy_name == name:
            return build(options)

    raise ValueError(f"Unknown pooling strategy: {strategy_name}")


# Built-in pooling strategies. Each accepts (and ignores) extra constructor
# kwargs and exposes setup(hidden_size) and forward(hidden_states, attention_mask).
class MeanPoolingStrategy(nn.Module):
    """Average token representations over dim 1, skipping masked-out positions."""

    def __init__(self, **kwargs):
        # Unused keyword arguments are tolerated so all strategies share one option dict.
        super().__init__()

    def setup(self, hidden_size):
        """No-op: mean pooling has no learnable parameters."""

    def forward(self, hidden_states, attention_mask):
        """
        Mean-pool ``hidden_states`` along dim 1.

        Args:
            hidden_states: Token representations, presumably (batch, seq, hidden)
                — TODO confirm.
            attention_mask: Per-token weights multiplied in before averaging
                (normally 1 = keep, 0 = padding), or None to average every position.

        Returns:
            The pooled tensor. With a mask, the mask is cast to float32 first,
            so the result follows torch type promotion with that dtype.
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)

        weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        totals = (hidden_states * weights).sum(dim=1)
        # The tiny floor avoids 0/0 for rows whose mask is entirely zero.
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return totals / counts


class MaxPoolingStrategy(nn.Module):
    """Element-wise maximum over dim 1, ignoring positions where the mask is 0."""

    # Value written into masked positions before taking the max.
    _MASK_FILL = -1e9

    def __init__(self, **kwargs):
        super().__init__()
        # Accept and ignore any extra kwargs for compatibility

    def setup(self, hidden_size):
        """No-op: max pooling has no learnable parameters."""

    def forward(self, hidden_states, attention_mask):
        """
        Max-pool ``hidden_states`` along dim 1.

        Args:
            hidden_states: Floating-point token representations, presumably
                (batch, seq, hidden) — TODO confirm.
            attention_mask: Positions equal to 0 are excluded from the max, or
                None to consider every position.

        Returns:
            The per-feature maximum. The input tensor is not modified.
        """
        if attention_mask is not None:
            # -1e9 overflows to -inf in float16, which makes fully padded rows
            # return -inf; clamp the fill to the dtype's finite range. For
            # float32/float64 the fill stays exactly -1e9.
            fill_value = max(self._MASK_FILL, torch.finfo(hidden_states.dtype).min)
            padding = (attention_mask == 0).unsqueeze(-1)
            hidden_states = hidden_states.masked_fill(padding, fill_value)
        return hidden_states.max(dim=1)[0]


# class MultiPoolingStrategy(nn.Module):
#     def __init__(self, strategies=None, combination='concat', **kwargs):
#         super().__init__()
#         self.combination = combination
#         self.strategies = strategies or ['mean', 'max']
        
#         # Initialize individual pooling strategies
#         self.pooling_modules = nn.ModuleList()
#         for strategy in self.strategies:
#             if strategy == 'mean':
#                 self.pooling_modules.append(MeanPoolingStrategy())
#             elif strategy == 'max':
#                 self.pooling_modules.append(MaxPoolingStrategy())
#             # Add more strategies as needed
    
#     def setup(self, hidden_size):
#         for pooling_module in self.pooling_modules:
#             pooling_module.setup(hidden_size)
    
#     def forward(self, hidden_states, attention_mask):
#         pooled_outputs = []
#         for pooling_module in self.pooling_modules:
#             pooled_outputs.append(pooling_module(hidden_states, attention_mask))
        
#         if self.combination == 'concat':
#             return torch.cat(pooled_outputs, dim=-1)
#         elif self.combination == 'mean':
#             return torch.stack(pooled_outputs).mean(dim=0)
#         elif self.combination == 'sum':
#             return torch.stack(pooled_outputs).sum(dim=0)
#         else:
#             raise ValueError(f"Unknown combination method: {self.combination}")