"""
Configuration utilities for dictionary learning.
"""
import os
from pathlib import Path
from typing import Dict, Any, Optional


class ConfigManager:
    """
    Manages configuration settings and path handling for dictionary learning.

    Wraps a plain configuration dict (``self.cfg``), fills in default output
    directories, derives a model name from the configuration, and exposes a
    small dict-like interface (``get``, ``[]``, ``in``, ``keys``, ``items``).

    Directory defaults and the generated model name are *derived* values: they
    are recomputed when the keys they depend on change through ``update`` or
    item assignment. Values supplied explicitly by the caller are never
    overwritten by derived ones.
    """

    # Config keys whose values are baked into the generated model name.
    _NAME_KEYS = frozenset({
        'feature_dim', 'model_name_or_path', 'concat',
        'use_norm', 'dictionary_size', 'sparsity',
    })
    # Config keys that influence the default directory paths.
    _PATH_KEYS = frozenset({
        'sparsity', 'checkpoint_dir', 'dictionary_dir', 'runs_dir', 'data_base_dir',
    })
    # Maps the ``path_type`` accepted by ``get_path`` to its config key.
    _PATH_TYPES = {
        'checkpoint': 'checkpoint_dir',
        'dictionary': 'dictionary_dir',
        'runs': 'runs_dir',
        'data_base': 'data_base_dir',
    }
    # Sparsity used for default directory names when 'sparsity' is missing/None.
    _DEFAULT_SPARSITY = 256

    def __init__(self, base_cfg: Dict[str, Any]):
        """
        Initialize configuration manager.

        Args:
            base_cfg: Base configuration dictionary. It is shallow-copied, so
                later changes made through this manager do not mutate it.
        """
        self.cfg = base_cfg.copy()
        # Path keys whose current value was filled in by _setup_default_paths
        # (not supplied by the caller); only these are regenerated later.
        self._defaulted_paths = set()
        # True while cfg['name'] holds a name produced by get_model_name
        # rather than one supplied by the caller.
        self._name_is_generated = False
        self._setup_default_paths()

    def _setup_default_paths(self):
        """
        Fill in default directory paths.

        A path key receives its default when it is missing or None (argparse
        leaves unspecified options as None), or when its current value was
        itself a default. The latter lets sparsity-dependent directories follow
        a later change of 'sparsity'.
        """
        sparsity = self.cfg.get('sparsity')
        if sparsity is None:
            sparsity = self._DEFAULT_SPARSITY

        defaults = {
            'checkpoint_dir': 'checkpoints',
            'dictionary_dir': f'dictionaries_s{sparsity}',
            'runs_dir': f'runs_s{sparsity}',
            'data_base_dir': '/data/llm/tmp',
        }

        for key, default_value in defaults.items():
            if self.cfg.get(key) is None or key in self._defaulted_paths:
                self.cfg[key] = default_value
                self._defaulted_paths.add(key)

    def _on_keys_changed(self, changed_keys):
        """Refresh derived values (default paths, generated name) after keys change."""
        changed = set(changed_keys)
        # Explicitly assigned directories now belong to the caller; stop
        # treating them as regenerable defaults.
        self._defaulted_paths -= changed
        if 'name' in changed:
            # A caller-assigned name must survive later component changes.
            self._name_is_generated = False
        if changed & self._NAME_KEYS and self._name_is_generated:
            # Drop the cached generated name so get_model_name rebuilds it.
            self.cfg.pop('name', None)
            self._name_is_generated = False
        if changed & self._PATH_KEYS:
            self._setup_default_paths()

    def get_model_name(self) -> str:
        """
        Generate model name from configuration.

        An explicit, non-None 'name' is returned unchanged. Otherwise a name
        is built from 'model_name_or_path' (with '/' replaced by '_'),
        'concat' (only when > 1), 'use_norm', 'dictionary_size', 'sparsity'
        and 'feature_dim' (only when set). The result is cached in
        ``cfg['name']``.

        Returns:
            str: Generated model name

        Raises:
            KeyError: If 'model_name_or_path', 'dictionary_size' or 'sparsity'
                is missing when a name has to be generated.
        """
        explicit_name = self.cfg.get('name')
        if explicit_name is not None:
            return explicit_name

        # Generate name from components
        model_path = self.cfg['model_name_or_path'].replace("/", "_")
        norm_suffix = '_norm' if self.cfg.get('use_norm', False) else ''
        concat = self.cfg.get('concat')
        concat_suffix = f'_concat{concat}' if concat is not None and concat > 1 else ''

        # Include feature_dim only if it has an actual value
        feature_dim = self.cfg.get('feature_dim')
        feature_suffix = f'_f_{feature_dim}' if feature_dim is not None else ''

        name = (f'{model_path}{concat_suffix}{norm_suffix}_'
                f'N_{self.cfg["dictionary_size"]}_'
                f's_{self.cfg["sparsity"]}'
                f'{feature_suffix}')

        self.cfg['name'] = name
        self._name_is_generated = True
        return name

    def get_path(self, path_type: str, filename: Optional[str] = None) -> Path:
        """
        Get a configured path.

        Args:
            path_type: Type of path ('checkpoint', 'dictionary', 'runs', 'data_base')
            filename: Optional filename to append (ignored when empty)

        Returns:
            Path: Configured path

        Raises:
            ValueError: If ``path_type`` is not one of the supported types.
        """
        try:
            key = self._PATH_TYPES[path_type]
        except KeyError:
            raise ValueError(f"Unknown path type: {path_type}") from None

        base_path = Path(self.cfg[key])
        return base_path / filename if filename else base_path

    def get_tensorboard_log_dir(self) -> str:
        """
        Get TensorBoard log directory.

        Returns:
            str: ``<runs_dir>/<model name>``
        """
        return str(self.get_path('runs') / self.get_model_name())

    def ensure_directories(self):
        """Create the checkpoint, dictionary and runs directories if missing."""
        for path_type in ('checkpoint', 'dictionary', 'runs'):
            self.get_path(path_type).mkdir(parents=True, exist_ok=True)

    def update(self, updates: Dict[str, Any]):
        """
        Update configuration with new values.

        Derived values are refreshed afterwards. Default directories are
        recomputed when a path-related key changes, and a generated (not
        caller-supplied) model name is dropped when one of its components
        changes.

        Args:
            updates: Dictionary of updates to apply
        """
        self.cfg.update(updates)
        self._on_keys_changed(updates.keys())

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get configuration value.

        Args:
            key: Configuration key
            default: Default value if key not found

        Returns:
            Configuration value
        """
        return self.cfg.get(key, default)

    def __getitem__(self, key: str) -> Any:
        """Get configuration value using dict-style access."""
        return self.cfg[key]

    def __setitem__(self, key: str, value: Any):
        """
        Set configuration value using dict-style access.

        Applies the same derived-value refresh as ``update``.
        """
        self.cfg[key] = value
        self._on_keys_changed((key,))

    def __contains__(self, key: str) -> bool:
        """Check if key exists in configuration."""
        return key in self.cfg

    def keys(self):
        """Get configuration keys."""
        return self.cfg.keys()

    def items(self):
        """Get configuration items."""
        return self.cfg.items()

    def copy(self) -> Dict[str, Any]:
        """Get a shallow copy of the configuration dictionary."""
        return self.cfg.copy()


def setup_training_config(args_dict: Dict[str, Any]) -> ConfigManager:
    """
    Build a ConfigManager from parsed command-line arguments.

    The arguments are wrapped in a new manager, which fills in default
    directory paths. The selected compute device is stored under 'device':
    CUDA when ``torch.cuda.is_available()`` reports True, otherwise CPU.

    Args:
        args_dict: Mapping of command-line argument names to values

    Returns:
        ConfigManager: Manager holding the arguments plus the chosen device
    """
    # torch is imported here rather than at module level, so the rest of this
    # module can be imported without it.
    import torch

    manager = ConfigManager(args_dict)

    use_cuda = torch.cuda.is_available()
    manager['device'] = torch.device("cuda" if use_cuda else "cpu")

    return manager
