import torch
import torch.nn as nn

class SpectralKernel1dTimeParam(nn.Module):
    """Spectral FNO kernel whose weights are parameterized before the FFT.

    The learnable tensor ``weights1`` is real-valued with a last dimension of
    ``2 * modes1 - 1``. On every forward pass it is mapped to ``modes1``
    complex Fourier coefficients with ``torch.fft.rfft``. Those coefficients
    then mix input channels into output channels independently for each mode.
    """

    def __init__(self, in_channels: int, out_channels: int, modes1: int):
        """Create the kernel.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            modes1: Number of Fourier modes to multiply (at most
                floor(N/2) + 1 for a signal of length N).
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1

        # torch.rand draws from [0, 1), so weights start uniform in [0, scale).
        self.scale = (1 / (in_channels * out_channels))
        # Real weights of length 2*modes1 - 1: an rfft of that length yields
        # exactly (2*modes1 - 1) // 2 + 1 == modes1 complex coefficients.
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(in_channels, out_channels, self.modes1 * 2 - 1, dtype=torch.float)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the spectral weights to ``x`` by complex multiplication per mode.

        Args:
            x: Tensor of shape ``(batch, in_channels, modes1)``. The einsum
                contracts it with complex weights, so ``x`` must be complex
                with a matching dtype (complex64 for the default float32
                weights). Presumably these are the input's Fourier
                coefficients, already truncated to ``modes1`` modes. This
                needs confirming with the caller.

        Returns:
            Complex tensor of shape ``(batch, out_channels, modes1)``.
        """
        # (in_channels, out_channels, 2*modes1 - 1) real -> (in_channels, out_channels, modes1) complex
        weights1_f = torch.fft.rfft(self.weights1)
        # (batch, in, x), (in, out, x) -> (batch, out, x): sum over input channels per mode.
        return torch.einsum("bix,iox->box", x, weights1_f)