import logging
from dataclasses import dataclass, field

import torch
from optimum.quanto import freeze, Optimizer, qint8, QLinear, qtype, quantize
from torch import nn

from ..super_weights.super_weights import SuperWeight
from ..utils import clear_cache, get_matching_module_names, str_to_qtype
from .optimizer import SymmetricEntropyOptimizer, WrappedAbsmaxOptimizer
from .tensor import get_tensor_data, resolve_signed_zeros
from .utils import entropy

logger = logging.getLogger(__name__)


@dataclass
class QuantoConfig:
    """Settings describing which modules get quantized and how.

    ``EntQuantManager.quantize`` hands ``include``, ``exclude`` and
    ``layer_types`` to ``get_matching_module_names`` to select modules, and
    passes the qtypes and optimizers on to ``optimum.quanto.quantize``.
    """

    # Module-name selectors forwarded to get_matching_module_names.
    include: str | list[str] | None = None
    exclude: str | list[str] | None = None
    # Module classes eligible for selection; any number of types is accepted.
    layer_types: tuple[type[nn.Module], ...] | list[type[nn.Module]] = (nn.Linear,)
    # A string value is converted to a qtype in __post_init__.
    weight_qtype: qtype | str = qint8
    # Forwarded unchanged as the ``activations`` argument of quanto's quantize.
    activation_qtype: qtype | str | None = None
    # default_factory builds a fresh optimizer per config instead of sharing one
    # object (created at class-definition time) between all configs.
    optimizer: Optimizer = field(default_factory=SymmetricEntropyOptimizer)
    # For super weights, we use the basic absmax optimizer
    optimizer_super_weights: Optimizer = field(default_factory=WrappedAbsmaxOptimizer)

    def __post_init__(self):
        """Normalize ``weight_qtype`` to a qtype object when it was given by name."""
        # EntQuantManager.quantize reads ``weight_qtype.is_floating_point``, so
        # a plain string cannot be kept here.
        if isinstance(self.weight_qtype, str):
            self.weight_qtype = str_to_qtype(self.weight_qtype)


class EntQuantManager:
    """Quantize selected modules of a model with optimum-quanto and report weight statistics."""

    def __init__(
        self,
        model: nn.Module,
        quanto_config: QuantoConfig | None = None,
        super_weights: dict[str, list[SuperWeight]] | None = None,
    ):
        """Store the model and settings and record the model's base dtype.

        Args:
            model: Model whose modules will be quantized in place.
            quanto_config: Quantization settings; a default ``QuantoConfig`` is
                created when omitted.
            super_weights: Optional mapping from module name to that module's
                super weights. Modules with a non-empty list are quantized with
                ``quanto_config.optimizer_super_weights``.
        """
        self.model = model
        self.quanto_config = quanto_config if quanto_config is not None else QuantoConfig()
        self.super_weights = super_weights

        if hasattr(model, "dtype") and isinstance(model.dtype, torch.dtype):
            self.base_dtype = model.dtype
        else:
            # A parameter-free model would make a bare next() raise StopIteration;
            # fall back to torch's default dtype in that case.
            first_param = next(self.model.parameters(), None)
            self.base_dtype = first_param.dtype if first_param is not None else torch.get_default_dtype()

        self._quantized = False

    def quantize(self):
        """Quantize and freeze every selected module, one module name at a time.

        Selected names are the union of the ``super_weights`` keys and the names
        returned by ``get_matching_module_names``, processed in
        ``model.named_modules()`` order. Calling this again after a successful
        run only logs a warning.
        """
        if self._quantized:
            logger.warning("Model is already quantized, skipping quantization.")
            return

        logger.info("Quantizing layers")

        self.model.eval()

        config = self.quanto_config

        # create a list of all modules to be quantized, merge super weights with rest to avoid quantizing them twice
        super_weights = {} if self.super_weights is None else self.super_weights
        selected = set(super_weights) | set(
            get_matching_module_names(self.model, config.include, config.exclude, config.layer_types)
        )
        # sort module names according to the order of model.named_modules()
        module_names = [name for name, _ in self.model.named_modules() if name in selected]

        for module_name in module_names:
            optimizer = (
                config.optimizer_super_weights
                if super_weights.get(module_name)  # excludes missing and empty super weights list
                else config.optimizer
            )
            # quanto matches ``include`` with fnmatch. Target the module itself and
            # its descendants only; a "*name*" substring pattern would also hit
            # unrelated modules (e.g. "layers.1" matching "layers.10") and quantize
            # them with the wrong optimizer or a second time.
            # Assumes module names contain no glob metacharacters - TODO confirm.
            include = [module_name, f"{module_name}.*"] if module_name else "*"
            quantize(
                self.model,
                optimizer=optimizer,
                weights=config.weight_qtype,
                include=include,
                activations=config.activation_qtype,
            )
            # Freezing happens per module rather than once at the end.
            freeze(self.model)
            logger.info("Quantized %s", module_name)

        # resolving signed zeros for floating point weights to eliminate unnecessary redundancies
        if config.weight_qtype.is_floating_point:
            for _, module in self.model.named_modules():
                if isinstance(module, QLinear):
                    module.weight = resolve_signed_zeros(module.weight)
            logger.info("Resolved signed zeros")

        self._quantized = True

        clear_cache()

        logger.info("Quantization completed")

    def entropy(self) -> dict[str, float]:
        """Calculate the entropy, sparsity and average number of unique values of the quantized model.

        Returns:
            A dict with ``"<module>/entropy"`` and ``"<module>/sparsity"`` for every
            ``QLinear`` module, plus ``"average_entropy"``, ``"average_sparsity"``
            and ``"average_unique_val"``, each averaged over modules weighted by
            the module's number of weight elements.

        Raises:
            ValueError: If the model has not been quantized yet, or if it has no
                ``QLinear`` weights to average over.
        """
        if not self._quantized:
            raise ValueError("Model is not quantized, cannot calculate entropy.")

        result = {}
        weighted_entropy = 0.0
        weighted_sparsity = 0.0
        weighted_unique = 0
        total_numel = 0
        for module_name, module in self.model.named_modules():
            if not isinstance(module, QLinear):
                continue

            numel = module.weight.numel()
            total_numel += numel

            # Entropy is computed over the raw storage reinterpreted as bytes.
            # ``val`` presumably holds the distinct byte values - confirm in .utils.entropy.
            ent, val, _ = entropy(get_tensor_data(module.weight).view(torch.uint8), return_val_p=True)
            ent = ent.item()
            weighted_unique += val.numel() * numel
            weighted_entropy += ent * numel
            result[f"{module_name}/entropy"] = ent

            # Fraction of weight elements that compare equal to zero.
            sparsity = (module.weight == 0).float().mean().item()
            weighted_sparsity += sparsity * numel
            result[f"{module_name}/sparsity"] = sparsity

        if total_numel == 0:
            # Without any QLinear weights the averages are undefined; fail with a
            # clear message instead of a ZeroDivisionError.
            raise ValueError("Model has no quantized linear weights, cannot calculate entropy.")

        result["average_entropy"] = weighted_entropy / total_numel
        result["average_sparsity"] = weighted_sparsity / total_numel
        result["average_unique_val"] = weighted_unique / total_numel
        return result
