import gc
import logging
from typing import Optional, Any
import torch

import bitsandbytes as bnb
import bitsandbytes.functional

from bof4.quantization.quantization import Quantizer, QuantState

# Module-level logger; used below to warn when double quantization is requested.
_logger = logging.getLogger(__name__)


class BnbQuantState(bitsandbytes.functional.QuantState):
    """A bitsandbytes `QuantState` that additionally carries a free-form attribute dict.

    Every positional and keyword argument other than ``attributes`` is forwarded
    unchanged to `bitsandbytes.functional.QuantState`.

    Args:
        attributes: Extra per-tensor data produced by a custom quantizer. A falsy
            value (``None`` or an empty dict) is replaced by a new empty dict.
    """

    def __init__(self, *args, attributes: Optional[dict[str, Any]] = None, **kwargs):
        super().__init__(*args, **kwargs)
        # Always expose a dict, never None, so callers can read it unconditionally.
        self.attributes = attributes if attributes else {}


def patch_bnb(quantizer: Quantizer):
    """Route bitsandbytes' 4-bit quantize/dequantize/gemv calls through ``quantizer``.

    This replaces the NF4 implementation for every later attribute lookup on
    ``bnb.functional``.
    """
    quantize_fn, dequantize_fn, gemv_fn = wrap_quantizer(quantizer)
    replace_quant_functions(quantize_fn, dequantize_fn, gemv_fn)


def replace_quant_functions(quant_func, dequant_func, gemv):
    """Monkey-patch the 4-bit entry points of ``bnb.functional``.

    Args:
        quant_func: Replacement for ``bnb.functional.quantize_4bit``.
        dequant_func: Replacement for ``bnb.functional.dequantize_4bit``.
        gemv: Replacement for ``bnb.functional.gemv_4bit``.

    Note:
        Only attribute lookups on ``bnb.functional`` see the new functions. Code that
        bound the originals by name before this call keeps using them.
    """
    patches = {
        "quantize_4bit": quant_func,
        "dequantize_4bit": dequant_func,
        "gemv_4bit": gemv,
    }
    for attr_name, replacement in patches.items():
        setattr(bnb.functional, attr_name, replacement)


