# *
# @file Different utility functions
# Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
# All rights reserved.
# This file is part of ZeroQ repository.
#
# ZeroQ is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ZeroQ is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ZeroQ repository.  If not, see <http://www.gnu.org/licenses/>.
# *

import torch
import time
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, Parameter
from .quant_utils import *
import sys


class QuantAct(Module):
    """
    Asymmetric activation quantizer with a moving-average activation range.

    While ``running_stat`` is on, every forward pass folds the batch min/max
    into a bias-corrected exponential moving average kept in the ``x_min`` /
    ``x_max`` buffers.  Unless ``full_precision_flag`` is set, the input is
    then passed through ``AsymmetricQuantFunction`` with that range.
    """

    def __init__(self,
                 activation_bit,
                 full_precision_flag=False,
                 running_stat=True,
                 beta=0.9):
        """
        activation_bit: bit-setting for activation
        full_precision_flag: if True, forward returns its input unquantized
        running_stat: whether the activation range is updated (True) or frozen
        beta: decay factor of the moving average over the activation range
        """
        super(QuantAct, self).__init__()
        self.activation_bit = activation_bit
        self.full_precision_flag = full_precision_flag
        self.running_stat = running_stat
        self.register_buffer('x_min', torch.zeros(1))
        self.register_buffer('x_max', torch.zeros(1))
        self.register_buffer('beta', torch.Tensor([beta]))
        # beta ** t, where t is the number of range updates performed so far.
        self.register_buffer('beta_t', torch.ones(1))
        self.act_function = AsymmetricQuantFunction.apply

    def __repr__(self):
        return "{0}(activation_bit={1}, full_precision_flag={2}, running_stat={3}, Act_min: {4:.2f}, Act_max: {5:.2f})".format(
            self.__class__.__name__, self.activation_bit,
            self.full_precision_flag, self.running_stat, self.x_min.item(),
            self.x_max.item())

    def fix(self):
        """
        Freeze the activation range (stop updating the running statistics).
        """
        self.running_stat = False

    def unfix(self):
        """
        Resume updating the activation range from incoming batches.
        """
        self.running_stat = True

    def _update_range(self, batch_min, batch_max):
        """
        Fold one batch's range into the bias-corrected moving average.

        ``x_min``/``x_max`` hold the corrected estimate m_t / (1 - beta^t).
        The raw average m_{t-1} is recovered by multiplying the stored value
        by (1 - beta^(t-1)) before the new batch is blended in.  Re-blending
        the already-corrected value would compound the correction and
        inflate early estimates (about 5x on step 2 for a constant input
        with beta=0.9).
        """
        prev_correction = 1 - self.beta_t
        self.beta_t = self.beta_t * self.beta
        correction = 1 - self.beta_t
        self.x_min = (self.x_min * prev_correction * self.beta +
                      batch_min * (1 - self.beta)) / correction
        self.x_max = (self.x_max * prev_correction * self.beta +
                      batch_max * (1 - self.beta)) / correction

    def forward(self, x):
        """
        Update the range from ``x`` (if running_stat), then quantize ``x``
        with it, or return ``x`` unchanged in full-precision mode.
        """
        if self.running_stat:
            # NOTE(review): buffers are rebound rather than updated in place;
            # the original code noted an in-place variant for multi-GPU use.
            # Confirm before relying on this under nn.DataParallel.
            data = x.detach()
            self._update_range(data.min(), data.max())

        if self.full_precision_flag:
            return x
        return self.act_function(x, self.activation_bit, self.x_min,
                                 self.x_max)

