import torch
import torch.nn as nn
import sys
from quantize.quantizer import round_ste, clamp_ste, floor_ste

class TwoLevelFunction(torch.autograd.Function):
    """Threshold quantizer onto the levels ``{0, th, 2 * th}``.

    Forward: elements below ``th`` map to 0, elements in ``[th, 2 * th)``
    map to ``th`` and elements at or above ``2 * th`` map to ``2 * th``.

    Backward: the incoming gradient is passed through unchanged where the
    forward input lies in ``[0.5 * th, 2.5 * th]`` and is zeroed elsewhere.
    ``th`` never receives a gradient.
    """

    @staticmethod
    def forward(ctx, input, th):
        raw = input.detach()
        # Rectangular pass-through window; this is all backward needs.
        window = ((raw >= 0.5 * th) & (raw <= 2.5 * th)).float()
        ctx.save_for_backward(window)

        upper = (input >= 2. * th).float()
        # The lower level is active only where the upper level is not.
        lower = (input >= 1. * th).float() * (1. - upper)
        return lower * th + upper * 2. * th

    @staticmethod
    def backward(ctx, grad_output):
        (window,) = ctx.saved_tensors
        return grad_output * window, None


class FourLevelFunction(torch.autograd.Function):
    """Threshold quantizer onto the levels ``{0, th, 2*th, 3*th, 4*th}``.

    Forward: each element is mapped to the highest level ``k * th``
    (``k`` in 1..4) whose threshold it reaches, or to 0 if it reaches none.

    Backward: the incoming gradient is passed through unchanged where the
    forward input lies in ``[0.5 * th, 4.5 * th]`` and is zeroed elsewhere.
    ``th`` never receives a gradient.
    """

    @staticmethod
    def forward(ctx, input, th):
        raw = input.detach()
        # Rectangular pass-through window; this is all backward needs.
        window = ((raw >= 0.5 * th) & (raw <= 4.5 * th)).float()
        ctx.save_for_backward(window)

        # One-hot level indicators, resolved from the top level downwards so
        # that at most one of them is 1 for any element.
        lvl4 = (input >= 4. * th).float()
        lvl3 = (input >= 3. * th).float() * (1. - lvl4)
        lvl2 = (input >= 2. * th).float() * (1. - lvl4) * (1. - lvl3)
        lvl1 = (
            (input >= 1. * th).float()
            * (1. - lvl4) * (1. - lvl3) * (1. - lvl2)
        )
        return lvl1 * th + lvl2 * 2. * th + lvl3 * 3. * th + lvl4 * 4. * th

    @staticmethod
    def backward(ctx, grad_output):
        (window,) = ctx.saved_tensors
        return grad_output * window, None



class MultiLevelFunction(torch.autograd.Function):
    """Multi-level threshold quantizer with a sigmoid-derivative surrogate.

    Forward computes ``clamp(floor(input / th), 0, L) * th``.

    Backward multiplies the incoming gradient by the sum, over
    ``k = 1 .. L``, of the derivative of ``sigmoid((input - k * th) / sigma)``
    with respect to ``input``. ``th``, ``L`` and ``sigma`` receive no
    gradient.

    Args:
        input: tensor to quantize.
        th: threshold / level spacing; a tensor broadcastable against
            ``input`` or a plain Python number.
        L: highest level index (int).
        sigma: width of the sigmoid surrogate used in backward.
    """

    @staticmethod
    def forward(ctx, input, th, L, sigma=0.5):
        # Autograd is disabled inside Function.forward, so a straight-through
        # floor has no effect here and a plain floor is used instead.
        # NOTE(review): this assumes the project's floor_ste floors in the
        # forward direction; its definition is not visible from this file.
        levels = torch.floor(input / th).clamp(0, L)
        out = levels.mul(th)

        # save_for_backward only accepts tensors (or None); callers may pass
        # the threshold as a Python float, so keep that on ctx instead.
        if torch.is_tensor(th):
            ctx.save_for_backward(input, th)
            ctx.th_scalar = None
        else:
            ctx.save_for_backward(input)
            ctx.th_scalar = th
        ctx.L = L
        ctx.sigma = sigma
        return out

    @staticmethod
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        input = saved[0]
        th = saved[1] if len(saved) > 1 else ctx.th_scalar
        L = ctx.L
        sigma = ctx.sigma

        # Accumulate the derivative of one sigmoid step per level boundary.
        grad_input = torch.zeros_like(input)
        for k in range(1, L + 1):
            center = k * th
            surrogate = torch.sigmoid((input - center) / sigma)
            grad_input += surrogate * (1 - surrogate) / sigma

        grad_input = grad_output * grad_input
        return grad_input, None, None, None

