import json
import os
from typing import Tuple
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    AutoConfig,
    QuantoConfig
)
from safetensors.torch import load_file
from accelerate import init_empty_weights
from src.loggers.setup_logging import setup_logging
from src.model_loading.common.config.model_config import ModelConfig
from src.model_loading.loaders.interface import ModelLoaderInterface

# Module-level logger obtained from the project's logging setup helper.
logger = setup_logging()

class QuantoModelLoader(ModelLoaderInterface):
    """Loader for QUANTO quantized models supporting both local and HuggingFace loading.

    * Local models (``config.identifier.is_local``) are rebuilt from a directory
      holding ``model.safetensors`` and ``quantization_map.json`` and then
      re-quantized with ``optimum.quanto.requantize``.
    * Other models are loaded with ``AutoModelForCausalLM.from_pretrained`` and
      quantized on load through a :class:`QuantoConfig`.

    On a single device the model is optionally wrapped with ``torch.compile``
    (when CUDA is available and ``config.apply_compile`` is set); models spread
    across several devices are never compiled.
    """

    def load_model(
        self,
        config: ModelConfig,
        compile_mode: str = "reduce-overhead",
        fullgraph: bool = True,
        dynamic: bool = False
    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load a QUANTO model and its tokenizer.

        Dispatches to the local loader when ``config.identifier.is_local`` is
        truthy, otherwise to the HuggingFace loader.

        Args:
            config: Model configuration (paths, device, memory limits, flags).
            compile_mode: ``mode`` forwarded to ``torch.compile``.
            fullgraph: ``fullgraph`` forwarded to ``torch.compile``.
            dynamic: ``dynamic`` forwarded to ``torch.compile``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        logger.info(f"Loading QUANTO model with identifier: {config.identifier}")
        if config.identifier.is_local:
            return self._load_local_model(
                config=config,
                compile_mode=compile_mode,
                fullgraph=fullgraph,
                dynamic=dynamic
            )

        return self._load_hf_model(
            config=config,
            compile_mode=compile_mode,
            fullgraph=fullgraph,
            dynamic=dynamic
        )

    def _load_local_model(
        self,
        config: ModelConfig,
        compile_mode: str = "reduce-overhead",
        fullgraph: bool = True,
        dynamic: bool = False
    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load a locally quantized QUANTO model using optimum.quanto.

        Args:
            config: Model configuration; ``model_path`` must contain
                ``model.safetensors`` and ``quantization_map.json``.
            compile_mode: ``mode`` forwarded to ``torch.compile``.
            fullgraph: ``fullgraph`` forwarded to ``torch.compile``.
            dynamic: ``dynamic`` forwarded to ``torch.compile``.

        Returns:
            A ``(model, tokenizer)`` tuple.

        Raises:
            FileNotFoundError: If either required file is missing from
                ``config.model_path``.
        """
        logger.debug(f"Loading local QUANTO model from: {config.model_path}")

        # Verify required files exist before doing any expensive work
        safetensors_path = os.path.join(config.model_path, "model.safetensors")
        quantmap_path = os.path.join(config.model_path, "quantization_map.json")

        if not os.path.exists(safetensors_path):
            raise FileNotFoundError(f"Model weights not found at: {safetensors_path}")
        if not os.path.exists(quantmap_path):
            raise FileNotFoundError(f"Quantization map not found at: {quantmap_path}")

        # Load model components: quantized tensors plus the per-module
        # quantization spec that requantize() needs to rebuild QModules.
        state_dict = load_file(safetensors_path)
        with open(quantmap_path, "r", encoding="utf-8") as f:
            quantization_map = json.load(f)

        # The architecture config is read from tokenizer_path, not model_path.
        # NOTE(review): presumably tokenizer_path points at the original
        # checkpoint that carries config.json — confirm.
        model_config = AutoConfig.from_pretrained(
            config.tokenizer_path,
            trust_remote_code=config.trust_remote_code,
            cache_dir=config.cache_dir
        )

        # Build the module skeleton without allocating real weight storage;
        # requantize() below populates it from state_dict.
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                model_config,
                trust_remote_code=config.trust_remote_code
            )

        # Apply quantization. The import is deferred, presumably so that
        # optimum.quanto is only required when this path is taken.
        from optimum.quanto import requantize
        requantize(model, state_dict, quantization_map, config.device)

        # from_config() leaves the module in training mode (unlike
        # from_pretrained, which calls eval()); switch to inference mode so
        # layers such as dropout behave the same as on the HuggingFace path.
        model.eval()

        # Honor the compile arguments, as the single-device HF path does.
        model = self._maybe_compile(model, config, compile_mode, fullgraph, dynamic)

        tokenizer = self._load_tokenizer(config)
        return model, tokenizer

    def _load_hf_model(
        self,
        config: ModelConfig,
        compile_mode: str = "reduce-overhead",
        fullgraph: bool = True,
        dynamic: bool = False
    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load a QUANTO model directly from HuggingFace using transformers.

        When ``config.max_memory`` has more than one entry the model is placed
        with ``device_map="auto"`` across those devices and is not compiled;
        otherwise it is placed on ``config.device`` and optionally compiled.

        Args:
            config: Model configuration.
            compile_mode: ``mode`` forwarded to ``torch.compile``.
            fullgraph: ``fullgraph`` forwarded to ``torch.compile``.
            dynamic: ``dynamic`` forwarded to ``torch.compile``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        logger.debug(f"Loading HuggingFace QUANTO model: {config.model_path}")

        # Determine quantization bit width from config (defaults to int8)
        bits = config.identifier.bits.value if config.identifier.bits else 8
        quanto_config = QuantoConfig(weights=f"int{bits}")

        max_memory = self._normalize_max_memory(config.max_memory)

        # More than one memory budget means the model should be spread
        # across several devices.
        is_multi_gpu = max_memory is not None and len(max_memory) > 1

        if is_multi_gpu:
            logger.info(f"Loading model {config.model_path} across multiple devices using max_memory configuration")
            device_map = "auto"  # Use automatic device mapping for multi-GPU
        else:
            logger.info(f"Loading model {config.model_path} on single device: {config.device}")
            device_map = config.device

        model = AutoModelForCausalLM.from_pretrained(
            config.model_path,
            torch_dtype="auto",
            device_map=device_map,
            max_memory=max_memory,
            quantization_config=quanto_config,
            trust_remote_code=config.trust_remote_code,
            cache_dir=config.cache_dir
        )

        # Skip compilation for multi-device models as it's not compatible
        if not is_multi_gpu:
            model = self._maybe_compile(model, config, compile_mode, fullgraph, dynamic)

        tokenizer = self._load_tokenizer(config)
        return model, tokenizer

    @staticmethod
    def _normalize_max_memory(max_memory):
        """Return a copy of ``max_memory`` with numeric string keys converted to ``int``.

        GPU devices are identified by integer index in the copy, while the
        config may carry them as strings (``"0"``). Non-numeric keys such as
        ``"cpu"`` are kept as-is, and ``None`` is passed through unchanged.
        """
        if max_memory is None:
            return None
        # isdecimal() (not isdigit()) matches exactly what int() can parse,
        # so characters like "²" are left as string keys instead of raising.
        return {
            int(key) if isinstance(key, str) and key.isdecimal() else key: value
            for key, value in dict(max_memory).items()
        }

    def _maybe_compile(
        self,
        model: torch.nn.Module,
        config: ModelConfig,
        compile_mode: str,
        fullgraph: bool,
        dynamic: bool
    ) -> torch.nn.Module:
        """Wrap ``model`` with ``torch.compile`` when CUDA is available and ``config.apply_compile`` is set.

        Returns the model unchanged otherwise.
        """
        if not (torch.cuda.is_available() and config.apply_compile):
            return model
        logger.info(f"Compiling model with mode: {compile_mode}")
        return torch.compile(
            model,
            mode=compile_mode,
            fullgraph=fullgraph,
            dynamic=dynamic
        )

    def _load_tokenizer(self, config: ModelConfig) -> AutoTokenizer:
        """Load tokenizer with common configuration.

        No ``device_map`` is passed: it is a model-placement option, and
        ``AutoTokenizer.from_pretrained`` would only forward it to the
        tokenizer constructor as an unused keyword argument.
        """
        return AutoTokenizer.from_pretrained(
            config.tokenizer_path,
            cache_dir=config.cache_dir,
            trust_remote_code=config.trust_remote_code
        )