class QuantAct_MSE(Module):
    """
    Asymmetric activation quantizer whose range is chosen by an error search.

    While ``running_stat`` is on, each forward pass tries 80 progressively
    shrunk copies of the batch range.  It keeps the candidate with the
    lowest L_p quantization error (p=2.4, as in LAPQ) and folds it into a
    bias-corrected moving average stored in ``x_min``/``x_max``.  Unless
    ``full_precision_flag`` is set, the input is then passed through
    ``AsymmetricQuantFunction`` with that range.
    """

    def __init__(self,
                 activation_bit,
                 full_precision_flag=False,
                 running_stat=True,
                 beta=0.9):
        """
        activation_bit: bit-setting for activation
        full_precision_flag: if True, forward returns its input unquantized
        running_stat: whether the activation range is updated (True) or frozen
        beta: decay factor of the moving average over the activation range
        """
        super(QuantAct_MSE, self).__init__()
        self.activation_bit = activation_bit
        self.full_precision_flag = full_precision_flag
        self.running_stat = running_stat
        self.register_buffer('x_min', torch.zeros(1))
        self.register_buffer('x_max', torch.zeros(1))
        self.register_buffer('beta', torch.Tensor([beta]))
        # beta ** t, where t is the number of range updates performed so far.
        self.register_buffer('beta_t', torch.ones(1))

        # Raw (un-searched) min/max of the most recent calibration batch.
        self.register_buffer('cur_x_min', torch.zeros(1))
        self.register_buffer('cur_x_max', torch.zeros(1))
        self.act_function = AsymmetricQuantFunction.apply

    def __repr__(self):
        return "{0}(activation_bit={1}, full_precision_flag={2}, running_stat={3}, Act_min: {4:.2f}, Act_max: {5:.2f})".format(
            self.__class__.__name__, self.activation_bit,
            self.full_precision_flag, self.running_stat, self.x_min.item(),
            self.x_max.item())

    def fix(self):
        """
        Freeze the activation range (stop updating the running statistics).
        """
        self.running_stat = False

    def unfix(self):
        """
        Resume updating the activation range from incoming batches.
        """
        self.running_stat = True

    def _search_range(self, x, x_min, x_max):
        """
        Return the (min, max) candidate with the lowest L_p quantization error.

        Both ends of [x_min, x_max] are scaled by 1.00, 0.99, ..., 0.21.  The
        search falls back to the unshrunk range when no candidate scores
        below infinity (e.g. every score is NaN).  This avoids failing on
        unbound locals.
        """
        x_clone = x.detach().clone()
        best_score = float('inf')
        best_min, best_max = x_min, x_max
        for i in range(80):
            scale = 1.0 - (i * 0.01)
            new_min = x_min * scale
            new_max = x_max * scale

            quant_act = find_MSESmallest(x_clone, self.activation_bit,
                                         new_min, new_max)
            # L_p norm minimization as described in LAPQ
            # https://arxiv.org/abs/1911.07190
            score = lp_loss(x_clone, quant_act, p=2.4, reduction='all')
            if score < best_score:
                best_score = score
                best_min, best_max = new_min, new_max
        return best_min, best_max

    def _update_range(self, batch_min, batch_max):
        """
        Fold one range estimate into the bias-corrected moving average.

        Without the 1 / (1 - beta^t) correction the average is pulled toward
        the zero-initialised buffers (only 10% of the searched range after
        the first batch with beta=0.9).  The stored corrected value is
        multiplied by (1 - beta^(t-1)) to recover the raw average before the
        new estimate is blended in.
        """
        prev_correction = 1 - self.beta_t
        self.beta_t = self.beta_t * self.beta
        correction = 1 - self.beta_t
        self.x_min = (self.x_min * prev_correction * self.beta +
                      batch_min * (1 - self.beta)) / correction
        self.x_max = (self.x_max * prev_correction * self.beta +
                      batch_max * (1 - self.beta)) / correction

    def forward(self, x):
        """
        Update the range from ``x`` (if running_stat), then quantize ``x``
        with it, or return ``x`` unchanged in full-precision mode.
        """
        if self.running_stat:
            data = x.detach()
            x_min = data.min()
            x_max = data.max()

            self.cur_x_min = x_min
            self.cur_x_max = x_max

            best_min, best_max = self._search_range(x, x_min, x_max)
            self._update_range(best_min, best_max)

        if self.full_precision_flag:
            return x
        return self.act_function(x, self.activation_bit, self.x_min,
                                 self.x_max)

class Quant_Linear(Module):
    """
    Linear layer whose weights are quantized per output row on each forward.

    Each weight row uses its own [min, max] range with
    ``AsymmetricQuantFunction``.  Call ``set_param`` with an existing
    linear layer before use; it supplies ``weight`` and ``bias``.
    """

    def __init__(self, weight_bit, full_precision_flag=False):
        """
        weight_bit: bit-setting for weight
        full_precision_flag: if True, the weights are used unquantized
        """
        super(Quant_Linear, self).__init__()
        self.full_precision_flag = full_precision_flag
        self.weight_bit = weight_bit
        self.weight_function = AsymmetricQuantFunction.apply

    def __repr__(self):
        s = super(Quant_Linear, self).__repr__()
        s = "(" + s + " weight_bit={}, full_precision_flag={})".format(
            self.weight_bit, self.full_precision_flag)
        return s

    def set_param(self, linear):
        """
        Copy shape metadata, weight and (optional) bias from ``linear``.
        """
        self.in_features = linear.in_features
        self.out_features = linear.out_features
        self.weight = Parameter(linear.weight.data.clone())
        bias = getattr(linear, 'bias', None)
        self.bias = None if bias is None else Parameter(bias.data.clone())

    def forward(self, x):
        """
        Apply the linear map to ``x`` with (optionally) quantized weights.
        """
        if self.full_precision_flag:
            return F.linear(x, weight=self.weight, bias=self.bias)

        # Per-output-row range; only needed when actually quantizing.
        w_detached = self.weight.detach()
        w_min = w_detached.min(dim=1).values
        w_max = w_detached.max(dim=1).values
        w = self.weight_function(self.weight, self.weight_bit, w_min, w_max)
        return F.linear(x, weight=w, bias=self.bias)


