import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import quant_cuda
import time
from quant_affine import *

logger = logging.getLogger(__name__)
class Quantizer(nn.Module):
    def __init__(self, mode="base", bit=8, is_signed=True, is_enable=False, is_input=False, args=None, operator=None):
        """Set up the quantizer configuration and its persistent buffers.

        Args:
            mode: Quantization mode string; ``"base"`` makes ``tensor_forward``
                a pass-through.
            bit: Bit width, stored in the ``bit`` buffer.
            is_signed: Whether the value grid includes negative values.
            is_enable: Initial enabled state; also used as the initial state
                of ``is_enable_input``.
            is_input: True when this quantizer handles input tensors; such
                quantizers never use per-channel ranges.
            args: Object providing ``percent``, ``sigma``,
                ``constraint_radius`` and ``packed_element`` attributes.
                Required in practice: it is dereferenced unconditionally.
            operator: Stored as-is on the instance.
        """
        super().__init__()

        # --- static configuration -------------------------------------
        self.mode = mode
        self.register_buffer('bit', torch.tensor(bit))
        self.is_signed = is_signed
        self.is_enable = is_enable
        self.is_enable_input = is_enable
        self.is_input = is_input
        self.args = args
        self.operator = operator

        # --- clipping threshold ---------------------------------------
        # A percentage by default; a positive ``sigma`` replaces it with a
        # sigma-based factor.
        self.percent = args.percent / 100
        self.is_sigma = False
        if args.sigma > 0:
            self.is_sigma = True
            self.percent = args.sigma / 10

        self.cr = args.constraint_radius
        self.pe = args.packed_element

        # --- state filled in lazily during the first forward pass -----
        self.name = None
        self.has_zero = True
        self.quant_value = None
        self.quant_weight_tensor = None
        self.register_buffer('x_max', torch.tensor(1.0))
        self.register_buffer('x_min', torch.tensor(1.0))
        self.has_inited_quant_para = False

        # Both rounding refinement stages start enabled.
        self.squant_k = True
        self.squant_c = True

        # Inputs must not be quantized per channel.
        self.is_perchannel = not is_input

        self.tensor_sum = None
        self.tensor_sum_cov = None
        

    def disable_input_quantization(self):
        """Make ``tensor_forward`` pass input tensors through unquantized.

        Only affects quantizers created with ``is_input=True``; the flag is
        not consulted for other tensors.
        """
        self.is_enable_input = False
        
    def enable_quantization(self, name):
        """Switch this quantizer on and record ``name`` as its label.

        Args:
            name: Identifier stored in ``self.name`` (printed when the
                quantization parameters are initialised).
        """
        self.name, self.is_enable = name, True

    def disable_quantization(self, name):
        """Switch this quantizer off and record ``name`` as its label.

        Args:
            name: Identifier stored in ``self.name``.
        """
        self.name, self.is_enable = name, False

    def _sigma(self, tensor):
        """Return the standard deviation used for sigma-based clipping.

        For an unsigned quantizer only the strictly positive entries of
        ``tensor`` contribute; otherwise every entry does.
        """
        samples = tensor if self.is_signed else tensor[tensor > 0]
        return samples.std()

    def updata_packed_element(self, tensor):
        """Halve ``self.pe`` once if it does not divide the flattened tail.

        The tensor is viewed as ``(shape[0], shape[1], -1)``; when the size of
        that last dimension is not a multiple of ``self.pe``, ``self.pe`` is
        replaced by ``int(self.pe / 2)``. Requires at least two dimensions.
        """
        tail_size = tensor.view(tensor.shape[0], tensor.shape[1], -1).shape[-1]
        if tail_size % self.pe:
            self.pe = int(self.pe / 2)

    def updata_signed(self, tensor):
        """Promote the quantizer to signed when ``tensor`` has negative values.

        The flag is only ever raised here, never cleared. A warning is
        printed when this happens on an input quantizer.
        """
        if not tensor.min() < 0:
            return
        self.is_signed = True
        if self.is_input:
            print("Warning!: Signed input!")

    def convert_tensor(self, values):
        """Return the distinct entries of ``values`` as an ascending tensor.

        Args:
            values: Iterable of hashable numbers; duplicates are dropped.

        Returns:
            A 1-D tensor of the default floating-point type, sorted in
            ascending order.
        """
        distinct = torch.Tensor(list(set(values)))
        return torch.sort(distinct).values

    def int_value(self, q_type="narrow"):
        """Build the sorted grid of representable integer levels.

        The grid holds a centre value (``0`` when ``self.has_zero`` is set,
        ``0.5`` otherwise) plus ``1 .. 2**B - 1``, where ``B`` is the bit
        width minus one for signed quantizers and the full bit width
        otherwise. Signed grids also contain the mirrored negative levels,
        and with ``q_type == "int"`` the extra level ``-2**B``
        (e.g. ``[-128, ..., 0, ..., 127]`` for 8 bits).

        Args:
            q_type: ``"int"`` adds the extra most-negative level for signed
                grids; any other value yields a symmetric grid.

        Returns:
            Ascending 1-D tensor of the levels.
        """
        total_bits = self.bit.item()
        magnitude_bits = total_bits - 1 if self.is_signed else total_bits

        levels = [0. if self.has_zero else 0.5]
        magnitudes = range(1, 2 ** magnitude_bits)
        levels.extend(magnitudes)
        if self.is_signed:
            levels.extend(-m for m in magnitudes)
            if q_type == "int":
                levels.append(- 2 ** magnitude_bits)

        # The grid must be encodable with the configured number of bits.
        assert(2 ** total_bits >= len(levels))
        return self.convert_tensor(levels)

    def _quantization(self, tensor, quant_value):
        """Map ``tensor`` onto the levels in ``quant_value`` via ``quant_cuda.quant``.

        The tensor is flattened for the extension call and both results are
        reshaped back to the input shape.

        Returns:
            Tuple ``(quant_tensor, quant_idx)``: the quantized values and a
            ``torch.long`` index tensor, both shaped like ``tensor``.
            NOTE(review): the indices presumably point into ``quant_value``;
            confirm against the ``quant_cuda`` kernel.
        """
        original_shape = tensor.shape
        flat = tensor.view(-1)
        # The kernel receives the levels in the same dtype/device as the data.
        levels = quant_value.type_as(flat)
        quantized, indices = quant_cuda.quant(flat, levels)
        return (
            quantized.view(original_shape),
            indices.view(original_shape).type(torch.long),
        )
    
    def adaptive_round(self, x, t_max = None, t_min = None):
        """Round ``x`` onto ``self.quant_value`` and refine the rounding with ``quant_cuda.rounding_loop``.

        ``x`` is first quantized with ``_quantization`` (after clamping to the
        range of ``self.quant_value``). For every element the method then
        prepares the value and error it would have if moved one step up or
        one step down, and hands these tensors (as views) to
        ``quant_cuda.rounding_loop`` in up to two passes:

        * ``self.squant_k``: elements grouped as ``(shape[0], shape[1], -1)``;
        * ``self.squant_c``: elements grouped as ``(1, shape[0], -1)``.

        NOTE(review): ``rounding_loop`` is an external CUDA kernel; it
        presumably flips selected elements of ``rounding_number`` in place to
        shrink each group's summed rounding error -- confirm against the
        kernel source.

        Args:
            x: Tensor with at least two dimensions (``x.shape[1]`` is read).
            t_max: Upper bound at or above which elements are never moved up;
                defaults to ``max(self.quant_value)``.
            t_min: Lower bound at or below which elements are never moved
                down; defaults to ``min(self.quant_value)``.

        Returns:
            ``rounding_number``, shaped like ``x``.

        Side effects:
            Sets ``self.squant_k = False`` (persistently) when the flattened
            trailing dimension has size 1.
        """
        # Initial rounding of the clamped input and the signed error it introduces.
        rounding_number, rounding_idx = self._quantization(torch.clamp(x, torch.min(self.quant_value), torch.max(self.quant_value)), self.quant_value)
        rounding_error  = rounding_number - x

        if t_max is None:
            t_max = torch.max(self.quant_value)
        if t_min is None:
            t_min = torch.min(self.quant_value)

        # "Up" candidates: elements with a negative rounding error (value was
        # rounded down) that are still below t_max. All others are zeroed out.
        up_number = rounding_number.clone()
        up_error  = rounding_error.clone()
        up_error[x >= t_max]  = 0.0
        up_error[up_error > 0]  = 0.0
        # Priority is the magnitude of the current error of each candidate.
        up_priority = up_error.clone().abs()

        # Value and error each candidate would have after moving one step up.
        up_error[up_error != 0]  += 1
        up_number[up_error != 0] += 1

        # "Down" candidates: the mirror image -- positive rounding error and
        # still above t_min.
        down_number = rounding_number.clone()
        down_error  = rounding_error.clone()
        down_error[x <= t_min]  = 0.0
        down_error[down_error < 0]  = 0.0
        down_priority = down_error.clone().abs()

        # Value and error each candidate would have after moving one step down.
        down_error[down_error != 0]  -= 1
        down_number[down_error != 0] -= 1

        # One-element tensors passed to the kernel; they are not read again
        # in this method (presumably flip counters -- TODO confirm).
        flip_number = torch.tensor([0.0], device=x.device)
        flip_up_number = torch.tensor([0.0], device=x.device)
        flip_down_number = torch.tensor([0.0], device=x.device)

        conver_shape = x.view(x.shape[0], x.shape[1], -1).shape
        if conver_shape[2] == 1:
            # A single element per group: the first pass is switched off.
            # Note that this changes instance state, not just this call.
            self.squant_k = False

        if self.squant_k:
            # First pass: groups are the (shape[0], shape[1]) slices.
            rounding_error_sum = rounding_error.view(conver_shape).sum(-1)
            _, up_order = torch.sort(up_priority.view(conver_shape), descending=True)
            _, down_order = torch.sort(down_priority.view(conver_shape), descending=True)
            # Priorities are reset after the orderings have been extracted
            # (presumably the kernel writes new priorities for the next
            # pass -- TODO confirm).
            up_priority *= 0.0
            down_priority *= 0.0

            quant_cuda.rounding_loop(
                flip_number,
                flip_up_number,
                flip_down_number,

                rounding_error_sum,
                rounding_number.view(conver_shape), 
                rounding_error.view(conver_shape), 

                up_number.view(conver_shape), 
                up_error.view(conver_shape), 
                up_priority.view(conver_shape), 
                up_order.type_as(up_priority).view(conver_shape), 

                down_number.view(conver_shape), 
                down_error.view(conver_shape), 
                down_priority.view(conver_shape),
                down_order.type_as(down_priority).view(conver_shape),
            )

        if self.squant_c:
            # Second pass: one group per index of the first dimension.
            conver_shape = (1, x.shape[0], -1)
            rounding_error_sum = rounding_error.view(conver_shape).sum(-1)
            _, up_order = torch.sort(up_priority.view(conver_shape), descending=True)
            _, down_order = torch.sort(down_priority.view(conver_shape), descending=True)

            quant_cuda.rounding_loop(
                flip_number,
                flip_up_number,
                flip_down_number,

                rounding_error_sum,
                rounding_number.view(conver_shape), 
                rounding_error.view(conver_shape), 

                up_number.view(conver_shape), 
                up_error.view(conver_shape), 
                up_priority.view(conver_shape), 
                up_order.type_as(up_priority).view(conver_shape), 

                down_number.view(conver_shape), 
                down_error.view(conver_shape), 
                down_priority.view(conver_shape),
                down_order.type_as(down_priority).view(conver_shape)
            )

        # Sanity check: the result must still fit in the configured bit width.
        assert (rounding_number.unique().numel() <= 2 ** self.bit.item())
        return rounding_number

    @torch.no_grad()
    def _init_quant_para(self, data):
        """Initialise the quantization state from ``data`` (first call only).

        Later calls return immediately. On the first call the method

        1. prints the bit width and name, then updates ``self.pe`` and
           ``self.is_signed`` from ``data``;
        2. derives the clipping threshold ``alpha`` (percentage of the
           absolute maximum, or sigma based when ``self.is_sigma`` is set;
           per row of ``data.view(shape[0], -1)`` when ``self.is_perchannel``
           is set);
        3. resolves the ``squant-e`` / ``squant-k`` / ``squant-c`` variants
           into the ``squant_k`` / ``squant_c`` switches and rewrites
           ``self.mode`` to ``"squant"``;
        4. for non-input tensors, quantizes ``data`` once and caches the
           de-quantized result in ``self.quant_weight_tensor``; for input
           tensors, stores the clipping range in the ``x_min`` / ``x_max``
           buffers.

        Args:
            data: Tensor to calibrate on; needs at least two dimensions
                because ``updata_packed_element`` reads ``shape[1]``.

        Raises:
            RuntimeError: If the (resolved) mode is not ``"squant"``. In that
                case the quantizer is *not* marked as initialised.

        NOTE(review): ``asymmetric_linear_quantization_params``,
        ``linear_quantize`` and ``linear_dequantize`` are presumably provided
        by the ``quant_affine`` star import -- confirm.
        """
        if self.has_inited_quant_para:
            return

        print("QUANT %d bit: %s " %(self.bit.item(), self.name))
        self.updata_packed_element(data)
        self.updata_signed(data)

        # Global maximum, only needed to cap alpha in the low-bit sigma case.
        x_max = data.abs().max() if "int" in self.mode else data.max()

        alpha = self.percent * data.abs().max()
        if self.is_sigma:
            sigma = self._sigma(data) * 10
            alpha = self.percent * sigma
            if self.is_signed:
                # We also consider the signed activation.
                alpha = self.percent * sigma / 1.25

            # For a higher bit-width, using a wider range still will not
            # cause accuracy loss; small bit-widths need clipping.
            if self.bit < 6:
                alpha = min(alpha, x_max)

        if self.is_perchannel:
            # One threshold per row of the (shape[0], -1) view; this
            # replaces the global (or sigma-based) alpha computed above.
            row_abs_max = data.view(data.shape[0], -1).abs().max(1).values
            alpha = self.percent * row_abs_max.unsqueeze(1)

        # Resolve the ablation variants into the two refinement switches.
        if self.mode == "squant-e":
            self.squant_k = False
            self.squant_c = False
            self.mode = "squant"
        elif self.mode == "squant-k":
            self.squant_c = False
            self.mode = "squant"
        elif self.mode == "squant-c":
            self.squant_k = False
            self.mode = "squant"

        if self.mode != "squant":
            raise RuntimeError("Unsupported mode: " + self.mode)

        def _quant(tensor):
            """Quantize ``tensor`` with adaptive rounding and de-quantize it again."""
            if self.is_perchannel:
                rows = tensor.view(tensor.shape[0], -1)
                t_max = rows.max(1).values.unsqueeze(1)
                t_min = rows.min(1).values.unsqueeze(1)
            else:
                t_max = tensor.max()
                t_min = tensor.min()

            scale, zero_point = asymmetric_linear_quantization_params(self.bit, t_min, t_max)
            quant_tensor = linear_quantize(tensor, scale, zero_point, inplace=False)
            quant_tensor = self.adaptive_round(quant_tensor)

            # Keep the result inside the signed integer range of the bit width.
            n = 2 ** (self.bit - 1)
            quant_tensor = torch.clamp(quant_tensor, -n, n - 1)
            return linear_dequantize(quant_tensor, scale, zero_point, inplace=False)

        self.quant_value = self.int_value(q_type="int").cuda()
        if not self.is_input:
            # Quantize once and cache the result.
            start = time.perf_counter()
            self.quant_weight_tensor = _quant(data)
            elapsed = (time.perf_counter() - start)
            print("Quantzation time: %f ms" %(elapsed * 1000))
        else:
            # Inputs: only record the clipping range; the quantization itself
            # happens on every forward pass.
            if self.is_signed:
                self.x_min = alpha / self.quant_value.max() * self.quant_value.min()
            else:
                self.x_min.data = torch.zeros_like(alpha)
            self.x_max.data = alpha

        self.has_inited_quant_para = True
        
    def _forward(self, data):
        """Apply ``AsymmetricQuantFunction`` to ``data``.

        Uses the stored bit width and the ``x_min`` / ``x_max`` buffers as
        arguments and returns the function's result.
        """
        return AsymmetricQuantFunction.apply(data, self.bit, self.x_min, self.x_max)
    
    def tensor_forward(self, tensor, image_size):
        """Quantize ``tensor``, or return it untouched when quantization is off.

        The tensor is passed through unchanged when the mode is ``"base"``,
        when the quantizer is disabled, or when it is an input quantizer whose
        input quantization has been disabled. Otherwise the quantization
        parameters are initialised (first call only) and, without gradient
        tracking, input tensors are quantized on the fly while other tensors
        yield the cached ``quant_weight_tensor``.

        Args:
            tensor: Tensor to quantize.
            image_size: Stored on the instance as ``self.image_size``.
        """
        self.image_size = image_size

        bypass = (
            self.mode == "base"
            or not self.is_enable
            or (self.is_input and not self.is_enable_input)
        )
        if bypass:
            return tensor

        with torch.no_grad():
            self._init_quant_para(tensor)
            return self._forward(tensor) if self.is_input else self.quant_weight_tensor

