import gc
import logging
from fnmatch import fnmatch

import optimum.quanto
import torch
from torch import nn

logger = logging.getLogger(__name__)


def clear_cache(use_cuda: bool = True, use_gc: bool = True):
    """Release cached memory.

    Args:
        use_cuda: If True and CUDA is available, synchronize the device, call
            ``torch.cuda.empty_cache()`` and reset the peak-memory statistics.
        use_gc: If True, run Python's garbage collector first (before any CUDA work).
    """
    if use_gc:
        gc.collect()

    # `and` short-circuits, so is_available() is only queried when use_cuda is set.
    cuda_ready = use_cuda and torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    logger.debug("Cache cleared.")


def get_device(device: str | torch.device | int) -> torch.device:
    """Get a torch device (hashable)."""
    return device if isinstance(device, torch.device) else torch.device(device)


def get_matching_module_names(
    model: nn.Module,
    include: str | list[str] | None = None,
    exclude: str | list[str] | None = None,
    layer_types: tuple[type[nn.Module]] | list[type[nn.Module]] | None = None,
) -> list[str]:
    """Get all module names matching the include/exclude patterns (fnmatch glob style)."""
    if include is not None:
        include = [include] if isinstance(include, str) else include
    if exclude is not None:
        exclude = [exclude] if isinstance(exclude, str) else exclude
    module_names = []
    for name, module in model.named_modules():
        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        if layer_types is not None and not any(isinstance(module, t) for t in layer_types):
            continue
        module_names.append(name)
    return module_names


def get_matching_param_names(
    model: nn.Module,
    include: str | list[str] | None = None,
    exclude: str | list[str] | None = None,
    include_buffers: bool = False,
) -> list[str]:
    """Get all parameter names matching the include/exclude patterns (fnmatch glob style)."""
    if include is not None:
        include = [include] if isinstance(include, str) else include
    if exclude is not None:
        exclude = [exclude] if isinstance(exclude, str) else exclude
    param_names = []
    for name, _ in model.named_parameters():
        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        param_names.append(name)
    if include_buffers:
        for name, _ in model.named_buffers():
            if include is not None and not any(fnmatch(name, pattern) for pattern in include):
                continue
            if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
                continue
            param_names.append(name)
    return param_names


def str_to_dtype(dtype_str: str) -> torch.dtype:
    """Convert a string to a torch dtype.

    Accepts bare names (``"bfloat16"``) as well as the fully qualified
    spelling (``"torch.bfloat16"``).

    Raises:
        ValueError: If the name does not resolve to a ``torch.dtype``. A plain
            ``hasattr`` check is not enough: names such as ``"nn"`` or
            ``"Tensor"`` exist on ``torch`` but are not dtypes.
    """
    dtype = getattr(torch, dtype_str.removeprefix("torch."), None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unknown dtype: {dtype_str}")
    return dtype


def str_to_qtype(qtype_str: str) -> optimum.quanto.qtype:
    """Convert a string (e.g. ``"qint8"``) to an ``optimum.quanto.qtype``.

    Raises:
        ValueError: If ``optimum.quanto`` has no attribute of that name, or the
            attribute is not a ``qtype``. A plain ``hasattr`` check would also
            accept functions and submodules exported by ``optimum.quanto``.
    """
    qtype = getattr(optimum.quanto, qtype_str, None)
    if not isinstance(qtype, optimum.quanto.qtype):
        raise ValueError(f"Unknown qtype: {qtype_str}")
    return qtype
