from typing import Callable, List, Optional

import torch
import torch.nn as nn

def mul(signal, kernel):
    """Mix the channel axis of ``signal`` with a per-position weight tensor.

    For every batch index ``b`` and position ``(x, y)`` this computes
    ``out[b, x, y, :] = signal[b, x, y, :] @ kernel[x, y]``.

    Args:
        signal: Tensor of shape ``(batch, X, Y, in_channels)``.
        kernel: Tensor of shape ``(X, Y, in_channels, out_channels)``.

    Returns:
        Tensor of shape ``(batch, X, Y, out_channels)``.
    """
    return torch.einsum("bxyi,xyio->bxyo", signal, kernel)

class SpectralConv2D(nn.Module):
    """2-D spectral convolution on channels-last tensors.

    The input ``(batch, H, W, in_channels)`` is transformed with a real 2-D FFT
    over the two spatial axes. The retained low-frequency corners are then
    channel-mixed with learned complex weights, and the result is transformed
    back to the spatial domain at the original ``(H, W)`` size.

    Args:
        in_channels: Size of the input channel (last) axis.
        out_channels: Size of the output channel axis.
        modes: Number of Fourier modes kept along each spatial axis.
    """

    def __init__(self, in_channels, out_channels, modes):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes

        # Random complex weights, kept small by the 1 / (C_in * C_out) factor.
        # Index 0 is applied to the first `modes` rows of the spectrum,
        # index 1 to the last `modes` rows.
        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (2, modes, modes, in_channels, out_channels)
        self.kernels = nn.Parameter(self.scale * torch.rand(weight_shape, dtype=torch.cfloat))

    def forward(self, x):
        # rfft2 keeps only W // 2 + 1 columns, so the original spatial size
        # must be handed back to irfft2 to recover H and W exactly.
        n_rows, n_cols = x.size(-3), x.size(-2)
        spectrum = torch.fft.rfft2(x, dim=(-3, -2))

        # Every coefficient outside the retained corners stays zero.
        mixed = torch.zeros(
            *spectrum.shape[:-1], self.out_channels,
            dtype=torch.cfloat, device=x.device,
        )

        m = self.modes
        # The first and last `m` rows, each restricted to the first `m`
        # columns of the half spectrum. If the row windows overlap
        # (H < 2 * m), the second write wins.
        row_windows = (slice(None, m), slice(-m, None))
        for weights, rows in zip(self.kernels, row_windows):
            mixed[:, rows, :m, :] = mul(spectrum[:, rows, :m, :], weights)

        return torch.fft.irfft2(mixed, s=(n_rows, n_cols), dim=(-3, -2))

class FNOBlock(nn.Module):
    """Single Fourier layer computing ``act(spectral_conv(x) + bypass_conv(x))``.

    Both branches operate on channels-last tensors ``(batch, H, W, channels)``.
    The spectral branch mixes channels on the retained Fourier modes. The
    bypass branch is a pointwise ``nn.Linear`` over the channel (last) axis.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        modes: Fourier modes kept per spatial axis by the spectral branch.
        act: Activation applied to the summed branches. When omitted, each
            block gets its own ``nn.ReLU()``.
    """

    def __init__(self, in_channels: int, out_channels: int, modes: int, act: Optional[Callable] = None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes
        # Build the default per instance. A module used as a default argument
        # is created once at definition time and shared by every block.
        self.act = nn.ReLU() if act is None else act
        self.spectral_conv = SpectralConv2D(in_channels=self.in_channels, out_channels=self.out_channels, modes=self.modes)
        self.bypass_conv = nn.Linear(self.in_channels, self.out_channels)

    def forward(self, x):
        spectral_out = self.spectral_conv(x)
        # nn.Linear acts on the last axis, which is already the channel axis in
        # this channels-last layout, so no permute is needed. Permuting to
        # channels-first would make the Linear mix the W axis instead and
        # yield a tensor whose shape does not match `spectral_out`.
        bypass_out = self.bypass_conv(x)
        return self.act(spectral_out + bypass_out)

class FNO(nn.Module):
    """Fourier Neural Operator with two linear output heads.

    Pipeline:
      1. A pointwise ``nn.Linear`` encoder lifts ``in_channels`` to
         ``hidden_channels[0]``.
      2. One ``FNOBlock`` is applied per consecutive pair of widths in
         ``hidden_channels``.
      3. Two independent linear heads (``decoder_velx`` and ``decoder_vely``)
         map ``hidden_channels[-1]`` to ``out_channels``. Their outputs are
         stacked on a new trailing axis.

    Args:
        in_channels: Channels of the input (last axis).
        hidden_channels: Channel widths. Must be non-empty, because the first
            and last entries are indexed. ``len(hidden_channels) - 1`` Fourier
            blocks are created.
        out_channels: Output size of each head.
        modes: Fourier modes passed to every block.

    The input is presumably channels-last ``(batch, H, W, in_channels)``, the
    layout the spectral layers index. The output has shape
    ``(*x.shape[:-1], out_channels, 2)``. Index 0 of the last axis comes from
    ``decoder_velx`` and index 1 from ``decoder_vely``.
    """

    def __init__(
            self,
            in_channels: int,
            hidden_channels: List[int],
            out_channels: int,
            modes: int):
        super().__init__()
        self.modes = modes
        self.encoder = nn.Linear(in_channels, hidden_channels[0])

        # One block per adjacent (width_in, width_out) pair of hidden widths.
        width_pairs = zip(hidden_channels, hidden_channels[1:])
        self.blocks = nn.ModuleList(
            [FNOBlock(in_channels=w_in, out_channels=w_out, modes=modes) for w_in, w_out in width_pairs]
        )

        last_width = hidden_channels[-1]
        self.decoder_velx = nn.Linear(last_width, out_channels)
        self.decoder_vely = nn.Linear(last_width, out_channels)

    def forward(self, x):
        h = self.encoder(x)
        for block in self.blocks:
            h = block(h)
        # Stacking on dim=-1 equals concatenating velx[..., None] and vely[..., None].
        return torch.stack((self.decoder_velx(h), self.decoder_vely(h)), dim=-1)