import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.components.layers.spectral import SpectralConv2d


def get_grid(batchsize, size_x, size_y, device, dtype=torch.float):
    """Return normalized (x, y) coordinate channels for a batch of 2-D fields.

    Args:
        batchsize: Number of samples ``B`` in the batch.
        size_x: Number of points along the first spatial axis (dim 2).
        size_y: Number of points along the second spatial axis (dim 3).
        device: Device on which the grid is created.
        dtype: Dtype of the returned grid (float32 by default).

    Returns:
        A contiguous tensor of shape ``(B, 2, size_x, size_y)``. Channel 0
        holds ``linspace(0, 1, size_x)`` (constant along dim 3) and channel 1
        holds ``linspace(0, 1, size_y)`` (constant along dim 2).
    """
    # Only the two small 1-D coordinate vectors are moved to ``device``; the
    # full (B, 2, X, Y) grid is then formed there from broadcasting views
    # instead of being built on the CPU and copied over on every call.
    xs = torch.as_tensor(np.linspace(0, 1, size_x), dtype=dtype, device=device)
    ys = torch.as_tensor(np.linspace(0, 1, size_y), dtype=dtype, device=device)
    shape = (batchsize, 1, size_x, size_y)
    gridx = xs.reshape(1, 1, size_x, 1).expand(*shape)
    gridy = ys.reshape(1, 1, 1, size_y).expand(*shape)
    # torch.cat copies the expanded views into a fresh contiguous tensor, so
    # the result does not alias the 1-D vectors.
    return torch.cat((gridx, gridy), dim=1)


class FNO2d(nn.Module):
    """2-D Fourier Neural Operator.

    ``forward`` appends two normalized coordinate channels to the input and
    lifts the result to ``width`` channels with a 1x1 conv (``fc0``). It then
    applies ``n_layers`` blocks of ``SpectralConv2d(x) + Conv1x1(x)``, with the
    activation after every block except the last. Finally it projects through
    ``latent_dim`` channels (``fc1`` + activation) down to ``out_channels``
    (``fc2``).

    Args:
        in_channels: Channels of the input field, not counting the 2 grid
            channels added in ``forward``.
        out_channels: Channels of the output field.
        n_layers: Number of Fourier blocks.
        modes1: First mode count forwarded to each ``SpectralConv2d``
            (presumably the retained Fourier modes along dim 2 — confirm).
        modes2: Second mode count forwarded to each ``SpectralConv2d``
            (presumably the retained Fourier modes along dim 3 — confirm).
        width: Channel width inside the Fourier blocks.
        latent_dim: Hidden channels of the projection head.
        act: Name of a function in ``torch.nn.functional`` used as activation.
        **unused_kwargs: Accepted and ignored (presumably so configs carrying
            extra keys can be passed straight through).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 n_layers=4,
                 modes1=12,
                 modes2=12,
                 width=10,
                 latent_dim=128,
                 act='gelu',
                 **unused_kwargs):
        super().__init__()
        self.n_layers = n_layers
        self.act = getattr(F, act)
        # Submodule names (fc0/fc1/fc2, conv{i}, w{i}) define the state_dict
        # keys, and their creation order defines the order in which weights are
        # randomly initialized, so both are kept stable.
        self.layers = nn.ModuleDict({
            'fc0': nn.Conv2d(in_channels + 2, width, 1),  # +2 grid channels
            'fc1': nn.Conv2d(width, latent_dim, 1),
            'fc2': nn.Conv2d(latent_dim, out_channels, 1),
        })
        for i in range(n_layers):
            self.layers[f'conv{i}'] = SpectralConv2d(width, width, modes1, modes2)
            self.layers[f'w{i}'] = nn.Conv2d(width, width, 1)

    def forward(self, x):
        """Map ``x`` of shape ``(B, in_channels, X, Y)`` to ``(B, out_channels, X, Y)``."""
        batchsize = x.shape[0]
        size_x, size_y = x.shape[-2], x.shape[-1]
        grid = get_grid(batchsize, size_x, size_y, x.device)
        if x.is_floating_point():
            # get_grid yields float32. Matching the input dtype keeps the model
            # usable after .half()/.double(); otherwise torch.cat promotes to
            # float32 and the result no longer matches half-precision weights.
            # Non-floating inputs keep the original promotion behavior.
            grid = grid.to(x.dtype)
        x = torch.cat((x, grid), dim=1)

        # Lift with P: (B, in_channels + 2, X, Y) -> (B, width, X, Y).
        x = self.layers['fc0'](x)

        # Fourier blocks: spectral conv plus pointwise (1x1) conv skip path.
        # The activation is skipped after the final block.
        last = self.n_layers - 1
        for i in range(self.n_layers):
            x = self.layers[f'conv{i}'](x) + self.layers[f'w{i}'](x)
            if i < last:
                x = self.act(x)

        # Project with Q: width -> latent_dim (activated) -> out_channels.
        x = self.act(self.layers['fc1'](x))
        return self.layers['fc2'](x)


if __name__ == '__main__':
    # Smoke test: build a small model and print the output shape.
    # Keyword arguments are used because the positional order is
    # (in_channels, out_channels, n_layers, modes1, modes2, width, ...) and
    # is easy to misalign.
    model = FNO2d(in_channels=3, out_channels=3, n_layers=4,
                  modes1=12, modes2=12, width=10)
    x = torch.rand(64, 3, 32, 32)
    y = model(x)
    print(y.shape)
