from .new_fno import FNOBlocks1, SpectralConvAttn2d, SpectralConv
from ..layers.mlp import MLP
from ..layers.skip_connections import skip_connection
import torch.nn.functional as F

import torch.nn as nn

class AttnFNOtest(nn.Module):
    """Experimental 2-D FNO-style network that interleaves spectral-attention layers.

    Pipeline applied by ``forward``:

    1. lift the input with a 1x1 convolution (``q0``);
    2. two spectral layers (``convs`` indices 0, 1, each with a 3x3 local conv
       and a skip connection);
    3. two spectral-attention layers (``attnconvs`` indices 0, 1, each with a
       skip connection);
    4. two more spectral layers (``convs`` indices 2, 3);
    5. project with a 1x1 convolution (``p0``).

    Every hidden layer is followed by GELU; the final projection is not.

    Args:
        n_modes: Forwarded as the mode specification to both ``SpectralConv``
            and ``SpectralConvAttn2d``.
        hidden_channels: Channel width of every hidden layer.
        in_channels: Number of channels of the input tensor (default 3).
        out_channels: Number of channels of the output tensor (default 1).
        dk: Forwarded unchanged to ``SpectralConvAttn2d`` (meaning defined there).
        n_heads: Forwarded unchanged to ``SpectralConvAttn2d`` as ``n_heads``.
        fno_skip: Skip type string passed to ``skip_connection`` as ``skip_type``.

    NOTE(review): seven skip modules are created, but ``forward`` only uses
    indices 0-3. The attention layers reuse ``fno_skips[0]`` and
    ``fno_skips[1]``, so they share weights with the first two spectral layers,
    and indices 4-6 never take part in the forward pass (no gradients). That may
    require ``find_unused_parameters=True`` under DDP. This looks unintended, but
    changing it alters the outputs of already-trained checkpoints, so confirm the
    intent first.
    """

    def __init__(self, n_modes, hidden_channels, in_channels=3, out_channels=1,
                 dk=0, n_heads=2,
                 fno_skip='soft-gating'
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_modes = n_modes
        self.hidden_channels = hidden_channels
        self.n_dim = 2

        width = self.hidden_channels

        # NOTE: submodules are created in the same order as before so that
        # random initialisation and state_dict keys are unchanged.

        # Four spectral-conv layers stored in one module; forward selects a
        # layer by passing its index.
        self.convs = SpectralConv(
            width,
            width,
            self.n_modes,
            output_scaling_factor=None,
            incremental_n_modes=None,
            rank=1.,
            fixed_rank_modes=False,
            separable=False,
            factorization=None,
            decomposition_kwargs=dict(),
            joint_factorization=False,
            n_layers=4,
        )

        # Two spectral-attention layers, also selected by index in forward.
        self.attnconvs = SpectralConvAttn2d(
            width,
            width,
            self.n_modes,
            dk=dk,
            n_heads=n_heads,
            n_layers=2,
        )

        skip_modules = [
            skip_connection(width, width, skip_type=fno_skip, n_dim=self.n_dim)
            for _ in range(7)
        ]
        self.fno_skips = nn.ModuleList(skip_modules)

        # 3x3 convolutions with padding=1, added alongside each spectral layer.
        self.local_convs = nn.ModuleList(
            [nn.Conv2d(width, width, kernel_size=3, padding=1) for _ in range(4)]
        )

        # 1x1 lifting (input -> hidden) and projection (hidden -> output).
        self.q0 = nn.Conv2d(self.in_channels, width, kernel_size=1)
        self.p0 = nn.Conv2d(width, self.out_channels, kernel_size=1)

    def _spectral_layer(self, x, idx):
        """Spectral layer ``idx``: GELU(spectral conv + skip + local 3x3 conv)."""
        spectral = self.convs(x, idx)
        local = self.local_convs[idx](x)
        skip = self.fno_skips[idx](x)
        # Summed as (spectral + skip) + local, same order as the original.
        return F.gelu(spectral + skip + local)

    def _attention_layer(self, x, idx):
        """Attention layer ``idx``: GELU(spectral attention + skip)."""
        attended = self.attnconvs(x, idx)
        # Reuses fno_skips[idx] (idx in {0, 1}); see the class-level NOTE.
        skip = self.fno_skips[idx](x)
        return F.gelu(attended + skip)

    def forward(self, x, **kwargs):
        """Map ``x`` through lift -> 2 spectral -> 2 attention -> 2 spectral -> project.

        ``x`` must be acceptable to ``nn.Conv2d`` with ``in_channels`` channels.
        It is presumably ``(batch, in_channels, H, W)``; TODO confirm against
        callers. Any extra keyword arguments are accepted and ignored.
        """
        x = self.q0(x)
        for idx in (0, 1):
            x = self._spectral_layer(x, idx)
        for idx in (0, 1):
            x = self._attention_layer(x, idx)
        for idx in (2, 3):
            x = self._spectral_layer(x, idx)
        return self.p0(x)
        

        