def wrap_quantizer(quantizer: Quantizer):
    """Wrap a quantizer in three closures that mimic the bitsandbytes 4-bit API.

    The returned closures have the same interfaces as
    `bitsandbytes.functional.quantize_4bit`, `bitsandbytes.functional.dequantize_4bit`
    and `bitsandbytes.functional.gemv_4bit`, respectively. The actual
    (de)quantization is delegated to ``quantizer``.

    Args:
        quantizer: The `Quantizer` whose ``quantize``/``dequantize`` methods are used.

    Returns:
        A ``(quantize, dequantize, gemv)`` tuple of callables.
    """

    @torch.inference_mode()
    def quantize(
        A: torch.Tensor,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize=64,
        compress_statistics=False,
        quant_type="nf4",
        quant_storage=torch.uint8,
    ):
        """Quantize ``A`` with ``quantizer``; drop-in replacement for ``quantize_4bit``.

        ``absmax`` is only checked for device placement; the scaling constants come
        from ``quantizer`` and are returned in the quant state. ``blocksize`` is
        validated against ``A.numel()`` and recorded in the returned state, but it
        is not passed to ``quantizer``. ``quant_storage`` is accepted only for
        signature compatibility.

        Returns:
            ``(out, bnb_quant_state)``. ``out`` holds the quantized values.
            ``bnb_quant_state`` is a `BnbQuantState` with the constants as
            ``absmax`` plus the quantizer's attributes.

        Raises:
            NotImplementedError: If ``A`` is not on CUDA or ``quant_type != "nf4"``.
        """
        if A.device.type != "cuda":
            raise NotImplementedError(
                f"Device type not supported for 4-bit quantization: {A.device.type}"
            )
        if quant_type not in ["nf4"]:
            raise NotImplementedError(
                f"4-bit quantization data type {quant_type} is not implemented."
            )

        n = A.numel()
        input_shape = A.shape

        assert (
            n % blocksize == 0
        ), f"Current implementation requires tensor size to be divisible by blocksize, {n} / {blocksize}"

        # `out` (and `absmax`) may legitimately be None here; no placeholder
        # `absmax` buffer is allocated because it would never be used.
        bitsandbytes.functional.is_on_gpu([A, out, absmax])

        quantizer.to(A.device)
        quant_state = quantizer.quantize(A)

        if out is not None:
            # Fill the caller's buffer using the quantizer's layout.
            out = out.reshape(quant_state.quant_values.shape)
            out[:] = quant_state.quant_values
        else:
            out = quant_state.quant_values

        # Quantizers without reconstruction values leave `code` as None.
        code = getattr(quantizer, "reconstruction_values", None)

        if compress_statistics:
            _logger.warning(
                "Called with double quant. Double quant is not supported. Using single quant."
            )

        bnb_quant_state = BnbQuantState(
            absmax=quant_state.constants.flatten(),
            shape=input_shape,
            dtype=quant_state.original_dtype,
            blocksize=blocksize,
            code=code,
            quant_type=quant_type,
            attributes=quant_state.attributes,
        )

        return out, bnb_quant_state

    @torch.inference_mode()
    def dequantize(
        A: torch.Tensor,
        bnb_quant_state: Optional[bitsandbytes.functional.QuantState] = None,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize: int = 64,
        quant_type="nf4",
    ) -> torch.Tensor:
        """Dequantize ``A`` with ``quantizer``; drop-in replacement for ``dequantize_4bit``.

        Either ``bnb_quant_state`` or both ``absmax`` and ``out`` must be given. In
        the latter case, ``out``'s shape and dtype define the original tensor.
        ``blocksize`` is accepted only for signature compatibility.

        When ``out`` is given, the result is written into it. If ``A`` arrived
        transposed, the returned tensor is transposed, mirroring the ``out is None``
        path.

        Raises:
            NotImplementedError: If ``quant_type != "nf4"``.
        """
        if quant_type not in ["nf4"]:
            raise NotImplementedError(
                f"4-bit quantization data type {quant_type} is not implemented."
            )

        if bnb_quant_state is None:
            assert absmax is not None and out is not None
            # `A` is the packed quantized data, so the original shape/dtype must
            # come from `out`. Constants use the same (blocks, 1) layout as the
            # quant-state branch below.
            py_quant_state = QuantState(
                A,
                constants=absmax.reshape(-1, 1),
                original_shape=out.shape,
                original_dtype=out.dtype,
                quantizer=quantizer,
            )
        else:
            py_quant_state = QuantState(
                A,
                constants=bnb_quant_state.absmax.reshape(-1, 1),
                original_shape=bnb_quant_state.shape,
                original_dtype=bnb_quant_state.dtype,
                quantizer=quantizer,
            )
            if isinstance(bnb_quant_state, BnbQuantState):
                py_quant_state.attributes = bnb_quant_state.attributes

        # Bitsandbytes can transpose the input tensor during matmul. Transpose it
        # back before dequantizing, and transpose the output afterwards.
        transposed = py_quant_state.constants.shape[0] != A.shape[0]
        if transposed:
            py_quant_state.quant_values = A.t()

        if out is not None:
            assert out.shape == py_quant_state.original_shape, (
                f"out has shape {tuple(out.shape)}, expected "
                f"{tuple(py_quant_state.original_shape)}"
            )

        result = quantizer.dequantize(py_quant_state)

        if out is None:
            return result.t() if transposed else result

        # Write into the caller's buffer instead of allocating a new tensor.
        out.copy_(result)
        return out.t() if transposed else out

    @torch.no_grad()
    def gemv(
        A: torch.Tensor,
        B: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        transposed_A=False,
        transposed_B=False,
        state=None,
    ):
        """Dequantize ``B`` and multiply it with ``A``; drop-in replacement for ``gemv_4bit``.

        ``A`` is flattened into a single vector. Reshaping the ``B.shape[0]``-long
        product to the output shape therefore only succeeds when all leading
        dimensions of ``A`` have size 1. ``transposed_A``/``transposed_B`` are
        accepted only for signature compatibility.

        Returns:
            The product with shape ``(A0, A1, B0)`` for 3-D ``A`` or ``(A0, B0)``
            otherwise. If ``out`` is given, it is filled and returned.
        """
        B = dequantize(B, state)

        if len(A.shape) == 3:
            out_shape = (A.shape[0], A.shape[1], B.shape[0])
        else:
            out_shape = (A.shape[0], B.shape[0])

        result = B.matmul(A.flatten()).reshape(out_shape)

        if out is None:
            return result

        # copy_ writes through to `out` even when it is non-contiguous. Writing to
        # `out.flatten()` could target a temporary copy instead.
        out.copy_(result.reshape(out.shape))
        return out

    return quantize, dequantize, gemv
