import torch

from lbmqt.conf import config
import lbmqt.cpp_extension.quantization as ext_quantization

# Public names exported by a star-import of this module.
# NOTE(review): neither 'quantize_data' nor 'dequantize_data' is defined in the
# visible part of this file (the closest definitions are quantize_data_general /
# dequantize_data_general) -- confirm these names exist, otherwise a star-import
# of this module raises AttributeError.
__all__ = ['quantize_data', 'dequantize_data']


def quantize_and_pack_linear(data: torch.Tensor, absmax: torch.Tensor, bits: int, simulate: bool, signed: bool, stochastic: bool) -> torch.Tensor:
    r"""Group-quantize floating point data with the absmax linear scheme.

    Params:
        data(torch.Tensor): (num_groups, group_size), the data to quantize
        absmax(torch.Tensor): (num_groups, 1), the grouped absmax of data
        bits(int): bit width of the target format
        simulate(bool): when True the integer levels are computed with plain
            torch ops; when False the work is delegated to
            ``ext_quantization.pack_absmax_linear``
        signed(bool): only read on the simulated path; selects the level range
            ``[-(2**(bits-1)-1), 2**(bits-1)-1]`` (signed) or ``[0, 2**bits-1]``
        stochastic(bool): on the simulated path, add uniform noise in
            ``[-0.5, 0.5)`` before rounding; forwarded as-is to the extension
            otherwise
    Returns:
        On the simulated path, a ``torch.int`` tensor of quantization levels
        with the same shape as ``data``; otherwise whatever the extension
        kernel returns.
    """
    if not simulate:
        return ext_quantization.pack_absmax_linear(data, absmax, bits, stochastic)

    max_level = (2 ** (bits - 1)) - 1 if signed else (2 ** bits) - 1
    min_level = -max_level if signed else 0

    # The division allocates a fresh tensor, so the in-place ops below never
    # touch the caller's `data`.
    scaled = data / absmax * max_level
    if stochastic:
        jitter = scaled.new(scaled.shape).uniform_(-0.5, 0.5)
        scaled.add_(jitter)

    scaled.clamp_(min_level, max_level)
    return scaled.round_().to(torch.int)


def quantize_and_pack_linear_without_rounding(data: torch.Tensor, absmax: torch.Tensor, bits: int, signed=True) -> torch.Tensor:
    r"""Scale grouped data onto the linear quantization grid without rounding.

    Both the input and the output are grouped tensors. The values are divided
    by ``absmax``, multiplied by the largest level for ``bits`` and clamped to
    the representable range, but they stay in floating point.
    """
    max_level = (2 ** (bits - 1)) - 1 if signed else (2 ** bits) - 1
    min_level = -max_level if signed else 0
    # Out-of-place scaling, so `data` is left untouched by the in-place clamp.
    scaled = data / absmax * max_level
    return scaled.clamp_(min_level, max_level)


def dequantize_and_unpack_linear(data, absmax: torch.Tensor, num_groups: int, group_size: int, bits: int, simulate: bool, signed: bool) -> torch.Tensor:
    r"""Invert the absmax linear quantization.

    On the simulated path the levels in ``data`` are divided by the largest
    level for ``bits`` and rescaled by ``absmax``. Otherwise the call is
    forwarded to ``ext_quantization.unpack_absmax_linear``; ``num_groups`` and
    ``group_size`` are only used by that extension call, and ``signed`` only by
    the simulated path.
    """
    if not simulate:
        return ext_quantization.unpack_absmax_linear(data, bits, absmax, num_groups, group_size)

    max_level = (2 ** (bits - 1)) - 1 if signed else (2 ** bits) - 1
    return data / max_level * absmax


