from collections import namedtuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import InplaceFunction, Function

import numpy as np
import scipy.io as sio

# Quantization statistics as produced by calculate_qparams():
#   range      - max_value minus zero_point
#   zero_point - the minimum statistic; used as the offset of the quantization grid
#   num_bits   - bit width the statistics are meant to be used with
#   max_value  - the maximum statistic
QParams = namedtuple('QParams', ['range', 'zero_point', 'num_bits', 'max_value'])

# Default (start_dim, end_dim) pair unpacked into Tensor.flatten() when statistics are computed.
_DEFAULT_FLATTEN = (1, -1)
# NOTE(review): not referenced in the visible part of this file; presumably the flatten
# dims intended for gradient tensors - confirm against callers.
_DEFAULT_FLATTEN_GRAD = (0, -1)


def _deflatten_as(x, x_full):
    shape = list(x.shape) + [1] * (x_full.dim() - x.dim())
    return x.view(*shape)

 
def calculate_qparams(x, num_bits, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, reduce_type='mean', keepdim=False,
                      true_zero=False):
    """Compute min/max based quantization statistics for ``x`` without tracking gradients.

    Args:
        x: tensor to analyse.
        num_bits: bit width stored verbatim in the returned ``QParams``.
        flatten_dims: ``(start_dim, end_dim)`` unpacked into ``x.flatten``; the
            minimum and maximum are taken over the last axis of the flattened
            tensor (or over everything when the flattened tensor is 1-D).
        reduce_dim: axis over which the per-group extrema are combined, or
            ``None`` to keep them per group.
        reduce_type: ``'mean'`` averages the minima and the maxima; any other
            value takes the smallest minimum and the largest maximum.
        keepdim: forwarded to the reduction over ``reduce_dim``.
        true_zero: accepted but not used by this function.

    Returns:
        ``QParams`` with ``zero_point`` set to the minimum, ``max_value`` to the
        maximum and ``range`` to their difference. Before any reduction the
        statistics are padded with trailing size-1 axes to the rank of ``x``.
    """
    with torch.no_grad():
        flat = x.flatten(*flatten_dims)
        if flat.dim() == 1:
            lowest = _deflatten_as(flat.min(), x)
            highest = _deflatten_as(flat.max(), x)
        else:
            lowest = _deflatten_as(flat.min(-1)[0], x)
            highest = _deflatten_as(flat.max(-1)[0], x)

        if reduce_dim is not None:
            if reduce_type == 'mean':
                lowest = lowest.mean(reduce_dim, keepdim=keepdim)
                highest = highest.mean(reduce_dim, keepdim=keepdim)
            else:
                lowest = lowest.min(reduce_dim, keepdim=keepdim)[0]
                highest = highest.max(reduce_dim, keepdim=keepdim)[0]

        return QParams(range=highest - lowest, zero_point=lowest,
                       max_value=highest, num_bits=num_bits)


def calculate_qparams_permute(x, num_bits, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, reduce_type='mean', keepdim=False,
                      true_zero=False):
    """Compute quantization statistics on ``x`` with its first two axes swapped.

    ``x`` is permuted with ``(1, 0, 2, 3)`` (so it must be 4-D), the statistics
    are computed exactly as in ``calculate_qparams`` on the permuted tensor, and
    ``range``, ``zero_point`` and ``max_value`` are permuted back with the same
    axis order before being returned. All other arguments have the meaning
    they have in ``calculate_qparams``.

    NOTE(review): the permute applied to the results needs 4-D statistics. When
    ``reduce_dim`` is not None and ``keepdim`` is False (the defaults) the
    reduction drops an axis and that permute raises, so callers need
    ``reduce_dim=None`` or ``keepdim=True``.
    """
    axis_swap = (1, 0, 2, 3)
    with torch.no_grad():
        swapped = calculate_qparams(
            x.permute(*axis_swap), num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim,
            reduce_type=reduce_type, keepdim=keepdim, true_zero=true_zero)
        return QParams(range=swapped.range.permute(*axis_swap),
                       zero_point=swapped.zero_point.permute(*axis_swap),
                       max_value=swapped.max_value.permute(*axis_swap),
                       num_bits=num_bits)


class UniformQuantize(InplaceFunction):
    """Uniform affine quantization with a straight-through gradient.

    The forward pass snaps ``input`` onto ``2 ** num_bits`` evenly spaced levels
    spanning ``[zero_point, zero_point + range]`` and, unless ``dequantize`` is
    False, maps the levels back to the original value range. The backward pass
    returns the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN,
                reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False):
        """Quantize ``input``.

        Args:
            ctx: autograd context.
            input: tensor to quantize.
            num_bits: bit width; required when ``qparams`` is None.
            qparams: precomputed ``QParams``. When None they are derived from
                ``input`` with ``calculate_qparams``.
            flatten_dims: forwarded to ``calculate_qparams``.
            reduce_dim: forwarded to ``calculate_qparams``.
            dequantize: map the integer levels back to the input value range.
            signed: use levels ``[-2**(b-1), 2**(b-1) - 1]`` instead of
                ``[0, 2**b - 1]``.
            stochastic: add noise drawn with ``uniform_(-0.5, 0.5)`` before
                rounding.
            inplace: overwrite ``input`` instead of working on a clone.

        Returns:
            The quantized tensor, same shape as ``input``.
        """
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if qparams is None:
            assert num_bits is not None, "either provide qparams or num_bits to quantize"
            qparams = calculate_qparams(
                input, num_bits=num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim)

        zero_point = qparams.zero_point
        num_bits = qparams.num_bits
        qmin = -(2. ** (num_bits - 1)) if signed else 0.
        qmax = qmin + 2. ** num_bits - 1.
        scale = qparams.range / (qmax - qmin)

        # Floor the step size so a zero range cannot cause a division by zero.
        # The floor is created on the device of ``scale`` so CPU tensors work too.
        min_scale = torch.tensor(1e-8, device=scale.device).expand_as(scale)
        scale = torch.max(scale, min_scale)

        with torch.no_grad():
            output.add_(qmin * scale - zero_point).div_(scale)
            if stochastic:
                noise = output.new_empty(output.shape).uniform_(-0.5, 0.5)
                output.add_(noise)
            # quantize
            output.clamp_(qmin, qmax).round_()

            if dequantize:
                output.mul_(scale).add_(
                    zero_point - qmin * scale)  # dequantize

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass the gradient to ``input`` unchanged.

        One value is returned per ``forward`` argument after ``ctx``; the eight
        non-tensor arguments receive ``None``.
        """
        return (grad_output,) + (None,) * 8
    

