import logging

import torch
from optimum.quanto import QBytesTensor, WeightQBytesTensor
from optimum.quanto.tensor.packed import PackedTensor
from optimum.quanto.tensor.weights.marlin import MarlinF8PackedTensor, MarlinF8QBytesTensor
from torch import nn, Tensor

logger = logging.getLogger(__name__)


def get_tensor_data(tensor: Tensor) -> Tensor:
    """
    Return the raw storage tensor behind a (possibly quantized) tensor.

    Anything that is not a QBytesTensor is handed back unchanged. For a QBytesTensor the
    result is taken from its ``_data`` attribute, going one level deeper (``_data._data``)
    when that attribute is a packed tensor. No copy is made on these paths, so in-place
    writes by the caller (e.g. the compression step in entquant/compress) reach the tensor itself.

    NOTE: the float8_e4m3fn path is different. Its result comes from a fresh
    ``MarlinF8PackedTensor.pack`` call on a CUDA copy of the data, not from ``tensor._data``
    itself, and it calls ``.cuda()`` unconditionally.
    """
    if isinstance(tensor, QBytesTensor):
        # QBytesTensor keeps its payload in `_data`
        payload = tensor._data

        if isinstance(payload, (MarlinF8PackedTensor, PackedTensor)):
            # Reach into the packed wrapper instead of calling unpack(): the caller may
            # need to modify the packed data in place.
            payload = payload._data
        elif payload.dtype == torch.float8_e4m3fn:
            # Plain float8_e4m3fn data (a Marlin tensor that was moved to CPU): pack it on
            # the GPU to obtain the packed representation, then move it back to the
            # tensor's device.
            packed = MarlinF8PackedTensor.pack(payload.cuda())
            payload = packed._data.to(tensor.device)

        return payload

    return tensor


def rebuild_marlin_f8_qbytes_tensor(
    tensor: WeightQBytesTensor,
) -> MarlinF8QBytesTensor | WeightQBytesTensor | nn.Parameter:
    """
    Convert a WeightQBytesTensor into a MarlinF8QBytesTensor when that is possible.
    This is required for GPU inference.

    The input is returned as-is when it is not on a CUDA device, when it already is a
    MarlinF8QBytesTensor, or when its data is neither a MarlinF8PackedTensor nor of dtype
    float8_e4m3fn. Otherwise a new MarlinF8QBytesTensor is built from the input's qtype,
    axis, size, stride, data and scale.
    """
    # Nothing to do off-GPU (the conversion only works on CUDA) or when the tensor has
    # already been converted.
    if tensor.device.type != "cuda" or isinstance(tensor, MarlinF8QBytesTensor):
        return tensor

    payload = tensor._data

    # Two kinds of data can be wrapped: data that is already a MarlinF8PackedTensor, and
    # plain float8_e4m3fn data (a Marlin tensor that was moved to CPU), which is expected
    # to be packed automatically when the MarlinF8QBytesTensor is constructed.
    convertible = isinstance(payload, MarlinF8PackedTensor) or payload.dtype == torch.float8_e4m3fn
    if not convertible:
        return tensor

    return MarlinF8QBytesTensor(
        tensor._qtype, tensor._axis, tensor.size(), tensor.stride(), payload, tensor._scale
    )


def _rebuild_members(module: nn.Module, store_name: str, kind: str, rebuilt_by_id: dict) -> None:
    """Rebuild every WeightQBytesTensor held in one per-module store (``_parameters`` or ``_buffers``)."""
    for prefix, submodule in module.named_modules():
        store = getattr(submodule, store_name)
        # Iterate over a snapshot so that writing into the store cannot disturb the iteration.
        for attr, tensor in list(store.items()):
            if not isinstance(tensor, WeightQBytesTensor):
                continue

            cached = rebuilt_by_id.get(id(tensor))
            if cached is None:
                # Keep the original next to its replacement: holding a reference guarantees
                # that its id() cannot be reused by another tensor while the cache is alive.
                cached = (tensor, rebuild_marlin_f8_qbytes_tensor(tensor))
                rebuilt_by_id[id(tensor)] = cached

            rebuilt = cached[1]
            if rebuilt is not tensor:
                store[attr] = rebuilt
                name = f"{prefix}.{attr}" if prefix else attr
                logger.debug(f"Rebuilt {kind} {name} to MarlinF8QBytesTensor")


def rebuild_tensors(module: nn.Module) -> None:
    """
    Rebuild all WeightQBytesTensor parameters/buffers in a model to MarlinF8QBytesTensor where applicable.

    The module tree is modified in place: each submodule's ``_parameters`` and ``_buffers``
    entries are overwritten with the result of ``rebuild_marlin_f8_qbytes_tensor`` whenever
    that call returns a different object. All parameters are processed first, then all buffers.

    Every slot is visited, including slots in different modules that hold the same tensor
    object (e.g. tied weights). Such a tensor is rebuilt only once and every slot receives
    the same rebuilt object, so the sharing is preserved. Iterating ``named_parameters()`` /
    ``named_buffers()`` would skip the duplicates and leave them pointing at the old tensor.

    Args:
        module: Root module whose parameters and buffers (including those of all
            submodules) are rebuilt.
    """
    # Maps id(original) -> (original, rebuilt); shared by both passes so that a tensor
    # object always maps to a single rebuilt object.
    rebuilt_by_id: dict = {}
    _rebuild_members(module, "_parameters", "parameter", rebuilt_by_id)
    _rebuild_members(module, "_buffers", "buffer", rebuilt_by_id)


def resolve_signed_zeros(tensor: Tensor) -> Tensor:
    """
    Replace every zero in a floating-point quantized tensor with positive zero.

    Only QBytesTensor instances whose qtype is floating point are touched; any other input
    is returned unchanged. For the handled case the tensor's ``_data`` attribute is
    overwritten (the tensor is modified in place) and the same tensor object is returned.
    Data stored as a MarlinF8PackedTensor is unpacked first and re-packed afterwards; the
    re-packing calls ``.cuda()`` before moving the result to the tensor's device.
    """
    if not (isinstance(tensor, QBytesTensor) and tensor.qtype.is_floating_point):
        return tensor

    is_marlin_packed = isinstance(tensor._data, MarlinF8PackedTensor)
    raw = tensor._data.unpack() if is_marlin_packed else tensor._data

    # Work in float32 because not every dtype supports the comparison/fill below.
    as_float = raw.float()
    # -0.0 compares equal to 0.0, so this mask selects zeros of either sign and the fill
    # rewrites them all as +0.0.
    as_float.masked_fill_(as_float == -0.0, 0.0)
    cleaned = as_float.to(raw.dtype)

    if is_marlin_packed:
        cleaned = MarlinF8PackedTensor.pack(cleaned.cuda()).to(tensor.device)

    tensor._data = cleaned
    return tensor
