import torch
import torch.nn.functional as F
from torch.nn.modules.linear import Linear
import math
from torch.nn.parameter import Parameter
from ._quan_base_plus import _Conv2dQ, Qmodes, _LinearQ, _ActQ, _LinearQ_v2


__all__ = ['Conv2dLSQ', 'LinearLSQ', 'ActLSQ', 'LinearMCN', 'LinearLSQ_v2']


class FunLSQ(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weight, alpha, g, Qn, Qp):
        assert alpha > 0, 'alpha = {}'.format(alpha)
        ctx.save_for_backward(weight, alpha)
        ctx.other = g, Qn, Qp
        q_w = (weight / alpha).round().clamp(Qn, Qp)
        w_q = q_w * alpha
        return w_q

    @staticmethod
    def backward(ctx, grad_weight):
        weight, alpha = ctx.saved_tensors
        g, Qn, Qp = ctx.other
        q_w = weight / alpha
        indicate_small = (q_w < Qn).float()
        indicate_big = (q_w > Qp).float()
        indicate_middle = 1.0 - indicate_small - indicate_big 
        grad_alpha = ((indicate_small * Qn + indicate_big * Qp + indicate_middle * (
            -q_w + q_w.round())) * grad_weight * g).sum().unsqueeze(dim=0)
        grad_weight = indicate_middle * grad_weight
        return grad_weight, grad_alpha, None, None, None


def grad_scale(x, scale):
    y = x
    y_grad = x * scale
    return y.detach() - y_grad.detach() + y_grad


def round_pass(x):
    y = x.round()
    y_grad = x
    return y.detach() - y_grad.detach() + y_grad


class Conv2dLSQ(_Conv2dQ):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, nbits_w=4, mode=Qmodes.kernel_wise, **kwargs):
        super(Conv2dLSQ, self).__init__(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias,
            nbits=nbits_w, mode=mode)
        self.act = ActLSQ(in_features=in_channels, nbits_a=nbits_w)
        self.task_buffer = 0  # task
        

    def forward(self, x):
        task = self.task_buffer
        if self.alpha is None:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        Qn = -2 ** (self.nbits - 1)
        Qp = 2 ** (self.nbits - 1) - 1
        if self.training and self.init_state == 0:
            self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            self.init_state.fill_(1)
        g = 1.0 / math.sqrt(self.weight.numel() * Qp)

        alpha = grad_scale(self.alpha, g)
        
        
        alpha = alpha.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha

        x = self.act(x, task)
        return F.conv2d(x, w_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
    
    def change_buffer(self, task):
        self.task_buffer = task



class LinearLSQ_v2(_LinearQ_v2):
    def __init__(self, in_features, out_features, bias=True, nbits_w=4, **kwargs):
        super(LinearLSQ_v2, self).__init__(in_features=in_features,
                                        out_features=out_features, bias=bias, nbits=nbits_w, mode=Qmodes.kernel_wise)
        self.act = ActLSQ(in_features=in_features, nbits_a=nbits_w)

    def forward(self, x):
        if self.alpha is None:
            return F.linear(x, self.weight, self.bias)
        Qn = -2 ** (self.nbits - 1)
        Qp = 2 ** (self.nbits - 1) - 1
        if self.training and self.init_state == 0:
            self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            self.beta.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            self.init_state.fill_(1)
        g = 1.0 / math.sqrt(self.weight.numel() * Qp)

        alpha = grad_scale(self.alpha, g)
        beta = grad_scale(self.beta, g)
        alpha = alpha.unsqueeze(1)
        beta = beta.unsqueeze(0)
        scale = alpha @ beta   
        w_q = round_pass((self.weight / scale).clamp(Qn, Qp)) * scale

        x = self.act(x)
        return F.linear(x, w_q, self.bias)


class LinearLSQ(_LinearQ):
    def __init__(self, in_features, out_features, bias=True, nbits_w=4, **kwargs):
        super(LinearLSQ, self).__init__(in_features=in_features,
                                        out_features=out_features, bias=bias, nbits=nbits_w, mode=Qmodes.kernel_wise)
        self.act = ActLSQ(in_features=in_features, nbits_a=nbits_w)


    def qw(self, weight):
        Qn = -2 ** (self.nbits - 1)
        Qp = 2 ** (self.nbits - 1) - 1
        if self.training and self.init_state == 0:
            self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            self.init_state.fill_(1)
        g = 1.0 / math.sqrt(self.weight.numel() * Qp)

        alpha = grad_scale(self.alpha, g)

        alpha = alpha.unsqueeze(1)
        w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha
        return w_q

    def forward(self, x, task):
        if self.alpha is None:
            return F.linear(x, self.weight, self.bias)

        w_q = self.qw(self.weight)

        x = self.act(x, task)
        return F.linear(x, w_q, self.bias)


class LinearMCN(Linear):
    def __init__(self, in_features, out_features, bias=True, nbits_w=4, **kwargs):
        super(LinearMCN, self).__init__(in_features=in_features,
                                        out_features=out_features, bias=bias)
        self.act = ActLSQ(in_features=in_features, nbits_a=nbits_w)
        self.MCF_Function = MCF_Function.apply
        self.nbits = nbits_w
        self.generate_MFilters()
        self.out_features = out_features
        
    def generate_MFilters(self):
        self.MFilters = Parameter(torch.randn(self.out_features, 1))

    def forward(self, x):
        if self.MFilters is None:
            return F.linear(x, self.weight, self.bias)
        Qn = -2 ** (self.nbits - 1)
        Qp = 2 ** (self.nbits - 1) - 1
        bin = 0.02
        self.MFilters.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))

        x = self.act(x)
        weight_bin = (self.weight / self.MFilters).clamp(Qn, Qp).round() # * bin
        w_q = self.MCF_Function(self.weight, self.MFilters, weight_bin)
        
        return F.linear(x, w_q, self.bias)