class TensorQuantizer(Quantizer):
    """Callable ``Quantizer``: ``forward`` delegates to ``tensor_forward``."""

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``Quantizer.__init__``."""
        super().__init__(**kwargs)

    def forward(self, tensor, image_size=0):
        """Return ``self.tensor_forward(tensor, image_size)``."""
        return self.tensor_forward(tensor, image_size)

class ActivationQuantizer(nn.Module):
    """
    Module that quantizes the tensor it is given with an unsigned input quantizer
    """
    def __init__(self, mode=None, wbit=None, abit=None, args=None):
        """
        mode: quantization mode, must not be None
        wbit: accepted for signature parity with the layer quantizers; unused here
        abit: bit width of the quantizer
        args: option object forwarded to the underlying TensorQuantizer
        """
        super().__init__()
        assert mode is not None,'Quantizer is not initilized!'
        self.quant_output = TensorQuantizer(
            mode=mode,
            bit=abit,
            is_signed=False,
            is_enable=True,
            args=args,
            is_input=True,
        )

    def forward(self, output):
        """Return ``output`` after passing it through ``self.quant_output``."""
        quantized = self.quant_output(output)
        return quantized

class LinearQuantizer(nn.Module):
    """
    Linear layer replacement that quantizes its input and its weight
    before calling F.linear
    """
    def __init__(self, mode=None, wbit=None, abit=None, args=None):
        """
        mode: quantization mode, must not be None
        wbit: bit width of the weight quantizer
        abit: bit width of the input quantizer
        args: option object forwarded to both TensorQuantizer instances
        """
        super().__init__()
        assert mode is not None,'Quantizer is not initilized!'
        common = dict(mode=mode, is_enable=True, args=args)
        self.quant_input = TensorQuantizer(bit=abit, is_signed=False, is_input=True, **common)
        self.quant_weight = TensorQuantizer(bit=wbit, is_signed=True, **common)

    def set_param(self, linear):
        """
        Copy the feature sizes, the weight and (if present) the bias from
        the given linear layer; must be called before forward
        """
        self.in_features = linear.in_features
        self.out_features = linear.out_features
        self.weight = nn.Parameter(linear.weight.data.clone())
        try:
            bias = nn.Parameter(linear.bias.data.clone())
        except AttributeError:
            # The source layer has no usable bias (e.g. ``bias`` is None).
            bias = None
        self.bias = bias

    def forward(self, input):
        """
        Quantize input and weight, then apply the linear transform
        """
        quantized_input = self.quant_input(input)
        quantized_weight = self.quant_weight(self.weight)
        return F.linear(quantized_input, quantized_weight, self.bias)


class Conv2dQuantizer(nn.Module):
    """
    Conv2d layer replacement that quantizes its input and its weight
    before calling F.conv2d
    """
    def __init__(self, mode=None, wbit=None, abit=None, args=None):
        """
        mode: quantization mode, must not be None
        wbit: bit width of the weight quantizer
        abit: bit width of the input quantizer
        args: option object forwarded to both TensorQuantizer instances
        """
        super().__init__()
        assert mode is not None,'Quantizer is not initilized!'
        common = dict(mode=mode, is_enable=True, args=args)
        self.quant_input = TensorQuantizer(bit=abit, is_signed=False, is_input=True, **common)
        self.quant_weight = TensorQuantizer(bit=wbit, is_signed=True, **common)

    def set_param(self, conv):
        """
        Copy the convolution hyper-parameters, the weight and (if present)
        the bias from the given conv layer; must be called before forward
        """
        for attr in ("in_channels", "out_channels", "kernel_size", "stride",
                     "padding", "dilation", "groups"):
            setattr(self, attr, getattr(conv, attr))
        self.weight = nn.Parameter(conv.weight.data.clone())
        try:
            bias = nn.Parameter(conv.bias.data.clone())
        except AttributeError:
            # The source layer has no usable bias (e.g. ``bias`` is None).
            bias = None
        self.bias = bias

    def _conv_forward(self, input, weight):
        """
        Run F.conv2d with the stored bias and hyper-parameters
        """
        return F.conv2d(
            input,
            weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

    def forward(self, input):
        """
        Quantize input and weight, then apply the convolution
        """
        quantized_input = self.quant_input(input)
        quantized_weight = self.quant_weight(self.weight)
        return self._conv_forward(quantized_input, quantized_weight)