from pathlib import Path
from typing import List, Optional, Any, Union, Dict, Tuple
from argparse import Namespace
from config_loader import load_saved_config
from model_adapter import ModelArchitectureAdapter

def namespace_to_dict(namespace: Namespace) -> dict:
    """
    Return the attributes of an argparse Namespace as a dictionary

    Args:
        namespace: Parsed arguments; a falsy value such as None is accepted

    Returns:
        The namespace's own attribute dictionary as given by ``vars`` (not a
        copy), or a new empty dict when namespace is falsy
    """
    if not namespace:
        return {}
    return vars(namespace)

def get_dirs_to_process(base_dir: Path, model_dirs: Optional[Union[List[str], Namespace]] = None, 
                       latents_prefix: bool = False) -> List[Path]:
    """
    Resolve which directories under base_dir should be processed

    Args:
        base_dir: Directory that holds the per-model directories
        model_dirs: Explicit directory names, or a Namespace whose
            'model_dirs' attribute (if present) holds them
        latents_prefix: If True, each requested name is prefixed with
            'latents_' before being joined onto base_dir

    Returns:
        One path per requested name (existence is not checked), or, when no
        names were requested, every existing subdirectory of base_dir
    """
    requested = model_dirs
    if isinstance(requested, Namespace):
        # A Namespace without the attribute behaves like "nothing requested"
        requested = getattr(requested, 'model_dirs', None)

    if not requested:
        # Nothing requested: fall back to scanning base_dir itself
        return [entry for entry in base_dir.iterdir() if entry.is_dir()]

    if latents_prefix:
        names = [f"latents_{name}" for name in requested]
    else:
        names = list(requested)
    return [base_dir / name for name in names]

