import torch
import torch.nn as nn
import torch.fft
from spikingjelly.activation_based import functional, layer, neuron
class DilatedDepthwiseConvBranch(nn.Module):
    """Depthwise-separable dilated 1-D convolution followed by BatchNorm and a spiking neuron.

    The input is unpacked as ``(T, N, C, V)``. ``T`` and ``N`` are merged into a
    single batch axis so that ``nn.Conv1d`` slides along the last axis ``V``.
    The result is reshaped back to ``(T, N, out_channels, V')`` before the
    multi-step ``ParametricLIFNode``.

    Args:
        in_channels: number of input channels ``C``.
        out_channels: number of channels produced by the pointwise conv.
        kernel_size: kernel size of the depthwise conv. Only odd sizes keep the
            length ``V`` unchanged with the padding used here.
        dilation: dilation of the depthwise conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()

        # "Same"-style padding for odd kernel sizes at any dilation.
        same_pad = dilation * (kernel_size // 2)
        # groups=in_channels -> one filter per channel (depthwise).
        self.depthwise_conv = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            padding=same_pad,
            dilation=dilation,
            groups=in_channels,
        )
        # 1x1 conv mixes channels and maps in_channels -> out_channels.
        self.pointwise_conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm1d(out_channels)
        # Multi-step spiking neuron. NOTE(review): backend='cupy' presumably
        # requires CuPy and a CUDA device — confirm before running on CPU.
        self.lif = neuron.ParametricLIFNode(step_mode='m', v_threshold=0.5, backend='cupy')

    def forward(self, x):
        steps, batch, _, length = x.size()

        # (T, N, C, V) -> (T*N, C, V) so Conv1d/BatchNorm1d see a plain batch.
        merged = x.flatten(0, 1)
        merged = self.depthwise_conv(merged)
        merged = self.pointwise_conv(merged)
        merged = self.bn(merged)

        # Restore the leading time axis required by the multi-step neuron.
        per_step = merged.reshape(steps, batch, -1, length).contiguous()
        spikes = self.lif(per_step)
        return spikes.reshape(steps, batch, -1, length).contiguous()

class MultiScaleDilatedFourierFeatureExtractorplus(nn.Module):
    """Spectral feature extractor with one depthwise-separable branch per dilation.

    The input is unpacked as ``(T, N, C, V)``. It is processed as follows:

    1. Multiply by a learnable Hamming window over the last axis ``V``. The
       window has ``num_point`` entries, so it must broadcast against ``V``;
       in practice ``V == num_point``.
    2. Transform with ``torch.fft.fft`` along ``V``.
    3. Process the real and imaginary parts separately with
       ``len(dilations)`` :class:`DilatedDepthwiseConvBranch` modules each.
       Branch outputs are concatenated along the channel axis (dim 2).
    4. Recombine into a complex tensor, apply ``torch.fft.ifft``, and keep the
       real part.
    5. Add a residual.

    Args:
        in_channels: input channel count ``C``.
        out_channels: output channel count; must be divisible by
            ``len(dilations)`` (each branch produces an equal share).
        kernel_size: depthwise kernel size used by every branch.
        stride: unused; accepted for signature compatibility.
        dilations: one real and one imaginary branch per dilation value.
        residual: if True, add a residual path; if False, no residual is added.
            When ``in_channels != out_channels``, the residual is a 1x1 conv +
            BatchNorm projection of the input passed through ``self.lif``.
            Otherwise it is the identity.
        residual_kernel_size: unused; accepted for signature compatibility.
        num_point: length of the Hamming window.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilations=(1, 2, 3, 4), residual=True, residual_kernel_size=1, num_point=25):
        super(MultiScaleDilatedFourierFeatureExtractorplus, self).__init__()

        # Tuple default avoids a shared mutable default; any sequence still works.
        assert out_channels % (len(dilations)) == 0, '# out channels should be multiples of # branches'
        self.branch_out_channels = out_channels // len(dilations)

        self.real_branches = nn.ModuleList([
            DilatedDepthwiseConvBranch(in_channels, self.branch_out_channels, kernel_size, dilation=d)
            for d in dilations
        ])
        self.imag_branches = nn.ModuleList([
            DilatedDepthwiseConvBranch(in_channels, self.branch_out_channels, kernel_size, dilation=d)
            for d in dilations
        ])

        # Registered as a Parameter, so the window is trained along with the model.
        self.hamming_window = nn.Parameter(torch.hamming_window(num_point, periodic=True))

        self.residual = residual
        if residual and in_channels != out_channels:
            # Project the input to out_channels so it can be added to the output.
            self.residual_conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
            self.residual_bn = nn.BatchNorm1d(out_channels)
        else:
            # Always defined, so forward() never hits a missing attribute.
            self.residual_conv = None
        # Only used on the projected residual path.
        self.lif = neuron.ParametricLIFNode(step_mode='m', v_threshold=0.5, backend='cupy')

    def forward(self, x):
        T, N, C, V = x.size()

        if not self.residual:
            # residual=False: contribute nothing (previously the input was still added).
            res = 0
        elif self.residual_conv is not None:
            res = self.residual_conv(x.flatten(0, 1))
            res = self.residual_bn(res)
            res = res.reshape(T, N, -1, V).contiguous()
            # Spike on the *projected* residual. Feeding the raw input here
            # discarded the projection and left the channel count at in_channels.
            res = self.lif(res)
        else:
            res = x

        x = x * self.hamming_window

        freq_tensor = torch.fft.fft(x, dim=-1)
        real_part = freq_tensor.real  # (T, N, C, V)
        imag_part = freq_tensor.imag  # (T, N, C, V)

        # Concatenate branch outputs on the channel axis -> out_channels total.
        # .float() because torch.complex needs matching float/double inputs.
        real_features = torch.cat([branch(real_part) for branch in self.real_branches], dim=2).float()
        imag_features = torch.cat([branch(imag_part) for branch in self.imag_branches], dim=2).float()

        freq_features = torch.complex(real_features, imag_features)

        output_tensor = torch.fft.ifft(freq_features, dim=-1).real
        output_tensor = output_tensor.reshape(T, N, -1, V).contiguous()
        return output_tensor + res