class UniformQuantizePrun(InplaceFunction):
    """Range-based pruning followed by uniform quantization at an adaptive bit width.

    Groups whose value range is below ``prun_factor`` times the largest range
    are zeroed; the fraction of pruned groups then selects the bit width used
    to quantize the tensor. Every call appends diagnostics to two text files
    under ``_LOG_DIR``. The backward pass is a straight-through estimator.
    """

    # NOTE(review): hard-coded log directory, relative to the working directory;
    # it has to exist before forward() is called.
    _LOG_DIR = './cifar_10_tuning/UQ_wide_all4_piecewise_78000_INTV4_percsquant_nonDSQ_prunv3c_noendmask_dynamicbwwide_nofcprun/'

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN,
                reduce_dim=None, dequantize=True, signed=False, stochastic=False, inplace=False, prun_factor=0.2, 
                index='', iter=None):
        """Prune low-range groups of ``input`` and quantize the result.

        Args:
            ctx: autograd context.
            input: 2-D tensor, or a 4-D tensor (the first two axes are swapped
                for the pruning statistics; presumably ``(N, C, H, W)`` - confirm).
            num_bits: base bit width; required when ``qparams`` is None.
            qparams: optional precomputed ``QParams``. For 2-D inputs its
                ``range`` drives the pruning mask; for 4-D inputs only its
                ``num_bits`` is used. The final quantization statistics are
                always recomputed from ``input``.
            flatten_dims: forwarded to ``calculate_qparams``.
            reduce_dim: forwarded to ``calculate_qparams``.
            dequantize: map the integer levels back to the input value range.
            signed: use a signed level range.
            stochastic: add noise drawn with ``uniform_(-0.5, 0.5)`` before
                rounding.
            inplace: overwrite ``input`` instead of working on a clone.
            prun_factor: fraction of the largest range below which a group is
                zeroed.
            index: name stem of the log files.
            iter: not used by this function.

        Returns:
            The pruned and quantized tensor, same shape as ``input``.
        """
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if qparams is None:
            assert num_bits is not None, "either provide qparams or num_bits to quantize"
            qparams = calculate_qparams(
                input, num_bits=num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim)

        is_2d = output.dim() == 2
        if is_2d:
            prune_target = output
            group_range = qparams.range
            range_max = group_range.max()
        else:
            # Work on views with the first two axes swapped; the in-place
            # multiplication below writes through to ``output``.
            swapped_input = input.permute(1, 0, 2, 3)
            prune_target = output.permute(1, 0, 2, 3)
            qparams = calculate_qparams(
                swapped_input, num_bits=qparams.num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim)
            x_flat = prune_target.flatten(2, -1)
            range_max = (x_flat.max(-1).values - x_flat.min(-1).values).max()
            group_range = qparams.range[:, :, 0, 0].repeat(1, swapped_input.shape[1])

        # 0 for pruned groups, 1 for kept groups; built on the data's device.
        pruning_mask = torch.where(
            group_range < prun_factor * range_max,
            torch.tensor(0., device=group_range.device).expand_as(group_range),
            torch.tensor(1., device=group_range.device).expand_as(group_range))
        if not is_2d:
            pruning_mask = pruning_mask[:, :, None, None]
        tensor_num = pruning_mask.shape[0] * pruning_mask.shape[1]
        pruning_num = tensor_num - torch.sum(pruning_mask)
        prune_target.mul_(pruning_mask)

        # The more groups were pruned, the more bits the remaining ones get.
        rate = pruning_num / tensor_num
        if rate < 0.33:
            num_bits = qparams.num_bits
        elif rate < 0.5:
            num_bits = qparams.num_bits + qparams.num_bits // 2
        elif rate < 0.66:
            num_bits = qparams.num_bits * 2
        elif rate < 0.75:
            num_bits = qparams.num_bits * 3
        else:
            num_bits = qparams.num_bits * 4

        log_stem = UniformQuantizePrun._LOG_DIR + index
        with open(log_stem + '.txt', "a") as stats_file:
            stats_file.write("{} {} {}\n".format(pruning_num, tensor_num, num_bits))

        with open(log_stem + '_range.txt', "a") as range_file:
            temp = qparams.range.to(device='cpu').numpy()
            range_file.write(str(temp[:, 0] if is_2d else temp[:, 0, 0, 0]))
            range_file.write('\n')

        # NOTE(review): the statistics come from ``input``; when ``inplace`` is
        # False that is the un-pruned tensor - confirm this is intended.
        qparams = calculate_qparams(
            input, num_bits=num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim)
        zero_point = qparams.zero_point
        num_bits = qparams.num_bits
        qmin = -(2. ** (num_bits - 1)) if signed else 0.
        qmax = qmin + 2. ** num_bits - 1.
        scale = qparams.range / (qmax - qmin)

        # Floor the step size so a zero range cannot cause a division by zero.
        min_scale = torch.tensor(1e-8, device=scale.device).expand_as(scale)
        scale = torch.max(scale, min_scale)

        with torch.no_grad():
            output.add_(qmin * scale - zero_point).div_(scale)
            if stochastic:
                noise = output.new_empty(output.shape).uniform_(-0.5, 0.5)
                output.add_(noise)
            # quantize
            output.clamp_(qmin, qmax).round_()

            if dequantize:
                output.mul_(scale).add_(
                    zero_point - qmin * scale)  # dequantize

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass the gradient to ``input`` unchanged.

        autograd expects one value per ``forward`` argument after ``ctx``
        (twelve here); the eleven non-tensor arguments receive ``None``.
        """
        return (grad_output,) + (None,) * 11
    

class UniformQuantizeSCPrun(InplaceFunction):
    """Uniform quantization with one range per slice of the first two axes.

    Statistics are computed with ``flatten_dims=(2, -1)``, i.e. one min/max per
    ``input[i, j]`` slice. No pruning mask is applied by this variant. When
    ``iter`` is a multiple of 100 the range statistics are dumped as ``.mat``
    files under ``_LOG_DIR``. The backward pass is a straight-through estimator.

    NOTE(review): the statistics and the dumps index four axes, so 4-D inputs
    (presumably ``(N, C, H, W)``) appear to be the only supported case - confirm.
    """

    # NOTE(review): absolute, machine-specific dump directory; it has to exist
    # whenever a dump is triggered.
    _LOG_DIR = 'C:/wenjinguo/prj/cpt_cifar_3/cifar10/UQ_wide_all4_piecewise_78000_INTV4_percsquant_nonDSQ_noendmask_GAaroundBN4bit/ga/'

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None,
                reduce_dim=None, dequantize=True, signed=False, stochastic=False, inplace=False, prun_factor=0.2, 
                index='', iter=None):
        """Quantize ``input`` with a separate range per ``input[i, j]`` slice.

        Args:
            ctx: autograd context.
            input: tensor to quantize.
            num_bits: bit width; required when ``qparams`` is None.
            qparams: optional ``QParams``; only its ``num_bits`` is used, the
                ranges are always recomputed from ``input``.
            reduce_dim: forwarded to ``calculate_qparams``.
            dequantize: map the integer levels back to the input value range.
            signed: use a signed level range.
            stochastic: add noise drawn with ``uniform_(-0.5, 0.5)`` before
                rounding.
            inplace: overwrite ``input`` instead of working on a clone.
            prun_factor: kept for interface compatibility; unused because the
                pruning mask is not applied.
            index: name stem of the dumped files.
            iter: iteration counter; statistics are dumped when it is a
                multiple of 100. ``None`` disables the dumps.

        Returns:
            The quantized tensor, same shape as ``input``.
        """
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if qparams is None:
            assert num_bits is not None, "either provide qparams or num_bits to quantize"
            # (1, -1) replaces the former (2, 1), which Tensor.flatten rejects
            # because the start dim comes after the end dim.
            qparams = calculate_qparams(
                input, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=reduce_dim)

        num_bits = qparams.num_bits
        should_dump = iter is not None and iter % 100 == 0

        if should_dump:
            if input.dim() != 2:
                # One range per index of axis 1 (axis 0 after the swap).
                channel_qparams = calculate_qparams(
                    input.permute(1, 0, 2, 3), num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=reduce_dim)
            else:
                channel_qparams = qparams
            temp = channel_qparams.range.to(device='cpu').numpy()
            sio.savemat(UniformQuantizeSCPrun._LOG_DIR + 'channel_' + index + '_' + str(iter) + '.mat',
                        {'data': temp[:, 0, 0, 0]})

        qparams = calculate_qparams(
            input, num_bits=num_bits, flatten_dims=(2, -1), reduce_dim=reduce_dim)
        zero_point = qparams.zero_point
        num_bits = qparams.num_bits
        qmin = -(2. ** (num_bits - 1)) if signed else 0.
        qmax = qmin + 2. ** num_bits - 1.
        scale = qparams.range / (qmax - qmin)

        if should_dump:
            temp = qparams.range.to(device='cpu').numpy()
            sio.savemat(UniformQuantizeSCPrun._LOG_DIR + 'sc_' + index + '_' + str(iter) + '.mat',
                        {'data': temp[:, :, 0, 0]})

        # Floor the step size so a zero range cannot cause a division by zero.
        # The floor is created on the device of ``scale`` so CPU tensors work too.
        min_scale = torch.tensor(1e-8, device=scale.device).expand_as(scale)
        scale = torch.max(scale, min_scale)

        with torch.no_grad():
            output.add_(qmin * scale - zero_point).div_(scale)
            if stochastic:
                noise = output.new_empty(output.shape).uniform_(-0.5, 0.5)
                output.add_(noise)
            # quantize
            output.clamp_(qmin, qmax).round_()

            if dequantize:
                output.mul_(scale).add_(
                    zero_point - qmin * scale)  # dequantize

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass the gradient to ``input`` unchanged.

        autograd expects one value per ``forward`` argument after ``ctx``
        (eleven here); the ten non-tensor arguments receive ``None``.
        """
        return (grad_output,) + (None,) * 10


