import bitsandbytes.functional as F
import torch
import torch.nn as nn
from accelerate.utils import CustomDtype
from bitsandbytes.nn import Linear4bit, Params4bit
from transformers import BitsAndBytesConfig

from ..conditions import linear_precondition
from .quantized import QuantizedLayer


class Portable4BitLinear(QuantizedLayer, Linear4bit):
    """
    4-bit linear layer built on top of bitsandbytes' `Linear4bit`.

    Instances can be created empty (`empty_init`), from an existing linear
    layer (`from_linear`), or from another instance's settings (`quantize`).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        compute_dtype: torch.dtype = torch.bfloat16,
        compress_statistics: bool = False,
        quant_type: str = "nf4",
        **kwargs,
    ):
        """
        Builds the underlying 4-bit linear layer and records its settings.

        Parameters:
        - in_features: size of each input sample.
        - out_features: size of each output sample.
        - bias: whether the layer has a bias term.
        - compute_dtype: dtype forwarded to the parent constructor; also
          stored as `original_dtype`.
        - compress_statistics: forwarded to the parent constructor and stored
          so `quantize` can reuse it.
        - quant_type: quantization type name (e.g. "nf4"); stored as
          `is_quantized`.
        - **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(
            in_features,
            out_features,
            bias,
            compute_dtype=compute_dtype,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
            **kwargs,
        )
        # NOTE: despite its name, `is_quantized` holds the quant type string.
        self.is_quantized = quant_type
        self.original_dtype = compute_dtype
        self.has_bias = bias
        # Kept on the instance so `quantize` can build layers with the same
        # setting as this one.
        self.compress_statistics = compress_statistics

    @classmethod
    def empty_init(
        cls,
        in_features,
        out_features,
        bias,
        qconfig: BitsAndBytesConfig,
    ):
        """
        Creates an empty layer. This is useful for quantization while loading the model

        Parameters:
        - in_features: size of each input sample.
        - out_features: size of each output sample.
        - bias: whether the layer has a bias term.
        - qconfig: an `SSBitsAndBytesConfig`; supplies the quant type, the
          compute dtype (`original_dtype`) and the double-quant flag.

        Raises:
        - AssertionError: if `qconfig` is not an `SSBitsAndBytesConfig`.
        """
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is so
        # callers still observe an AssertionError on a wrong config type.
        assert isinstance(qconfig, SSBitsAndBytesConfig), TypeError(
            "qconfig is incorrect type.\n"
            f"Expected `SSBitsAndBytesConfig`, got {str(qconfig.__class__)}"
        )
        return cls(
            in_features,
            out_features,
            bias,
            quant_type=qconfig.bnb_4bit_quant_type,
            compute_dtype=qconfig.original_dtype,
            compress_statistics=qconfig.bnb_4bit_use_double_quant,
        )

    @classmethod
    def from_linear(
        cls,
        existing_layer: nn.Linear,
        is_quantized="nf4",
        compress_statistics=False,
    ):
        """
        Creates a quantized layer from an existing linear layer

        Parameters:
        - existing_layer: source layer; passed through `linear_precondition`
          before its features, dtype and state dict are read.
        - is_quantized: quantization type, either "fp4" or "nf4".
        - compress_statistics: forwarded to the constructor.

        Returns:
        - A new instance carrying the source layer's parameters.

        Raises:
        - AssertionError: if `is_quantized` is not "fp4" or "nf4".
        """
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is to
        # preserve the exception type callers currently see.
        assert is_quantized in [
            "fp4",
            "nf4",
        ], "Unsupported quantization type. Choose either 'fp4' or 'nf4'."

        existing_layer = linear_precondition(existing_layer)
        has_bias = existing_layer.bias is not None

        instance = cls(
            existing_layer.in_features,
            existing_layer.out_features,
            has_bias,
            compress_statistics=compress_statistics,
            compute_dtype=existing_layer.weight.dtype,
            quant_type=is_quantized,
        )

        # strict=False: tolerate key mismatches between the source layer's
        # state dict and this layer's.
        instance.load_state_dict(
            state_dict=existing_layer.state_dict(),
            strict=False,
        )

        return instance

    def dequantize(self) -> torch.Tensor:
        """
        Dequantizes this layer's weight using its stored quantization state.

        Returns:
        - torch.Tensor, the dequantized weight.

        Raises:
        - ValueError: if the weight has no quantization state yet.
        """
        with torch.no_grad():
            quant_state = getattr(self.weight, "quant_state", None)
            if quant_state is None:
                raise ValueError(
                    "Weight quantization state is not initialized. Please quantize before dequantizing."
                )

            # Hand the raw packed tensor (not the parameter wrapper) to
            # bitsandbytes together with the existing quant_state.
            return F.dequantize_4bit(self.weight.data, quant_state=quant_state)

    def quantize(self, layer: nn.Linear):
        """
        Converts input layer to a Portable4BitLinear
        Takes in a linear layer and returns a quantized linear layer that
        uses this instance's quant type and `compress_statistics` setting.

        Parameters:
        - layer: the linear layer to convert.

        Returns:
        - The new quantized layer, moved to the device of `layer.weight`.
        """
        # Read the device up front so the result lands where the input lived.
        target_device = layer.weight.device
        quantized = type(self).from_linear(
            layer,
            is_quantized=self.is_quantized,
            compress_statistics=self.compress_statistics,
        )
        return quantized.to(target_device)


class SSBitsAndBytesConfig(BitsAndBytesConfig):
    """
    `BitsAndBytesConfig` subclass that also records the model's original dtype.

    Lets the SlimscaleQuantizer quantize on the fly with the BitsAndBytes
    library: the class attributes below name the layer class, parameter
    class and target dtype associated with this configuration.
    """

    # Capability flags.
    is_trainable = True
    is_serializable = True

    # Classes and dtype associated with this configuration.
    target_dtype = CustomDtype.INT4
    param_class = Params4bit
    target_class = Portable4BitLinear

    def __init__(self, original_dtype: "torch.dtype", **kwargs):
        """
        Parameters:
        - original_dtype: passed to the parent as `bnb_4bit_compute_dtype`
          and stored as `self.original_dtype`.
        - **kwargs: forwarded unchanged to `BitsAndBytesConfig`.
        """
        super().__init__(bnb_4bit_compute_dtype=original_dtype, **kwargs)
        self.original_dtype = original_dtype
