import torch
import logging

# Per-module logger; its threshold is pinned to INFO at import time so the
# logger itself does not filter INFO records (attached handlers still may).
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)


def _symmetric_uniform_quantization(x, nbits, stochastic=False):
    assert (torch.isnan(x).sum() == 0)
    assert (torch.isinf(x).sum() == 0)

    c = torch.max(torch.abs(x))
    s = c / (2**(nbits - 1) - 1)
    if s == 0:
        return x, s
    c_minus = c * -1.0

    # qx = torch.where(x.ge(c), c, x)
    # qx = torch.where(qx.le(c_minus), c_minus, qx)
    # qx.div_(s)
    qx = x / s

    if stochastic:
        noise = qx.new(qx.shape).uniform_(-0.5, 0.5)
        qx.add_(noise)

    qx.clamp_(-(2**(nbits - 1) - 1), (2**(nbits - 1) - 1)).round_()
    return qx, s


def symmetric_uniform_quantization(state_dict, nbits=8):
    """
    Perform symmetric uniform quantization on the weights of conv & fc layers.

    An entry is quantized when its key contains ``'fc'`` or ``'conv'`` and its
    last dot-separated component is exactly ``'weight'``. Each such entry
    ``<prefix>.weight`` is replaced by two entries:

    * ``<prefix>.weight_quant``: the quantized values, cast to ``torch.int8``
      (``nbits == 8``) or ``torch.int16`` (``nbits == 16``);
    * ``<prefix>.weight_scale``: the 0-dim scale tensor, so that the weight
      is approximately ``weight_quant * weight_scale``.

    All other entries are copied through unchanged (the same tensor objects).

    Args:
        state_dict: dict of model parameters (e.g. ``torch_model.state_dict()``)
        nbits: bit width of the quantized values, chosen from [8, 16]. Any
            other value logs a warning and falls back to 8.

    Returns:
        A new dict with the quantized model parameters.
    """
    if nbits == 8:
        quant_data_type = torch.int8
    elif nbits == 16:
        quant_data_type = torch.int16
    else:
        # WARNING rather than INFO: without logging configuration, Python's
        # last-resort handler drops INFO, hiding this silent correction.
        logger.warning('The provided value of nbits (%s) is invalid, and we'
                       ' change it to 8', nbits)
        nbits = 8
        quant_data_type = torch.int8

    quant_state_dict = dict()
    for key, value in state_dict.items():
        prefix, sep, leaf = key.rpartition('.')
        if leaf == 'weight' and ('fc' in key or 'conv' in key):
            q_weight, w_s = _symmetric_uniform_quantization(value, nbits=nbits)
            # Rename only the trailing 'weight' component; ``str.replace``
            # would also rewrite any 'weight' substring earlier in the key
            # (e.g. 'weight_net.fc.weight').
            # q_weight is already rounded and clamped, so the cast is exact.
            quant_state_dict[prefix + sep + 'weight_quant'] = \
                q_weight.type(quant_data_type)
            quant_state_dict[prefix + sep + 'weight_scale'] = w_s
        else:
            quant_state_dict[key] = value

    return quant_state_dict


def symmetric_uniform_dequantization(state_dict):
    """
    Perform symmetric uniform dequantization, reversing
    ``symmetric_uniform_quantization``.

    Matching is by substring, and every occurrence in a key is rewritten:

    * a key containing ``'weight_quant'`` is multiplied by the entry whose key
      has ``'weight_quant'`` replaced with ``'weight_scale'``. The product is
      stored under the key with ``'weight_quant'`` replaced by ``'weight'``;
    * any other key containing ``'weight_scale'`` is dropped, since it was
      consumed as a scale;
    * every remaining entry is copied through unchanged.

    Output keys keep the iteration order of ``state_dict``.

    Args:
        state_dict: dict of model parameters produced by the quantizer

    Returns:
        The model parameters after dequantization.

    Raises:
        KeyError: if a ``weight_quant`` entry has no matching scale entry.
    """
    restored = {}
    for name, tensor in state_dict.items():
        if 'weight_quant' not in name:
            if 'weight_scale' not in name:
                restored[name] = tensor
            continue
        scale_name = name.replace('weight_quant', 'weight_scale')
        float_name = name.replace('weight_quant', 'weight')
        restored[float_name] = tensor * state_dict[scale_name]
    return restored
