import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import triton
from spikingjelly.activation_based import neuron, cuda_utils
from spikingjelly.activation_based import layer, neuron, surrogate, functional, base
import triton.language as tl
from torch import Tensor
from typing import Tuple

@torch.jit.script
def round_to_pow2(x: torch.Tensor):
    """Snap every element of ``x`` to a signed power of two.

    The exponent is ``round(log2(|x|))``, i.e. the rounding is performed in
    the log domain. The sign of each element is kept. A zero element stays
    zero, because its sign factor is zero.

    Args:
        x: tensor to quantize.

    Returns:
        A tensor with the same shape as ``x`` whose non-zero entries are
        ``+/- 2**k`` for an integer-valued ``k``.
    """
    exponent = torch.round(torch.log2(torch.abs(x)))
    return torch.sign(x) * torch.pow(2., exponent)

class RoundToPow2(torch.autograd.Function):
    """Autograd wrapper around ``round_to_pow2`` with a pass-through gradient.

    The forward pass quantizes to powers of two; the backward pass hands the
    incoming gradient back unchanged, so the rounding step is ignored when
    differentiating.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Return ``x`` quantized by ``round_to_pow2``; nothing is saved on ``ctx``."""
        quantized = round_to_pow2(x)
        return quantized

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return the upstream gradient as the gradient w.r.t. ``x``."""
        return grad_output

def roundtopow(x: torch.Tensor):
    """Quantize ``x`` through ``RoundToPow2`` so that gradients pass through unchanged."""
    quantized = RoundToPow2.apply(x)
    return quantized

class BN1d(nn.Module):
    """Batch-norm parameters and running statistics over dimension 1.

    The module has no ``forward``; it only owns the affine parameters
    (``weight``, ``bias``) and the running buffers (``running_mean``,
    ``running_var``) and exposes :meth:`get_mean_var` so that a caller can
    apply or fold the normalization itself.

    Args:
        num_features: size of dimension 1 of the tensors passed to
            :meth:`get_mean_var`.
        eps: stored as ``self.eps`` for callers to add to the variance.
        momentum: weight of the new batch statistics in the running update.
        device: device for the parameters and buffers.
        dtype: dtype for the parameters and buffers.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.ones(num_features, **factory_kwargs))
        self.bias = nn.Parameter(torch.zeros(num_features, **factory_kwargs))
        self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
        self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))

    def get_mean_var(self, input: Tensor=None) -> Tuple[Tensor, Tensor]:
        """Return the mean and variance to normalize with.

        In training mode the statistics are computed from ``input`` over every
        dimension except dimension 1 (biased variance), and the running
        buffers are updated in place with weight ``momentum`` (the running
        variance uses the unbiased estimate). In eval mode ``input`` is
        ignored and the running buffers themselves are returned.

        Args:
            input: tensor to take statistics from; required in training mode,
                may be omitted in eval mode.

        Returns:
            ``(mean, var)``.

        Raises:
            ValueError: if the module is in training mode and ``input`` is
                ``None``.
        """
        if not self.training:
            return self.running_mean, self.running_var

        if input is None:
            raise ValueError('BN1d.get_mean_var() requires `input` in training mode')

        dims_to_sum = [i for i in range(input.dim()) if i != 1]
        mean = input.mean(dims_to_sum)
        var = input.var(dims_to_sum, unbiased=False)

        # Number of values reduced into each statistic.
        count = input.numel() // max(mean.numel(), 1)

        with torch.no_grad():
            # Bessel's correction turns the biased variance into the unbiased
            # one without a second reduction over `input`. With a single value
            # the unbiased estimate is undefined (NaN), so the biased value is
            # kept instead of poisoning the running buffer.
            if count > 1:
                unbiased_var = var * (count / (count - 1))
            else:
                unbiased_var = var

            # Update in place: the registered buffer objects keep their
            # identity and dtype, so external references to them stay valid.
            self.running_mean.mul_(1. - self.momentum).add_(
                mean.to(self.running_mean.dtype), alpha=self.momentum)
            self.running_var.mul_(1. - self.momentum).add_(
                unbiased_var.to(self.running_var.dtype), alpha=self.momentum)

        return mean, var