class UniformQuantizeSCSPrun(InplaceFunction):
    """Uniform quantization with one range per index of the first axis, plus diagnostics.

    The tensor is quantized with statistics computed with
    ``flatten_dims=(1, -1)``. Neither a pruning mask nor the smoothing scale is
    applied to the data; the smoothing scale is only computed for the
    diagnostic dumps written under ``_LOG_ROOT`` when ``iter`` is a multiple of
    1000. The backward pass is a straight-through estimator.
    """

    # NOTE(review): absolute, machine-specific dump directory; its sub-folders
    # scale_ga/, scale_smooth/ and scale_w/ have to exist whenever a dump is
    # triggered.
    _LOG_ROOT = 'C:/wenjinguo/prj/cpt_cifar_3/cifar10/UQ_wide_all4_piecewise_78000_INTV4_uniformpers_nonDSQ_noendmask_sslog/'

    @staticmethod
    def _smoothing_scales(sw, sample_channel_range):
        """Return ``(smooth_scale, weight_scale)`` for the diagnostic dumps.

        ``sw`` is read as ``sw[:, 0, 0, 0]``, divided by its maximum, offset by
        small constants and tiled to one row per row of
        ``sample_channel_range``. The smooth scale is that weight scale divided
        by the row-normalised ``sample_channel_range`` (plus ``1e-4``).
        """
        weight_scale = sw[:, 0, 0, 0]
        weight_scale = weight_scale / torch.max(weight_scale, -1).values + 1e-4
        weight_scale = weight_scale[None, :] + 1e-8 * torch.max(weight_scale, -1).values
        weight_scale = weight_scale.repeat([sample_channel_range.shape[0], 1])
        row_max = torch.max(sample_channel_range, -1).values[:, None].repeat(
            [1, sample_channel_range.shape[1]])
        smooth_scale = weight_scale / (sample_channel_range / row_max + 1e-4)
        return smooth_scale, weight_scale

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None,
                reduce_dim=None, dequantize=True, signed=False, stochastic=False, inplace=False, prun_factor=0.2, sw=None,
                index='', iter=None):
        """Quantize ``input`` and periodically dump scale statistics.

        Args:
            ctx: autograd context.
            input: tensor to quantize.
            num_bits: bit width; required when ``qparams`` is None.
            qparams: optional ``QParams``; only its ``num_bits`` is used, the
                ranges are always recomputed from ``input``.
            reduce_dim: forwarded to ``calculate_qparams``.
            dequantize: map the integer levels back to the input value range.
            signed: use a signed level range.
            stochastic: add noise drawn with ``uniform_(-0.5, 0.5)`` before
                rounding.
            inplace: overwrite ``input`` instead of working on a clone.
            prun_factor: kept for interface compatibility; unused because the
                pruning mask is not applied.
            sw: optional 4-D tensor indexed as ``sw[:, 0, 0, 0]`` (presumably
                per-channel weight scales - confirm). It only feeds the
                diagnostic dumps; when None the two smoothing dumps are skipped.
            index: name stem of the dumped files.
            iter: iteration counter; statistics are dumped when it is a
                multiple of 1000. ``None`` disables the dumps.

        Returns:
            The quantized tensor, same shape as ``input``.
        """
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if qparams is None:
            assert num_bits is not None, "either provide qparams or num_bits to quantize"
            # (1, -1) replaces the former (2, 1), which Tensor.flatten rejects
            # because the start dim comes after the end dim.
            qparams = calculate_qparams(
                input, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=reduce_dim)

        num_bits = qparams.num_bits

        if iter is not None and iter % 1000 == 0:
            if input.dim() != 2:
                # One range per index of axis 1 (axis 0 after the swap).
                channel_qparams = calculate_qparams(
                    input.permute(1, 0, 2, 3), num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=reduce_dim)
            else:
                channel_qparams = qparams
            file_name = index + '_' + str(iter) + '.mat'
            log_root = UniformQuantizeSCSPrun._LOG_ROOT
            sio.savemat(log_root + 'scale_ga/' + file_name,
                        {'data': channel_qparams.range.to(device='cpu').numpy()})
            if sw is not None:
                slice_qparams = calculate_qparams(
                    input, num_bits=num_bits, flatten_dims=(2, -1), reduce_dim=reduce_dim)
                smooth_scale, weight_scale = UniformQuantizeSCSPrun._smoothing_scales(
                    sw, slice_qparams.range[:, :, 0, 0])
                sio.savemat(log_root + 'scale_smooth/' + file_name,
                            {'data': smooth_scale.to(device='cpu').numpy()})
                sio.savemat(log_root + 'scale_w/' + file_name,
                            {'data': weight_scale.to(device='cpu').numpy()})

        qparams = calculate_qparams(
            input, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=reduce_dim)
        zero_point = qparams.zero_point
        num_bits = qparams.num_bits
        qmin = -(2. ** (num_bits - 1)) if signed else 0.
        qmax = qmin + 2. ** num_bits - 1.
        scale = qparams.range / (qmax - qmin)

        # Floor the step size so a zero range cannot cause a division by zero.
        # The floor is created on the device of ``scale`` so CPU tensors work too.
        min_scale = torch.tensor(1e-8, device=scale.device).expand_as(scale)
        scale = torch.max(scale, min_scale)

        with torch.no_grad():
            output.add_(qmin * scale - zero_point).div_(scale)
            if stochastic:
                noise = output.new_empty(output.shape).uniform_(-0.5, 0.5)
                output.add_(noise)
            # quantize
            output.clamp_(qmin, qmax).round_()

            if dequantize:
                output.mul_(scale).add_(
                    zero_point - qmin * scale)  # dequantize

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass the gradient to ``input`` unchanged.

        autograd expects one value per ``forward`` argument after ``ctx``
        (twelve here); the eleven other arguments receive ``None``.
        """
        return (grad_output,) + (None,) * 11
    

class UniformQuantizeSCDPrun(InplaceFunction):
    """Range-based pruning followed by magnitude-grouped uniform quantization.

    Groups whose value range is below ``prun_factor`` times the largest range
    are zeroed first. Then, for every index of axis 0, the ``input[i, j]``
    slices are sorted into four magnitude groups by comparing each slice
    maximum with the largest slice maximum ``M`` of that index: ``>= M/2``,
    ``[M/4, M/2)``, ``[M/8, M/4)`` and ``< M/8``. All slices of a group share
    one quantization grid whose upper end is ``M``, ``M/2``, ``M/4`` or ``M/8``
    and whose lower end is the smallest slice minimum inside the group. The
    backward pass is a straight-through estimator.

    NOTE(review): the grouping stage indexes four axes, so 2-D inputs appear
    to be unsupported beyond the pruning step - confirm.
    """

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None,
                reduce_dim=None, dequantize=True, signed=False, stochastic=False, inplace=False, prun_factor=0.2, sw=None,
                index='', iter=None):
        """Prune low-range groups of ``input`` and quantize it group-wise.

        Args:
            ctx: autograd context.
            input: tensor to quantize (presumably ``(N, C, H, W)`` - confirm).
            num_bits: bit width; required when ``qparams`` is None.
            qparams: optional ``QParams``. For 2-D inputs its ``range`` drives
                the pruning mask; otherwise only its ``num_bits`` is used.
            reduce_dim: forwarded to ``calculate_qparams``.
            dequantize: map the integer levels back to the input value range.
            signed: use a signed level range.
            stochastic: add noise drawn with ``uniform_(-0.5, 0.5)`` before
                rounding.
            inplace: overwrite ``input`` instead of working on a clone.
            prun_factor: fraction of the largest range below which a group is
                zeroed.
            sw: kept for interface compatibility. It used to supply only the
                size of axis 1, which is now taken from the data by
                broadcasting.
            index: not used by this function.
            iter: not used by this function.

        Returns:
            The pruned and quantized tensor, same shape as ``input``.
        """
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if qparams is None:
            assert num_bits is not None, "either provide qparams or num_bits to quantize"
            # (1, -1) replaces the former (2, 1), which Tensor.flatten rejects
            # because the start dim comes after the end dim.
            qparams = calculate_qparams(
                input, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=reduce_dim)

        # ---- pruning ----
        is_2d = output.dim() == 2
        if is_2d:
            prune_target = output
            group_range = qparams.range
            range_max = group_range.max()
        else:
            # Work on views with the first two axes swapped; the in-place
            # multiplication below writes through to ``output``.
            swapped_input = input.permute(1, 0, 2, 3)
            prune_target = output.permute(1, 0, 2, 3)
            qparams = calculate_qparams(
                swapped_input, num_bits=qparams.num_bits, flatten_dims=(1, -1), reduce_dim=reduce_dim)
            x_flat = prune_target.flatten(2, -1)
            range_max = (x_flat.max(-1).values - x_flat.min(-1).values).max()
            group_range = qparams.range[:, :, 0, 0].repeat(1, swapped_input.shape[1])

        # 0 for pruned groups, 1 for kept groups; built on the data's device.
        pruning_mask = torch.where(
            group_range < prun_factor * range_max,
            torch.tensor(0., device=group_range.device).expand_as(group_range),
            torch.tensor(1., device=group_range.device).expand_as(group_range))
        if not is_2d:
            pruning_mask = pruning_mask[:, :, None, None]
        prune_target.mul_(pruning_mask)

        # ---- per-slice statistics ----
        # NOTE(review): the statistics come from ``input``; when ``inplace`` is
        # False that is the un-pruned tensor - confirm this is intended.
        num_bits = qparams.num_bits
        qparams = calculate_qparams(
            input, num_bits=num_bits, flatten_dims=(2, -1), reduce_dim=reduce_dim)
        max_values = qparams.max_value
        min_values = qparams.zero_point

        # ---- divide the slices into four magnitude groups ----
        top = torch.max(max_values, 1).values[:, :, :, None]
        half, quarter, eighth = top / 2, top / 4, top / 8
        groups = (
            (max_values >= half, top),
            ((max_values >= quarter) & (max_values < half), half),
            ((max_values >= eighth) & (max_values < quarter), quarter),
            (max_values < eighth, eighth),
        )
        # Large placeholder so slices outside a group never win the group minimum.
        outside = torch.tensor(1e5, device=min_values.device).expand_as(min_values)

        temp_max = max_values
        temp_zero_point = None
        for in_group, group_max in groups:
            temp_max = torch.where(in_group, group_max.expand_as(max_values), temp_max)
            masked_min = torch.where(in_group, min_values, outside)
            group_min = torch.min(masked_min, 1).values[:, :, :, None].expand_as(min_values)
            fallback = masked_min if temp_zero_point is None else temp_zero_point
            temp_zero_point = torch.where(in_group, group_min, fallback)

        temp_range = temp_max - temp_zero_point

        # ---- quantize ----
        zero_point = temp_zero_point
        num_bits = qparams.num_bits
        qmin = -(2. ** (num_bits - 1)) if signed else 0.
        qmax = qmin + 2. ** num_bits - 1.
        scale = temp_range / (qmax - qmin)

        # Floor the step size so a zero range cannot cause a division by zero.
        min_scale = torch.tensor(1e-8, device=scale.device).expand_as(scale)
        scale = torch.max(scale, min_scale)

        with torch.no_grad():
            output.add_(qmin * scale - zero_point).div_(scale)
            if stochastic:
                noise = output.new_empty(output.shape).uniform_(-0.5, 0.5)
                output.add_(noise)
            # quantize
            output.clamp_(qmin, qmax).round_()

            if dequantize:
                output.mul_(scale).add_(
                    zero_point - qmin * scale)  # dequantize

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass the gradient to ``input`` unchanged.

        autograd expects one value per ``forward`` argument after ``ctx``
        (twelve here); the eleven other arguments receive ``None``.
        """
        return (grad_output,) + (None,) * 11
    

