import importlib
import logging
from typing import Any, Dict, Optional

import torch
import torch.nn as nn

# Module-level logger. No handlers or levels are configured here; that is left
# to whichever application imports this module.
logger = logging.getLogger(__name__)

# Config name -> dotted import path of the quantizer class. The classes are
# only imported when `_get_quantizer_class` resolves one of these paths.
_QUANTIZER_PACKAGE = "benford_quant.quantizer"
QUANTIZER_REGISTRY = {
    "benford-quant": f"{_QUANTIZER_PACKAGE}.benford_quantizer.BenfordQuantizer",
    "uniform-rtn": f"{_QUANTIZER_PACKAGE}.uniform_quantizer.UniformQuantizer",
}


def _tensor_nbytes(t: torch.Tensor) -> int:
    """Return the bytes occupied by the elements of ``t`` (element count x element width)."""
    element_width = t.element_size()
    return element_width * t.nelement()


def _calc_quantized_size(packed, state, scale_dtype=torch.float32) -> int:
    """Return the byte footprint of a quantized weight.

    The packed payload is counted at its own element width. Each entry of
    ``state["scales"]`` is counted at the width of ``scale_dtype``, regardless
    of the dtype the scales are actually stored in. Other ``state`` entries
    are not counted.
    """
    payload = packed.numel() * packed.element_size()
    n_scales = state["scales"].numel()
    scale_width = torch.empty(0, dtype=scale_dtype).element_size()
    return payload + n_scales * scale_width


class QuantizedLinear(nn.Module):
    """
    A wrapper for a quantized linear layer with memory accounting.
    """
    def __init__(self, original_layer: nn.Linear, quantizer):
        super().__init__()
        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features
        self.quantizer = quantizer

        # Quantize the weight and store the state
        weight_fp32 = original_layer.weight.detach().clone()
        self.quantized_weight, self.quant_state = self.quantizer.quantize(weight_fp32)

        # CPU-QUANT: Uncomment the following to move to GPU after quantization. You can keep this commented if you already have device_map = 'auto' in the model loading step.
        self.quantized_weight = self.quantized_weight.to('cuda')
        self.quant_state = {k: v.to('cuda') if isinstance(v, torch.Tensor) else v for k, v in self.quant_state.items()}

        # Save memory stats for reporting
        self.original_nbytes = _tensor_nbytes(weight_fp32)
        self.quantized_nbytes = _calc_quantized_size(self.quantized_weight, self.quant_state)

        # Bias is typically not quantized and kept in high precision
        self.bias = original_layer.bias.detach().clone() if original_layer.bias is not None else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dequantized_weight = self.quantizer.dequantize(self.quantized_weight, self.quant_state)
        dequantized_weight = dequantized_weight.to(device=x.device)

        # QWEN-MODEL: Uncomment the following line when using bf16-based models
        dequantized_weight = dequantized_weight.to(dtype=x.dtype) 

        bias = self.bias.to(device=x.device) if self.bias is not None else None
        return nn.functional.linear(x, dequantized_weight.to(dtype=x.dtype), bias)

    def __repr__(self):
        return (f"QuantizedLinear(in_features={self.in_features}, "
                f"out_features={self.out_features}, "
                f"bias={self.bias is not None}, "
                f"quantizer={self.quantizer.__class__.__name__})")


def _get_quantizer_class(quantizer_name: str):
    """Return the quantizer class registered under ``quantizer_name``.

    The class's module is imported on demand from the dotted path stored in
    ``QUANTIZER_REGISTRY``.

    Raises:
        ValueError: if ``quantizer_name`` is not a registry key.
    """
    registry = QUANTIZER_REGISTRY
    if quantizer_name in registry:
        module_name, attr_name = registry[quantizer_name].rsplit('.', 1)
        return getattr(importlib.import_module(module_name), attr_name)
    raise ValueError(f"Unknown quantizer '{quantizer_name}'. Available: {list(QUANTIZER_REGISTRY.keys())}")


def _saving_ratio(original_nbytes: int, quantized_nbytes: int) -> float:
    """Return original/quantized size, or ``inf`` when the quantized size is zero."""
    return original_nbytes / quantized_nbytes if quantized_nbytes else float('inf')