class LMHTNeuron(nn.Module):
    """Multi-level threshold neuron that reuses an existing quantizer's state.

    The module shares ``scale`` (and ``zero_point`` when present) with the
    quantizer object ``ori`` and copies its grouping / range configuration.
    In ``forward`` it keeps a potential ``self.v`` that, over ``self.T``
    steps, accumulates an input current, emits a ``MultiLevelFunction``
    output and subtracts that output from the potential.

    Args:
        L: number of levels passed to ``MultiLevelFunction``.
        ori: quantizer-like object providing ``scale``, ``zero_point``,
            ``quantized_shape``, ``group_size``, ``mode``, ``asym``,
            ``disable_zero_point_in_sym``, ``activation_clipping``,
            ``qmin`` and ``qmax``.
        T: number of steps run by ``forward``.
        avg: when true, every step is driven by the clamped step-sum of the
            input divided by ``T``; otherwise step ``t`` is driven by
            ``x[t]`` plus the per-step zero-point offset.
    """

    def __init__(self, L: int, ori, T=2, avg=True):
        super().__init__()
        self.register_parameter('scale', ori.scale)
        # Registering None is allowed and keeps ``self.zero_point`` readable
        # (as None) when the source quantizer has no zero point.
        self.register_parameter('zero_point', ori.zero_point)
        self.v = None
        self.avg = avg
        self.act = MultiLevelFunction.apply
        self.T = T
        self.L = L
        self.quantized_shape = ori.quantized_shape
        self.group_size = ori.group_size
        self.mode = ori.mode
        self.asym = ori.asym
        self.disable_zero_point_in_sym = ori.disable_zero_point_in_sym
        self.activation_clipping = ori.activation_clipping
        self.enable = True
        self.qmin = ori.qmin
        self.qmax = ori.qmax

    def forward(self, x: torch.Tensor):
        """Run ``self.T`` steps and return the stacked per-step outputs.

        Args:
            x: 4-D tensor unpacked as ``(T, bs, n, dim1)``; the last
                dimension is split into groups of ``self.group_size``.
                NOTE(review): the loop runs ``self.T`` steps, so the leading
                dimension presumably equals ``self.T`` - confirm with callers.

        Returns:
            Tensor of shape ``(T, bs, n, dim1)``. The outputs are detached,
            so no gradient flows through them.
        """
        scale = clamp_ste(self.scale, 1e-4, 1e4)

        zero_point = self.zero_point
        if zero_point is not None and (self.asym or not self.disable_zero_point_in_sym):
            round_zero_point = clamp_ste(round_ste(zero_point), self.qmin, self.qmax)
        else:
            # No usable zero point: behave as a zero offset instead of
            # failing on an unbound / None value further down.
            round_zero_point = x.new_zeros(())

        T, bs, n, dim1 = x.shape
        x_reshaped = x.reshape(T, bs, n, -1, self.group_size)

        # Potential starts at half a scale step.
        self.v = torch.ones_like(x_reshaped[0, ...]).mul(scale) * 0.5

        # Zero-point offset spread evenly over the T steps, in scaled units.
        I_z = (round_zero_point / self.T).mul(scale)
        I_z_const = I_z.detach()

        if self.avg:
            # Drive every step with the same current: the step-sum of the
            # input plus the zero-point offset, clamped to [0, qmax * scale]
            # and divided by T.
            x_s = x_reshaped.sum(dim=0).detach().add(round_zero_point.mul(scale))
            max_val = self.qmax * scale
            min_val = torch.zeros_like(max_val)
            x_tmp = torch.clamp(x_s, min_val, max_val) / self.T

        spike_pot = []
        for t in range(self.T):
            if not self.avg:
                x_tmp = x_reshaped[t, ...].add(I_z)
            self.v = self.v.detach().add(x_tmp.detach())
            output = self.act(self.v, scale, self.L)
            # Subtract what was emitted from the potential.
            self.v -= output.detach()
            # Remove the per-step zero-point share from the emitted value.
            spike_pot.append(output.detach().sub(I_z_const))

        # Merge the group dimension back into the original last dimension.
        return torch.stack(spike_pot, dim=0).reshape(T, bs, n, dim1)
    