class UniformQuantizeSCDPrunPercent(InplaceFunction):
    """Randomly prune channels, then quantize with tiered per-sample ranges.

    Forward pass:
      1. A Bernoulli keep-mask is drawn from per-channel range statistics and
         multiplied into the working tensor.
      2. Min/max statistics are taken over dims 2.. of ``input``.  Every entry
         is assigned to one of four tiers according to how its maximum
         compares with the largest maximum along dim 1
         (>= 1/2, [1/4, 1/2), [1/8, 1/4), < 1/8).  All entries of a tier share
         one upper bound (the largest maximum divided by 1, 2, 4 or 8) and one
         zero point (the smallest minimum found inside that tier).
      3. The tensor is uniformly quantized with the resulting scale and zero
         point and, optionally, dequantized again.

    The backward pass is a straight-through estimator.
    """

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None,
                reduce_dim=None, dequantize=True, signed=False, stochastic=False, inplace=False, prun_factor=0.5, sw=None,
                index='', iter=None):
        """Prune and quantize ``input``.

        Args:
            input: 4-D tensor (it is permuted with ``(1, 0, 2, 3)``).
            num_bits: bit width; required when ``qparams`` is None.
            qparams: optional ``QParams``; only its ``num_bits`` is used, the
                range statistics are always recomputed from ``input``.
            reduce_dim: forwarded to ``calculate_qparams``.  The tier logic
                indexes the statistics as 4-D tensors, which presumably
                requires ``reduce_dim=None`` -- TODO confirm.
            dequantize: map the integer codes back to real values.
            signed: use a signed integer grid.
            stochastic: add uniform noise in [-0.5, 0.5) before rounding.
            inplace: operate on ``input`` directly instead of on a clone.
            prun_factor, index, iter: accepted for interface compatibility;
                not used by the active code path.
            sw: optional tensor; only ``sw.shape[0]`` is read, as the channel
                count used to tile per-sample statistics.  When None, the
                channel count of the statistics themselves is used.

        Returns:
            The pruned, quantized tensor with the same layout as ``input``.
        """
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        # Only the bit width is needed from a caller-supplied qparams.  The
        # original code also computed throw-away statistics here with
        # flatten_dims=(2, 1), which torch.flatten rejects.
        if qparams is None:
            assert num_bits is not None, "either provide qparams of num_bits to quantize"
        else:
            num_bits = qparams.num_bits

        # ---- per-channel pruning (channel-major view: C, N, H, W) ----
        input = input.permute(1, 0, 2, 3)
        output = output.permute(1, 0, 2, 3)
        C, N = output.shape[:2]
        channel_stats = calculate_qparams(
            input, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=reduce_dim)
        p = channel_stats.range * C * N / (2 * channel_stats.range.sum())
        # NOTE(review): p is not bounded to [0, 1], which torch.bernoulli
        # requires.  The original code also computed the max-normalised
        # variant ``p / p.max()`` but never used it -- confirm which of the
        # two was intended before enabling this class.
        pruning_mask = torch.bernoulli(p)
        # ``output`` is a view, so this also prunes the underlying storage.
        output.mul_(pruning_mask)

        output = output.permute(1, 0, 2, 3)
        input = input.permute(1, 0, 2, 3)

        # Statistics come from ``input``; it only reflects the pruning above
        # when inplace=True (otherwise ``output`` is a separate clone).
        qparams = calculate_qparams(
            input, num_bits=num_bits, flatten_dims=(2, -1), reduce_dim=reduce_dim)

        # ---- divide channels into four magnitude tiers ----
        max_values = qparams.max_value
        min_values = qparams.zero_point
        num_channels = sw.shape[0] if sw is not None else max_values.shape[1]
        tile = [1, num_channels, 1, 1]

        sample_max = torch.max(max_values, 1).values[:, :, :, None]
        sample_max_tiled = sample_max.repeat(tile)

        # (membership mask, divisor applied to the per-sample maximum)
        tiers = (
            (max_values >= sample_max / 2, 1),
            ((max_values >= sample_max / 4) & (max_values < sample_max / 2), 2),
            ((max_values >= sample_max / 8) & (max_values < sample_max / 4), 4),
            (max_values < sample_max / 8, 8),
        )

        # Large placeholder so entries outside a tier never win the min().
        sentinel = torch.tensor(1e5, device=min_values.device)

        temp_max = max_values
        temp_zero_point = None
        for state, divisor in tiers:
            temp_max = torch.where(state, sample_max_tiled / divisor, temp_max)

            masked_min = torch.where(state, min_values, sentinel.expand_as(state))
            tier_min = torch.min(masked_min, 1).values[:, :, :, None].repeat(tile)
            previous = masked_min if temp_zero_point is None else temp_zero_point
            temp_zero_point = torch.where(state, tier_min, previous)

        temp_range = temp_max - temp_zero_point

        # ---- uniform quantization ----
        zero_point = temp_zero_point
        num_bits = qparams.num_bits
        qmin = -(2. ** (num_bits - 1)) if signed else 0.
        qmax = qmin + 2. ** num_bits - 1.
        # Lower-bound the step size so the division below is always defined.
        scale = torch.clamp(temp_range / (qmax - qmin), min=1e-8)

        with torch.no_grad():
            output.add_(qmin * scale - zero_point).div_(scale)
            if stochastic:
                noise = output.new(output.shape).uniform_(-0.5, 0.5)
                output.add_(noise)
            # quantize
            output.clamp_(qmin, qmax).round_()

            if dequantize:
                output.mul_(scale).add_(
                    zero_point - qmin * scale)  # dequantize

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: one gradient slot per forward argument."""
        # forward() takes 12 arguments after ctx; autograd raises if fewer
        # gradients than that are returned.
        return (grad_output,) + (None,) * 11


class UniformQuantizeBp(nn.Module):
    """Module form of the uniform quantizer.

    The input is cloned and the clone is quantized in place under
    ``torch.no_grad()``, so the in-place arithmetic is not recorded by
    autograd.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN,
                reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False):
        """Quantize ``input`` on a uniform grid.

        Args:
            input: tensor to quantize (never modified; a clone is used).
            num_bits: bit width; required when ``qparams`` is None.
            qparams: optional precomputed ``QParams``; when None it is
                computed from ``input`` with ``calculate_qparams``.
            flatten_dims, reduce_dim: forwarded to ``calculate_qparams``.
            dequantize: map the integer codes back to real values.
            signed: use a signed integer grid.
            stochastic: add uniform noise in [-0.5, 0.5) before rounding.
            inplace: accepted for interface compatibility; not used.

        Returns:
            The quantized (and optionally dequantized) clone of ``input``.
        """
        output = input.clone()
        if qparams is None:
            assert num_bits is not None, "either provide qparams of num_bits to quantize"
            qparams = calculate_qparams(
                input, num_bits=num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim)

        zero_point = qparams.zero_point
        num_bits = qparams.num_bits
        qmin = -(2. ** (num_bits - 1)) if signed else 0.
        qmax = qmin + 2. ** num_bits - 1.
        # Lower-bound the step size so the division below is always defined.
        # clamp keeps the computation on whatever device ``range`` lives on
        # (the previous ``.cuda()`` constant forced the default GPU).
        scale = torch.clamp(qparams.range / (qmax - qmin), min=1e-8)

        with torch.no_grad():
            output.add_(qmin * scale - zero_point).div_(scale)
            if stochastic:
                noise = output.new(output.shape).uniform_(-0.5, 0.5)
                output.add_(noise)
            # quantize
            output.clamp_(qmin, qmax).round_()

            if dequantize:
                output.mul_(scale).add_(
                    zero_point - qmin * scale)  # dequantize

        return output

    def backward(ctx, grad_output):
        """Straight-through estimator kept from the autograd-Function variant.

        NOTE(review): ``nn.Module`` does not call a ``backward`` method during
        autograd, so this is presumably only reachable when called explicitly.
        """
        # straight-through estimator
        grad_input = grad_output
        return grad_input, None, None, None, None, None, None, None, None


class UniformQuantizeW(InplaceFunction):
    """Uniform weight quantizer with an integer-scale gradient rescaling.

    Forward quantizes a clone of the input and remembers the step size.
    Backward derives a step size from the gradient's range, converts both
    step sizes to a rounded mantissa plus a left shift
    (``cal_scale_persample``) and rescales the incoming gradient with the
    integerized gradient step size.
    """

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN,
                reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False, weight_gap_bits=2, index=None, momentum=0.9, iter=0):
        """Quantize ``input`` on a uniform grid.

        Args:
            input: tensor to quantize (a clone is modified, not the input).
            num_bits: bit width; required when ``qparams`` is None.
            qparams: optional precomputed ``QParams``.
            flatten_dims, reduce_dim: forwarded to ``calculate_qparams`` here
                and again in ``backward``.
            dequantize: map the integer codes back to real values.
            signed: use a signed integer grid.
            stochastic: add uniform noise in [-0.5, 0.5) before rounding.
            inplace, weight_gap_bits: accepted for interface compatibility;
                not used.
            index, momentum, iter: stored on ``ctx`` only.

        Returns:
            The quantized (and optionally dequantized) clone of ``input``.
        """
        ctx.flatten_dims = flatten_dims
        ctx.reduce_dim = reduce_dim
        ctx.dequantize = dequantize
        ctx.signed = signed
        ctx.stochastic = stochastic
        ctx.index = index
        ctx.momentum = momentum
        ctx.iter = iter

        output = input.clone()
        if qparams is None:
            assert num_bits is not None, "either provide qparams of num_bits to quantize"
            qparams = calculate_qparams(
                input, num_bits=num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim)

        # backward() multiplies the bit width by two, so it must never be None:
        # fall back to the width carried by qparams when the caller only
        # supplied qparams.
        ctx.num_bits = num_bits if num_bits is not None else qparams.num_bits

        zero_point = qparams.zero_point
        num_bits = qparams.num_bits
        qmin = -(2. ** (num_bits - 1)) if signed else 0.
        qmax = qmin + 2. ** num_bits - 1.
        # Lower-bound the step size; clamp stays on the tensor's own device.
        scale = torch.clamp(qparams.range / (qmax - qmin), min=1e-8)
        ctx.scale_w = scale

        with torch.no_grad():
            output.add_(qmin * scale - zero_point).div_(scale)
            if stochastic:
                noise = output.new(output.shape).uniform_(-0.5, 0.5)
                output.add_(noise)
            # quantize
            output.clamp_(qmin, qmax).round_()

            if dequantize:
                output.mul_(scale).add_(
                    zero_point - qmin * scale)  # dequantize

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Rescale the gradient with its integerized step size.

        Returns the rescaled gradient for ``input`` and ``None`` for the
        twelve remaining forward arguments.
        """
        # Range statistics of the gradient, requested at twice the weight
        # bit width and reduced with min/max ('extreme').
        qparams = calculate_qparams(
            grad_output, num_bits=2 * ctx.num_bits, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim,
            reduce_type='extreme')

        # Adaptive weight update (24/1/29 variant): the step sizes are stored
        # as a rounded mantissa plus a left shift capped at ctx.l_shift.
        ctx.num_bits_s = 16
        ctx.l_shift = 32
        qmin = -(2. ** (ctx.num_bits - 1)) if ctx.signed else 0.
        qmax = qmin + 2. ** ctx.num_bits - 1.
        scale = torch.clamp(qparams.range / (qmax - qmin), min=1e-8)
        ctx.scale_gw = scale

        # calculate the scales stored as integers
        scale_w_int, l_shift_w = cal_scale_persample(ctx.num_bits_s, ctx.l_shift, ctx.scale_w)
        scale_gw_int, l_shift_gw = cal_scale_persample(ctx.num_bits_s, ctx.l_shift, ctx.scale_gw)

        # State factor: log2 ratio of the two integerized scales.  It is
        # recorded on ctx only; the earlier tuning schemes that consumed it
        # (23/12/07 .. 24/1/27) are no longer active.
        ctx.state_factor = torch.log2(scale_gw_int / scale_w_int)

        # 24/1/29: united with per-sample scale integerization
        grad_input = scale_gw_int * grad_output / scale * 2 ** - l_shift_gw

        return (grad_input,) + (None,) * 12
    

def cal_scale(num_bits_s, l_shift, x):
    """Shift ``x`` left by ``l_shift`` bits and saturate it.

    Args:
        num_bits_s: bit width that defines the saturation limits.
        l_shift: number of bits to shift by (``x`` is multiplied by
            ``2 ** l_shift``).
        x: tensor of any shape; it is not modified.

    Returns:
        A new tensor equal to ``x * 2 ** l_shift`` limited to
        ``[-2 ** num_bits_s, 2 ** num_bits_s - 1]``.
    """
    upper = 2 ** num_bits_s - 1
    lower = -(2 ** num_bits_s)
    # One vectorised clamp replaces the per-element Python loop and works for
    # tensors of any rank.
    return torch.clamp(x * 2 ** l_shift, min=lower, max=upper)


def cal_scale_persample(num_bits_s, l_shift_max, x):
    """Split each element of ``x`` into a rounded mantissa and a left shift.

    The shift of an element is ``-floor(log2(x / 2 ** (num_bits_s - 1)))``,
    capped at ``l_shift_max``; the mantissa is ``round(x * 2 ** shift)``.

    Returns:
        ``(mantissa, shift)`` -- two tensors shaped like ``x``.
    """
    relative = x / 2 ** (num_bits_s - 1)
    exponent = torch.floor(torch.log2(relative))
    # Cap the shift from above only; smaller shifts are left untouched.
    shift = torch.clamp(-exponent, max=l_shift_max)
    mantissa = torch.round(x * 2 ** shift)
    return mantissa, shift


class RecordGrad(InplaceFunction):
    """Identity in the forward pass; routes the gradient through ``record_grad``."""

    @staticmethod
    def forward(ctx, input, index, iter):
        """Remember ``index`` and ``iter`` for the backward pass; return ``input``."""
        ctx.index, ctx.iter = index, iter
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Hand the gradient to ``record_grad`` and pass its result upstream."""
        recorded = record_grad(grad_output, ctx.index, ctx.iter)
        # One gradient followed by seven None placeholders.
        return (recorded,) + (None,) * 7


