import torch
import torch.nn as nn

class SpectralKernel1dDiag(nn.Module):
    """Diagonal (per-channel, per-mode) complex multiplier for 1D Fourier modes.

    The layer holds one learnable complex weight for each (channel, mode) pair
    and multiplies its input element-wise by those weights. It does NOT perform
    an FFT or inverse FFT itself. The input is expected to already be Fourier
    coefficients truncated to ``modes1`` modes, and any inverse transform is
    left to the caller (presumably the surrounding operator — TODO confirm
    against callers).

    There is no channel mixing, so the output always has ``in_channels``
    channels, whatever value is passed as ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, modes1: int) -> None:
        """
        Args:
            in_channels: Number of channels. One row of weights is learned per
                channel.
            out_channels: Stored as ``self.out_channels`` but not used in the
                computation, because this diagonal layer cannot change the
                channel count.
            modes1: Number of Fourier modes to multiply, at most
                ``floor(N/2) + 1`` for a length-``N`` signal.
        """
        super().__init__()
        self.in_channels = in_channels
        # Previously discarded. Kept so the constructor argument is
        # inspectable. It has no effect on forward().
        self.out_channels = out_channels
        self.modes1 = modes1

        # Multiplier applied to the torch.rand initialisation; 1.0 leaves the
        # raw uniform draw unchanged.
        self.scale = 1.0
        # One complex weight per (channel, mode): shape (in_channels, modes1).
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(in_channels, self.modes1, dtype=torch.cfloat)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply ``x`` element-wise by the complex per-channel weights.

        Args:
            x: Tensor broadcastable against ``(in_channels, modes1)``, typically
                ``(batch, in_channels, modes1)`` Fourier coefficients.

        Returns:
            ``x * weights1``, e.g. shape ``(batch, in_channels, modes1)``.
            Each channel is scaled only by its own weights (no channel mixing).
        """
        return x * self.weights1