import torch
from torch import nn
from quant.quant_layer import UniformAffineQuantizer, round_ste
import torch.nn.functional as F


class AdaRoundQuantizer(nn.Module):
    """
    Adaptive Rounding Quantizer, used to optimize the rounding policy
    by reconstructing the intermediate output.
    Based on
     Up or Down? Adaptive Rounding for Post-Training Quantization: https://arxiv.org/abs/2004.10568

    On top of the learned rounding variable ``alpha``, this quantizer keeps a
    second learnable tensor ``alphaChannel`` that is applied to the
    de-quantized weight:

    * ``bias_ch_quant`` falsy: ``alphaChannel`` is a multiplicative scale
      (initialised to ones) that broadcasts over the weight.
    * ``bias_ch_quant`` truthy: ``alphaChannel`` holds one logit per
      (channel, shift candidate); a softmax over the candidates selects or
      blends the multipliers listed in ``shiftTarget`` (see :meth:`shifted`).

    :param uaq: UniformAffineQuantizer, used to initialize quantization parameters in this quantizer
                (reads ``n_bits``, ``sym``, ``delta``, ``zero_point``, ``n_levels``,
                ``bias_ch_quant`` and ``shiftTargets``)
    :param weight_tensor: weight used to initialize alpha (cloned, never modified)
    :param round_mode: controls the forward pass in this quantizer. Only
                       ``'learned_hard_sigmoid'`` can be used at construction time,
                       because it is the only mode ``init_alpha`` knows how to
                       initialise; ``'nearest'``, ``'nearest_ste'`` and ``'stochastic'``
                       are additionally understood by ``forward`` if ``round_mode``
                       is reassigned afterwards.
    """

    def __init__(self, uaq: "UniformAffineQuantizer", weight_tensor: torch.Tensor,
                 round_mode='learned_hard_sigmoid'):
        super(AdaRoundQuantizer, self).__init__()
        # copying all attributes from UniformAffineQuantizer
        self.n_bits = uaq.n_bits
        self.sym = uaq.sym
        self.delta = uaq.delta
        self.zero_point = uaq.zero_point
        self.n_levels = uaq.n_levels

        self.round_mode = round_mode
        self.alpha = None          # learnable rounding logits, set by init_alpha
        self.alphaChannel = None   # learnable channel scale / shift logits, set by init_alpha
        self.soft_targets = False  # True: use relaxed (differentiable) targets in forward
        # A step size that is not 4-D is treated as belonging to a
        # fully-connected (2-D) weight; a 4-D one as a convolution weight.
        self.isFC = len(self.delta.shape) != 4
        self.DW = False            # set to True by init_alpha/init_alphaCh for (N, 1, ...) weights

        # False: plain softmax over shift candidates; True: stretched/clamped variant
        self.normal_sigmoid = False

        # params for sigmoid function
        self.gamma, self.zeta = -0.1, 1.1
        self.beta = 2/3

        self.bias_ch_quant = uaq.bias_ch_quant
        self.shiftTarget = uaq.shiftTargets

        self.init_alpha(x=weight_tensor.clone())

    def forward(self, x):
        """
        Fake-quantize ``x`` with the rounding policy selected by ``round_mode``.

        :param x: tensor to quantize (expected to broadcast with ``delta``,
                  ``alpha`` and ``alphaChannel``)
        :return: de-quantized tensor, post-processed with ``alphaChannel``
                 (shift selection if ``bias_ch_quant`` else plain scaling)
        :raises ValueError: if ``round_mode`` is not one of the supported modes
        """
        if self.round_mode == 'nearest':
            x_int = torch.round(x / self.delta)
        elif self.round_mode == 'nearest_ste':
            x_int = round_ste(x / self.delta)
        elif self.round_mode == 'stochastic':
            x_floor = torch.floor(x / self.delta)
            rest = (x / self.delta) - x_floor  # rest of rounding
            x_int = x_floor + torch.bernoulli(rest)
            print('Draw stochastic sample')
        elif self.round_mode == 'learned_hard_sigmoid':
            x_floor = torch.floor(x / self.delta)
            if self.soft_targets:
                # relaxed rounding offset in [0, 1]
                x_int = x_floor + self.get_soft_targets()
            else:
                # hard decision: round up exactly where alpha is non-negative
                x_int = x_floor + (self.alpha >= 0).float()
        else:
            raise ValueError('Wrong rounding mode')

        x_quant = torch.clamp(x_int + self.zero_point, 0, self.n_levels - 1)
        x_float_q = (x_quant - self.zero_point) * self.delta
        if self.bias_ch_quant:
            x_float_q = self.shifted(x_float_q)
        else:
            x_float_q = x_float_q * self.alphaChannel
        return x_float_q

    def shifted(self, x):
        """
        Multiply ``x`` by the per-channel shift candidate(s) from ``shiftTarget``.

        With ``soft_targets`` set, the result is the probability-weighted sum
        over all candidates; otherwise every channel uses the single candidate
        with the highest probability.

        :param x: de-quantized weight tensor
        :return: tensor of the same shape as ``x``
        """
        p = self.get_sig_soft_targets()
        if p.dim() == 2:
            # alphaChannel is (channels, candidates). Insert a singleton axis so
            # the channel axis lines up with the weight axis it was sized from:
            # weight dim 0 when DW is set, weight dim 1 otherwise.
            if self.DW:
                p = p.unsqueeze(1)
            else:
                p = p.unsqueeze(0)

        shifts = self.shiftTarget
        if self.soft_targets:  # soft target: blend all candidates
            if not self.isFC:
                # add the two spatial axes so p[:, :, i] broadcasts over a 4-D weight
                p = p.unsqueeze(-1).unsqueeze(-1)
            x_out = x * (shifts[0] * p[:, :, 0])
            for i in range(1, len(shifts)):
                # every candidate is weighted by its own probability
                x_out = x_out + x * (shifts[i] * p[:, :, i])
        else:  # hard target: pick the most probable candidate per channel
            max_index = torch.argmax(p, dim=-1)
            x_out = x * shifts[0]
            for i in range(1, len(shifts)):
                mask = max_index == i
                if not self.isFC:
                    mask = mask.unsqueeze(-1).unsqueeze(-1)
                x_out = torch.where(mask, x * shifts[i], x_out)
        return x_out

    def get_sig_soft_targets(self):
        """
        Return the probability of each shift candidate, per channel.

        By default this is a softmax of ``alphaChannel`` over its last axis.
        With ``normal_sigmoid`` set, the softmax is stretched to
        ``[gamma, zeta]``, clamped to ``[0, 1]`` and renormalised to sum to 1.

        :return: tensor with the same shape as ``alphaChannel``
        """
        if self.normal_sigmoid:
            p = torch.clamp(F.softmax(self.alphaChannel, dim=-1) * (self.zeta - self.gamma) + self.gamma, 0, 1)
            # NOTE(review): if every stretched probability clamps to 0 the sum
            # is 0 and the division below yields NaN - confirm this cannot
            # happen for the candidate counts in use.
            p_sum = torch.sum(p, dim=-1).unsqueeze(-1)
            p = p / p_sum
        else:
            p = F.softmax(self.alphaChannel, dim=-1)
        return p

    def get_soft_targets(self):
        """
        Return the relaxed rounding offset: ``sigmoid(alpha)`` stretched to
        ``[gamma, zeta]`` and clamped to ``[0, 1]``.
        """
        return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1)

    def init_alphaCh(self, x: torch.Tensor, device='cuda'):
        """
        Build the initial shift-candidate logits (all ones, i.e. uniform softmax).

        Side effect: sets ``self.DW`` when ``x`` has shape ``(N, 1, ...)`` with ``N > 1``.

        :param x: weight tensor; only its first two dimensions are inspected
        :param device: device on which the logits are allocated
        :return: float tensor of shape ``(channels, len(shiftTarget))`` where
                 ``channels`` is ``x.shape[0]`` if DW was detected, else ``x.shape[1]``
        """
        shiftNum = len(self.shiftTarget)
        alphaCh_Num = x.shape[1]
        if x.shape[1] == 1 and x.shape[0] > 1:
            self.DW = True
            alphaCh_Num = x.shape[0]
        alphach = torch.ones((alphaCh_Num, shiftNum), dtype=torch.float, device=device)
        return alphach

    def init_alpha(self, x: torch.Tensor):
        """
        Create the learnable parameters ``alpha`` and ``alphaChannel`` from the weight.

        ``alpha`` is initialised so that the soft target equals the fractional
        part of ``x / delta``. ``alphaChannel`` is initialised to ones, shaped
        either as shift logits (``bias_ch_quant``) or as a broadcastable scale.

        :param x: weight tensor (at least 2-D)
        :raises NotImplementedError: if ``round_mode`` is not ``'learned_hard_sigmoid'``
        """
        x_floor = torch.floor(x / self.delta)
        if self.round_mode == 'learned_hard_sigmoid':
            print('Init alpha to be FP32')
            rest = (x / self.delta) - x_floor  # rest of rounding [0, 1)
            alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1)  # => sigmoid(alpha) = rest
            self.alpha = nn.Parameter(alpha)
        else:
            raise NotImplementedError

        if self.bias_ch_quant:
            shiftTarget = self.shiftTarget
            print("Optimal shift candidates: ", shiftTarget)
            self.alphaChannel = self.init_alphaCh(x, device=self.alpha.device)
            self.alphaChannel = nn.Parameter(self.alphaChannel)
        else:
            is_2d = x.dim() == 2
            if x.shape[1] == 1 and x.shape[0] > 1:
                # (N, 1, ...) weight: one scale per entry of dim 0
                self.DW = True
                scale_shape = (x.shape[0], 1) if is_2d else (x.shape[0], 1, 1, 1)
            else:
                # otherwise one scale per entry of dim 1
                scale_shape = (1, x.shape[1]) if is_2d else (1, x.shape[1], 1, 1)
            self.alphaChannel = nn.Parameter(torch.ones(scale_shape, device=self.alpha.device))