def record_grad(x, index, iter):
    """Run ``x`` through ``RecordTensor`` and return the result.

    ``index`` and ``iter`` are forwarded unchanged to ``RecordTensor``.
    """
    # Autograd Functions are invoked on the class; creating an instance first
    # is deprecated in PyTorch.
    return RecordTensor.apply(x, index, iter)
    

class RecordTensor(InplaceFunction):
    """Identity op that dumps range statistics to a .mat file every 100 iterations."""

    @staticmethod
    def forward(ctx, input, index, iter):
        """Return ``input`` unchanged, saving range statistics when due.

        Args:
            input: tensor to inspect; statistics are taken with
                ``flatten_dims=(2, -1)`` and no further reduction.
            index: string that becomes part of the output file name.
            iter: integer iteration counter; a dump is written whenever it is
                a multiple of 100.
        """
        ctx.index = index
        ctx.iter = iter

        output = input

        if ctx.iter % 100 == 0:
            # The statistics are only needed for the dump, so they are
            # computed only on iterations that actually write a file.
            qparams = calculate_qparams(
                output, num_bits=4, flatten_dims=(2, -1), reduce_dim=None)
            temp = qparams.range.to(device='cpu').numpy()
            # NOTE(review): hard-coded absolute output directory.
            sio.savemat('C:/wenjinguo/prj/cpt_cifar_3/cifar10/UQ_wide_all4_piecewise_78000_INTV4_percsquant_nonDSQ_noendmask_GAaroundBN/gabn/sc_' + ctx.index + '_' + str(ctx.iter) + '.mat',
                        {'data': temp[:, 0, 0, 0]})

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: one gradient slot per forward argument."""
        return grad_output, None, None
    

class UniformQuantizeGrad(InplaceFunction):
    """Identity in the forward pass; quantizes the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD,
                reduce_dim=0, dequantize=True, signed=False, stochastic=True):
        """Store the quantization settings on ``ctx`` and return ``input`` as is."""
        ctx.inplace = False
        ctx.num_bits, ctx.qparams = num_bits, qparams
        ctx.flatten_dims, ctx.reduce_dim = flatten_dims, reduce_dim
        ctx.signed, ctx.stochastic = signed, stochastic
        ctx.dequantize = dequantize
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Quantize ``grad_output`` with ``quantize`` and return it.

        When no qparams were given to ``forward`` they are derived from the
        gradient itself using min/max ('extreme') reduction.
        """
        with torch.no_grad():
            grad_qparams = ctx.qparams
            if grad_qparams is None:
                assert ctx.num_bits is not None, "either provide qparams of num_bits to quantize"
                grad_qparams = calculate_qparams(
                    grad_output,
                    num_bits=ctx.num_bits,
                    flatten_dims=ctx.flatten_dims,
                    reduce_dim=ctx.reduce_dim,
                    reduce_type='extreme',
                )

            # The gradient is always dequantized, whatever ctx.dequantize says.
            quantized_grad = quantize(
                grad_output,
                num_bits=None,
                qparams=grad_qparams,
                flatten_dims=ctx.flatten_dims,
                reduce_dim=ctx.reduce_dim,
                dequantize=True,
                signed=ctx.signed,
                stochastic=ctx.stochastic,
                inplace=False,
            )
        return (quantized_grad,) + (None,) * 7
    

class UniformQuantizeGradPrun(InplaceFunction):
    """Identity in the forward pass; applies ``quantize_prun`` to the gradient."""

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD,
                reduce_dim=0, dequantize=True, signed=False, stochastic=True, prun_factor=0.2,
                index='', iter=None):
        """Store every setting on ``ctx`` for the backward pass; return ``input``."""
        settings = {
            'num_bits': num_bits,
            'qparams': qparams,
            'flatten_dims': flatten_dims,
            'stochastic': stochastic,
            'signed': signed,
            'dequantize': dequantize,
            'reduce_dim': reduce_dim,
            'inplace': False,
            'prun_factor': prun_factor,
            'index': index,
            'iter': iter,
        }
        for attr, value in settings.items():
            setattr(ctx, attr, value)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Return the pruned and quantized gradient plus eleven ``None`` slots."""
        with torch.no_grad():
            stats = ctx.qparams
            if stats is None:
                assert ctx.num_bits is not None, "either provide qparams of num_bits to quantize"
                # Gradient ranges use min/max ('extreme') reduction.
                stats = calculate_qparams(
                    grad_output, num_bits=ctx.num_bits, flatten_dims=ctx.flatten_dims,
                    reduce_dim=ctx.reduce_dim, reduce_type='extreme')

            pruned_grad = quantize_prun(
                grad_output, num_bits=None, qparams=stats,
                flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim,
                dequantize=True, signed=ctx.signed, stochastic=ctx.stochastic,
                inplace=False, prun_factor=ctx.prun_factor,
                index=ctx.index, iter=ctx.iter)
        return (pruned_grad,) + (None,) * 11
    

class UniformQuantizeSCGradPrun(InplaceFunction):
    """Identity in the forward pass; applies ``quantize_psc_prun`` to the gradient.

    Gradient statistics are taken with ``flatten_dims=(2, -1)``.
    """

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None,
                reduce_dim=0, dequantize=True, signed=False, stochastic=True, prun_factor=0.2,
                index='', iter=None):
        """Record the settings on ``ctx``; the tensor itself passes through."""
        ctx.inplace = False
        ctx.num_bits, ctx.qparams, ctx.reduce_dim = num_bits, qparams, reduce_dim
        ctx.dequantize, ctx.signed, ctx.stochastic = dequantize, signed, stochastic
        ctx.prun_factor, ctx.index, ctx.iter = prun_factor, index, iter
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Return the pruned and quantized gradient plus eleven ``None`` slots."""
        with torch.no_grad():
            if ctx.qparams is not None:
                grad_stats = ctx.qparams
            else:
                assert ctx.num_bits is not None, "either provide qparams of num_bits to quantize"
                grad_stats = calculate_qparams(
                    grad_output, num_bits=ctx.num_bits, flatten_dims=(2, -1),
                    reduce_dim=ctx.reduce_dim, reduce_type='extreme')

            result = quantize_psc_prun(
                grad_output, num_bits=None, qparams=grad_stats, reduce_dim=ctx.reduce_dim,
                dequantize=True, signed=ctx.signed, stochastic=ctx.stochastic, inplace=False,
                prun_factor=ctx.prun_factor, index=ctx.index, iter=ctx.iter)

        placeholders = (None,) * 11
        return (result,) + placeholders
    

class UniformQuantizeSCSGradPrun(InplaceFunction):
    """Identity in the forward pass; applies ``quantize_pscs_prun`` to the gradient.

    In addition to the usual settings, ``sw`` is kept on ``ctx`` and handed to
    ``quantize_pscs_prun`` during the backward pass.
    """

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None,
                reduce_dim=0, dequantize=True, signed=False, stochastic=True, prun_factor=0.2, sw=None,
                index='', iter=None):
        """Keep all settings for the backward pass and return ``input`` untouched."""
        # Quantization settings
        ctx.qparams = qparams
        ctx.num_bits = num_bits
        ctx.reduce_dim = reduce_dim
        ctx.dequantize = dequantize
        ctx.signed = signed
        ctx.stochastic = stochastic
        ctx.inplace = False
        # Pruning / bookkeeping settings
        ctx.prun_factor = prun_factor
        ctx.sw = sw
        ctx.index = index
        ctx.iter = iter
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Return the pruned and quantized gradient plus eleven ``None`` slots."""

        def gradient_qparams():
            # Prefer caller-supplied qparams; otherwise measure the gradient.
            if ctx.qparams is not None:
                return ctx.qparams
            assert ctx.num_bits is not None, "either provide qparams of num_bits to quantize"
            return calculate_qparams(
                grad_output, num_bits=ctx.num_bits, flatten_dims=(2, -1),
                reduce_dim=ctx.reduce_dim, reduce_type='extreme')

        with torch.no_grad():
            grad_input = quantize_pscs_prun(
                grad_output, num_bits=None, qparams=gradient_qparams(),
                reduce_dim=ctx.reduce_dim, dequantize=True, signed=ctx.signed,
                stochastic=ctx.stochastic, inplace=False, prun_factor=ctx.prun_factor,
                sw=ctx.sw, index=ctx.index, iter=ctx.iter)

        return (grad_input, *([None] * 11))


class UniformQuantizeSCDGradPrun(InplaceFunction):
    """Identity in the forward pass; applies ``quantize_pscd_prun`` to the gradient."""

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None,
                reduce_dim=0, dequantize=True, signed=False, stochastic=True, prun_factor=0.2, sw=None,
                index='', iter=None):
        """Attach every setting to ``ctx`` and return ``input`` unchanged."""
        saved = (
            ('num_bits', num_bits),
            ('qparams', qparams),
            ('stochastic', stochastic),
            ('signed', signed),
            ('dequantize', dequantize),
            ('reduce_dim', reduce_dim),
            ('inplace', False),
            ('prun_factor', prun_factor),
            ('index', index),
            ('iter', iter),
            ('sw', sw),
        )
        for key, value in saved:
            setattr(ctx, key, value)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Return the pruned and quantized gradient plus eleven ``None`` slots."""
        with torch.no_grad():
            params = ctx.qparams
            if params is None:
                assert ctx.num_bits is not None, "either provide qparams of num_bits to quantize"
                # Statistics are taken over dims 2.. with min/max reduction.
                params = calculate_qparams(
                    grad_output,
                    num_bits=ctx.num_bits,
                    flatten_dims=(2, -1),
                    reduce_dim=ctx.reduce_dim,
                    reduce_type='extreme',
                )

            call_kwargs = dict(
                num_bits=None, qparams=params, reduce_dim=ctx.reduce_dim,
                dequantize=True, signed=ctx.signed, stochastic=ctx.stochastic,
                inplace=False, prun_factor=ctx.prun_factor, sw=ctx.sw,
                index=ctx.index, iter=ctx.iter,
            )
            grad_input = quantize_pscd_prun(grad_output, **call_kwargs)

        return tuple([grad_input] + [None] * 11)


class UniformQuantizeGradEF(InplaceFunction):
    """Identity in the forward pass; quantizes the gradient with an error term.

    The backward pass quantizes the gradient once, measures the quantization
    error, adds that error to the gradient and quantizes the sum.
    """

    @staticmethod
    def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD,
                reduce_dim=0, dequantize=True, signed=False, stochastic=True):
        """Store the settings, reset the error-feedback state, return ``input``."""
        ctx.num_bits, ctx.qparams = num_bits, qparams
        ctx.flatten_dims, ctx.reduce_dim = flatten_dims, reduce_dim
        ctx.stochastic, ctx.signed, ctx.dequantize = stochastic, signed, dequantize
        ctx.inplace = False
        # Error-feedback state
        ctx.grad_error, ctx.ite = None, 0
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Return the error-compensated quantized gradient plus seven ``None`` slots."""
        with torch.no_grad():
            grad_qparams = ctx.qparams
            if grad_qparams is None:
                assert ctx.num_bits is not None, "either provide qparams of num_bits to quantize"
                grad_qparams = calculate_qparams(
                    grad_output, num_bits=ctx.num_bits, flatten_dims=ctx.flatten_dims,
                    reduce_dim=ctx.reduce_dim, reduce_type='extreme')

            def run_quantizer(tensor):
                # Both passes share the same settings.
                return quantize(tensor, num_bits=None,
                                qparams=grad_qparams, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim,
                                dequantize=True, signed=ctx.signed, stochastic=ctx.stochastic, inplace=False)

            first_pass = run_quantizer(grad_output)

            # Error feedback
            # NOTE(review): forward() sets ctx.ite to 0 and nothing in this
            # class changes it afterwards, so the accumulating branch looks
            # unreachable from here -- confirm whether the error was meant to
            # persist across iterations.
            if ctx.ite == 0:
                ctx.grad_error = grad_output - first_pass
            else:
                ctx.grad_error = ctx.grad_error + grad_output - first_pass

            compensated = run_quantizer(grad_output + ctx.grad_error)

        return (compensated,) + (None,) * 7


