import os
from typing import Optional, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM

from src.loggers.setup_logging import setup_logging
from src.model_loading.common.config.model_config import ModelConfig
from src.model_loading.common.enums.model_enums import QuantizationMethod
from src.model_loading.common.models.stats import ModelStats
from src.model_loading.factory.loader_factory import ModelLoaderFactory
from src.model_loading.registry.registry import ModelRegistry


logger = setup_logging()


class ModelManager:
    """High-level manager for model loading and configuration.

    Resolves model/tokenizer paths through a ``ModelRegistry``, delegates the
    actual loading to a quantization-specific loader obtained from
    ``ModelLoaderFactory``, normalizes the tokenizer, and logs summary stats.
    """

    # Tokenizers reporting a ``model_max_length`` above this are treated as
    # having no meaningful limit (transformers reports a very large sentinel
    # value when no maximum length is configured).
    _UNBOUNDED_MAX_LENGTH_THRESHOLD = 1e6
    # Context length applied when the tokenizer reports no meaningful limit.
    DEFAULT_MAX_LENGTH = 2048

    def __init__(self):
        self.registry = ModelRegistry()

    def _prepare_tokenizer(
        self,
        tokenizer: AutoTokenizer,
        max_length: int = DEFAULT_MAX_LENGTH,
    ) -> AutoTokenizer:
        """Apply the common tokenizer configuration in place and return it.

        - The pad token is set to the EOS token (overwriting any existing one).
        - Padding is done on the left.
        - An effectively unbounded ``model_max_length`` is capped to
          ``max_length``.

        Args:
            tokenizer: Tokenizer to configure; it is mutated in place.
            max_length: Value used when the tokenizer's ``model_max_length``
                exceeds ``_UNBOUNDED_MAX_LENGTH_THRESHOLD``.

        Returns:
            The same tokenizer object, for call chaining.
        """
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"
        if tokenizer.model_max_length > self._UNBOUNDED_MAX_LENGTH_THRESHOLD:
            logger.warning(
                f"Tokenizer model max length reduced from {tokenizer.model_max_length} to {max_length}"
            )
            tokenizer.model_max_length = max_length
        return tokenizer

    @staticmethod
    def _count_parameters(model) -> int:
        """Return the model's parameter count, or 0 if it cannot be determined.

        On transformers models ``num_parameters`` is a method rather than a
        number, so it is called when callable; a plain attribute is returned
        as-is.
        """
        num_parameters = getattr(model, 'num_parameters', 0)
        if callable(num_parameters):
            num_parameters = num_parameters()
        return num_parameters

    @staticmethod
    def _memory_footprint_bytes(model) -> float:
        """Return the model's memory footprint in bytes, or 0 if unavailable.

        Prefers the transformers ``get_memory_footprint()`` method, then falls
        back to a ``memory_footprint`` attribute (called if it is callable).
        """
        get_footprint = getattr(model, 'get_memory_footprint', None)
        if callable(get_footprint):
            return get_footprint()
        footprint = getattr(model, 'memory_footprint', 0)
        return footprint() if callable(footprint) else footprint

    def _collect_model_stats(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> ModelStats:
        """Collect model and tokenizer statistics into a ``ModelStats``.

        Model attributes that may be absent are read defensively and fall back
        to ``None`` / ``0``. The memory footprint is reported in GiB.
        """
        return ModelStats(
            cache_dir=getattr(model, 'cache_dir', None),
            model_max_length=tokenizer.model_max_length,
            dtype=getattr(model, 'dtype', None),
            device=getattr(model, 'device', None),
            num_parameters=self._count_parameters(model),
            memory_footprint=self._memory_footprint_bytes(model) / (1024 ** 3),
            vocab_size=tokenizer.vocab_size,
            # Report the tokenizer's actual pad token (equal to EOS after
            # _prepare_tokenizer, but correct even if called without it).
            pad_token_id=tokenizer.pad_token_id,
            special_tokens=tokenizer.special_tokens_map,
        )

    def load_model(self, config: ModelConfig) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load a model and its tokenizer as described by ``config``.

        Note: ``config.model_path`` and ``config.tokenizer_path`` are
        overwritten in place with the paths resolved from the registry.

        Args:
            config: Model configuration; ``config.identifier`` selects the
                registry entry and the quantization method.

        Returns:
            ``(model, tokenizer)``. The model gets a ``NAME`` attribute set to
            ``str(config.identifier)``, and the tokenizer has been prepared by
            ``_prepare_tokenizer``.

        Raises:
            ValueError: If the registry has no paths for the identifier.
            OSError: If the identifier is local but its model path is missing.
            Any exception raised by the loader is logged and re-raised.
        """
        try:
            logger.info(f"Loading model: {config.identifier}")

            paths = self.registry.get_model_paths(config.identifier)
            if not paths:
                raise ValueError(f"No paths found for model: {config.identifier}")

            if config.identifier.is_local and not os.path.exists(paths.model_path):
                raise OSError(f"Local model path does not exist: {paths.model_path}")

            # Fall back to an unquantized loader when no method is configured.
            loader = ModelLoaderFactory.create_loader(
                quant_method=config.identifier.quantization or QuantizationMethod.NONE
            )

            # Update config with paths from registry (mutates the caller's config).
            config.model_path = paths.model_path
            config.tokenizer_path = paths.tokenizer_path

            model, tokenizer = loader.load_model(config)
            model.NAME = str(config.identifier)
            tokenizer = self._prepare_tokenizer(tokenizer)

            stats = self._collect_model_stats(model, tokenizer)
            self._log_model_info(str(config.identifier), stats)

            return model, tokenizer

        except Exception as e:
            logger.error(f"Error loading model {config.identifier}: {str(e)}")
            raise

    def _log_model_info(self, model_name: str, stats: ModelStats):
        """Log a success message followed by one line per collected statistic."""
        logger.info(f"Successfully loaded model: {model_name}")
        logger.info("Model configuration:")
        logger.info(f"- Cache directory: {stats.cache_dir}")
        logger.info(f"- Model max length: {stats.model_max_length}")
        logger.info(f"- Model dtype: {stats.dtype}")
        logger.info(f"- Model device: {stats.device}")
        logger.info(f"- Model parameters: {stats.num_parameters}")
        logger.info(f"- Memory footprint: {stats.memory_footprint:.2f} GB")
        logger.info(f"- Vocabulary size: {stats.vocab_size}")
        logger.info(f"- Padding token ID: {stats.pad_token_id}")
        logger.info(f"- Special tokens: {stats.special_tokens}")