class MCF_Function(torch.autograd.Function):

    @staticmethod
    def forward(ctx, weight, MScale, weight_bin):
        weight = weight_bin * MScale

        ctx.save_for_backward(weight, MScale, weight_bin)
        return weight

    @staticmethod
    def backward(ctx, gradOutput):
        weight, MScale, weight_bin = ctx.saved_tensors

        nChannel = MScale.size()[0]
        nOutputPlane = gradOutput.size()[0]
        nInputPlane = gradOutput.size()[1]
        para_loss = 0.0001

        target1 = para_loss * (weight - weight_bin * MScale)
        gradWeight = target1 + (gradOutput * MScale)
        
        target2 = (weight - weight_bin * MScale) * weight_bin
        grad_h2_sum = torch.sum(gradOutput * weight,keepdim=True, dim=1)

        grad_target2 = torch.sum(target2, keepdim=True, dim=1)

        gradMScale = grad_h2_sum - para_loss * grad_target2
        gradweight_bin = gradOutput * 0
        return gradWeight, gradMScale, gradweight_bin


class ActLSQ(_ActQ):
    def __init__(self, in_features, nbits_a=4, mode=Qmodes.layer_wise, **kwargs):
        super(ActLSQ, self).__init__(in_features=in_features, nbits=nbits_a, mode=mode)
        self.grad_zero_coef = 1
    def forward(self, x, task):
        if self.alpha is None:
            return x

        if self.training and self.init_state[task] == 0:
            if x.min() < -1e-5:
                self.signed.data[task] = 1
            signed = self.signed[task]
            if signed == 1:
                Qn = -2 ** (self.nbits - 1)
                Qp = 2 ** (self.nbits - 1) - 1
            else:
                Qn = 0
                Qp = 2 ** self.nbits - 1
            
            init_a = 2 * x.abs().mean() / math.sqrt(Qp)
            self.alpha.data[task] = init_a

            self.init_state[task] = 1
        

        alpha = self.alpha[task]
        zero_point = self.zero_point[task]
        signed = self.signed[task]
        
        if signed == 1:
            Qn = -2 ** (self.nbits - 1)
            Qp = 2 ** (self.nbits - 1) - 1
        else:
            Qn = 0
            Qp = 2 ** self.nbits - 1

        if self.grad_zero_coef == 1:
            g = 1.0 / math.sqrt(x.numel() * Qp)
        else:
            g = 0

        zero_point = (zero_point.round() - zero_point).detach() + zero_point
        alpha = grad_scale(alpha, g)
        zero_point = grad_scale(zero_point, g)
        if len(x.shape)==2:
            alpha = alpha.unsqueeze(0)
            zero_point = zero_point.unsqueeze(0)
        elif len(x.shape)==3:
            if x.shape[0] == alpha.shape[0]:
                alpha = alpha.unsqueeze(1).unsqueeze(2)
                zero_point = zero_point.unsqueeze(1).unsqueeze(2)
            elif x.shape[1] == alpha.shape[0]:
                alpha = alpha.unsqueeze(0).unsqueeze(2)
                zero_point = zero_point.unsqueeze(0).unsqueeze(2)
            elif x.shape[2] == alpha.shape[0]:
                alpha = alpha.unsqueeze(0).unsqueeze(0)
                zero_point = zero_point.unsqueeze(0).unsqueeze(0)
        elif len(x.shape)==4:
            if x.shape[0] == alpha.shape[0]:
                alpha = alpha.unsqueeze(1).unsqueeze(2).unsqueeze(3)
                zero_point = zero_point.unsqueeze(1).unsqueeze(2).unsqueeze(3)
            elif x.shape[1] == alpha.shape[0]:
                alpha = alpha.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                zero_point = zero_point.unsqueeze(0).unsqueeze(2).unsqueeze(3)
            elif x.shape[2] == alpha.shape[0]:
                alpha = alpha.unsqueeze(0).unsqueeze(0).unsqueeze(3)
                zero_point = zero_point.unsqueeze(0).unsqueeze(0).unsqueeze(3)
            elif x.shape[3] == alpha.shape[0]:
                alpha = alpha.unsqueeze(0).unsqueeze(0).unsqueeze(0)
                zero_point = zero_point.unsqueeze(0).unsqueeze(0).unsqueeze(0)

        x = round_pass((x / alpha + zero_point).clamp(Qn, Qp))
        x = (x - zero_point) * alpha

        return x

    def change_grad(self, g):
        self.grad_zero_coef = g