def conv2d_biprec(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, num_bits_grad=None):
    """2-D convolution whose input-gradient path goes through ``quantize_grad``.

    Two convolutions are evaluated: one with ``input`` detached (gradients
    reach ``weight`` and ``bias`` only) and one with ``weight`` and ``bias``
    detached (gradients reach ``input`` only).  Only the second one is wrapped
    in ``quantize_grad``.
    """
    conv_args = (stride, padding, dilation, groups)
    frozen_bias = None if bias is None else bias.detach()

    param_branch = F.conv2d(input.detach(), weight, bias, *conv_args)
    input_branch = F.conv2d(input, weight.detach(), frozen_bias, *conv_args)
    input_branch = quantize_grad(input_branch, num_bits=num_bits_grad, flatten_dims=(1, -1))

    # The detached copy cancels the value of the parameter branch while
    # keeping its gradient path alive.
    return param_branch + input_branch - param_branch.detach()


def linear_biprec(input, weight, bias=None, num_bits_grad=None):
    """Linear layer whose input-gradient path goes through ``quantize_grad``.

    One product is evaluated with ``input`` detached (gradients reach
    ``weight`` and ``bias`` only) and one with ``weight`` and ``bias``
    detached (gradients reach ``input`` only); only the latter is wrapped in
    ``quantize_grad``.
    """
    if bias is not None:
        frozen_bias = bias.detach()
    else:
        frozen_bias = None

    param_branch = F.linear(input.detach(), weight, bias)
    input_branch = quantize_grad(
        F.linear(input, weight.detach(), frozen_bias), num_bits=num_bits_grad)

    # The detached copy cancels the value of the parameter branch while
    # keeping its gradient path alive.
    return param_branch + input_branch - param_branch.detach()


def quantize(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, dequantize=True, signed=False,
             stochastic=False, inplace=False):
    """Quantize ``x`` with ``UniformQuantize`` when a bit width is configured.

    The bit width is taken from ``qparams.num_bits`` when ``qparams`` is
    given and from ``num_bits`` otherwise.  If that width is missing or zero,
    ``x`` is returned unchanged.
    """
    active_bits = qparams.num_bits if qparams else num_bits
    if not active_bits:
        return x
    # Autograd Functions are invoked on the class; creating an instance first
    # is deprecated in PyTorch.
    return UniformQuantize.apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed,
                                 stochastic, inplace)


def quantize_prun(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, dequantize=True, signed=False,
             stochastic=False, inplace=False, prun_factor=0.2, index='', iter=None):
    """Prune and quantize ``x`` with ``UniformQuantizePrun`` when a bit width is configured.

    The bit width is taken from ``qparams.num_bits`` when ``qparams`` is
    given and from ``num_bits`` otherwise.  If that width is missing or zero,
    ``x`` is returned unchanged.
    """
    active_bits = qparams.num_bits if qparams else num_bits
    if not active_bits:
        return x
    # Autograd Functions are invoked on the class; creating an instance first
    # is deprecated in PyTorch.
    return UniformQuantizePrun.apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed,
                                     stochastic, inplace, prun_factor, index, iter)


def quantize_psc_prun(x, num_bits=None, qparams=None, reduce_dim=0, dequantize=True, signed=False,
             stochastic=False, inplace=False, prun_factor=0.2, index='', iter=None):
    """Prune and quantize ``x`` with ``UniformQuantizeSCPrun`` when a bit width is configured.

    The bit width is taken from ``qparams.num_bits`` when ``qparams`` is
    given and from ``num_bits`` otherwise.  If that width is missing or zero,
    ``x`` is returned unchanged.
    """
    active_bits = qparams.num_bits if qparams else num_bits
    if not active_bits:
        return x
    # Autograd Functions are invoked on the class; creating an instance first
    # is deprecated in PyTorch.
    return UniformQuantizeSCPrun.apply(x, num_bits, qparams, reduce_dim, dequantize, signed,
                                       stochastic, inplace, prun_factor, index, iter)


def quantize_pscs_prun(x, num_bits=None, qparams=None, reduce_dim=0, dequantize=True, signed=False,
             stochastic=False, inplace=False, prun_factor=0.2, sw=None, index='', iter=None):
    """Run ``x`` through ``UniformQuantizeSCSPrun`` when a bit-width is configured.

    The gating bit-width is ``qparams.num_bits`` when ``qparams`` is truthy,
    otherwise ``num_bits``.  A falsy gate (``None`` or ``0``) returns ``x``
    untouched.

    Args:
        x: value to quantize.
        num_bits: bit-width used when no ``qparams`` is supplied.
        qparams: optional pre-computed parameters exposing ``num_bits``.
        reduce_dim, dequantize, signed, stochastic, inplace, prun_factor, sw,
        index, iter: forwarded unchanged, in this order, to
            ``UniformQuantizeSCSPrun``.

    Returns:
        The output of ``UniformQuantizeSCSPrun.apply`` or ``x`` itself.
    """
    gate_bits = qparams.num_bits if qparams else num_bits
    if not gate_bits:
        return x
    # Call ``apply`` on the class; instantiating an autograd Function is deprecated.
    return UniformQuantizeSCSPrun.apply(x, num_bits, qparams, reduce_dim, dequantize, signed,
                                        stochastic, inplace, prun_factor, sw, index, iter)


def quantize_pscd_prun(x, num_bits=None, qparams=None, reduce_dim=0, dequantize=True, signed=False,
             stochastic=False, inplace=False, prun_factor=0.2, sw=None, index='', iter=None):
    """Run ``x`` through ``UniformQuantizeSCDPrun`` when a bit-width is configured.

    The gating bit-width is ``qparams.num_bits`` when ``qparams`` is truthy,
    otherwise ``num_bits``.  A falsy gate (``None`` or ``0``) returns ``x``
    untouched.

    Args:
        x: value to quantize.
        num_bits: bit-width used when no ``qparams`` is supplied.
        qparams: optional pre-computed parameters exposing ``num_bits``.
        reduce_dim, dequantize, signed, stochastic, inplace, prun_factor, sw,
        index, iter: forwarded unchanged, in this order, to
            ``UniformQuantizeSCDPrun``.

    Returns:
        The output of ``UniformQuantizeSCDPrun.apply`` or ``x`` itself.
    """
    gate_bits = qparams.num_bits if qparams else num_bits
    if not gate_bits:
        return x
    # Call ``apply`` on the class; instantiating an autograd Function is deprecated.
    # Disabled alternative taking the same argument list: UniformQuantizeSCDPrunPercent.
    return UniformQuantizeSCDPrun.apply(x, num_bits, qparams, reduce_dim, dequantize, signed,
                                        stochastic, inplace, prun_factor, sw, index, iter)


def quantize_grad(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, reduce_dim=0, dequantize=True,
                  signed=False, stochastic=True):
    """Route ``x`` through ``UniformQuantizeGrad`` when a bit-width is configured.

    The gating bit-width is ``qparams.num_bits`` when ``qparams`` is truthy,
    otherwise ``num_bits``.  A falsy gate (``None`` or ``0``) returns ``x``
    untouched.

    Args:
        x: tensor to wrap (presumably its gradient is what gets quantized;
            see ``UniformQuantizeGrad``).
        num_bits: bit-width used when no ``qparams`` is supplied.
        qparams: optional pre-computed parameters exposing ``num_bits``.
        flatten_dims, reduce_dim, dequantize, signed, stochastic: forwarded
            unchanged, in this order, to ``UniformQuantizeGrad``.

    Returns:
        The output of ``UniformQuantizeGrad.apply`` or ``x`` itself.
    """
    gate_bits = qparams.num_bits if qparams else num_bits
    if not gate_bits:
        return x
    # Call ``apply`` on the class; instantiating an autograd Function is deprecated.
    return UniformQuantizeGrad.apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed,
                                     stochastic)


def quantize_grad_prun(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, reduce_dim=0, dequantize=True,
                  signed=False, stochastic=True, prun_factor=0.05, index='', iter=None):
    """Route ``x`` through ``UniformQuantizeGradPrun`` when a bit-width is configured.

    The gating bit-width is ``qparams.num_bits`` when ``qparams`` is truthy,
    otherwise ``num_bits``.  A falsy gate (``None`` or ``0``) returns ``x``
    untouched.

    Args:
        x: tensor to wrap.
        num_bits: bit-width used when no ``qparams`` is supplied.
        qparams: optional pre-computed parameters exposing ``num_bits``.
        flatten_dims, reduce_dim, dequantize, signed, stochastic, prun_factor,
        index, iter: forwarded unchanged, in this order, to
            ``UniformQuantizeGradPrun``.

    Returns:
        The output of ``UniformQuantizeGradPrun.apply`` or ``x`` itself.
    """
    gate_bits = qparams.num_bits if qparams else num_bits
    if not gate_bits:
        return x
    # Call ``apply`` on the class; instantiating an autograd Function is deprecated.
    return UniformQuantizeGradPrun.apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed,
                                         stochastic, prun_factor, index, iter)


