import torch
import torch.nn as nn

class SpectralKernel1dHiddenOnly(nn.Module):
    """Per-channel complex re-weighting for a 1D spectral (Fourier) layer.

    This module holds one learnable complex coefficient per input channel
    and multiplies every position along the last axis of the input by the
    coefficient of its channel. The same coefficient is shared by all
    positions (frequencies) of a channel.

    No FFT or inverse FFT is performed here; presumably the caller applies
    the transforms around this module -- confirm against the call site.
    """

    def __init__(self, in_channels, out_channels, modes1):
        """Create the per-channel complex weights.

        Args:
            in_channels: Number of channels; sets the length of ``weights1``.
            out_channels: Accepted for signature compatibility; it is neither
                stored nor used by this module.
            modes1: Number of Fourier modes (per the original note, at most
                floor(N/2) + 1). Stored on the instance but not read by
                ``forward``.
        """
        super().__init__()

        # One complex coefficient per channel, randomly initialised.
        initial_weights = torch.rand(in_channels, dtype=torch.cfloat)
        self.weights1 = nn.Parameter(initial_weights)

        # Bookkeeping attributes; ``forward`` does not consult them.
        self.modes1 = modes1
        self.in_channels = in_channels

    def forward(self, x):
        """Scale each channel of ``x`` by its complex weight.

        Args:
            x: 3-D tensor indexed as (b, i, x) in the einsum below, where
                axis ``i`` is matched against ``weights1``. Presumably a
                complex tensor of Fourier coefficients laid out as
                (batch, channel, mode) -- confirm with the caller.

        Returns:
            Tensor with the same axis layout as ``x`` where entry
            ``[b, i, k]`` equals ``x[b, i, k] * weights1[i]``.
        """
        channel_weights = self.weights1
        # No index is summed over: this is a broadcasted complex product
        # along the channel axis.
        return torch.einsum("bix,i->bix", x, channel_weights)