class mul_free_channel_wise_psn(nn.Module):
    """Channel-wise temporal convolution with power-of-two weights.

    The time axis of the input is filtered per channel by a depthwise
    ``Conv1d`` of kernel size ``K``. The batch-norm scale is folded into the
    convolution weights, the folded weights are quantized with ``roundtopow``
    and the batch-norm shift becomes the convolution bias. The result is
    passed through ``surrogate_function``.

    Args:
        C: number of channels (depthwise groups).
        K: temporal kernel size.
        surrogate_function: callable applied to the convolution output.
        dilation: temporal dilation of the convolution.
        step_mode: stored on the instance; not read anywhere in this class.
        backend: stored on the instance; not read anywhere in this class.
    """

    def __init__(self, C: int, K: int, surrogate_function, dilation, step_mode: str='m', backend='torch'):
        super().__init__()
        self.C, self.K = C, K
        self.dilation = dilation
        self.step_mode = step_mode
        self.backend = backend
        self.surrogate_function = surrogate_function
        # Depthwise (groups=C) convolution along time, without a bias term.
        self.conv1d = nn.Conv1d(
            in_channels=C,
            out_channels=C,
            kernel_size=K,
            groups=C,
            dilation=dilation,
            bias=False,
        )
        self.bn1d = BN1d(C)
        nn.init.constant_(self.bn1d.bias, -1.)

    def forward(self, x_seq: torch.Tensor):
        """Apply the layer to a time-major sequence.

        Args:
            x_seq: tensor laid out as ``[T, N, C, H, W]``, ``[T, N, C, L]`` or
                ``[T, N, C]``.

        Returns:
            A contiguous tensor with the same layout and shape as ``x_seq``.
        """
        in_shape = x_seq.shape
        n_dims = len(in_shape)

        # Bring the input to [rows, C, T] with time last, where rows collects
        # the batch and any spatial dimensions.
        if n_dims == 5:
            order = (1, 3, 4, 2, 0)   # [N, H, W, C, T]
        elif n_dims == 4:
            order = (1, 3, 2, 0)      # [N, L, C, T]
        else:
            order = (1, 2, 0)         # [N, C, T]
        rows = x_seq.permute(*order).reshape(-1, in_shape[2], in_shape[0])

        # Zero-pad on the left only, so the convolution output keeps length T
        # and each step sees just the current and earlier steps.
        left_pad = self.dilation * (self.K - 1)
        padded = F.pad(rows, (left_pad, 0))

        bn, conv = self.bn1d, self.conv1d
        if self.training:
            # Statistics come from the convolution with the raw weights.
            mean, var = bn.get_mean_var(conv(padded))
        else:
            mean, var = bn.get_mean_var()

        # Fold the batch-norm scale into the [C, 1, K] kernel; the transposes
        # move C to the last axis so the per-channel factors broadcast.
        folded = (conv.weight.transpose(0, 2) * bn.weight / (var + bn.eps).sqrt()).transpose(0, 2)
        quantized = roundtopow(folded)
        shift = bn.bias - mean * bn.weight / (var + bn.eps).sqrt()

        h_seq = F.conv1d(
            padded,
            quantized,
            bias=shift,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
        )
        spikes = self.surrogate_function(h_seq)

        # Undo the initial permute/reshape to restore the caller's layout.
        if n_dims == 5:
            T, N, C, H, W = in_shape
            restored = spikes.view(N, H, W, C, T).permute(4, 0, 3, 1, 2)
        elif n_dims == 4:
            T, N, C, L = in_shape
            restored = spikes.view(N, L, C, T).permute(3, 0, 2, 1)
        else:
            restored = spikes.view(in_shape[1], in_shape[2], in_shape[0]).permute(2, 0, 1)
        return restored.contiguous()