def quantize_and_pack_nonlinear(data: torch.Tensor, absmax: torch.Tensor, qmap: torch.Tensor, bits: int, simulate: bool, stochastic: bool) -> torch.Tensor:
    r"""Group-quantize floating point data against a nonlinear quantization map.

    Params:
        data(torch.Tensor): (num_groups, group_size), the data to quantize
        absmax(torch.Tensor): (num_groups, 1), the grouped absmax of data
        qmap(torch.Tensor): 1-D table of quantization points; its first and
            last entries are used as the clamp bounds
        bits(int): bit width; on the simulated path it is only read when
            ``stochastic`` is True (to cap the floor index at ``2**bits - 2``)
        simulate(bool): when False, delegate to
            ``ext_quantization.pack_absmax_nonlinear``
        stochastic(bool): pick between the two neighbouring map entries at
            random (Bernoulli, weighted by distance) instead of the nearest one
    Returns:
        On the simulated path, a ``torch.int`` tensor of indices into ``qmap``
        with the same shape as ``data``; otherwise whatever the extension
        kernel returns.
    """
    if not simulate:
        return ext_quantization.pack_absmax_nonlinear(data, absmax, qmap, bits, stochastic)

    lowest, highest = qmap[0], qmap[-1]
    # The division allocates a new tensor, so the in-place clamp leaves `data` alone.
    normalized = (data / absmax).clamp_(lowest, highest)
    grouped_shape = normalized.shape
    # One value per row so that it broadcasts against every entry of qmap.
    column = normalized.view(-1, 1)

    if not stochastic:
        nearest = (column - qmap).abs().argmin(dim=-1)
        return nearest.to(torch.int).view(grouped_shape)

    num_levels = 2 ** bits
    # Count the map entries <= value to find the bracket's lower index; the cap
    # keeps `lower_idx + 1` a valid index when the value equals the last entry.
    lower_idx = (column >= qmap).sum(dim=-1).clamp_max_(num_levels - 1) - 1
    flat = column.view(-1)
    left = qmap[lower_idx]
    right = qmap[lower_idx + 1]
    # Probability of rounding up is the relative position inside the bracket.
    round_up = torch.bernoulli((flat - left) / (right - left))
    return (lower_idx + round_up).to(torch.int).view(grouped_shape)


def dequantize_and_unpack_nonlinear(data, absmax: torch.Tensor, qmap: torch.Tensor, num_groups: int, group_size: int, bits: int, simulate: bool) -> torch.Tensor:
    r"""Invert the absmax nonlinear quantization.

    On the simulated path ``data`` is treated as indices into ``qmap`` and the
    looked-up values are rescaled by ``absmax``. Otherwise the call is
    forwarded to ``ext_quantization.unpack_absmax_nonlinear``; ``bits``,
    ``num_groups`` and ``group_size`` are only used by that extension call.
    """
    if not simulate:
        return ext_quantization.unpack_absmax_nonlinear(data, absmax, qmap, bits, num_groups, group_size)

    # Tensor indexing needs integer (int64) indices.
    codes = data.to(torch.int64)
    return qmap[codes] * absmax


def compute_absmax_per_group(input_groups, scale):
    r"""Return the scaled absolute maximum of every group.

    The infinity norm is taken over the last dimension (kept as a size-1
    dimension) and multiplied by ``scale``.
    """
    group_absmax = torch.norm(input_groups, p=float('inf'), dim=-1, keepdim=True)  # groups, 1
    return group_absmax * scale


def compute_absmax_before_grouping(input: torch.Tensor, gp_sz, scale):
    r"""Split ``input`` into groups of ``gp_sz`` and return each group's scaled absmax.

    Thin composition of ``group_tensor`` and ``compute_absmax_per_group``.
    """
    return compute_absmax_per_group(group_tensor(input, gp_sz=gp_sz), scale)


def group_tensor(input: torch.Tensor, gp_sz: int):
    r"""Group tensor into subtensors of size 'gp_sz'

    The input is flattened and, when its element count is not a multiple of
    ``gp_sz``, zero-padded at the end (same dtype and device as ``input``) so
    that the last group is full.

    Params:
        input(torch.Tensor): tensor of any shape
        gp_sz(int): number of elements per group, must be positive
    Returns:
        tensor of shape (num_groups, gp_sz) with num_groups = ceil(numel / gp_sz);
        an empty input yields a (0, gp_sz) tensor
    Raises:
        ValueError: if ``gp_sz`` is not positive
    """
    if not gp_sz > 0:
        raise ValueError("group size need to be a positive integer, but found {}".format(gp_sz))

    input_flatten = input.flatten()
    num_features = input_flatten.shape[0]

    # Ceiling division: number of groups needed to hold every element.
    num_groups = (num_features + gp_sz - 1) // gp_sz
    delta = num_groups * gp_sz - num_features
    if delta > 0:
        # Pad the tail with zeros so the last group is complete.
        input_flatten = torch.cat([input_flatten, input_flatten.new_zeros(delta)], dim=0)

    # An explicit row count (instead of -1) keeps the view well defined for
    # empty inputs, where -1 cannot be inferred.
    return input_flatten.view(num_groups, gp_sz)  # num_groups, group_size


