import logging
from typing import Any

import torch
from accelerate import dispatch_model
from accelerate.hooks import remove_hook_from_submodules
from torch import nn

from ..entquant.tensor import rebuild_tensors
from ..utils import clear_cache, get_device

logger = logging.getLogger(__name__)


def get_memory_stats(label: str = "") -> tuple[float, float]:
    """Sum allocated and reserved CUDA memory (in GiB) over all visible GPUs.

    One DEBUG line is logged per device; when more than one device is present, an additional DEBUG line with the
    totals is logged as well.

    Args:
        label: Tag printed at the start of every log line (padded to a width of 45).

    Returns:
        ``(allocated, reserved)`` in GiB, summed across devices. ``(0.0, 0.0)`` if CUDA is not available.
    """
    alloc_sum = 0.0
    reserved_sum = 0.0
    if not torch.cuda.is_available():
        return alloc_sum, reserved_sum

    bytes_per_gib = 1024**3
    gpu_count = torch.cuda.device_count()

    for gpu in range(gpu_count):
        alloc = torch.cuda.memory_allocated(gpu) / bytes_per_gib
        reserved = torch.cuda.memory_reserved(gpu) / bytes_per_gib
        # The gap is memory held by the caching allocator but not currently backing tensors.
        gap = reserved - alloc
        logger.debug(f"[{label:45}] GPU {gpu}: Alloc: {alloc:.3f} GiB | Rsrvd: {reserved:.3f} GiB | Gap: {gap:.3f} GiB")
        alloc_sum += alloc
        reserved_sum += reserved

    # A totals line would only repeat the single per-device line, so emit it for multi-GPU setups only.
    if gpu_count > 1:
        total_gap = reserved_sum - alloc_sum
        logger.debug(
            f"[{label:45}] Total: Alloc: {alloc_sum:.3f} GiB | Rsrvd: {reserved_sum:.3f} GiB | Gap: {total_gap:.3f} GiB"
        )

    return alloc_sum, reserved_sum


def dispatch_helper(model: nn.Module, device_map: dict[str, str | torch.device | int], **dispatch_kwargs: Any):
    """Place ``model`` on the devices named in ``device_map``, calling accelerate's ``dispatch_model`` only if needed.

    Steps:
      1. Remove any accelerate hooks already attached to the model's submodules.
      2. Move every mapped submodule to its target device (``"disk"`` entries are skipped, see below) and call
         ``rebuild_tensors`` on the model.
      3. If the map names at most one distinct device (ignoring ``"cpu"``/``"disk"`` entries) and contains no
         ``"disk"`` entry, move the whole model with ``model.to(...)``. Otherwise call
         ``dispatch_model(model, device_map, **dispatch_kwargs)``.
      4. Call ``clear_cache()``.

    Args:
        model: Module to dispatch; it is modified in place.
        device_map: Maps submodule names (as accepted by ``model.get_submodule``) to a target device.
        **dispatch_kwargs: Forwarded to ``dispatch_model``. Ignored on the simple ``model.to(...)`` path.
    """
    remove_hook_from_submodules(model)

    # "disk" is an accelerate offload target, not a torch device: `.to("disk")` would raise. Those modules are left
    # in place here and handed to dispatch_model below.
    disk_modules = {name for name, device in device_map.items() if isinstance(device, str) and device == "disk"}

    # Move modules to devices before accelerate's dispatching, otherwise it might fail for custom Parameter classes
    for module_name, device in device_map.items():
        if module_name in disk_modules:
            continue
        model.get_submodule(module_name).to(device)
        logger.debug(f"Moved {module_name} to {device}")
    rebuild_tensors(model)

    # HOTFIX: dispatch_model will allocate as much memory as the base model, as it is not aware of quantized and/or
    # compressed weights with special storage pointers. This allocation is of course not required.
    # Single device: use .to() (no-op for injected weights since they already point to decompression buffer)
    # Multi-GPU or disk offload: use dispatch_model for proper hooks and offloading
    unique_devices = {get_device(d) for d in device_map.values() if d not in ("cpu", "disk")}
    if len(unique_devices) <= 1 and not disk_modules:
        target_device = unique_devices.pop() if unique_devices else "cpu"
        # NOTE(review): this moves the *whole* model, so "cpu" entries combined with a single accelerator end up on
        # that accelerator as well -- confirm this is intended.
        model.to(target_device)
        logger.debug(f"Simple dispatch: model.to({target_device})")
    else:
        dispatch_model(model, device_map, **dispatch_kwargs)

    clear_cache()