def _calibrate_min_exponent(model: nn.Module, quantizer_cls, q_config: Dict[str, Any]):
    """Estimate ``min_exponent`` from a random sample of every ``nn.Linear`` weight.

    Up to 100_000 values are drawn (with replacement) from each non-empty
    ``nn.Linear`` weight. They are gathered on the device of the first sampled
    weight, cast to float32, and passed to ``quantizer_cls.estimate_range``.
    Returns ``None`` when there is no weight to sample.
    """
    calib_method = q_config.get('calibration_method', 'percentile')
    calib_value = q_config.get('calibration_value', 0.999)

    # Only nn.Linear layers; zero-sized weights would make torch.randint fail.
    weights = [
        m.weight.detach() for m in model.modules()
        if isinstance(m, nn.Linear) and m.weight.numel() > 0
    ]
    if not weights:
        return None

    # A sharded model (e.g. loaded with device_map='auto') spreads layers over
    # several devices, but torch.cat needs one device: gather samples on one.
    target_device = weights[0].device
    sampled = []
    for w in weights:
        flat = w.flatten()
        idx = torch.randint(0, flat.numel(), (min(100_000, flat.numel()),), device=flat.device)
        sampled.append(flat[idx].to(device=target_device, dtype=torch.float32))

    all_weights = torch.cat(sampled, dim=0)
    return quantizer_cls.estimate_range(all_weights, method=calib_method, value=calib_value)


def apply_quantization(model: nn.Module, config: Dict[str, Any], quant_stats: Optional[Dict[str, Any]] = None):
    """
    Applies quantization to all linear layers of a model in-place
    and logs memory usage savings.

    Args:
        model: Model whose ``nn.Linear`` submodules are replaced by
            ``QuantizedLinear`` in place.
        config: Configuration dict; settings are read from
            ``config['quantization']``. ``method`` selects the quantizer from
            ``QUANTIZER_REGISTRY`` and ``weight_bits`` is forwarded as ``n_bits``.
            ``calibrate``/``calibration_method``/``calibration_value`` control
            the optional ``min_exponent`` calibration (``benford-quant`` only).
            Every other key is passed to the quantizer constructor.
        quant_stats: Optional dict that receives the total original/quantized
            sizes in MB and the overall saving factor.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if no method is configured or the method is not registered.
    """
    logger.info("Starting model quantization...")

    # `or {}` also tolerates an explicit `quantization: null` entry.
    q_config = config.get('quantization') or {}
    method = q_config.get('method')
    if not method:
        raise ValueError("Quantization 'method' not specified in config.")

    # Factory logic to get the quantizer class
    QuantizerClass = _get_quantizer_class(method)

    # Calibration keys are consumed here; everything else goes to the constructor.
    exclude_keys = {'method', 'calibrate', 'calibration_method', 'calibration_value'}
    quantizer_args = {k: v for k, v in q_config.items() if k not in exclude_keys}
    if 'weight_bits' in quantizer_args:
        quantizer_args['n_bits'] = quantizer_args.pop('weight_bits')

    # Optional calibration
    if method == 'benford-quant' and q_config.get('calibrate', False):
        min_exp = _calibrate_min_exponent(model, QuantizerClass, q_config)
        if min_exp is None:
            logger.warning("Calibration requested but no nn.Linear weights found; skipping.")
        else:
            quantizer_args['min_exponent'] = min_exp
            logger.info(f"Pre-calibrated min_exponent: {min_exp}")

    quantizer = QuantizerClass(**quantizer_args)
    logger.info(f"Initialized {QuantizerClass.__name__} with args: {quantizer_args}")

    # Collect targets first: the model is mutated while we iterate below.
    modules_to_replace = {
        name: module for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    }
    if not modules_to_replace:
        logger.warning("No nn.Linear layers found to quantize.")
        return model

    total_orig, total_quant = 0, 0
    for name, module in modules_to_replace.items():
        # For top-level layers parent_name is '' and get_submodule('') is the model.
        parent_name, _, child_name = name.rpartition('.')
        parent_module = model.get_submodule(parent_name)

        logger.info(f"Replacing layer: {name} with QuantizedLinear")
        quantized_module = QuantizedLinear(module, quantizer)
        setattr(parent_module, child_name, quantized_module)

        # accumulate memory stats
        orig_nbytes = quantized_module.original_nbytes
        quant_nbytes = quantized_module.quantized_nbytes
        total_orig += orig_nbytes
        total_quant += quant_nbytes

        logger.info(
            f"Layer {name}: "
            f"orig={orig_nbytes/1024**2:.2f}MB, "
            f"quant={quant_nbytes/1024**2:.2f}MB, "
            f"saving={_saving_ratio(orig_nbytes, quant_nbytes):.2f}x"
        )

    overall_saving = _saving_ratio(total_orig, total_quant)
    logger.info("Model quantization complete.")
    logger.info(
        f"Total original={total_orig/1024**2:.2f}MB, "
        f"Total quantized={total_quant/1024**2:.2f}MB, "
        f"Overall saving={overall_saving:.2f}x"
    )

    if quant_stats is not None:
        quant_stats['total_original_megabytes'] = total_orig/1024**2
        quant_stats['total_quantized_megabytes'] = total_quant/1024**2
        quant_stats['overall_saving_times'] = overall_saving

    return model