class Quant_Conv2d(Module):
    """
    2-D convolution whose kernels are quantized per output channel.

    Each output channel's kernel uses its own [min, max] range with
    ``AsymmetricQuantFunction``.  Call ``set_param`` with an existing
    convolution before use; it supplies the weight, bias and geometry.
    """

    def __init__(self, weight_bit, full_precision_flag=False):
        """
        weight_bit: bit-setting for weight
        full_precision_flag: if True, the weights are used unquantized
        """
        super(Quant_Conv2d, self).__init__()
        self.full_precision_flag = full_precision_flag
        self.weight_bit = weight_bit
        self.weight_function = AsymmetricQuantFunction.apply

    def __repr__(self):
        s = super(Quant_Conv2d, self).__repr__()
        s = "(" + s + " weight_bit={}, full_precision_flag={})".format(
            self.weight_bit, self.full_precision_flag)
        return s

    def set_param(self, conv):
        """
        Copy geometry, weight and (optional) bias from ``conv``.
        """
        self.in_channels = conv.in_channels
        self.out_channels = conv.out_channels
        self.kernel_size = conv.kernel_size
        self.stride = conv.stride
        self.padding = conv.padding
        self.dilation = conv.dilation
        self.groups = conv.groups
        self.weight = Parameter(conv.weight.data.clone())
        bias = getattr(conv, 'bias', None)
        self.bias = None if bias is None else Parameter(bias.data.clone())

    def forward(self, x):
        """
        Convolve ``x`` with (optionally) quantized kernels.
        """
        w = self.weight
        if not self.full_precision_flag:
            # Flatten each output channel's kernel to one row for its range;
            # only needed when actually quantizing.
            per_channel = self.weight.detach().reshape(self.out_channels, -1)
            w_min = per_channel.min(dim=1).values
            w_max = per_channel.max(dim=1).values
            w = self.weight_function(self.weight, self.weight_bit, w_min,
                                     w_max)

        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


# *
# @file Different utility functions
# Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
# All rights reserved.
# This file is part of ZeroQ repository.
#
# ZeroQ is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ZeroQ is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ZeroQ repository.  If not, see <http://www.gnu.org/licenses/>.
# *

import torch
import time
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, Parameter
from .quant_utils import *
import sys


class QuantAct_DSG(Module):
    """
    Symmetric activation quantizer with a moving-average activation range.

    While ``running_stat`` is on, each forward pass takes the batch's largest
    absolute value b and folds [-b, b] into a bias-corrected exponential
    moving average kept in ``x_min``/``x_max``.  Unless
    ``full_precision_flag`` is set, the input is then passed through
    ``SymmetricQuantFunction_DSG`` with that range.
    """

    def __init__(self,
                 activation_bit,
                 full_precision_flag=False,
                 running_stat=True,
                 beta=0.9):
        """
        activation_bit: bit-setting for activation
        full_precision_flag: if True, forward returns its input unquantized
        running_stat: whether the activation range is updated (True) or frozen
        beta: decay factor of the moving average over the activation range
        """
        super(QuantAct_DSG, self).__init__()
        self.activation_bit = activation_bit
        self.full_precision_flag = full_precision_flag
        self.running_stat = running_stat
        self.register_buffer('x_min', torch.zeros(1))
        self.register_buffer('x_max', torch.zeros(1))
        self.register_buffer('beta', torch.Tensor([beta]))
        # beta ** t, where t is the number of range updates performed so far.
        self.register_buffer('beta_t', torch.ones(1))
        self.act_function = SymmetricQuantFunction_DSG.apply

    def __repr__(self):
        return "{0}(activation_bit={1}, full_precision_flag={2}, running_stat={3}, Act_min: {4:.2f}, Act_max: {5:.2f})".format(
            self.__class__.__name__, self.activation_bit,
            self.full_precision_flag, self.running_stat, self.x_min.item(),
            self.x_max.item())

    def fix(self):
        """
        Freeze the activation range (stop updating the running statistics).
        """
        self.running_stat = False

    def unfix(self):
        """
        Resume updating the activation range from incoming batches.
        """
        self.running_stat = True

    def _update_range(self, batch_min, batch_max):
        """
        Fold one batch's range into the bias-corrected moving average.

        ``x_min``/``x_max`` hold the corrected estimate m_t / (1 - beta^t).
        The raw average m_{t-1} is recovered by multiplying the stored value
        by (1 - beta^(t-1)) before the new batch is blended in.  Re-blending
        the already-corrected value would compound the correction and
        inflate early estimates (about 5x on step 2 for a constant input
        with beta=0.9).
        """
        prev_correction = 1 - self.beta_t
        self.beta_t = self.beta_t * self.beta
        correction = 1 - self.beta_t
        self.x_min = (self.x_min * prev_correction * self.beta +
                      batch_min * (1 - self.beta)) / correction
        self.x_max = (self.x_max * prev_correction * self.beta +
                      batch_max * (1 - self.beta)) / correction

    def forward(self, x):
        """
        Update the symmetric range from ``x`` (if running_stat), then quantize
        ``x`` with it, or return ``x`` unchanged in full-precision mode.
        """
        if self.running_stat:
            data = x.detach()
            # Symmetric bound: the larger of |min| and |max|.
            bound = torch.max(data.min().abs(), data.max().abs())
            self._update_range(-bound, bound)

        if self.full_precision_flag:
            return x
        return self.act_function(x, self.activation_bit, self.x_min,
                                 self.x_max)