class DTIFNeuron(LMHTNeuron):
    """LMHTNeuron variant whose scale is supplied per call as ``y / L``."""

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Run ``self.T`` steps and return the stacked per-step outputs.

        Args:
            x: 4-D tensor unpacked as ``(T, bs, n, dim1)``; the last
                dimension is split into groups of ``self.group_size``.
            y: value divided by ``self.L`` to obtain the scale used for the
                potential, the zero-point offset and the clamp range.

        Returns:
            Tensor of shape ``(T, bs, n, dim1)``. The outputs are detached,
            so no gradient flows through them.
        """
        scale = y / self.L
        # NOTE(review): the potential and input current are in units of
        # ``y / L`` while the threshold handed to ``self.act`` is the plain
        # number ``1 / L`` - confirm this mismatch is intended.
        scale1 = 1 / self.L

        # ``zero_point`` is only registered when the source quantizer had
        # one, so read it defensively.
        zero_point = getattr(self, 'zero_point', None)
        if zero_point is not None and (self.asym or not self.disable_zero_point_in_sym):
            round_zero_point = clamp_ste(round_ste(zero_point), self.qmin, self.qmax)
        else:
            # No usable zero point: behave as a zero offset instead of
            # failing on an unbound / None value further down.
            round_zero_point = x.new_zeros(())

        T, bs, n, dim1 = x.shape
        x_reshaped = x.reshape(T, bs, n, -1, self.group_size)

        # Potential starts at half a scale step.
        self.v = torch.ones_like(x_reshaped[0, ...]).mul(scale) * 0.5

        # Zero-point offset spread evenly over the T steps, in scaled units.
        I_z = (round_zero_point / self.T).mul(scale)
        I_z_const = I_z.detach()

        if self.avg:
            # Same current for every step: step-sum of the input plus the
            # zero-point offset, clamped to [0, qmax * scale], divided by T.
            x_s = x_reshaped.sum(dim=0).detach().add(round_zero_point.mul(scale))
            max_val = self.qmax * scale
            min_val = torch.zeros_like(max_val)
            x_tmp = torch.clamp(x_s, min_val, max_val) / self.T

        spike_pot = []
        for t in range(self.T):
            if not self.avg:
                x_tmp = x_reshaped[t, ...].add(I_z)
            self.v = self.v.detach().add(x_tmp.detach())
            output = self.act(self.v, scale1, self.L)
            # Subtract what was emitted from the potential.
            self.v -= output.detach()
            # Remove the per-step zero-point share from the emitted value.
            spike_pot.append(output.detach().sub(I_z_const))

        # Merge the group dimension back into the original last dimension.
        return torch.stack(spike_pot, dim=0).reshape(T, bs, n, dim1)
    
class DTDFNeuron(LMHTNeuron):
    """LMHTNeuron variant whose scale is supplied per call as ``y / (L * T)``."""

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Run ``self.T`` steps and return the stacked per-step outputs.

        Args:
            x: 4-D tensor unpacked as ``(T, bs, n, dim1)``; the last
                dimension is split into groups of ``self.group_size``.
            y: value divided by ``self.L * self.T`` to obtain the scale used
                for the potential, the zero-point offset and the clamp range.

        Returns:
            Tensor of shape ``(T, bs, n, dim1)``. The outputs are detached,
            so no gradient flows through them.
        """
        steps = self.L * self.T
        scale = y / steps
        # NOTE(review): the potential and input current are in units of
        # ``y / (L * T)`` while the threshold handed to ``self.act`` is the
        # plain number ``1 / (L * T)`` - confirm this mismatch is intended.
        scale1 = 1 / steps

        # ``zero_point`` is only registered when the source quantizer had
        # one, so read it defensively.
        zero_point = getattr(self, 'zero_point', None)
        if zero_point is not None and (self.asym or not self.disable_zero_point_in_sym):
            round_zero_point = clamp_ste(round_ste(zero_point), self.qmin, self.qmax)
        else:
            # No usable zero point: behave as a zero offset instead of
            # failing on an unbound / None value further down.
            round_zero_point = x.new_zeros(())

        T, bs, n, dim1 = x.shape
        x_reshaped = x.reshape(T, bs, n, -1, self.group_size)

        # Potential starts at half a scale step.
        self.v = torch.ones_like(x_reshaped[0, ...]).mul(scale) * 0.5

        # Zero-point offset spread evenly over the T steps, in scaled units.
        I_z = (round_zero_point / self.T).mul(scale)
        I_z_const = I_z.detach()

        if self.avg:
            # Same current for every step: step-sum of the input plus the
            # zero-point offset, clamped to [0, qmax * scale], divided by T.
            x_s = x_reshaped.sum(dim=0).detach().add(round_zero_point.mul(scale))
            max_val = self.qmax * scale
            min_val = torch.zeros_like(max_val)
            x_tmp = torch.clamp(x_s, min_val, max_val) / self.T

        spike_pot = []
        for t in range(self.T):
            if not self.avg:
                x_tmp = x_reshaped[t, ...].add(I_z)
            self.v = self.v.detach().add(x_tmp.detach())
            output = self.act(self.v, scale1, self.L)
            # Subtract what was emitted from the potential.
            self.v -= output.detach()
            # Remove the per-step zero-point share from the emitted value.
            spike_pot.append(output.detach().sub(I_z_const))

        # Merge the group dimension back into the original last dimension.
        return torch.stack(spike_pot, dim=0).reshape(T, bs, n, dim1)