import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from spikingjelly.activation_based import neuron, functional, surrogate, layer, base

from torch import Tensor
from typing import Optional, List, Tuple, Union, Callable


@torch.jit.script
def round_to_pow2(x: torch.Tensor) -> torch.Tensor:
    """Snap every element of ``x`` to a signed power of two.

    The exponent is ``round(log2(|x|))`` (``torch.round``), so rounding happens
    in log2 space: e.g. 3 -> 4, 5 -> 4, -6 -> -8, 0.25 -> 0.25.
    Zeros stay zero, because ``2 ** round(log2(0)) == 2 ** -inf == 0`` and
    ``sign(0) == 0``.
    """
    exponent = torch.round(torch.log2(torch.abs(x)))
    magnitude = torch.pow(2., exponent)
    return torch.sign(x) * magnitude


class RoundToPow2(torch.autograd.Function):
    """Power-of-two quantization with a straight-through gradient.

    Forward: every element is replaced by ``round_to_pow2(x)``.
    Backward: the incoming gradient is returned unchanged (identity), so the
    quantizer is transparent to gradient-based training.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        quantized = round_to_pow2(x)
        return quantized

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Straight-through estimator: treat d(quantized)/dx as 1.
        identity_grad = grad_output
        return identity_grad


class BN1d(nn.Module):
    """Batch-norm parameters and statistics intended to be folded into a conv.

    This module has no ``forward``. Callers fetch the normalization statistics
    through :meth:`get_mean_var` and apply ``weight`` (gamma) and ``bias``
    (beta) themselves, for example by folding them into convolution weights.

    Args:
        num_features: number of channels (size of dim 1 of the input).
        eps: value added to the variance before taking the square root.
        momentum: running-statistics update factor,
            ``running = (1 - momentum) * running + momentum * batch_stat``.
        device, dtype: factory kwargs for the parameters and buffers.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.ones(num_features, **factory_kwargs))
        self.bias = nn.Parameter(torch.zeros(num_features, **factory_kwargs))
        self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
        self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))

    def get_mean_var(self, input: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Return the per-channel ``(mean, var)`` to normalize with.

        Training mode: the statistics are computed from ``input`` over every
        dim except dim 1 (channels), and the *biased* variance is returned.
        The running buffers are updated in place with the *unbiased* variance.
        Eval mode: the running buffers are returned and ``input`` is ignored.

        Args:
            input: tensor of shape ``[N, C, ...]``. Required in training mode.

        Returns:
            ``(mean, var)``, each of shape ``[C]``.

        Raises:
            ValueError: in training mode, if ``input`` is missing, has fewer
                than 2 dims, or has at most one value per channel.
        """
        if not self.training:
            return self.running_mean, self.running_var

        if input is None:
            raise ValueError("BN1d.get_mean_var requires `input` in training mode")
        if input.dim() < 2:
            raise ValueError(f"expected input of shape [N, C, ...], got {list(input.shape)}")

        # Reduce over the batch dim and all trailing dims; dim 1 is the channel dim.
        dims = [0] + list(range(2, input.dim()))
        n = input.shape[0] * math.prod(input.shape[2:])
        if n <= 1:
            # The unbiased variance would be NaN and would corrupt running_var.
            raise ValueError(
                f"Expected more than 1 value per channel when training, got input size {list(input.shape)}"
            )

        mean = input.mean(dims)
        var = input.var(dims, unbiased=False)
        with torch.no_grad():
            # Bessel's correction from the biased estimate avoids a second reduction.
            unbiased_var = var * (n / (n - 1))
            # In-place update keeps the registered buffer objects stable.
            self.running_mean.mul_(1. - self.momentum).add_(mean, alpha=self.momentum)
            self.running_var.mul_(1. - self.momentum).add_(unbiased_var, alpha=self.momentum)

        return mean, var


class mul_free_channel_wise_psn(nn.Module):
    """Channel-wise parallel spiking neuron with power-of-two (multiplication-free) weights.

    A depthwise ``Conv1d`` runs over the time axis with left-only padding. A
    ``BN1d`` follows it, and the BN affine transform is folded into the conv.
    The folded weights are rounded to signed powers of two with
    ``RoundToPow2`` (straight-through gradient). The resulting membrane
    potential goes through ``surrogate_function``.

    Args:
        C: number of channels.
        K: temporal kernel size.
        surrogate_function: callable applied to the membrane potential.
        dilation: temporal dilation of the convolution.

    Input layout is ``[T, N, C, *S]``, with any number (>= 0) of trailing dims
    ``S``. The output has the same shape and layout as the input.
    """

    def __init__(self, C: int, K: int, surrogate_function, dilation):
        super().__init__()
        self.C = C
        self.K = K
        self.surrogate_function = surrogate_function
        self.conv1d = nn.Conv1d(in_channels=C, out_channels=C, kernel_size=K, groups=C, dilation=dilation, bias=False)
        self.dilation = dilation
        self.bn1d = BN1d(C)
        nn.init.constant_(self.bn1d.bias, -1.)

    @staticmethod
    def _time_last_perm(ndim: int) -> Tuple[List[int], List[int]]:
        """Return the permutation ``[T, N, C, *S] -> [N, *S, C, T]`` and its inverse."""
        perm = [1] + list(range(3, ndim)) + [2, 0]
        inv_perm = [0] * ndim
        for dst, src in enumerate(perm):
            inv_perm[src] = dst
        return perm, inv_perm

    def forward(self, x_seq: torch.Tensor):
        shape = x_seq.shape
        if x_seq.dim() < 3:
            raise ValueError(f"expected input of shape [T, N, C, ...], got {list(shape)}")
        T, C = shape[0], shape[2]
        perm, inv_perm = self._time_last_perm(x_seq.dim())

        # Fold every non-(C, T) dim into the conv batch: [N * prod(S), C, T].
        x_seq = x_seq.permute(perm).reshape(-1, C, T)
        # Left-only padding keeps the output length at T and makes step t
        # depend only on inputs at steps <= t.
        x_pad = F.pad(x_seq, (self.dilation * (self.K - 1), 0))

        if self.training:
            # Batch statistics come from the full-precision conv output.
            mean, var = self.bn1d.get_mean_var(self.conv1d(x_pad))
        else:
            mean, var = self.bn1d.get_mean_var()

        # Fold BN into the conv: w' = w * gamma / std, b' = beta - mean * gamma / std.
        scale = self.bn1d.weight / (var + self.bn1d.eps).sqrt()
        weight = self.conv1d.weight * scale.view(-1, 1, 1)  # conv weight is [C, 1, K]
        bias = self.bn1d.bias - mean * scale
        qweight = RoundToPow2.apply(weight)

        v_seq = F.conv1d(x_pad, qweight, bias=bias, stride=self.conv1d.stride, padding=self.conv1d.padding, dilation=self.conv1d.dilation, groups=self.conv1d.groups)
        s_seq = self.surrogate_function(v_seq)

        # [N * prod(S), C, T] -> [N, *S, C, T] -> [T, N, C, *S]
        return s_seq.view([shape[i] for i in perm]).permute(inv_perm).contiguous()