class QuantLinear_DSG(Module):
    """
    Linear layer whose weights are quantized with a symmetric per-row range.

    Each weight row r gets the range [-max|W_r|, max|W_r|], which is passed
    to ``SymmetricQuantFunction_DSG``.  Call ``set_param`` with an existing
    linear layer before use; it supplies ``weight`` and ``bias``.
    """

    def __init__(self, weight_bit, full_precision_flag=False):
        """
        weight_bit: bit-setting for weight
        full_precision_flag: if True, the weights are used unquantized
        """
        super(QuantLinear_DSG, self).__init__()
        self.full_precision_flag = full_precision_flag
        self.weight_bit = weight_bit
        self.weight_function = SymmetricQuantFunction_DSG.apply

    def __repr__(self):
        s = super(QuantLinear_DSG, self).__repr__()
        s = "(" + s + " weight_bit={}, full_precision_flag={})".format(
            self.weight_bit, self.full_precision_flag)
        return s

    def set_param(self, linear):
        """
        Copy shape metadata, weight and (optional) bias from ``linear``.
        """
        self.in_features = linear.in_features
        self.out_features = linear.out_features
        self.weight = Parameter(linear.weight.data.clone())
        bias = getattr(linear, 'bias', None)
        self.bias = None if bias is None else Parameter(bias.data.clone())

    def forward(self, x):
        """
        Apply the linear map to ``x`` with (optionally) quantized weights.
        """
        if self.full_precision_flag:
            return F.linear(x, weight=self.weight, bias=self.bias)

        # One abs-max reduction serves both ends of the symmetric range.
        abs_max = self.weight.detach().abs().max(dim=1).values
        w = self.weight_function(self.weight, self.weight_bit, -abs_max,
                                 abs_max)
        return F.linear(x, weight=w, bias=self.bias)


class QuantConv2d_DSG(Module):
    """
    2-D convolution whose kernels use a symmetric per-output-channel range.

    Each output channel c gets the range [-max|W_c|, max|W_c|], which is
    passed to ``SymmetricQuantFunction_DSG``.  Call ``set_param`` with an
    existing convolution before use; it supplies the weight, bias and
    geometry.
    """

    def __init__(self, weight_bit, full_precision_flag=False):
        """
        weight_bit: bit-setting for weight
        full_precision_flag: if True, the weights are used unquantized
        """
        super(QuantConv2d_DSG, self).__init__()
        self.full_precision_flag = full_precision_flag
        self.weight_bit = weight_bit
        self.weight_function = SymmetricQuantFunction_DSG.apply

    def __repr__(self):
        s = super(QuantConv2d_DSG, self).__repr__()
        s = "(" + s + " weight_bit={}, full_precision_flag={})".format(
            self.weight_bit, self.full_precision_flag)
        return s

    def set_param(self, conv):
        """
        Copy geometry, weight and (optional) bias from ``conv``.
        """
        self.in_channels = conv.in_channels
        self.out_channels = conv.out_channels
        self.kernel_size = conv.kernel_size
        self.stride = conv.stride
        self.padding = conv.padding
        self.dilation = conv.dilation
        self.groups = conv.groups
        self.weight = Parameter(conv.weight.data.clone())
        bias = getattr(conv, 'bias', None)
        self.bias = None if bias is None else Parameter(bias.data.clone())

    def forward(self, x):
        """
        Convolve ``x`` with (optionally) quantized kernels.
        """
        w = self.weight
        if not self.full_precision_flag:
            # Flatten each output channel's kernel to one row; a single
            # abs-max reduction then gives both ends of the symmetric range.
            per_channel = self.weight.detach().reshape(self.out_channels, -1)
            abs_max = per_channel.abs().max(dim=1).values
            w = self.weight_function(self.weight, self.weight_bit, -abs_max,
                                     abs_max)

        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)