import torch
from pina.model import FNO


# Sizes shared by every FNO test in this module.
output_channels: int = 5  # width of every projecting net's output
batch_size: int = 15  # leading (batch) dimension of every random input
# Spatial grid size per axis; the N-D forward tests use the first N entries.
resolution: list = [30, 40, 50]
lifting_dim: int = 128  # output width of every lifting net


def test_constructor():
    """Check that ``FNO`` can be built from each supported argument form.

    Four variants are constructed:

    * a scalar ``n_modes``;
    * an ``n_modes`` list;
    * an ``n_modes`` list of lists;
    * a scalar ``n_modes`` with an explicit ``layers`` list instead of
      ``inner_size``/``n_layers``.

    Only successful construction is checked. No forward pass is run and
    nothing is asserted about the built model.
    """
    input_channels = 3
    lifting_net = torch.nn.Linear(input_channels, lifting_dim)
    projecting_net = torch.nn.Linear(60, output_channels)

    # simple constructor
    FNO(lifting_net=lifting_net,
        projecting_net=projecting_net,
        n_modes=5,
        dimensions=3,
        inner_size=60,
        n_layers=5)

    # simple constructor with n_modes list
    # (presumably one entry per spatial dimension: 3 entries for dimensions=3)
    FNO(lifting_net=lifting_net,
        projecting_net=projecting_net,
        n_modes=[5, 3, 2],
        dimensions=3,
        inner_size=60,
        n_layers=5)

    # simple constructor with n_modes list of list
    # (presumably one inner list per layer: 2 lists for n_layers=2)
    FNO(lifting_net=lifting_net,
        projecting_net=projecting_net,
        n_modes=[[5, 3, 2], [5, 3, 2]],
        dimensions=3,
        inner_size=60,
        n_layers=2)

    # constructor with explicit hidden layer widths instead of
    # inner_size/n_layers. The projecting net is rebuilt with input width 50,
    # presumably to match the last entry of `layers` -- confirm against FNO.
    projecting_net = torch.nn.Linear(50, output_channels)
    FNO(lifting_net=lifting_net,
        projecting_net=projecting_net,
        n_modes=5,
        dimensions=3,
        layers=[50, 50])
    
def test_1d_forward():
    """Run a 1D FNO on a random batch and check the output shape.

    The input is laid out as (batch, x, channels). The output is expected to
    keep the batch and spatial axes and to carry ``output_channels`` last.
    """
    in_ch = 1
    hidden = 60  # shared by inner_size and the projecting net's input width
    spatial = resolution[:1]

    x = torch.rand(batch_size, *spatial, in_ch)
    model = FNO(
        lifting_net=torch.nn.Linear(in_ch, lifting_dim),
        projecting_net=torch.nn.Linear(hidden, output_channels),
        n_modes=5,
        dimensions=len(spatial),
        inner_size=hidden,
        n_layers=2,
    )

    expected = torch.Size([batch_size, *spatial, output_channels])
    assert model(x).shape == expected

def test_2d_forward():
    """Run a 2D FNO on a random batch and check the output shape.

    The input is laid out as (batch, x, y, channels). The output is expected
    to keep the batch and spatial axes and to carry ``output_channels`` last.
    """
    in_ch = 2
    hidden = 60  # shared by inner_size and the projecting net's input width
    spatial = resolution[:2]

    x = torch.rand(batch_size, *spatial, in_ch)
    model = FNO(
        lifting_net=torch.nn.Linear(in_ch, lifting_dim),
        projecting_net=torch.nn.Linear(hidden, output_channels),
        n_modes=5,
        dimensions=len(spatial),
        inner_size=hidden,
        n_layers=2,
    )

    expected = torch.Size([batch_size, *spatial, output_channels])
    assert model(x).shape == expected

def test_3d_forward():
    """Run a 3D FNO on a random batch and check the output shape.

    The input is laid out as (batch, x, y, z, channels). The output is
    expected to keep the batch and spatial axes and to carry
    ``output_channels`` last.
    """
    # NOTE(review): with the current module-level sizes the lifted tensor has
    # 15*30*40*50*128 elements (~460 MB, assuming the default float32 dtype).
    # Consider smaller sizes here if CI memory is tight.
    in_ch = 3
    hidden = 60  # shared by inner_size and the projecting net's input width
    spatial = resolution[:3]

    x = torch.rand(batch_size, *spatial, in_ch)
    model = FNO(
        lifting_net=torch.nn.Linear(in_ch, lifting_dim),
        projecting_net=torch.nn.Linear(hidden, output_channels),
        n_modes=5,
        dimensions=len(spatial),
        inner_size=hidden,
        n_layers=2,
    )

    expected = torch.Size([batch_size, *spatial, output_channels])
    assert model(x).shape == expected