import ctypes as ct
from functools import reduce  # Required in Python 3
import itertools
import operator
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from torch import Tensor

from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

from bitsandbytes.cextension import lib
from bitsandbytes.functional import pre_call, is_on_gpu, get_ptr, post_call, quantize_blockwise, QuantState


# Size in bytes of a single element for each torch dtype handled in this module.
dtype2bytes = {
    torch.float32: 4,
    torch.float16: 2,
    torch.bfloat16: 2,
    torch.uint8: 1,
    torch.int8: 1,
}


def get_4bit_type(typename, device=None, blocksize=64):
    """Return the 16-entry lookup table for a 4-bit data type.

    The table is created on ``device`` and scaled in place so that its largest
    absolute value is 1.

    Parameters
    ----------
    typename : str
        One of ``"nf4"``, ``"fp4"``, ``"int4"`` or ``"af4"``.
    device : optional
        Device handed to ``torch.tensor``. ``"cuda"`` is used when ``None``.
    blocksize : int
        Only consulted for ``"af4"``, whose table is provided for 64 only.

    Returns
    -------
    torch.Tensor
        Floating-point tensor with 16 elements.

    Raises
    ------
    NotImplementedError
        If ``typename`` is not one of the supported names, or ``"af4"`` is
        requested with a blocksize other than 64.
    """
    if device is None:
        device = "cuda"
    data = None
    if typename == "nf4":
        # Implements the NF4 data type.
        #
        # Constructs a quantization data type where each bin has equal area under a
        # standard normal distribution N(0, 1) that is normalized into the range [-1, 1].
        #
        # For more information read the paper:
        # QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314)
        #
        # Implementation of the NF4 data type in bitsandbytes can be found in the
        # `create_normal_map` function in the `functional.py` file:
        # https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
        data = [
            -1.0,
            -0.6961928009986877,
            -0.5250730514526367,
            -0.39491748809814453,
            -0.28444138169288635,
            -0.18477343022823334,
            -0.09105003625154495,
            0.0,
            0.07958029955625534,
            0.16093020141124725,
            0.24611230194568634,
            0.33791524171829224,
            0.44070982933044434,
            0.5626170039176941,
            0.7229568362236023,
            1.0,
        ]
    elif typename == "fp4":
        # 0b000 = 0
        # 0b001 = 0.0625
        # 0b010 = 8
        # 0b011 = 12
        # 0b100 = 4
        # 0b101 = 6
        # 0b110 = 2
        # 0b111 = 3
        # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
        data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
    elif typename == "int4":
        data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
    elif typename == "af4":
        # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
        # https://arxiv.org/abs/2306.06965
        if blocksize == 64:
            # The literal is written in ascending order and stored reversed.
            data = [
                -1.0,
                -0.69441008,
                -0.51243739,
                -0.3736951,
                -0.25607552,
                -0.14982478,
                -0.04934812,
                0.0,
                0.04273164,
                0.12934483,
                0.21961274,
                0.31675666,
                0.42563882,
                0.55496234,
                0.72424863,
                1.0,
            ][::-1]
        else:
            raise NotImplementedError("4-bit AbnormalFloats currently only support blocksize 64.")

    if data is None:
        raise NotImplementedError(f"Typename {typename} not supported")

    # Build the tensor from floats: an all-integer table (int4) would otherwise
    # become an int64 tensor, on which the in-place true division below raises.
    data = torch.tensor([float(value) for value in data], device=device)
    # Scale so the largest magnitude is exactly 1.
    data.div_(data.abs().max())

    # Internal invariant: 4 bits address exactly 16 table entries.
    assert data.numel() == 16

    return data


def quantize_4bit(
    A: Tensor,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=64,
    compress_statistics=False,
    quant_type="fp4",
    quant_storage=torch.uint8,
) -> Tuple[Tensor, QuantState]:
    """
    Quantize tensor A in blocks of 4-bit values.

    Quantizes tensor A by dividing it into blocks which are independently
    quantized to the 4-bit data type selected by `quant_type`.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor. Must live on a CUDA device and have dtype float32,
        float16 or bfloat16.
    absmax : torch.Tensor
        The absmax values. When omitted, a float32 zero tensor with one entry
        per block is allocated on A's device.
    out : torch.Tensor
        The output tensor. When omitted, a zero tensor of dtype `quant_storage`
        large enough to hold two 4-bit values per byte is allocated.
    blocksize : int
        The blocksize used in quantization. One of 64, 128, 256, 512, 1024,
        2048 or 4096.
    compress_statistics : bool
        When True, the mean of `absmax` is subtracted in place and the result
        is quantized with `quantize_blockwise(blocksize=256)`; the mean and the
        nested quantization state are stored in the returned state.
    quant_type : str
        The 4-bit quantization data type {fp4, nf4}
    quant_storage : torch.dtype
        Element dtype of the allocated `out` tensor. Must be a key of
        `dtype2bytes`. Ignored when `out` is supplied.

    Returns
    -------
    torch.Tensor:
        Tensor with packed 4-bit values.
    QuantState:
        The quantization state to undo the quantization.

    Raises
    ------
    NotImplementedError
        If A is not on a CUDA device or `quant_type` is not fp4/nf4.
    ValueError
        If A's dtype is not float32, float16 or bfloat16.
    AssertionError
        If `blocksize` is not one of the supported values.
    """
    if A.device.type != "cuda":
        raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
    if quant_type not in ["fp4", "nf4"]:
        raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")

    # Validate everything up front so that a bad argument fails before any
    # buffer is allocated or the device context is touched.
    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]

    kernel_dtype_tags = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}
    dtype_tag = kernel_dtype_tags.get(A.dtype)
    if dtype_tag is None:
        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
    # Resolves to one of lib.cquantize_blockwise_{fp32,fp16,bf16}_{fp4,nf4}.
    quantize_kernel = getattr(lib, f"cquantize_blockwise_{dtype_tag}_{quant_type}")

    n = A.numel()
    input_shape = A.shape

    if absmax is None:
        # One absmax entry per (possibly partial) block.
        blocks = (n + blocksize - 1) // blocksize
        absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)

    if out is None:
        # Each byte holds two 4-bit values, so one storage element holds
        # 2 * itemsize of them. Round up so a trailing partial element is
        # still allocated; for 1-byte storage this equals (n + 1) // 2.
        values_per_element = dtype2bytes[quant_storage] * 2
        out = torch.zeros(
            ((n + values_per_element - 1) // values_per_element, 1),
            dtype=quant_storage,
            device=A.device,
        )

    # NOTE(review): prev_device is never used and post_call receives A.device
    # rather than the value returned by pre_call -- confirm against the
    # pre_call/post_call contract in bitsandbytes.functional.
    prev_device = pre_call(A.device)
    is_on_gpu([A, out, absmax])
    quantize_kernel(
        get_ptr(None),
        get_ptr(A),
        get_ptr(absmax),
        get_ptr(out),
        ct.c_int32(blocksize),
        ct.c_int(n),
    )
    post_call(A.device)

    code = get_4bit_type(quant_type, device=A.device)

    nested_state = {}
    if compress_statistics:
        # Centre the statistics (in place) and quantize them a second time.
        offset = absmax.mean()
        absmax -= offset
        absmax, state2 = quantize_blockwise(absmax, blocksize=256)
        nested_state = {"offset": offset, "state2": state2}

    state = QuantState(
        absmax=absmax,
        shape=input_shape,
        dtype=A.dtype,
        blocksize=blocksize,
        code=code,
        quant_type=quant_type,
        **nested_state,
    )

    return out, state