from typing import Tuple
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    BitsAndBytesConfig
)
from src.loggers.setup_logging import setup_logging
from src.model_loading.common.config.model_config import ModelConfig
from src.model_loading.loaders.interface import ModelLoaderInterface
from src.model_loading.common.enums import BitPrecision

# Module-level logger shared by every loader method in this file; obtained
# from the project's setup_logging() helper at import time.
logger = setup_logging()

class BitsAndBytesModelLoader(ModelLoaderInterface):
    """Loader for BitsAndBytes quantized models supporting both local and HuggingFace loading.

    ``load_model`` is the public entry point. It dispatches on
    ``config.identifier.is_local``:

    * local identifiers are loaded with plain ``from_pretrained`` (no
      quantization config is passed);
    * all other identifiers are loaded with a ``BitsAndBytesConfig`` built
      from ``config.identifier.bits`` (4-bit or 8-bit only).

    Both paths optionally wrap the model in ``torch.compile`` and return the
    model together with its tokenizer.
    """

    def load_model(
        self,
        config: ModelConfig,
        compile_mode: str = "reduce-overhead",
        fullgraph: bool = True,
        dynamic: bool = False
    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load a model based on configuration using either local or HF loading strategy.

        Args:
            config: Model configuration; ``config.identifier.is_local`` selects
                the loading strategy.
            compile_mode: ``mode`` argument forwarded to ``torch.compile``.
            fullgraph: ``fullgraph`` argument forwarded to ``torch.compile``.
            dynamic: ``dynamic`` argument forwarded to ``torch.compile``.

        Returns:
            A ``(model, tokenizer)`` tuple.

        Raises:
            ValueError: On the HuggingFace path, if the requested precision is
                neither 4-bit nor 8-bit.
        """
        logger.info(f"Loading BitsAndBytes model with identifier: {config.identifier}")
        loader = (
            self._load_local_model
            if config.identifier.is_local
            else self._load_hf_model
        )
        return loader(
            config=config,
            compile_mode=compile_mode,
            fullgraph=fullgraph,
            dynamic=dynamic
        )

    def _load_local_model(
        self,
        config: ModelConfig,
        compile_mode: str = "reduce-overhead",
        fullgraph: bool = True,
        dynamic: bool = False
    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load a local model using standard loading strategy.

        No quantization config is passed here; the checkpoint at
        ``config.model_path`` is loaded as stored.

        Args:
            config: Model configuration (path, device, memory limits, cache
                directory, remote-code trust flag, compile flag).
            compile_mode: ``mode`` argument forwarded to ``torch.compile``.
            fullgraph: ``fullgraph`` argument forwarded to ``torch.compile``.
            dynamic: ``dynamic`` argument forwarded to ``torch.compile``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        model = AutoModelForCausalLM.from_pretrained(
            config.model_path,
            torch_dtype="auto",
            device_map=config.device,
            # Normalised the same way as on the HF path so that a config with
            # string GPU ids ("0", "1") behaves identically for both strategies.
            max_memory=self._normalize_max_memory(config.max_memory),
            cache_dir=config.cache_dir,
            trust_remote_code=config.trust_remote_code
        )
        model = self._compile_if_enabled(
            model, config, compile_mode, fullgraph, dynamic
        )

        tokenizer = self._load_tokenizer(config)
        return model, tokenizer

    def _load_hf_model(
        self,
        config: ModelConfig,
        compile_mode: str = "reduce-overhead",
        fullgraph: bool = True,
        dynamic: bool = False
    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load a HuggingFace model using BitsAndBytes quantization.

        When the normalised ``max_memory`` mapping has more than one entry the
        model is loaded with ``device_map="auto"`` and compilation is skipped;
        otherwise it is placed on ``config.device`` and optionally compiled.

        Args:
            config: Model configuration; ``config.identifier.bits`` must be
                ``BitPrecision.INT4`` or ``BitPrecision.INT8``.
            compile_mode: ``mode`` argument forwarded to ``torch.compile``.
            fullgraph: ``fullgraph`` argument forwarded to ``torch.compile``.
            dynamic: ``dynamic`` argument forwarded to ``torch.compile``.

        Returns:
            A ``(model, tokenizer)`` tuple.

        Raises:
            ValueError: If ``config.identifier.bits`` is missing or is not a
                4-bit / 8-bit precision.
        """
        bits = config.identifier.bits
        if not bits or bits not in (BitPrecision.INT4, BitPrecision.INT8):
            raise ValueError(f"BitsAndBytes requires 4 or 8 bit precision, got: {bits}")

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=(bits == BitPrecision.INT4),
            load_in_8bit=(bits == BitPrecision.INT8)
        )

        max_memory = self._normalize_max_memory(config.max_memory)

        # More than one entry in max_memory is treated as a request to spread
        # the model over several devices.
        is_multi_gpu = max_memory is not None and len(max_memory) > 1

        if is_multi_gpu:
            logger.info(f"Loading HF BitsAndBytes model {config.model_path} across multiple devices using max_memory configuration")
            device_map = "auto"  # Use automatic device mapping for multi-GPU
        else:
            logger.info(f"Loading HF BitsAndBytes model {config.model_path} on single device: {config.device}")
            device_map = config.device

        model = AutoModelForCausalLM.from_pretrained(
            config.model_path,
            quantization_config=bnb_config,
            torch_dtype="auto",
            device_map=device_map,
            max_memory=max_memory,
            cache_dir=config.cache_dir,
            trust_remote_code=config.trust_remote_code
        )

        # Skip compilation for multi-GPU models as it's not compatible.
        if not is_multi_gpu:
            model = self._compile_if_enabled(
                model, config, compile_mode, fullgraph, dynamic
            )

        tokenizer = self._load_tokenizer(config)
        return model, tokenizer

    @staticmethod
    def _normalize_max_memory(max_memory):
        """Return a copy of ``max_memory`` with numeric string keys converted to ints.

        Keys such as ``"0"`` become ``0``; every other key (for example
        ``"cpu"``) is kept unchanged. ``None`` is returned as ``None``. The
        input mapping is never mutated.
        """
        if max_memory is None:
            return None
        return {
            int(key) if isinstance(key, str) and key.isdigit() else key: value
            for key, value in dict(max_memory).items()
        }

    @staticmethod
    def _compile_if_enabled(model, config, compile_mode, fullgraph, dynamic):
        """Wrap ``model`` in ``torch.compile`` when CUDA is available and ``config.apply_compile`` is set.

        Returns the compiled model, or ``model`` unchanged when either
        condition is not met.
        """
        if torch.cuda.is_available() and config.apply_compile:
            return torch.compile(
                model,
                mode=compile_mode,
                fullgraph=fullgraph,
                dynamic=dynamic
            )
        return model

    def _load_tokenizer(self, config: ModelConfig) -> AutoTokenizer:
        """Load tokenizer with common configuration.

        Uses ``config.tokenizer_path`` when it is set and falls back to
        ``config.model_path`` otherwise.
        """
        # device_map is deliberately not forwarded: tokenizers have no device
        # placement, and from_pretrained would merely hand the unknown keyword
        # to the tokenizer's __init__.
        return AutoTokenizer.from_pretrained(
            config.tokenizer_path or config.model_path,
            cache_dir=config.cache_dir,
            trust_remote_code=config.trust_remote_code
        )