def recon_grouped_tensor(grouped_tensor: torch.Tensor, shape) -> torch.Tensor :
    r"""Reconstruct a tensor of the original (or a specific) shape from its grouped form

    Params:
        grouped_tensor(torch.Tensor): grouped (possibly zero-padded) tensor
        shape: target shape, either a ``torch.Size`` or any sequence of ints
            (tuple, list)
    Returns:
        the first ``prod(shape)`` elements of the flattened ``grouped_tensor``,
        viewed as ``shape``; trailing (padding) elements are dropped
    """
    # Normalise so that plain tuples/lists work as well as torch.Size.
    target_shape = torch.Size(shape)
    numel = target_shape.numel()
    recon_flatten = grouped_tensor.flatten()[:numel]
    return recon_flatten.view(target_shape)


def quantize_data_general(input: torch.Tensor, q_config: dict) -> torch.Tensor:
    r"""Group ``input`` and quantize it as described by ``q_config``.

    Params:
        input(torch.Tensor): the tensor to quantize; it is made contiguous and
            split into groups of ``q_config['group_size']``
        q_config(dict): quantization settings. Keys read here: 'group_size',
            'absmax', 'bits', 'num_mode' ('linear' or 'nonlinear'), 'signed',
            'qmap' (indexed with 'dynamic' when signed, 'udynamic' otherwise),
            'simulate' and 'stochastic'. The absmax is taken from the config,
            it is not computed here, and 'qmap' is looked up for both modes.
    Returns:
        the single quantized value produced by the selected quantizer (a
        tensor on the simulated path; presumably a packed tensor from the
        extension otherwise -- its ``.shape`` is read in debug mode)
    Raises:
        ValueError: if ``q_config['num_mode']`` is neither 'linear' nor 'nonlinear'
    """
    # initialize and get quantization arguments
    gp_sz, absmax, b = q_config['group_size'], q_config['absmax'], q_config['bits']
    num_mode = q_config['num_mode']
    signed = q_config['signed']
    qmap = q_config['qmap']['dynamic' if signed else 'udynamic']
    simulate = q_config['simulate']
    stochastic = q_config['stochastic']

    input = input.contiguous()
    input_groups = group_tensor(input, gp_sz=gp_sz)

    if config.debug_quantize_func:
        print(f'in quantized_data: {input_groups}, {absmax}, {input_groups.shape}')
        print(f'in quantized_data, qmap: {qmap}, len = {len(qmap)}')

    # quantize
    if num_mode == 'linear':
        q_input = quantize_and_pack_linear(input_groups, absmax, b, simulate, signed, stochastic)
    elif num_mode == 'nonlinear':
        q_input = quantize_and_pack_nonlinear(input_groups, absmax, qmap, b, simulate, stochastic)
    else:
        raise ValueError(f'numerical mode is not supported for quantization (got {num_mode})')

    if config.debug_quantize_func:
        print(f'in quantized_data: {q_input}, {absmax}, {q_input.shape}')
    return q_input


def dequantize_data_general(q_input: torch.Tensor, q_config: dict):
    r"""Dequantize ``q_input`` and restore the shape stored in ``q_config``.

    Keys read from ``q_config``: 'absmax', 'bits', 'shape', 'num_mode'
    ('linear' or 'nonlinear'), 'signed', 'qmap' (indexed with 'dynamic' when
    signed, 'udynamic' otherwise), 'group_size' and 'simulate'. The number of
    groups is taken from the first dimension of the absmax tensor.

    Returns the reconstructed tensor as a contiguous tensor of shape
    ``q_config['shape']``; raises ValueError for an unknown 'num_mode'.
    """
    # get dequantization arguments
    absmax = q_config['absmax']
    bits = q_config['bits']
    target_shape = q_config['shape']
    num_mode = q_config['num_mode']
    signed = q_config['signed']
    qmap = q_config['qmap']['dynamic' if signed else 'udynamic']
    num_groups = absmax.shape[0]
    group_size = q_config['group_size']
    simulate = q_config['simulate']

    if num_mode not in ('linear', 'nonlinear'):
        raise ValueError(f'numerical mode is not supported for dequantization (got {num_mode})')

    # dequantize
    if num_mode == 'linear':
        groups = dequantize_and_unpack_linear(q_input, absmax, num_groups, group_size, bits, simulate, signed)
    else:
        groups = dequantize_and_unpack_nonlinear(q_input, absmax, qmap, num_groups, group_size, bits, simulate)

    # Drop the padding and go back to the recorded shape.
    restored = recon_grouped_tensor(grouped_tensor=groups, shape=target_shape)

    if config.debug_quantize_func:
        print(target_shape)
        print(f'in dequantized_data: {q_input}, {absmax}, {q_input.shape}')
        print(f'in dequantized_data: {restored}, {absmax}, {restored.shape}')
    return restored.contiguous()