def detect_model_architecture(config_path: Path) -> Optional[str]:
    """
    Detect model architecture from saved config.

    The config is read as JSON and its 'model_name' entry is matched
    case-insensitively against known architecture markers.

    Args:
        config_path: Path to the saved config file

    Returns:
        'gpt_neox', 'gemma' or 'llama', or None if the file is missing,
        unreadable, not valid JSON, or does not name a known architecture
    """
    import json

    # Substring markers, checked in this order; the first match wins.
    known_markers = (
        ('neox', 'gpt_neox'),
        ('gemma', 'gemma'),
        ('llama', 'llama'),
    )

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file or malformed JSON (JSONDecodeError and
        # UnicodeDecodeError are ValueErrors). Detection stays best-effort,
        # but KeyboardInterrupt and unrelated bugs are no longer swallowed.
        return None

    # A config whose top level is not an object carries no model name.
    if not isinstance(config_data, dict):
        return None

    model_name = config_data.get('model_name', '')
    if not isinstance(model_name, str):
        # e.g. "model_name": null
        return None

    lowered = model_name.lower()
    for marker, architecture in known_markers:
        if marker in lowered:
            return architecture
    return None

def validate_layer_paths(model_dir: Path, adapter: ModelArchitectureAdapter) -> bool:
    """
    Check whether model_dir follows the adapter's layer naming

    Only layer 0 is probed: the adapter's prefix for that layer, with any
    leading dots stripped, must exist as an entry inside model_dir.

    Args:
        model_dir: Directory containing model files
        adapter: Model architecture adapter instance

    Returns:
        True if the layer-0 entry exists, False otherwise
    """
    layer_zero_name = adapter.get_layer_prefix(0).lstrip('.')
    return (model_dir / layer_zero_name).exists()

def fix_directory_structure(model_dir: Path, adapter: ModelArchitectureAdapter) -> None:
    """
    Fix directory structure if needed for compatibility

    For every layer index reported by the adapter, an entry named
    'gpt_neox.layers.<i>' is renamed to the adapter's own layer prefix
    (leading dots stripped). A layer is left untouched when the old entry
    is absent or the new name already exists, so nothing is overwritten.

    Args:
        model_dir: Directory containing model files
        adapter: Model architecture adapter instance
    """
    for layer_idx in range(adapter.num_layers()):
        old_prefix = f"gpt_neox.layers.{layer_idx}"
        new_prefix = adapter.get_layer_prefix(layer_idx).lstrip('.')

        old_path = model_dir / old_prefix
        new_path = model_dir / new_prefix

        # Rename only when it cannot clobber an existing target
        if old_path.exists() and not new_path.exists():
            print(f"Fixing directory structure: {old_prefix} -> {new_prefix}")
            old_path.rename(new_path)

def get_sae_paths(model_dir: Path, adapter: ModelArchitectureAdapter, 
                  layer_stride: int = 1) -> List[Tuple[int, Path]]:
    """
    Collect the SAE weight files that exist for the selected layers

    Layers 0, stride, 2*stride, ... below adapter.num_layers() are probed.
    Each layer's file is expected at
    '<model_dir>/<layer prefix without leading dots>/sae.pt'.

    Args:
        model_dir: Directory containing model files
        adapter: Model architecture adapter instance
        layer_stride: Stride between layers

    Returns:
        List of tuples (layer_idx, path), in ascending layer order, holding
        only the layers whose sae.pt file exists
    """
    candidates = (
        (idx, model_dir / adapter.get_layer_prefix(idx).lstrip('.') / "sae.pt")
        for idx in range(0, adapter.num_layers(), layer_stride)
    )
    return [(idx, weights) for idx, weights in candidates if weights.exists()]

def load_model_config(model_dir: Path, current_config: Any, 
                     current_overrides: Union[dict, Namespace]) -> Tuple[Any, ModelArchitectureAdapter]:
    """
    Load configuration from a model directory while preserving overrides and creating adapter

    If '<model_dir>/config.json' exists, every attribute of the saved config
    is copied onto current_config unless its name appears in
    current_overrides. The model named by the resulting config is then
    loaded to build an architecture adapter, and the directory layout is
    repaired when the adapter's layer-0 directory is missing.

    Args:
        model_dir: Directory containing the model and its config
        current_config: Current configuration object (mutated in place);
            must expose 'model_name' and 'torch_dtype', the latter being the
            name of a torch attribute such as 'float16'
        current_overrides: Dictionary or Namespace of command line overrides

    Returns:
        Tuple of (updated config, model adapter)
    """
    # Imported locally, like transformers below. The original referenced
    # `torch` without importing it anywhere, which raised NameError here.
    import torch
    from transformers import AutoModelForCausalLM

    # Convert Namespace to dict if needed
    if isinstance(current_overrides, Namespace):
        current_overrides = namespace_to_dict(current_overrides)

    config_path = model_dir / "config.json"
    if config_path.exists():
        print(f"Loading saved configuration from {config_path}")
        saved_config = load_saved_config(config_path)

        # Apply saved config while preserving overrides: only membership of
        # the key matters, not the override's value
        for key, value in vars(saved_config).items():
            if key not in current_overrides:
                setattr(current_config, key, value)

    # Initialize model to create adapter
    model = AutoModelForCausalLM.from_pretrained(
        current_config.model_name,
        torch_dtype=getattr(torch, current_config.torch_dtype)
    )
    adapter = ModelArchitectureAdapter(model)

    # Check and fix directory structure if needed
    if not validate_layer_paths(model_dir, adapter):
        fix_directory_structure(model_dir, adapter)

    return current_config, adapter

def get_latents_dir(model_dir: Path, latents_prefix: bool = True) -> Path:
    """
    Build the latents directory path that pairs with a model directory

    Args:
        model_dir: Model directory whose name identifies the latents
        latents_prefix: If True, use a sibling named 'latents_<name>';
            otherwise nest '<name>' under a sibling 'latents' directory

    Returns:
        The derived path (nothing is created on disk)
    """
    root, leaf = model_dir.parent, model_dir.name
    if not latents_prefix:
        return root / "latents" / leaf
    return root / f"latents_{leaf}"

def get_eval_dir(model_dir: Path) -> Path:
    """
    Build the evaluation directory path that pairs with a model directory

    Args:
        model_dir: Model directory whose name identifies the evaluation run

    Returns:
        '<parent of model_dir>/eval/<model_dir name>' (nothing is created on disk)
    """
    eval_root = model_dir.parent / "eval"
    return eval_root / model_dir.name

def cleanup_unused_files(dir_path: Path, adapter: ModelArchitectureAdapter) -> None:
    """
    Clean up unused directories while preserving architecture-specific paths

    Every immediate subdirectory of dir_path whose name does not contain any
    of the adapter's layer prefixes (leading dots stripped) is removed
    recursively. Regular files directly inside dir_path are never touched.

    Args:
        dir_path: Directory to clean
        adapter: Model architecture adapter instance
    """
    import shutil

    # The layer names do not depend on the entry being inspected, so build
    # them once instead of once per directory entry.
    layer_names = [
        adapter.get_layer_prefix(i).lstrip('.')
        for i in range(adapter.num_layers())
    ]

    for item in dir_path.iterdir():
        if not item.is_dir():
            continue
        # Match against the entry's own name, not its full path: otherwise a
        # layer-like name anywhere in dir_path's ancestry makes every entry
        # look like a layer directory and nothing is ever cleaned.
        if any(name in item.name for name in layer_names):
            continue
        shutil.rmtree(item)