class DilatedConvBranch(nn.Module):
    """Dense dilated 1-D convolution followed by BatchNorm and a spiking neuron.

    The input is unpacked as ``(T, N, C, V)``. The time and batch axes are merged
    for the convolution, which runs along ``V``, and are split again before the
    multi-step ``ParametricLIFNode``.

    Args:
        in_channels: number of input channels ``C``.
        out_channels: number of output channels.
        kernel_size: convolution kernel size. Odd sizes preserve ``V`` with the
            padding used here.
        dilation: convolution dilation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()

        half_span = dilation * (kernel_size // 2)
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=half_span,
            dilation=dilation,
        )
        self.bn = nn.BatchNorm1d(out_channels)
        # NOTE(review): backend='cupy' presumably needs CuPy + CUDA — confirm.
        self.lif = neuron.ParametricLIFNode(step_mode='m', v_threshold=0.5, backend='cupy')

    def forward(self, x):
        steps, batch, _, length = x.size()

        # Merge (T, N) into one batch axis for Conv1d / BatchNorm1d.
        normed = self.bn(self.conv(x.flatten(0, 1)))

        # Split the time axis back out for the multi-step neuron.
        stacked = normed.reshape(steps, batch, -1, length).contiguous()
        fired = self.lif(stacked)
        return fired.reshape(steps, batch, -1, length).contiguous()

class MultiScaleDilatedFourierFeatureExtractor(nn.Module):
    """Single-branch spectral feature extractor with an optional identity residual.

    The input is unpacked as ``(T, N, C, V)``. It is processed as follows:

    1. Multiply by a learnable Hamming window of length ``num_point`` over
       ``V``; the window must broadcast against ``V``.
    2. Apply an FFT along ``V``.
    3. Process the real and imaginary parts with one :class:`DilatedConvBranch`
       each (dilation 1).
    4. Recombine, apply an inverse FFT, and keep the real part.

    Args:
        in_channels: input channel count ``C``.
        out_channels: output channel count of each branch. With
            ``residual=True`` the identity residual requires it to match
            ``in_channels``, or be broadcast-compatible with it.
        kernel_size: kernel size of both branch convolutions. This was
            previously ignored (always 3).
        stride: unused; accepted for signature compatibility.
        dilations: only used by the divisibility check below.
        residual: if True, add the un-windowed input to the output; if False,
            add nothing.
        residual_kernel_size: unused; accepted for signature compatibility.
        num_point: length of the Hamming window.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilations=(1, 2, 3, 4), residual=True, residual_kernel_size=1, num_point=25):
        super(MultiScaleDilatedFourierFeatureExtractor, self).__init__()

        # Constructor contract kept for compatibility, although the branches
        # below each emit the full out_channels.
        assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches'

        self.real_branches = DilatedConvBranch(in_channels, out_channels, kernel_size=kernel_size)
        self.imag_branches = DilatedConvBranch(in_channels, out_channels, kernel_size=kernel_size)

        # Learnable window (registered as a Parameter).
        self.hamming_window = nn.Parameter(torch.hamming_window(num_point, periodic=True))
        self.residual = residual

    def forward(self, x):
        T, N, C, V = x.size()
        # Honour the residual flag (previously the input was always added).
        res = x if self.residual else 0

        windowed = x * self.hamming_window
        freq_tensor = torch.fft.fft(windowed, dim=-1)

        # Both parts have the same (T, N, C, V) layout as the input.
        real_features = self.real_branches(freq_tensor.real)
        imag_features = self.imag_branches(freq_tensor.imag)

        # torch.complex needs matching float/double inputs.
        freq_features = torch.complex(real_features.float(), imag_features.float())
        output_tensor = torch.fft.ifft(freq_features, dim=-1).real
        output_tensor = output_tensor + res
        return output_tensor.reshape(T, N, -1, V).contiguous()


def main():
    """Smoke-test MultiScaleDilatedFourierFeatureExtractor on random input."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the LIF nodes are built with backend='cupy', which presumably
    # requires CUDA; the CPU fallback is expected to fail inside the neuron.

    # The model unpacks its input as (T, N, C, V): time steps first, then batch.
    T, N, C, V = 10, 8, 64, 64
    x = torch.rand(T, N, C, V, device=device)

    in_channels = C
    out_channels = C
    kernel_size = 3
    stride = 1
    dilations = [1, 2]
    residual = True
    residual_kernel_size = 1

    model = MultiScaleDilatedFourierFeatureExtractor(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        dilations=dilations,
        residual=residual,
        residual_kernel_size=residual_kernel_size,
        # The Hamming window multiplies the last axis, so its length must be V;
        # the default of 25 does not broadcast against V=64.
        num_point=V,
    ).to(device)

    print(model)

    with torch.no_grad():
        output = model(x)
    print(output.shape)

    # spikingjelly neurons are stateful; clear that state after a forward pass.
    functional.reset_net(model)


if __name__ == "__main__":
    main()