def quantize_psc_grad_prun(x, num_bits=None, qparams=None, reduce_dim=0, dequantize=True,
                  signed=False, stochastic=True, prun_factor=0.05, index='', iter=None):
    """Route ``x`` through ``UniformQuantizeSCGradPrun`` when a bit-width is configured.

    The gating bit-width is ``qparams.num_bits`` when ``qparams`` is truthy,
    otherwise ``num_bits``.  A falsy gate (``None`` or ``0``) returns ``x``
    untouched.

    Args:
        x: tensor to wrap.
        num_bits: bit-width used when no ``qparams`` is supplied.
        qparams: optional pre-computed parameters exposing ``num_bits``.
        reduce_dim, dequantize, signed, stochastic, prun_factor, index, iter:
            forwarded unchanged, in this order, to
            ``UniformQuantizeSCGradPrun``.

    Returns:
        The output of ``UniformQuantizeSCGradPrun.apply`` or ``x`` itself.
    """
    gate_bits = qparams.num_bits if qparams else num_bits
    if not gate_bits:
        return x
    # Call ``apply`` on the class; instantiating an autograd Function is deprecated.
    return UniformQuantizeSCGradPrun.apply(x, num_bits, qparams, reduce_dim, dequantize, signed,
                                           stochastic, prun_factor, index, iter)


def quantize_pscs_grad_prun(x, num_bits=None, qparams=None, reduce_dim=0, dequantize=True,
                  signed=False, stochastic=True, prun_factor=0.05, sw=None, index='', iter=None):
    """Route ``x`` through ``UniformQuantizeSCSGradPrun`` when a bit-width is configured.

    The gating bit-width is ``qparams.num_bits`` when ``qparams`` is truthy,
    otherwise ``num_bits``.  A falsy gate (``None`` or ``0``) returns ``x``
    untouched.

    Args:
        x: tensor to wrap.
        num_bits: bit-width used when no ``qparams`` is supplied.
        qparams: optional pre-computed parameters exposing ``num_bits``.
        reduce_dim, dequantize, signed, stochastic, prun_factor, sw, index,
        iter: forwarded unchanged, in this order, to
            ``UniformQuantizeSCSGradPrun``.

    Returns:
        The output of ``UniformQuantizeSCSGradPrun.apply`` or ``x`` itself.
    """
    gate_bits = qparams.num_bits if qparams else num_bits
    if not gate_bits:
        return x
    # Call ``apply`` on the class; instantiating an autograd Function is deprecated.
    return UniformQuantizeSCSGradPrun.apply(x, num_bits, qparams, reduce_dim, dequantize, signed,
                                            stochastic, prun_factor, sw, index, iter)


def quantize_pscd_grad_prun(x, num_bits=None, qparams=None, reduce_dim=0, dequantize=True,
                  signed=False, stochastic=True, prun_factor=0.05, sw=None, index='', iter=None):
    """Route ``x`` through ``UniformQuantizeSCDGradPrun`` when a bit-width is configured.

    The gating bit-width is ``qparams.num_bits`` when ``qparams`` is truthy,
    otherwise ``num_bits``.  A falsy gate (``None`` or ``0``) returns ``x``
    untouched.

    Args:
        x: tensor to wrap.
        num_bits: bit-width used when no ``qparams`` is supplied.
        qparams: optional pre-computed parameters exposing ``num_bits``.
        reduce_dim, dequantize, signed, stochastic, prun_factor, sw, index,
        iter: forwarded unchanged, in this order, to
            ``UniformQuantizeSCDGradPrun``.

    Returns:
        The output of ``UniformQuantizeSCDGradPrun.apply`` or ``x`` itself.
    """
    gate_bits = qparams.num_bits if qparams else num_bits
    if not gate_bits:
        return x
    # Call ``apply`` on the class; instantiating an autograd Function is deprecated.
    return UniformQuantizeSCDGradPrun.apply(x, num_bits, qparams, reduce_dim, dequantize, signed,
                                            stochastic, prun_factor, sw, index, iter)


def quantize_grad_ef(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, reduce_dim=0, dequantize=True,
                  signed=False, stochastic=True):
    """Route ``x`` through ``UniformQuantizeGradEF`` when a bit-width is configured.

    The gating bit-width is ``qparams.num_bits`` when ``qparams`` is truthy,
    otherwise ``num_bits``.  A falsy gate (``None`` or ``0``) returns ``x``
    untouched.

    Args:
        x: tensor to wrap.
        num_bits: bit-width used when no ``qparams`` is supplied.
        qparams: optional pre-computed parameters exposing ``num_bits``.
        flatten_dims, reduce_dim, dequantize, signed, stochastic: forwarded
            unchanged, in this order, to ``UniformQuantizeGradEF``.

    Returns:
        The output of ``UniformQuantizeGradEF.apply`` or ``x`` itself.
    """
    gate_bits = qparams.num_bits if qparams else num_bits
    if not gate_bits:
        return x
    # Call ``apply`` on the class; instantiating an autograd Function is deprecated.
    return UniformQuantizeGradEF.apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed,
                                       stochastic)


class QuantMeasure(nn.Module):
    """Quantize an input using running estimates of its range and zero point.

    While training (or when ``measure`` is set) the statistics of the current
    input are blended into the ``running_range`` / ``running_zero_point``
    buffers.  Outside of those modes the buffers themselves are used as the
    quantization parameters.  With ``measure=True`` the module only collects
    statistics and hands the input back unquantized.
    """

    def __init__(self, shape_measure=(1,), flatten_dims=_DEFAULT_FLATTEN,
                 inplace=False, dequantize=True, stochastic=False, momentum=0.9, measure=False):
        super().__init__()
        self.register_buffer('running_zero_point', torch.zeros(*shape_measure))
        self.register_buffer('running_range', torch.zeros(*shape_measure))
        self.measure = measure
        if measure:
            # Counts the batches seen so far; drives the averaging weight in measure mode.
            self.register_buffer('num_measured', torch.zeros(1))
        self.flatten_dims = flatten_dims
        self.momentum = momentum
        self.dequantize = dequantize
        self.stochastic = stochastic
        self.inplace = inplace

    def _update_running_stats(self, qparams):
        """Blend ``qparams`` into the running buffers (no gradient tracking)."""
        with torch.no_grad():
            if self.measure:
                # n / (n + 1) weighting, computed before the counter is bumped.
                keep = self.num_measured / (self.num_measured + 1)
                self.num_measured += 1
            else:
                keep = self.momentum
            fresh = 1 - keep
            self.running_zero_point.mul_(keep).add_(qparams.zero_point * fresh)
            self.running_range.mul_(keep).add_(qparams.range * fresh)

    def forward(self, input, num_bits, qparams=None):
        """Update or read the running statistics, then quantize ``input``.

        Args:
            input: tensor to measure / quantize.
            num_bits: bit-width stored in the quantization parameters.
            qparams: optional pre-computed parameters; only consulted while
                training or measuring, otherwise replaced by the running
                buffers.

        Returns:
            ``input`` itself in measure mode, else the result of ``quantize``.
        """
        collecting = self.training or self.measure
        if collecting:
            if qparams is None:
                qparams = calculate_qparams(
                    input, num_bits=num_bits, flatten_dims=self.flatten_dims, reduce_dim=0, reduce_type='extreme')
            self._update_running_stats(qparams)
        else:
            qparams = QParams(range=self.running_range, zero_point=self.running_zero_point,
                              num_bits=num_bits, max_value=None)

        if self.measure:
            return input
        return quantize(input, qparams=qparams, dequantize=self.dequantize,
                        stochastic=self.stochastic, inplace=self.inplace)