def is_quantifiable(param, quantifiable_lower_bound: int) -> bool:
    r"""Tell whether ``param`` has at least ``quantifiable_lower_bound`` elements."""
    num_elements = param.numel()
    return num_elements >= quantifiable_lower_bound


def create_dynamic_map(signed=True, b=7, n=None):
    r"""Build the sorted table of quantization points of a dynamic map.

    For every exponent step ``i`` in ``range(n)`` the interval [0.1, 1] is cut
    into equal buckets, and the bucket midpoints scaled by ``10 ** (i - (n - 1))``
    are added to the table (together with their negatives when ``signed``).
    When ``n < b`` one more set of midpoints is added at the last exponent.
    Finally 0 and 1.0 are appended and the table is sorted.

    Params:
        signed(bool): whether negative points are included
        b(int): controls the number of buckets per exponent step
        n(int | None): number of exponent steps; defaults to ``b`` and must
            satisfy ``0 < n <= b`` when given
    Returns:
        1-D tensor with the sorted quantization points
    """
    if n is None:
        n = b
    else:
        assert 0 < n <= b

    values = []

    def _midpoints(num_edges):
        # Centres of the equal-width buckets that split [0.1, 1].
        edges = torch.linspace(0.1, 1, num_edges)
        return (edges[:-1] + edges[1:]) / 2.0

    def _add_points(exponent, means):
        # Append the scaled midpoints, mirrored around zero for signed maps.
        scale = 10 ** exponent
        values.extend((scale * means).tolist())
        if signed:
            values.extend((-scale * means).tolist())

    extra_items = 2 ** (b - n) - 1
    if not signed:
        extra_items *= 2

    for i in range(n):
        power = i + b - n
        num_edges = 2 ** power + 1 if signed else 2 ** (power + 1) + 1
        _add_points(i - (n - 1), _midpoints(num_edges))

    if extra_items > 0:
        # Reuses `i` from the loop above, i.e. the largest exponent step.
        _add_points(i - (n - 1), _midpoints(extra_items + 1))

    values.append(0)
    values.append(1.0)
    values.sort()
    return torch.Tensor(values)


# Cache of dynamic quantization maps, keyed by bit width.
global_dynamic_maps = {}
def get_dynamic_map(bits):
    r"""Return the cached signed/unsigned dynamic maps for ``bits``, creating them on first use.

    The result is a dict with the keys 'dynamic' (signed map) and 'udynamic'
    (unsigned map), both built with ``b = bits - 1`` and moved to
    ``config.device`` at creation time. The cache key is ``bits`` only.
    """
    maps = global_dynamic_maps.get(bits, None)
    if maps is None:
        maps = {
            'dynamic': create_dynamic_map(signed=True, b=bits-1).to(config.device),
            'udynamic': create_dynamic_map(signed=False, b=bits-1).to(config.device)
        }
        global_dynamic_maps[bits] = maps
    return maps


def compute_outlier(input: torch.Tensor, q_config: dict):
    r"""Extract the part of ``input`` that lies outside ``[-absmax, absmax]`` per group.

    ``input`` is grouped with ``q_config['group_size']`` and compared against
    ``q_config['absmax']``. Elements above ``absmax`` contribute their excess
    ``x - absmax``, elements below ``-absmax`` contribute ``x + absmax``, and
    everything else is zero.

    Returns ``None`` when no element is out of range, otherwise a sparse tensor
    with the shape of ``input`` holding the excess values.
    """
    gp_sz, absmax = q_config['group_size'], q_config['absmax']
    input = input.contiguous()
    groups = group_tensor(input, gp_sz=gp_sz)

    excess = torch.zeros_like(groups)

    # Overshoot past the positive bound.
    over = groups - absmax
    over_mask = over > 0
    excess[over_mask] = over[over_mask]

    # Overshoot past the negative bound.
    under = groups + absmax
    under_mask = under < 0
    excess[under_mask] = under[under_mask]

    if not (excess != 0).any():
        return None

    return recon_grouped_tensor(excess, input.shape).to_sparse()
