from typing import Any, Dict, Optional, Tuple

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    HqqConfig
)
from hqq.engine.hf import HQQModelForCausalLM
from hqq.models.hf.base import AutoHQQHFModel

from src.loggers.setup_logging import setup_logging
from src.model_loading.common.config.model_config import ModelConfig
from src.model_loading.common.enums.model_enums import BitPrecision
from src.model_loading.loaders.interface import ModelLoaderInterface

# Module-level logger used by every method of the loader class below.
logger = setup_logging()

class HQQModelLoader(ModelLoaderInterface):
    """Loader for HQQ quantized models supporting both local and HuggingFace loading.

    ``load_model`` is the entry point. It dispatches on
    ``config.identifier.is_local``:

    * local identifiers are loaded as already-quantized checkpoints through
      the ``hqq`` package (``_load_local_model``);
    * everything else is loaded through ``transformers`` with an
      ``HqqConfig`` quantization config (``_load_hf_model``).

    Both paths return a ``(model, tokenizer)`` tuple and optionally wrap the
    model with ``torch.compile``.
    """

    def load_model(
        self,
        config: ModelConfig,
        compile_mode: str = "reduce-overhead",
        fullgraph: bool = True,
        dynamic: bool = False
    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load a model and its tokenizer according to ``config``.

        Args:
            config: Model configuration. ``config.identifier.is_local`` selects
                the loading path.
            compile_mode: ``mode`` argument forwarded to ``torch.compile``.
            fullgraph: ``fullgraph`` argument forwarded to ``torch.compile``.
            dynamic: ``dynamic`` argument forwarded to ``torch.compile``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        logger.info(f"Loading HQQ model with identifier: {config.identifier}")

        loader = (
            self._load_local_model
            if config.identifier.is_local
            else self._load_hf_model
        )
        return loader(
            config=config,
            compile_mode=compile_mode,
            fullgraph=fullgraph,
            dynamic=dynamic
        )

    def _load_local_model(
        self,
        config: ModelConfig,
        compile_mode: str = "reduce-overhead",
        fullgraph: bool = True,
        dynamic: bool = False
    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load a locally quantized HQQ model.

        ``HQQModelForCausalLM.from_quantized`` is tried first; if it raises,
        the failure is logged and ``AutoHQQHFModel.from_quantized`` is used
        instead. An exception raised by the fallback propagates to the caller.

        Args:
            config: Model configuration; ``model_path`` and ``cache_dir`` are
                passed to the loaders.
            compile_mode: ``mode`` argument forwarded to ``torch.compile``.
            fullgraph: ``fullgraph`` argument forwarded to ``torch.compile``.
            dynamic: ``dynamic`` argument forwarded to ``torch.compile``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        logger.debug(f"Loading local HQQ model from: {config.model_path}")

        load_kwargs = {"device_map": 'auto', "cache_dir": config.cache_dir}

        # Only the load itself is guarded: a compilation error must not be
        # mistaken for a load failure and trigger a second full model load.
        try:
            model = HQQModelForCausalLM.from_quantized(
                config.model_path,
                **load_kwargs
            )
        except Exception as e:
            # Deliberate best-effort fallback to the alternative HQQ loader.
            logger.warning(f"Failed to load with HQQModelForCausalLM, attempting AutoHQQHFModel: {e}")
            model = AutoHQQHFModel.from_quantized(
                config.model_path,
                **load_kwargs
            )

        model = self._maybe_compile(
            model, config, compile_mode, fullgraph, dynamic
        )
        tokenizer = self._load_tokenizer(config)
        return model, tokenizer

    def _load_hf_model(
        self,
        config: ModelConfig,
        compile_mode: str = "reduce-overhead",
        fullgraph: bool = True,
        dynamic: bool = False
    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load an HQQ model directly from HuggingFace.

        The model is loaded in ``torch.float16`` with an ``HqqConfig`` built
        from the bit width in ``config``. When the normalized ``max_memory``
        mapping has more than one entry, the model is loaded with
        ``device_map="auto"`` and is not compiled; otherwise it is loaded with
        ``device_map=config.device`` and optionally compiled.

        Args:
            config: Model configuration.
            compile_mode: ``mode`` argument forwarded to ``torch.compile``.
            fullgraph: ``fullgraph`` argument forwarded to ``torch.compile``.
            dynamic: ``dynamic`` argument forwarded to ``torch.compile``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        logger.debug(f"Loading HuggingFace HQQ model: {config.model_path}")

        # Determine quantization configuration
        bits = self._get_quantization_bits(config)
        quant_config = self._create_quantization_config(bits)

        max_memory = self._normalize_max_memory(config.max_memory)

        # NOTE(review): every key counts here, so a mapping such as
        # {0: ..., "cpu": ...} is also treated as multi-device — confirm
        # that this is intended.
        is_multi_gpu = max_memory is not None and len(max_memory) > 1

        if is_multi_gpu:
            logger.info(f"Loading HQQ model {config.model_path} across multiple devices using max_memory configuration")
            device_map = "auto"  # Use automatic device mapping for multi-GPU
        else:
            logger.info(f"Loading HQQ model {config.model_path} on single device: {config.device}")
            device_map = config.device

        model = AutoModelForCausalLM.from_pretrained(
            config.model_path,
            torch_dtype=torch.float16,
            device_map=device_map,
            max_memory=max_memory,
            quantization_config=quant_config,
            trust_remote_code=config.trust_remote_code,
            cache_dir=config.cache_dir
        )

        # Compilation is skipped for multi-GPU models; the original author
        # notes that it is not compatible with that setup.
        if not is_multi_gpu:
            model = self._maybe_compile(
                model, config, compile_mode, fullgraph, dynamic
            )

        tokenizer = self._load_tokenizer(config)
        return model, tokenizer

    def _maybe_compile(
        self,
        model: Any,
        config: ModelConfig,
        compile_mode: str,
        fullgraph: bool,
        dynamic: bool
    ) -> Any:
        """Wrap ``model`` with ``torch.compile`` when compilation applies.

        Compilation happens only if CUDA is available and
        ``config.apply_compile`` is truthy; otherwise ``model`` is returned
        unchanged.
        """
        if not (torch.cuda.is_available() and config.apply_compile):
            return model

        logger.info(f"Compiling HQQ model with mode: {compile_mode}")
        return torch.compile(
            model,
            mode=compile_mode,
            fullgraph=fullgraph,
            dynamic=dynamic
        )

    @staticmethod
    def _normalize_max_memory(max_memory: Any) -> Optional[Dict[Any, Any]]:
        """Return a plain-dict copy of ``max_memory`` with numeric keys as ints.

        String keys made only of decimal digits (e.g. ``"0"``) become
        integers; all other keys are kept as they are. ``None`` is returned
        unchanged. The input mapping is never mutated.
        """
        if max_memory is None:
            return None

        # ``isdecimal`` (rather than ``isdigit``) guarantees that ``int`` can
        # parse the key: ``isdigit`` also accepts characters such as "²".
        return {
            (int(key) if isinstance(key, str) and key.isdecimal() else key): value
            for key, value in dict(max_memory).items()
        }

    def _load_tokenizer(self, config: ModelConfig) -> AutoTokenizer:
        """Load the tokenizer for ``config``.

        Uses ``config.tokenizer_path`` when it is truthy and falls back to
        ``config.model_path`` otherwise.
        """
        # NOTE(review): ``device_map`` is forwarded to the tokenizer as in the
        # original code; tokenizers normally have no device placement —
        # confirm whether this argument is needed.
        return AutoTokenizer.from_pretrained(
            config.tokenizer_path or config.model_path,
            device_map=config.device,
            cache_dir=config.cache_dir,
            trust_remote_code=config.trust_remote_code
        )

    def _get_quantization_bits(self, config: ModelConfig) -> int:
        """Determine the quantization bit width from ``config.identifier.bits``.

        Returns the INT4 value when no precision is set, when the precision is
        FP32, FP16 or MIXED, or when the precision's value is not one of the
        INT8/INT4/INT3/INT2/INT1 values (the last case is logged as a
        warning). Otherwise returns the precision's own value.
        """
        default_bits = BitPrecision.INT4.value
        precision = config.identifier.bits

        non_quantized = (BitPrecision.FP32, BitPrecision.FP16, BitPrecision.MIXED)
        if not precision or precision in non_quantized:
            return default_bits

        valid_bits = {
            BitPrecision.INT8.value,
            BitPrecision.INT4.value,
            BitPrecision.INT3.value,
            BitPrecision.INT2.value,
            BitPrecision.INT1.value
        }
        bits = precision.value
        if bits in valid_bits:
            return bits

        logger.warning(f"Invalid bit width {bits}, defaulting to {default_bits}")
        return default_bits

    def _create_quantization_config(self, bits: int, group_size: int = 64) -> HqqConfig:
        """Create the ``HqqConfig`` for the given bit width.

        Args:
            bits: Value passed as ``nbits``.
            group_size: Value passed as ``group_size``.

        Returns:
            An ``HqqConfig`` with ``view_as_float=False`` and ``lm_head``
            listed in ``skip_modules``.
        """
        return HqqConfig(
            nbits=bits,
            group_size=group_size,
            view_as_float=False,
            skip_modules=['lm_head']
        )