class QConv2d(nn.Conv2d):
    """``nn.Conv2d`` that quantizes weight, bias and input before convolving.

    The convolution output is additionally routed through
    ``quantize_pscd_grad_prun`` configured with ``grad_bw`` bits.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded to ``nn.Conv2d``.
        weight_bw: bit-width used for the weight and bias quantizers.
        act_bw: bit-width handed to the input quantizer.
        grad_bw: bit-width handed to ``quantize_pscd_grad_prun``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True, weight_bw=8, act_bw=8, grad_bw=8):
        super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias)

        self.quantize_input = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1))
        # Re-assigned after the base constructor so the attribute holds the
        # value exactly as the caller passed it.
        self.stride = stride

        if self.bias is not None:
            self.quantize_b = UniformQuantizeBp()

        self.weight_bw = weight_bw
        self.act_bw = act_bw
        self.grad_bw = grad_bw

    def forward(self, input, index=None, iter=None, is_prun=False):
        """Quantized convolution.

        Args:
            input: tensor to convolve.
            index: forwarded to the weight quantizer and to
                ``quantize_pscd_grad_prun``.
            iter: forwarded to the same two callees.
            is_prun: accepted for call compatibility; not read here.

        Returns:
            The convolution output after ``quantize_pscd_grad_prun``.
        """
        if self.bias is not None:
            qbias = self.quantize_b(self.bias, num_bits=self.weight_bw, flatten_dims=(0, -1))
        else:
            qbias = None

        # One range / zero point per output channel: flatten every dim after
        # dim 0.  This was ``(1, 1)``, a no-op flatten that produced one range
        # per (out, in, kh) row, contradicting the per-output-channel intent
        # and the ``(1, -1)`` handed to the quantizer just below.
        weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bw, flatten_dims=(1, -1),
                                           reduce_dim=None)

        # Per-output-channel quantization for the forward pass.  The positional
        # arguments follow UniformQuantizeW's forward signature (defined
        # elsewhere in this file).  ``apply`` is called on the class because
        # instantiating an autograd Function is deprecated.
        qweight = UniformQuantizeW.apply(self.weight, self.weight_bw, weight_qparams, (1, -1), 0, True, False,
                                         False, False, 4, index, 0.9, iter)
        qinput = self.quantize_input(input, self.act_bw)
        output = F.conv2d(qinput, qweight, qbias, self.stride, self.padding, self.dilation, self.groups)

        # The weight range is passed along as ``sw``.
        output = quantize_pscd_grad_prun(output, num_bits=self.grad_bw, reduce_dim=None, prun_factor=0.05,
                                         sw=weight_qparams.range, index=index, iter=iter)
        return output

    def conv2d_quant_act(self, input_fw, input_bw, weight, bias=None, stride=1, padding=0, dilation=1, groups=1,
                         error_bits=0, gc_bits=0):
        """Convolve twice so the two gradient paths can be quantized separately.

        ``out1`` is computed with detached weight/bias, so gradient only
        reaches ``input_fw``.  ``out2`` is computed with a detached input, so
        gradient only reaches ``weight`` and ``bias``.  Each goes through
        ``quantize_grad`` with its own bit-width (``error_bits`` / ``gc_bits``;
        ``0`` disables it).

        Returns:
            ``out1 + out2 - out2.detach()``: ``out2`` cancels in value but
            still contributes its gradient path.
        """
        out1 = F.conv2d(input_fw, weight.detach(), bias.detach() if bias is not None else None,
                        stride, padding, dilation, groups)
        out2 = F.conv2d(input_bw.detach(), weight, bias,
                        stride, padding, dilation, groups)
        out1 = quantize_grad(out1, num_bits=error_bits)
        out2 = quantize_grad(out2, num_bits=gc_bits)
        return out1 + out2 - out2.detach()


class QConv2dSmooth(nn.Conv2d):
    """Quantized ``nn.Conv2d`` using separately quantized weights per gradient path.

    Two quantized copies of the weight are built: one from
    ``calculate_qparams`` and one from ``calculate_qparams_permute`` (which
    swaps weight dims 0 and 1 before computing the statistics).  Outside of
    validation, the first copy feeds the convolution whose gradient reaches
    the input, the second feeds the convolution whose gradient reaches the
    weight (see ``conv2d_quant_pscd``).

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded to ``nn.Conv2d``.
        weight_bw: bit-width used for the weight and bias quantizers.
        act_bw: bit-width handed to the input quantizer.
        grad_bw: bit-width handed to the gradient quantizers.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True, weight_bw=8, act_bw=8, grad_bw=8):
        super(QConv2dSmooth, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias)

        self.quantize_input = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1))
        # Re-assigned after the base constructor so the attribute holds the
        # value exactly as the caller passed it.
        self.stride = stride

        if self.bias is not None:
            self.quantize_b = UniformQuantizeBp()

        self.weight_bw = weight_bw
        self.act_bw = act_bw
        self.grad_bw = grad_bw

    def forward(self, input, index=None, iter=None, is_prun=False, is_val=False):
        """Quantized convolution.

        Args:
            input: tensor to convolve.
            index: forwarded to the weight and gradient quantizers.
            iter: forwarded to the same callees.
            is_prun: accepted for call compatibility; not read here.
            is_val: when true, run a single plain ``F.conv2d`` on the
                quantized operands and skip the gradient quantizers.

        Returns:
            The convolution output.
        """
        if self.bias is not None:
            qbias = self.quantize_b(self.bias, num_bits=self.weight_bw, flatten_dims=(0, -1))
        else:
            qbias = None

        weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bw, flatten_dims=(1, -1),
                                           reduce_dim=None)
        weight_qparams_bp = calculate_qparams_permute(self.weight, num_bits=self.weight_bw, flatten_dims=(1, -1),
                                           reduce_dim=None)

        # Per-output-channel quantization for the forward pass.  The positional
        # arguments follow UniformQuantizeW's forward signature (defined
        # elsewhere in this file).  ``apply`` is called on the class because
        # instantiating an autograd Function is deprecated.
        qweight = UniformQuantizeW.apply(self.weight, self.weight_bw, weight_qparams, (1, -1), 0, True, False,
                                         False, False, 4, index, 0.9, iter)
        # Second copy, quantized with the permuted-weight statistics.
        qweight_bp = UniformQuantizeW.apply(self.weight, self.weight_bw, weight_qparams_bp, (1, -1), 0, True, False,
                                            False, False, 4, index, 0.9, iter)
        qinput = self.quantize_input(input, self.act_bw)
        if is_val:
            output = F.conv2d(qinput, qweight, qbias, self.stride, self.padding, self.dilation, self.groups)
        else:
            output = self.conv2d_quant_pscd(qinput, qinput, qweight, qweight_bp, bias=qbias, stride=self.stride,
                                            padding=self.padding, dilation=self.dilation, groups=self.groups,
                                            sw=weight_qparams_bp.range,
                                            index=index, iter=iter,
                                            error_bits=self.grad_bw, gc_bits=self.grad_bw)

        return output

    def conv2d_quant_pscd(self, input_fw, input_bw, weight_fw, weight_bp, bias=None, stride=1, padding=0, dilation=1, groups=1, sw=None, index=None, iter=None,
                         error_bits=0, gc_bits=0):
        """Convolve twice and wrap each result with ``quantize_pscd_grad_prun``.

        ``out1`` uses detached ``weight_fw``/``bias``, so gradient only
        reaches ``input_fw``.  ``out2`` uses a detached ``input_bw``, so
        gradient only reaches ``weight_bp`` and ``bias``.

        Args:
            sw, index, iter: forwarded to both ``quantize_pscd_grad_prun`` calls.
            error_bits: bit-width for the ``out1`` path (``0`` disables it).
            gc_bits: bit-width for the ``out2`` path (``0`` disables it).

        Returns:
            ``out1 + out2 - out2.detach()``: ``out2`` cancels in value but
            still contributes its gradient path.
        """
        # Path whose gradient reaches the input activation.
        out1 = F.conv2d(input_fw, weight_fw.detach(), bias.detach() if bias is not None else None,
                        stride, padding, dilation, groups)
        # Path whose gradient reaches the weight (and bias).
        out2 = F.conv2d(input_bw.detach(), weight_bp, bias,
                        stride, padding, dilation, groups)
        out1 = quantize_pscd_grad_prun(out1, num_bits=error_bits, reduce_dim=None, prun_factor=0.05, sw=sw, index=index, iter=iter)
        out2 = quantize_pscd_grad_prun(out2, num_bits=gc_bits, reduce_dim=None, prun_factor=0.05, sw=sw, index=index, iter=iter)
        return out1 + out2 - out2.detach()

    def conv2d_quant_act(self, input_fw, input_bw, weight, bias=None, stride=1, padding=0, dilation=1, groups=1,
                         error_bits=0, gc_bits=0):
        """Same two-path scheme as ``conv2d_quant_pscd`` but with ``quantize_grad``.

        A single ``weight`` is shared by both convolutions.  ``error_bits``
        and ``gc_bits`` are the bit-widths for the input-gradient and
        weight-gradient paths respectively (``0`` disables quantization).

        Returns:
            ``out1 + out2 - out2.detach()``.
        """
        out1 = F.conv2d(input_fw, weight.detach(), bias.detach() if bias is not None else None,
                        stride, padding, dilation, groups)
        out2 = F.conv2d(input_bw.detach(), weight, bias,
                        stride, padding, dilation, groups)
        out1 = quantize_grad(out1, num_bits=error_bits)
        out2 = quantize_grad(out2, num_bits=gc_bits)
        return out1 + out2 - out2.detach()


class QLinear(nn.Linear):
    """``nn.Linear`` that quantizes weight, bias and input before the matmul.

    The output is additionally routed through ``quantize_grad`` configured
    with ``grad_bw`` bits.

    Args:
        in_features, out_features, bias: forwarded to ``nn.Linear``.
        weight_bw: bit-width used for the weight and bias quantizers.
        act_bw: bit-width handed to the input quantizer.
        grad_bw: bit-width handed to ``quantize_grad``.
    """

    def __init__(self, in_features, out_features, bias=True, weight_bw=8, act_bw=8, grad_bw=8):
        super(QLinear, self).__init__(in_features, out_features, bias)

        self.quantize_input = QuantMeasure(shape_measure=(1, 1), flatten_dims=(1, -1))
        if self.bias is not None:
            self.quantize_b = UniformQuantizeBp()

        self.weight_bw = weight_bw
        self.act_bw = act_bw
        self.grad_bw = grad_bw

    def forward(self, input, index=None, iter=None):
        """Quantized linear layer.

        Args:
            input: tensor to transform.
            index: forwarded to the weight quantizer.
            iter: forwarded to the weight quantizer.

        Returns:
            The ``F.linear`` output after ``quantize_grad``.
        """
        if self.bias is not None:
            qbias = self.quantize_b(self.bias, num_bits=self.weight_bw, flatten_dims=(0, -1))
        else:
            qbias = None

        weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bw, flatten_dims=(1, -1),
                                           reduce_dim=None)
        # The positional arguments follow UniformQuantizeW's forward signature
        # (defined elsewhere in this file).  ``apply`` is called on the class
        # because instantiating an autograd Function is deprecated.
        qweight = UniformQuantizeW.apply(self.weight, self.weight_bw, weight_qparams, (1, -1), 0, True, False,
                                         False, False, 4, index, 0.9, iter)

        qinput = self.quantize_input(input, self.act_bw)
        output = F.linear(qinput, qweight, qbias)
        # NOTE(review): flatten_dims=(2, -1) presupposes a tensor with at least
        # three dims if the gradient quantizer flattens with it -- confirm
        # against the shapes this layer actually receives.
        output = quantize_grad(output, num_bits=self.grad_bw, flatten_dims=(2, -1), reduce_dim=None)
        return output

    def linear_quant_act(self, input_fw, input_bw, weight, bias=None, stride=1, padding=0, dilation=1, groups=1,
                         error_bits=0, gc_bits=0):
        """Apply the linear map twice so the two gradient paths are quantized separately.

        ``out1`` is computed with detached weight/bias, so gradient only
        reaches ``input_fw``.  ``out2`` is computed with a detached input, so
        gradient only reaches ``weight`` and ``bias``.

        Args:
            stride, padding, dilation, groups: ignored.  They remain in the
                signature for backward compatibility, but ``F.linear`` only
                accepts ``(input, weight, bias)``; passing them on raised a
                ``TypeError`` on every call.
            error_bits: ``quantize_grad`` bit-width for ``out1`` (``0`` disables it).
            gc_bits: ``quantize_grad`` bit-width for ``out2`` (``0`` disables it).

        Returns:
            ``out1 + out2 - out2.detach()``: ``out2`` cancels in value but
            still contributes its gradient path.
        """
        out1 = F.linear(input_fw, weight.detach(), bias.detach() if bias is not None else None)
        out2 = F.linear(input_bw.detach(), weight, bias)
        out1 = quantize_grad(out1, num_bits=error_bits)
        out2 = quantize_grad(out2, num_bits=gc_bits)
        return out1 + out2 - out2.detach()


if __name__ == '__main__':
    # Manual smoke test: quantize a small random tensor to 8 bits and print
    # the tensor before and after.
    x = torch.rand(2, 3)
    # flatten_dims has to be a tuple because calculate_qparams star-unpacks it
    # into Tensor.flatten; the former ``(-1)`` was just the int -1.
    x_q = quantize(x, flatten_dims=(-1,), num_bits=8, dequantize=True)
    print(x)
    print(x_q)
