import logging
from collections import defaultdict
from typing import Any, Callable

import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import remove_hook_from_submodules
from torch import dtype, nn, Tensor

from ..entquant.tensor import get_tensor_data, rebuild_tensors
from ..utils import clear_cache, get_device, get_matching_module_names, get_matching_param_names
from .backend import Backend, nvCOMPBackend
from .utils import get_memory_stats

# Public API: star-imports of this module expose only CompressionManager.
__all__ = ["CompressionManager"]

# Module-level logger; CompressionManager reports compression progress and buffer/hook setup through it.
logger = logging.getLogger(__name__)


class DecompressionBuffer(dict):
    """
    Dictionary keyed by torch devices that lazily creates a decompression buffer per device.

    The first ``[]`` lookup of a device without an entry calls ``allocation_fn(device)``, caches the
    returned tensor under that device and returns it. Later lookups return the cached tensor until the
    mapping is cleared. ``get()``/``in`` do not trigger an allocation (plain dict semantics).
    """

    def __init__(self, allocation_fn: Callable[[torch.device], Tensor]):
        super().__init__()
        # Factory called with the missing device key; its result becomes that device's buffer.
        self.allocation_fn = allocation_fn

    def __missing__(self, key):
        # dict.__getitem__ only reaches this hook for absent keys: allocate once, remember, hand back.
        self[key] = buffer = self.allocation_fn(key)
        return buffer


