from typing import Optional, List, Any
import torch
from transformers import PreTrainedModel

class ModelArchitectureAdapter:
    """Uniform accessor facade over several causal-LM module layouts.

    The wrapped model is classified once, at construction time, by probing
    it for marker attributes. Every accessor afterwards branches on the
    resulting label (stored in ``model_type``).

    Labels and the attribute paths they imply, relative to the wrapped model:

    * ``'gpt_neox'`` -- layers at ``gpt_neox.layers``, embedding at
      ``gpt_neox.embed_in``
    * ``'gemma'``    -- layers at ``model.layers``, embedding at
      ``model.embed_tokens``
    * ``'llama'``    -- layers at ``model.decoder.layers``, embedding at
      ``model.decoder.embed_tokens``
    """

    # Attribute chains, walked from the wrapped model, to the layer container.
    _LAYER_PATHS = {
        'gpt_neox': ('gpt_neox', 'layers'),
        'gemma': ('model', 'layers'),
        'llama': ('model', 'decoder', 'layers'),
    }

    # Attribute chains, walked from the wrapped model, to the token embedding.
    _EMBEDDING_PATHS = {
        'gpt_neox': ('gpt_neox', 'embed_in'),
        'gemma': ('model', 'embed_tokens'),
        'llama': ('model', 'decoder', 'embed_tokens'),
    }

    # Name of the top-level submodule reported by ``return_module_str``.
    _RETURN_MODULES = {
        'gpt_neox': 'gpt_neox',
        'gemma': 'model',
        'llama': 'model',
    }

    # Templates for the dotted per-layer module prefix.
    _LAYER_PREFIX_TEMPLATES = {
        'gpt_neox': '.gpt_neox.layers.{layer_idx}',
        'gemma': '.model.layers.{layer_idx}',
        'llama': '.model.decoder.layers.{layer_idx}',
    }

    # Templates for the residual-stream hook name of a layer.
    _HOOK_TEMPLATES = {
        'gpt_neox': 'blocks.{layer_idx}.hook_resid_post',
        'gemma': 'blocks.{layer_idx}.hook_resid_post',
        'llama': 'decoder.layers.{layer_idx}.hook_resid_post',
    }

    def __init__(self, model: PreTrainedModel):
        """Wrap ``model`` and classify its layout immediately.

        Raises:
            ValueError: if none of the known marker attributes is present.
        """
        self.model = model
        self.model_type = self._detect_model_type()

    def _detect_model_type(self) -> str:
        """Classify the wrapped model by probing for marker attributes.

        Probe order matters: a top-level ``gpt_neox`` attribute wins, then
        ``model.layers``, then ``model.decoder``.

        NOTE(review): the 'gemma' / 'llama' labels are kept as-is. Any model
        exposing ``model.layers`` is reported as 'gemma', and the 'llama'
        label is tied to a ``model.decoder`` attribute -- confirm these names
        match the architectures actually passed in.

        Raises:
            ValueError: if no probe matches.
        """
        wrapped = self.model
        if hasattr(wrapped, 'gpt_neox'):
            return 'gpt_neox'

        # A missing ``model`` attribute is treated like ``None``: neither
        # probe below can match it, so we fall through to the error.
        inner = getattr(wrapped, 'model', None)
        for marker, label in (('layers', 'gemma'), ('decoder', 'llama')):
            if hasattr(inner, marker):
                return label

        raise ValueError("Unsupported model architecture")

    def _choose(self, options: dict) -> Any:
        """Return the entry of ``options`` whose key equals ``model_type``.

        Entries are compared with ``==`` in insertion order.

        Raises:
            ValueError: if no key matches the current ``model_type``.
        """
        for label, choice in options.items():
            if self.model_type == label:
                return choice
        raise ValueError(f"Unsupported model type: {self.model_type}")

    def _resolve(self, attr_chain: tuple) -> Any:
        """Follow ``attr_chain`` attribute by attribute from the wrapped model."""
        node = self.model
        for attr in attr_chain:
            node = getattr(node, attr)
        return node

    @property
    def layers(self) -> Any:
        """Container of transformer layers for the detected layout.

        Raises:
            ValueError: if ``model_type`` is not a known label.
        """
        chain = self._choose(self._LAYER_PATHS)
        return self._resolve(chain)

    @property
    def return_module_str(self) -> str:
        """Name of the top-level submodule for the detected layout.

        Raises:
            ValueError: if ``model_type`` is not a known label.
        """
        return self._choose(self._RETURN_MODULES)

    @property
    def embedding_layer(self) -> Any:
        """Token-embedding module for the detected layout.

        Raises:
            ValueError: if ``model_type`` is not a known label.
        """
        chain = self._choose(self._EMBEDDING_PATHS)
        return self._resolve(chain)

    @property
    def embedding_dim(self) -> int:
        """Value of the embedding module's ``embedding_dim`` attribute."""
        embedding = self.embedding_layer
        return embedding.embedding_dim

    def get_layer_prefix(self, layer_idx: int) -> str:
        """Dotted module prefix (with leading ``.``) for layer ``layer_idx``.

        Raises:
            ValueError: if ``model_type`` is not a known label.
        """
        template = self._choose(self._LAYER_PREFIX_TEMPLATES)
        return template.format(layer_idx=layer_idx)

    def get_hook_pattern(self, layer_idx: int) -> str:
        """Hook name ending in ``hook_resid_post`` for layer ``layer_idx``.

        Raises:
            ValueError: if ``model_type`` is not a known label.
        """
        template = self._choose(self._HOOK_TEMPLATES)
        return template.format(layer_idx=layer_idx)

    def num_layers(self) -> int:
        """Number of entries in ``layers`` (a method, not a property)."""
        layer_container = self.layers
        return len(layer_container)