import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Basic import MLP
from layers.Embedding import timestep_embedding, unified_pos_embedding
from layers.FFNO_Layers import SpectralConv1d, SpectralConv2d, SpectralConv3d

from baselines.model_factory import F_FNOConfig

# Lookup tables indexed by the number of spatial dimensions (1, 2 or 3).
# Slot 0 is a placeholder so that the list index equals the dimension count.
BlockList = [
    None,
    SpectralConv1d,
    SpectralConv2d,
    SpectralConv3d,
]
ConvList = [
    None,
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
]


class Model(nn.Module):
    """F_FNO model operating on a regular grid given as a flat point list.

    Pipeline: a point-wise lifting MLP on ``[coordinates | features]``, an
    optional additive time embedding, a stack of residual spectral-convolution
    layers applied on the (zero-padded) grid, and a two-layer point-wise
    projection to ``out_channels``.

    The configuration object must provide: ``shapelist``, ``pos_emb``, ``ref``,
    ``device``, ``in_channels``, ``hidden_channels``, ``out_channels``,
    ``activation``, ``time_input``, ``n_modes`` and ``n_layers``.
    """

    def __init__(self, model_cfg: F_FNOConfig):
        """Build all sub-modules from ``model_cfg``.

        Raises:
            ValueError: if ``len(model_cfg.shapelist)`` has no matching entry
                in ``BlockList`` (only 1, 2 and 3 spatial dims are supported).
            AssertionError: if ``n_modes`` is a sequence whose length differs
                from the number of spatial dimensions.
        """
        super().__init__()
        self.__name__ = 'F_FNO'
        self.cfg = model_cfg
        self.spatial_dim = len(self.cfg.shapelist)
        if not 1 <= self.spatial_dim < len(BlockList):
            raise ValueError(
                f"F_FNO supports 1 to {len(BlockList) - 1} spatial dims, "
                f"got shapelist={self.cfg.shapelist!r}"
            )
        hidden = self.cfg.hidden_channels

        ## embedding
        if self.cfg.pos_emb:
            pos = unified_pos_embedding(self.cfg.shapelist, self.cfg.ref, device=self.cfg.device)  # [1,N,ref^d]
            # Non-persistent: the buffer follows .to()/.cuda() but is not
            # written to the state dict.
            self.register_buffer('pos', pos, persistent=False)
            coord_channels = self.cfg.ref ** self.spatial_dim
        else:
            # Raw coordinates are concatenated to the features in forward().
            coord_channels = self.spatial_dim
        self.preprocess = MLP(self.cfg.in_channels + coord_channels, hidden * 2, hidden,
                              n_layers=0, res=False, act=self.cfg.activation)

        if self.cfg.time_input:
            self.time_fc = nn.Sequential(nn.Linear(hidden, hidden), nn.SiLU(),
                                         nn.Linear(hidden, hidden))

        # Trailing zero padding that rounds every grid axis up to a multiple of 16.
        self.padding = [(16 - size % 16) % 16 for size in self.cfg.shapelist]

        # support n_modes as int or tuple
        def modes_per_dim(n_modes, ndim):
            if isinstance(n_modes, int):
                return tuple([n_modes] * ndim)
            assert len(n_modes) == ndim, f"n_modes length must equal spatial dims ({ndim})"
            return tuple(n_modes)

        _modes = modes_per_dim(self.cfg.n_modes, self.spatial_dim)

        spectral_cls = BlockList[self.spatial_dim]
        self.spectral_layers = nn.ModuleList(
            [spectral_cls(hidden, hidden, *_modes) for _ in range(self.cfg.n_layers)]
        )
        # projectors
        self.fc1 = nn.Linear(hidden, hidden)
        self.fc2 = nn.Linear(hidden, self.cfg.out_channels)

    def forward(self, fx: torch.Tensor, x: torch.Tensor | None = None, T: torch.Tensor | None = None):
        """Run the model on a batch of flattened grid samples.

        Args:
            fx: input features, ``[B, N_points, in_channels]`` where
                ``N_points`` is the product of ``cfg.shapelist``.
            x: point coordinates, ``[B, N_points, spatial_dim]``. Ignored when
                ``cfg.pos_emb`` is set (the precomputed embedding is used
                instead); required otherwise.
            T: optional time steps passed to ``timestep_embedding``; only used
                when ``cfg.time_input`` is set.

        Returns:
            Tensor of shape ``[B, N_points, out_channels]``.
        """
        # x: [B, N_points, spatial_dim], fx: [B, N_points, in_channels]
        B, _, _ = fx.shape
        if x is None:
            assert self.cfg.pos_emb, "x may only be omitted when cfg.pos_emb is enabled"
        if x is None or self.cfg.pos_emb:
            x = self.pos.expand(B, -1, -1)    # [B, N_points, ref^d]
        fx = torch.cat((x, fx), -1)
        fx = self.preprocess(fx)    # [B, N_points, hidden_channels]

        # time embedding
        if self.cfg.time_input and (T is not None):
            # time_fc acts row-wise, so evaluate it once per sample and let
            # broadcasting spread the result over the points instead of
            # running the MLP on N_points identical rows.
            time_emb = self.time_fc(timestep_embedding(T, self.cfg.hidden_channels))
            fx = fx + time_emb.unsqueeze(1)    # [B, 1, hidden] -> [B, N_points, hidden]

        # Channels-first grid layout for the spectral layers.
        x = fx.permute(0, 2, 1).reshape(B, self.cfg.hidden_channels, *self.cfg.shapelist)

        # Padding is only applied to 2-D and 3-D grids; 1-D grids are
        # processed at their native size.
        padded = any(self.padding) and self.spatial_dim in (2, 3)
        if padded:
            # F.pad takes (left, right) pairs starting from the LAST axis.
            pad_spec = [amount for p in reversed(self.padding) for amount in (0, p)]
            x = F.pad(x, pad_spec)

        for layer in self.spectral_layers:
            x = x + layer(x)

        if padded:
            # Crop back to the original grid by size. Slicing with
            # ``:-padding`` would empty any axis whose padding is 0.
            crop = (Ellipsis, *(slice(0, size) for size in self.cfg.shapelist))
            x = x[crop]
        x = x.reshape(B, self.cfg.hidden_channels, -1).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x