class CompressionManager:
    """Manages compression and on-the-fly decompression of model weights.

    Compresses selected weights using a configurable backend (default: nvCOMP) and
    registers forward pre-hooks to decompress them into a shared buffer before layer execution.
    Supports multi-device dispatch with automatic device mapping.

    ``weight_include`` and ``weight_exclude`` determine which weights (parameters or buffers) will be
    compressed. ``target_layer_include`` and ``target_layer_exclude`` determine which weights will be
    compressed and decompressed jointly: every selected weight must belong to a target layer (otherwise a
    ValueError is raised), and all selected weights of one target layer are concatenated as raw bytes and
    compressed as a single blob. The decompression hooks are registered for the target layers and not the
    weights directly. This allows for more efficient decompression, as the decompression buffer can be
    shared across multiple weights (and across all target layers on the same device).
    By default (``target_layer_include=None``) the target layers are meant to be the individual transformer
    blocks; the actual selection is made by ``get_matching_module_names``.

    Args:
        model: The model whose weights will be compressed.
        target_layer_include: fnmatch patterns for layers to target.
        target_layer_exclude: fnmatch patterns for layers to exclude.
        weight_include: fnmatch patterns for weights to compress.
        weight_exclude: fnmatch patterns for weights to exclude.
        device_host: Device for storing compressed weights before dispatch.
        device_compute: Device for compression operations.
        device_map: Device mapping for model dispatch ('auto' or explicit mapping); None dispatches
            everything to ``device_host``.
        dtype_compressed: Data type for compressed representation. Byte sizes/offsets are used directly as
            element counts/indices of this dtype, so it is assumed to be a 1-byte type (default torch.uint8).
        backend: Compression backend instance; None creates an nvCOMPBackend.
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer_include: str | list[str] | None = None,
        target_layer_exclude: str | list[str] | None = None,
        weight_include: str | list[str] | None = None,
        weight_exclude: str | list[str] | None = None,
        device_host: str | torch.device | int = "cpu",
        device_compute: str | torch.device | int = "cuda",
        device_map: str | dict[str, str | torch.device | int] | None = None,
        dtype_compressed: dtype = torch.uint8,
        backend: Backend | None = None,  # None defaults to nvCOMPBackend (lazy loading to avoid unnecessary compile)
    ):
        super().__init__()

        self.device_host = get_device(device_host)
        self.device_compute = get_device(device_compute)
        self.device_map = device_map

        self.model = model

        # target layer name -> (module, {"original_size", "compressed_size"}); sizes are in bytes and stay None
        # until compress() has compressed that layer.
        self.target_layers: dict[str, tuple[nn.Module, dict[str, Any]]] = {}
        # target layer name -> weight name -> {"shape", "dtype", "size" (bytes), "byte_offset"}; byte_offset is
        # the weight's position inside the layer's concatenated byte blob and stays None until compress().
        self.injected_weights: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict)
        self._init_target_layers(target_layer_include, target_layer_exclude, weight_include, weight_exclude)

        self.dtype_compressed = dtype_compressed
        self.backend = backend if backend is not None else nvCOMPBackend()

        # Lazily allocated per-device buffers that the injected weights point into (see _setup_weight_pointers)
        self._decompression_buffer = DecompressionBuffer(self._allocate_decompression_buffer)
        self._decompression_hook_handles: dict[str, Any] = {}
        self._compressed = False

    def _named_tensors(self) -> dict[str, Tensor]:
        """Return a name -> tensor mapping over all parameters and buffers of the model.

        Parameters take precedence over buffers. Build it once per operation and pass it to ``_get_weight``:
        rebuilding it for every single weight lookup made the loops below quadratic in the number of tensors.
        """
        tensors: dict[str, Tensor] = dict(self.model.named_buffers())
        tensors.update(self.model.named_parameters())
        return tensors

    def _get_weight(
        self,
        weight_name: str,
        return_data: bool = True,
        tensors: dict[str, Tensor] | None = None,
    ) -> Tensor:
        """Get the weight (data) of a parameter or buffer (direct pointer to tensor data, no copy).

        Args:
            weight_name: Fully qualified parameter/buffer name.
            return_data: If True, return ``get_tensor_data(weight)``; otherwise the parameter/buffer itself.
            tensors: Optional mapping from ``_named_tensors()`` to reuse; built fresh when omitted.
        """
        if tensors is None:
            tensors = self._named_tensors()
        weight = tensors.get(weight_name)
        return get_tensor_data(weight) if return_data else weight

    def _init_target_layers(
        self,
        target_layer_include: str | list[str] | None = None,
        target_layer_exclude: str | list[str] | None = None,
        weight_include: str | list[str] | None = None,
        weight_exclude: str | list[str] | None = None,
    ):
        """Resolve target layers and injected weights and record per-weight metadata.

        Raises:
            ValueError: If a selected weight is not inside any target layer.
        """
        # Get target layer names using fnmatch patterns
        target_layer_names = get_matching_module_names(
            self.model,
            include=target_layer_include,
            exclude=target_layer_exclude,
        )
        for target_layer_name in target_layer_names:
            self.target_layers[target_layer_name] = (
                self.model.get_submodule(target_layer_name),
                {"original_size": None, "compressed_size": None},  # dummies, set by compress()
            )

        injected_weights_names = get_matching_param_names(
            self.model,
            include=weight_include,
            exclude=weight_exclude,
            include_buffers=True,
        )
        tensors = self._named_tensors()
        # Find the target layer each injected weight belongs to (first matching layer prefix wins)
        for weight_name in injected_weights_names:
            for target_layer_name in self.target_layers:
                if weight_name.startswith(target_layer_name + "."):
                    weight_data = self._get_weight(weight_name, tensors=tensors)
                    # TODO: We may need to make the meta info serializable for saving the model later
                    self.injected_weights[target_layer_name][weight_name] = {
                        "shape": weight_data.shape,
                        "dtype": weight_data.dtype,
                        "size": weight_data.numel() * weight_data.element_size(),
                        "byte_offset": None,  # dummy, set by compress()
                    }
                    break
            else:
                raise ValueError(f"Weight {weight_name} does not belong to any target layer.")

    def _is_injected(self, name: str) -> bool:
        """Helper to check if a parameter/buffer is injected."""
        return any(name in weights for weights in self.injected_weights.values())

    def _compute_max_buffer_size(self, device: torch.device) -> int:
        """Compute the buffer size (in bytes) required to decompress the weights of any single target layer.

        Returns the byte size of the largest target layer's injected weights. The result intentionally does
        not depend on where weights currently live: buffers are allocated lazily while ``dispatch`` moves
        target layers one at a time, so at allocation time only some layers have reached ``device``, and a
        placement-based size is too small as soon as a later layer is larger than the earlier ones.

        Args:
            device: Device the buffer is allocated for (kept for the allocation_fn interface; not used).
        """
        return max(
            (sum(meta["size"] for meta in weights.values()) for weights in self.injected_weights.values()),
            default=0,
        )

    def _allocate_decompression_buffer(self, device: torch.device) -> Tensor:
        """
        Allocate a buffer for decompressing the weights of a full target layer.
        Used as allocation_fn in DecompressionBuffer.

        NOTE(review): the byte size is used as the element count of dtype_compressed, so this assumes a
        1-byte dtype_compressed (the default torch.uint8).
        """
        max_buffer_size = self._compute_max_buffer_size(device)
        logger.debug(f"Allocating decompression buffer for device {device} with size {max_buffer_size / 1024 ** 2} MiB")
        return torch.empty(max_buffer_size, dtype=self.dtype_compressed, device=device)

    def _setup_weight_pointers(self, target_layers_include: list[str] | None = None):
        """
        Point weight tensors into decompression buffer using stored metadata.
        In this way, the weights are not copied, but only referenced.
        After decompression, a forward pass can be executed as usual.

        Weights of layers that were never compressed (byte_offset still None) keep their own storage.

        Args:
            target_layers_include: Restrict to these target layer names; None processes all layers.
        """
        tensors = self._named_tensors()
        for target_layer_name, weights in self.injected_weights.items():
            if target_layers_include is not None and target_layer_name not in target_layers_include:
                continue
            for weight_name, meta in weights.items():
                byte_offset = meta["byte_offset"]
                if byte_offset is None:
                    # Not compressed (e.g. skipped by compress()): nothing in the buffer to point to.
                    continue
                size = meta["size"]
                # Retrieve the weight directly from the model because the device may have changed
                weight_data = self._get_weight(weight_name, tensors=tensors)
                # Point the original tensor's storage to the decompression buffer
                buffer_slice = self._decompression_buffer[weight_data.device][byte_offset : byte_offset + size]
                weight_data.set_(buffer_slice.view(meta["dtype"]).view(meta["shape"]))
                logger.debug(
                    f"Pointed {weight_name} to decompression buffer [{byte_offset}:{byte_offset + size}] "
                    f"on {weight_data.device}"
                )

    def _remove_decompression_hooks(self):
        """Remove all decompression pre-hooks registered by this manager."""
        for handle in self._decompression_hook_handles.values():
            handle.remove()
        self._decompression_hook_handles.clear()

    def _register_decompression_hooks(self):
        """Register forward pre-hooks to decompress weights before layer execution.

        Only target layers that were actually compressed (compressed_size set by compress()) get a hook:
        other target layers have no ``_compressed_weights`` buffer, so decompress() would fail on them.
        Existing decompression hooks are removed first, so repeated calls do not stack hooks.
        """
        self._remove_decompression_hooks()

        def create_hook(layer_name: str):
            # Factory binds layer_name per hook (avoids the late-binding closure pitfall inside a loop)
            def hook(module, inp):
                self.decompress(layer_name)

            return hook

        for target_layer_name, (target_layer, layer_meta) in self.target_layers.items():
            if layer_meta["compressed_size"] is None:
                continue
            handle = target_layer.register_forward_pre_hook(create_hook(target_layer_name))
            self._decompression_hook_handles[target_layer_name] = handle
            logger.debug(f"Registered hook for {target_layer_name}")

        logger.debug(f"Registered {len(self._decompression_hook_handles)} decompression hooks")

    def dispatch(
        self,
        device_map: str | dict[str, str | torch.device | int] | None = None,
        **dispatch_kwargs: Any,
    ):
        """
        Dispatch the compressed model to device(s).

        Each target layer is moved to the device of its longest matching device_map prefix, and its injected
        weights are immediately re-pointed into that device's decompression buffer (presumably to keep peak
        device memory low, since the move materializes the weights on the device). Afterwards the whole model
        is moved (single device) or dispatched via accelerate (multiple devices), and decompression hooks
        are (re-)registered.

        Args:
            device_map: "auto" (inferred without splitting target layers), an explicit module-name -> device
                mapping, or None to put everything on device_host.
            **dispatch_kwargs: Forwarded to accelerate's dispatch_model (multi-device maps only).
        """
        if device_map == "auto":
            # make sure no target layers are split across devices
            no_split_modules = set(type(layer[0]).__name__ for layer in self.target_layers.values())
            if hasattr(self.model, "_no_split_modules"):
                no_split_modules.update(self.model._no_split_modules)  # noqa
            device_map = infer_auto_device_map(self.model, no_split_module_classes=list(no_split_modules))
            logger.debug(f"Device map is 'auto', inferred device map: {device_map}")
        elif device_map is None:
            device_map = {"": self.device_host}
            logger.debug(f"Device map is not provided, dispatching to {self.device_host}")

        self._decompression_buffer.clear()

        # Infer device for each target layer by finding the longest matching prefix in device_map
        device_map_target_layers = {}
        for target_layer_name in self.target_layers:
            best_match = ""
            best_device = device_map.get("", self.device_host)  # fallback to root mapping or device_host
            for module_name, device in device_map.items():
                # Check if module_name is a prefix of target_layer_name (or exact match)
                if target_layer_name == module_name or target_layer_name.startswith(module_name + "."):
                    if len(module_name) > len(best_match):
                        best_match = module_name
                        best_device = device
            device_map_target_layers[target_layer_name] = get_device(best_device)

        # remove all existing hooks (from previous dispatch)
        remove_hook_from_submodules(self.model)

        # Move target layers to devices before dispatching, otherwise transfer might fail for custom Parameter classes
        for target_layer_name, device in device_map_target_layers.items():
            target_layer = self.target_layers[target_layer_name][0]
            target_layer.to(device)
            rebuild_tensors(target_layer)
            self._setup_weight_pointers(target_layers_include=[target_layer_name])
            logger.debug(f"Moved {target_layer_name} to {device}")

        # HOTFIX: dispatch_model will allocate as much memory as the base model, as it is not aware quantized and/or
        # compressed weights with special storage pointers. This allocation is not required of course.
        # Single GPU: use .to() (no-op for injected weights since they already point to decompression buffer)
        # Multi-GPU: use dispatch_model for proper hooks and offloading
        # NOTE(review): "cpu"/"disk" entries are not counted here, so a map mixing one GPU with cpu/disk offload
        # takes the model.to(<gpu>) path and also moves the offloaded modules (including target layers placed
        # on cpu above) onto that GPU — confirm this is intended.
        unique_devices = set(get_device(d) for d in device_map.values() if d not in ("cpu", "disk"))
        if len(unique_devices) <= 1:
            target_device = unique_devices.pop() if unique_devices else self.device_host
            self.model.to(target_device)
            logger.debug(f"Simple dispatch: model.to({target_device})")
        else:
            dispatch_model(self.model, device_map, **dispatch_kwargs)

        self._register_decompression_hooks()

        clear_cache()
        logger.info(f"Model moved to devices {device_map}")

    def compress(self, target_layers_include: list[str] | None = None) -> None:
        """Compress the injected weights of all (or the selected) target layers, then dispatch the model.

        For each target layer, its injected weights are moved to device_compute, reinterpreted as raw bytes
        and concatenated in the order of injected_weights. The result is compressed with the backend and
        stored as the layer's ``_compressed_weights`` buffer on device_host. Each weight's byte offset is
        recorded so the weight can be re-pointed into the decompression buffer. Calling this again after a
        successful call only logs a warning.

        Args:
            target_layers_include: Target layer names to compress; None compresses all. Layers left out keep
                their original weights and get no decompression hook.
        """
        if self._compressed:
            logger.warning("Model is already compressed, skipping compression.")
            return

        self.model.to(self.device_host)
        logger.info(f"Model moved to host device {self.device_host}")

        clear_cache()
        get_memory_stats("Before compression")

        target_layers_compressed_count = 0
        tensors = self._named_tensors()

        for target_layer_name, weights in self.injected_weights.items():
            if target_layers_include is not None and target_layer_name not in target_layers_include:
                continue

            weights_all = []
            byte_offset = 0

            # Collect all injected weights of the target layer as flat byte tensors...
            for weight_name, meta in weights.items():
                meta["byte_offset"] = byte_offset
                weight_data = self._get_weight(weight_name, tensors=tensors).to(self.device_compute)
                # View as raw bytes *before* concatenating: torch.cat type-promotes mixed dtypes (e.g. uint8
                # weights + bfloat16 scales), which would change the byte layout and break byte_offset/size.
                weights_all.append(weight_data.contiguous().view(-1).view(torch.uint8))
                byte_offset += meta["size"]

            if not weights_all:
                continue

            # ... and concatenate them
            weights_all_bytes = torch.cat(weights_all).view(self.dtype_compressed).contiguous()
            original_size = weights_all_bytes.numel() * weights_all_bytes.element_size()

            # Compress and get back the compressed tensor
            compressed_weights = self.backend.compress(weights_all_bytes)

            compressed_weights = compressed_weights.to(self.device_host)
            compressed_size = compressed_weights.numel() * compressed_weights.element_size()

            # Free up memory
            del weights_all, weights_all_bytes

            # Store compressed data as registered buffer
            target_layer, layer_meta = self.target_layers[target_layer_name]
            target_layer.register_buffer("_compressed_weights", compressed_weights)
            layer_meta["original_size"] = original_size
            layer_meta["compressed_size"] = compressed_size

            # Clear original weight data (keep them as empty though to allow for device transfer)
            # Crucial for memory management because otherwise the original weights would be moved to target devices
            self._setup_weight_pointers(target_layers_include=[target_layer_name])

            ratio = original_size / compressed_size if compressed_size > 0 else 0
            logger.info(
                f"Compressed {target_layer_name}: "
                f"{original_size / (1024**2):.2f} -> {compressed_size / (1024**2):.2f} MiB "
                f"({ratio:.2f}x)"
            )
            target_layers_compressed_count += 1

        self.backend.clear_cache()
        clear_cache()
        get_memory_stats("After compression (weights still on host device)")

        self.dispatch(self.device_map)
        get_memory_stats("After compression (model moved to devices)")

        logger.info(f"Compression complete: {target_layers_compressed_count} layers")
        self._compressed = True

    def decompress(self, target_layer: str):
        """Decompress a single layer's weights into the decompression buffer.

        The layer's injected weights already point into the buffer of the compressed blob's device (see
        _setup_weight_pointers), so the decompressed bytes become visible through them. Called by the forward
        pre-hooks registered in _register_decompression_hooks.

        Args:
            target_layer: Name of a compressed target layer.
        """
        assert self._compressed, "Model is not compressed, please call compress() first."
        compressed_weights = self.target_layers[target_layer][0].get_buffer("_compressed_weights")
        self.backend.decompress(target_layer, compressed_weights, self._decompression_buffer[compressed_weights.device])
        # No explicit sync needed - decompress() uses CUDA events to make
        # PyTorch's stream wait for decompression without blocking CPU

    def compression_ratio(self) -> dict[str, float]:
        """Compute compression statistics for the compressed model.

        The original size is computed using each parameter/buffer's actual dtype. Injected weights of target
        layers that were not compressed are counted like regular (uncompressed) tensors.

        Returns a dictionary with the following keys:
        - ratio: Compression ratio for injected parameters only
        - full_ratio: Compression ratio across all parameters and buffers
        - injected_original_bytes: Total bytes of original injected parameters
        - injected_compressed_bytes: Total bytes of compressed injected parameters
        - full_original_bytes: Total bytes of all parameters/buffers (excluding _compressed buffers)
        - full_compressed_bytes: Total bytes after compression (excluding injected params)

        Raises:
            ValueError: If the model has not been compressed yet.
        """
        if not self._compressed:
            raise ValueError("Model is not compressed, cannot calculate compression ratio.")

        tensors = self._named_tensors()

        # Compute injected parameters compression ratio, using each weight's actual dtype
        compressed_names: set[str] = set()
        injected_original_bytes = 0
        injected_compressed_bytes = 0

        for target_layer_name, (_, layer_meta) in self.target_layers.items():
            compressed_size = layer_meta["compressed_size"]
            if compressed_size is None:
                continue
            # .get() avoids inserting empty entries into the defaultdict
            for weight_name in self.injected_weights.get(target_layer_name, {}):
                weight = self._get_weight(weight_name, return_data=False, tensors=tensors)
                injected_original_bytes += weight.numel() * weight.element_size()
                compressed_names.add(weight_name)
            injected_compressed_bytes += compressed_size

        # Count bytes for original model (exclude _compressed buffers and compressed weights, already counted)
        full_original_bytes = injected_original_bytes
        for name, param in self.model.named_parameters():
            if name not in compressed_names:
                full_original_bytes += param.numel() * param.element_size()

        for name, buffer in self.model.named_buffers():
            if not name.endswith("._compressed_weights") and name not in compressed_names:
                full_original_bytes += buffer.numel() * buffer.element_size()

        # Count bytes for compressed model (include _compressed buffers, exclude compressed weights)
        full_compressed_bytes = 0
        for name, tensor in list(self.model.named_parameters()) + list(self.model.named_buffers()):
            if name not in compressed_names:
                full_compressed_bytes += tensor.numel() * tensor.element_size()

        injected_ratio = injected_original_bytes / injected_compressed_bytes if injected_compressed_bytes > 0 else 0.0
        full_ratio = full_original_bytes / full_compressed_bytes if full_compressed_bytes > 0 else 0.0

        return {
            "ratio": injected_ratio,
            "full_ratio": full_ratio,
            "injected_original_bytes": injected_original_bytes,
            "injected_compressed_bytes": injected_compressed_bytes,
            "full_original_bytes": full_original_bytes,
            "full_compressed_bytes": full_compressed_bytes,
        }
