import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from layers.Basic import MLP
from layers.Embedding import timestep_embedding, unified_pos_embedding
from layers.FNO_Layers import SpectralConv1d, SpectralConv2d, SpectralConv3d
from neuralop.models import FNO as NeuralopFNO

from baselines.model_factory import FNOConfig

# Lookup tables indexed by the number of spatial dimensions (1, 2 or 3).
# Index 0 is a None placeholder so that the dimension count can be used
# directly as the list index.
BlockList = [
    None,
    SpectralConv1d,
    SpectralConv2d,
    SpectralConv3d,
]
ConvList = [
    None,
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
]



class Model(nn.Module):
    """
    Unified FNO model with two backends:
      - 'native'   : local spectral blocks
      - 'neuralop' : delegate forward to neuralop FNO
    Public I/O stays point-wise: fx:[B,N,Cin], x:[B,N,d] -> [B,N,Cout]
    """
    def __init__(self, model_cfg: FNOConfig):
        super(Model, self).__init__()
        self.__name__ = 'FNO'
        self.cfg = model_cfg
        self.spatial_dim = len(self.cfg.shapelist)
        self.backend = model_cfg.backend 

        if self.backend == 'native':
            # modules
            ## embedding
            if self.cfg.pos_emb:
                pos = unified_pos_embedding(self.cfg.shapelist, self.cfg.ref, device=self.cfg.device)  # [1,N,ref^d]
                self.register_buffer('pos', pos, persistent=False)
                self.preprocess = MLP(self.cfg.in_channels + self.cfg.ref ** self.spatial_dim, self.cfg.hidden_channels * 2,
                                    self.cfg.hidden_channels, n_layers=0, res=False, act=self.cfg.activation)
            else:
                self.preprocess = MLP(self.cfg.in_channels + self.spatial_dim, self.cfg.hidden_channels * 2, self.cfg.hidden_channels,
                                    n_layers=0, res=False, act=self.cfg.activation)
            if self.cfg.time_input:
                self.time_fc = nn.Sequential(nn.Linear(self.cfg.hidden_channels, self.cfg.hidden_channels), nn.SiLU(),
                                            nn.Linear(self.cfg.hidden_channels, self.cfg.hidden_channels))

            self.padding = [(16 - size % 16) % 16 for size in self.cfg.shapelist]

            # support n_modes as int or tuple
            def modes_per_dim(n_modes, ndim):
                if isinstance(n_modes, int):
                    return tuple([n_modes] * ndim)
                assert len(n_modes) == ndim, f"n_modes length must equal spatial dims ({ndim})"
                return tuple(n_modes)

            _modes = modes_per_dim(self.cfg.n_modes, len(self.padding))
            # pick 1d/2d/3d spectral + pointwise convs
            Block = BlockList[len(self.padding)]
            self.conv0 = Block(self.cfg.hidden_channels, self.cfg.hidden_channels, *_modes)
            self.conv1 = Block(self.cfg.hidden_channels, self.cfg.hidden_channels, *_modes)
            self.conv2 = Block(self.cfg.hidden_channels, self.cfg.hidden_channels, *_modes)
            self.conv3 = Block(self.cfg.hidden_channels, self.cfg.hidden_channels, *_modes)

            self.w0 = ConvList[len(self.padding)](self.cfg.hidden_channels, self.cfg.hidden_channels, 1)
            self.w1 = ConvList[len(self.padding)](self.cfg.hidden_channels, self.cfg.hidden_channels, 1)
            self.w2 = ConvList[len(self.padding)](self.cfg.hidden_channels, self.cfg.hidden_channels, 1)
            self.w3 = ConvList[len(self.padding)](self.cfg.hidden_channels, self.cfg.hidden_channels, 1)
            # projectors
            self.fc1 = nn.Linear(self.cfg.hidden_channels, self.cfg.hidden_channels)
            self.fc2 = nn.Linear(self.cfg.hidden_channels, self.cfg.out_channels)

        elif self.backend == 'neuralop':
            if isinstance(self.cfg.n_modes, int):
                n_modes = tuple([self.cfg.n_modes] * self.spatial_dim)
            else:
                n_modes = tuple(self.cfg.n_modes)
                assert len(n_modes) == self.spatial_dim, "n_modes length must match spatial dims"
            self.neuralop_fno = NeuralopFNO(
                n_modes=n_modes,
                in_channels=self.cfg.in_channels,
                out_channels=self.cfg.out_channels,
                hidden_channels=self.cfg.hidden_channels,
                n_layers=self.cfg.n_layers,
            )


    def _to_grid(self, pts: torch.Tensor, C: int) -> torch.Tensor:
        # pts: [B, N, C] -> [B, C, *shapelist]
        B = pts.size(0)
        return pts.permute(0, 2, 1).reshape(B, C, *self.cfg.shapelist)


    def _from_grid(self, grid: torch.Tensor) -> torch.Tensor:
        # grid: [B, C, *shapelist] -> [B, N, C]
        B, C = grid.size(0), grid.size(1)
        return grid.reshape(B, C, -1).permute(0, 2, 1)


    def forward_native(self, fx: torch.Tensor, x: torch.Tensor | None = None, T: torch.Tensor | None = None):
        # x: [B, N_points, spatial_dim], fx: [B, N_points, in_channels]
        B, N, _ = fx.shape
        if x is None:
            assert self.cfg.pos_emb
        if x is None or self.cfg.pos_emb:
            x = self.pos.expand(B, -1, -1)    # [B, N_points, ref^d]
        fx = torch.cat((x, fx), -1)
        fx = self.preprocess(fx)    # [B, N_points, hidden_channels]

        # time embedding
        if self.cfg.time_input and (T is not None):
            Time_emb = timestep_embedding(T, self.cfg.hidden_channels).unsqueeze(1).expand(-1, fx.shape[1], -1)    # [B, N_points, hidden_channels]
            Time_emb = self.time_fc(Time_emb)
            fx = fx + Time_emb
        x = fx.permute(0, 2, 1).reshape(B, self.cfg.hidden_channels, *self.cfg.shapelist)
        if not all(item == 0 for item in self.padding):
            if self.spatial_dim == 2:
                x = F.pad(x, [0, self.padding[1], 0, self.padding[0]])
            elif self.spatial_dim == 3:
                x = F.pad(x, [0, self.padding[2], 0, self.padding[1], 0, self.padding[0]])

        x1 = self.conv0(x)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv1(x)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv2(x)
        x2 = self.w2(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv3(x)
        x2 = self.w3(x)
        x = x1 + x2

        if not all(item == 0 for item in self.padding):
            if self.spatial_dim == 2:
                x = x[..., :-self.padding[0], :-self.padding[1]]
            elif self.spatial_dim == 3:
                x = x[..., :-self.padding[0], :-self.padding[1], :-self.padding[2]]
        x = x.reshape(B, self.cfg.hidden_channels, -1).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)    # [B, N_points, out_channels]
        return x
    

    # ---------- neuralop backend ----------
    def forward_neuralop(self, fx: torch.Tensor) -> torch.Tensor:
        assert self.neuralop_fno is not None, "neuralop backend not initialized"
        B, N, Cin = fx.shape
        n_points = int(np.prod(self.cfg.shapelist))
        assert N == n_points, f"N={N} must equal prod(shapelist)={n_points}"

        x_grid = self._to_grid(fx, Cin)          # [B, Cin, *S]
        y_grid = self.neuralop_fno(x_grid)       # [B, Cout, *S]
        y = self._from_grid(y_grid)              # [B, N, Cout]
        return y
    

    # ---------- unified forward ----------
    def forward(self, fx: torch.Tensor, x: torch.Tensor | None = None, T: torch.Tensor | None = None):
        # fx: [B, N_points, in_channels]  -> [B, N_points, out_channels]
        if self.backend == 'neuralop':
            return self.forward_neuralop(fx)                 
        else:
            return self.forward_native(fx=fx, x=x, T=T)        

