import torch
import torch.nn as nn
from models.fno import SpectralConv1d, SpectralConv2d, MLP1d, MLP2d
import torch.nn.functional as F
from models.custom_layers import get_residual_layer, act_registry

from functools import partial

class FNO1dBlock(nn.Module):
    """Stack of 1D Fourier Neural Operator layers on channels-last sequences.

    Each layer computes::

        x = act(mlp(spectral(x)) + conv1x1(x)) + residual(x_in)

    The per-layer residual comes from ``get_residual_layer("zero", d_model)``.
    According to the original author the real skip connection is applied by
    the caller (S4BaseModel), so this residual is presumably a no-op here —
    confirm in ``models.custom_layers``.

    Args:
        d_model: Channel width of every layer's input and output.
        n_layers: Number of stacked FNO layers.
        modes: Number of Fourier modes forwarded to ``SpectralConv1d``.
        spectral_type: Variant string forwarded to ``SpectralConv1d``.
        activation: Key looked up in ``act_registry``. Unknown keys silently
            fall back to ``F.gelu``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, d_model, n_layers=1, modes=12, spectral_type='full', activation='identity', **kwargs):
        super(FNO1dBlock, self).__init__()
        # TODO (original author): the layer count here is coupled to
        # FNO1d.fno_layers having n_layers - 1 layers; revisit.

        self.modes = modes
        self.d_model = d_model
        # Padding width for non-periodic domains. Currently unused because
        # padding is disabled in forward(); kept so the attribute still exists.
        self.padding = 2

        self.fno_layers = nn.ModuleList()
        self.ws = nn.ModuleList()
        self.mlps = nn.ModuleList()
        self.residual_layers = nn.ModuleList()

        spectral_layer = partial(SpectralConv1d, self.d_model, self.d_model, self.modes, spectral_type)

        # Falls back to GELU when `activation` is not a registry key.
        self.activation_fn = act_registry.get(activation, F.gelu)

        # Modules are created layer by layer (spectral, 1x1 conv, MLP, residual),
        # which fixes the order in which random initialisation consumes the RNG.
        for _ in range(n_layers):
            self.fno_layers.append(spectral_layer())
            self.ws.append(nn.Conv1d(d_model, d_model, kernel_size=1))
            self.mlps.append(MLP1d(self.d_model, self.d_model, self.d_model))
            # Residual will be applied in S4BaseModel (per original author).
            self.residual_layers.append(get_residual_layer("zero", d_model))

    def forward(self, x, batch_dt=None):
        """Run every FNO layer over ``x``.

        Args:
            x: Tensor of shape (batch, length, d_model). It is permuted to
                channels-first for the convolutions and permuted back.
            batch_dt: Unused; accepted for interface compatibility.

        Returns:
            Tuple ``(y, None)``. ``y`` has shape (batch, length, d_model).
            The ``None`` is presumably a placeholder for recurrent state
            expected by the caller — confirm.
        """
        # (batch, length, C) -> (batch, C, length), as required by nn.Conv1d.
        x = x.permute(0, 2, 1)
        # Non-periodic padding with self.padding is currently disabled.
        for layer, w, mlp, res in zip(self.fno_layers, self.ws, self.mlps, self.residual_layers):
            x_in = x
            x1 = mlp(layer(x))  # spectral path
            x2 = w(x)           # pointwise (1x1 conv) path
            x = self.activation_fn(x1 + x2)
            x = x + res(x_in)
        # Back to channels-last.
        x = x.permute(0, 2, 1)
        return x, None


class FNO2dBlock(nn.Module):
    """Stack of 2D Fourier Neural Operator layers on channels-last grids.

    Every layer computes ``gelu(mlp(spectral(x)) + conv1x1(x))``. The
    activation is fixed to GELU, and no residual layer is applied inside the
    block.

    Args:
        d_model: Channel width of every layer's input and output.
        modes: First Fourier-mode count forwarded to ``SpectralConv2d``.
        modes2: Second Fourier-mode count forwarded to ``SpectralConv2d``.
        n_layers: Number of stacked layers.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            d_model,
            modes=12,
            modes2=12,
            n_layers=1,
            **kwargs):
        super().__init__()

        self.modes, self.modes2, self.d_model = modes, modes2, d_model

        self.fno_layers = nn.ModuleList()
        self.ws = nn.ModuleList()
        self.mlps = nn.ModuleList()

        # Build each layer's modules in a fixed order (spectral, 1x1 conv,
        # MLP) so random initialisation consumes the RNG in the same sequence.
        for _ in range(n_layers):
            spectral = SpectralConv2d(d_model, d_model, modes, modes2)
            pointwise = nn.Conv2d(d_model, d_model, 1)
            channel_mlp = MLP2d(d_model, d_model, d_model)
            self.fno_layers.append(spectral)
            self.ws.append(pointwise)
            self.mlps.append(channel_mlp)

    def forward(self, x, batch_dt = None):
        """Run every FNO layer over ``x``.

        Args:
            x: Tensor of shape (batch, x1, x2, d_model). The original comment
                describes the last axis as t*v, which must equal d_model.
            batch_dt: Unused; accepted for interface compatibility.

        Returns:
            Tuple ``(y, None)``. ``y`` is channels-last like ``x``. The
            ``None`` is presumably a placeholder for state expected by the
            caller — confirm.
        """
        # (b, x1, x2, C) -> (b, C, x1, x2), as required by nn.Conv2d.
        h = x.permute(0, 3, 1, 2)
        for spectral, pointwise, channel_mlp in zip(self.fno_layers, self.ws, self.mlps):
            # Spectral path is evaluated before the pointwise path.
            h = F.gelu(channel_mlp(spectral(h)) + pointwise(h))
        # Back to channels-last.
        return h.permute(0, 2, 3, 1), None
