import torch
import torch.nn as nn
import torch.nn.functional as F

from numpy import linspace

import operator
from functools import reduce


####### Functions #######

def count_params(model):
    """Count the real-valued scalars stored in a model's parameters.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (for example an ``nn.Module``).

    Returns:
        int: Total number of elements over all parameters, where every
        element of a complex-valued parameter is counted twice (real and
        imaginary part).
    """
    # numel() is 1 for a 0-dim parameter; multiplying out an empty size
    # tuple with reduce() (no initial value) would raise instead.
    return sum(
        p.numel() * (2 if p.is_complex() else 1)
        for p in model.parameters()
    )


def activation_function(name):
    """Return a newly constructed activation module selected by name.

    Args:
        name: One of ``'relu'``, ``'gelu'`` or ``'tanh'``.

    Returns:
        nn.Module: ``torch.nn.ReLU()``, ``torch.nn.GELU()`` or
        ``torch.nn.Tanh()`` respectively.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported names.
    """
    if name == 'relu':
        return torch.nn.ReLU()
    if name == 'gelu':
        return torch.nn.GELU()
    if name == 'tanh':
        return torch.nn.Tanh()
    # Fail here rather than hand back None, which would only blow up later
    # (and far from the cause) when the "activation" is first called.
    raise NotImplementedError(
        "Activation function {!r} is not supported; "
        "you should add the activation function in the utils.".format(name)
    )


class Complex_Activation(nn.Module):
    """Apply a real activation function to a complex-valued tensor.

    Two modes are available:

    * ``use_phase=False``: the activation is applied to the real and the
      imaginary part independently.
    * ``use_phase=True``: the activation is applied to ``|z| + b`` (``b`` is a
      learnable offset) and the result is multiplied by the unit phase factor
      ``exp(i * angle(z))``.

    Args:
        act: Name passed to ``activation_function`` to build the activation.
        use_phase: Selects the magnitude/phase mode described above.
    """

    def __init__(self, act = 'relu', use_phase=False):
        super(Complex_Activation, self).__init__()
        self.act = activation_function(act)
        self.use_phase = use_phase

        # Learnable magnitude offset, only read when use_phase is True.
        # nn.Parameter is created with requires_grad=True, so no extra flag
        # needs to be set for it to be trained.
        self.b = nn.Parameter(torch.tensor([0.1]))

    def forward(self, z):
        """Return the activated tensor for ``z``."""
        if self.use_phase:
            return self.act(torch.abs(z) + self.b) * torch.exp(1.j * torch.angle(z)) 
        else:
            return self.act(z.real) + 1.j * self.act(z.imag)


def complex_multiplication_3d_FC(input, weights):
    """Contract the input-channel axis against a per-location dense weight.

    Shapes:
        input:   (batch, in_channel, x, y, z)
        weights: (in_channel, out_channel, x, y, z)
        returns: (batch, out_channel, x, y, z)
    """
    contraction = "bixyz,ioxyz->boxyz"
    mixed = torch.einsum(contraction, input, weights)
    return mixed

def complex_multiplication_3d_Diag(input, weights):
    """Scale ``input`` element-wise by ``weights`` (a diagonal channel map).

    A leading singleton axis is prepended to ``weights`` so that it
    broadcasts over the first (batch) axis of ``input``.
    """
    batched_weights = weights[None]
    return input * batched_weights


def add_diag_3d(tensor_FC, tensor_diag, n_hidden):
    """Add ``tensor_diag[i]`` onto ``tensor_FC[i, i]`` for ``i < n_hidden``.

    ``tensor_FC`` is modified in place and the same object is returned.
    """
    # Paired index tensors address the (i, i) entries of the two leading axes.
    channel = torch.arange(n_hidden, device=tensor_FC.device)
    tensor_FC[channel, channel] += tensor_diag[channel]
    return tensor_FC


class MLP(nn.Module):
    """Two linear layers with a GELU in between, applied along axis 1.

    ``forward`` moves axis 1 of a 5-D input to the last position, runs
    ``mlp1 -> GELU -> mlp2`` on that axis and moves it back, so an input of
    shape (d0, in_channels, d2, d3, d4) yields (d0, out_channels, d2, d3, d4).
    """

    def __init__(self, in_channels, out_channels, mid_channels, bias):
        super().__init__()
        self.mlp1 = nn.Linear(in_channels, mid_channels, bias=bias)
        self.mlp2 = nn.Linear(mid_channels, out_channels, bias=bias)

    def forward(self, x):
        # nn.Linear acts on the last axis, hence the round trip of axis 1.
        channels_last = x.permute(0, 2, 3, 4, 1)
        hidden = F.gelu(self.mlp1(channels_last))
        projected = self.mlp2(hidden)
        return projected.permute(0, 4, 1, 2, 3)
    
class MLP_Complex(nn.Module):
    """Complex-valued counterpart of a two-layer perceptron on axis 1.

    Both linear layers carry ``torch.cfloat`` weights and are separated by a
    ``Complex_Activation('gelu')``. ``forward`` moves axis 1 of a 5-D input to
    the last position, applies the layers there and moves the axis back.
    """

    def __init__(self, in_channels, out_channels, mid_channels, bias = False):
        super().__init__()
        complex_dtype = torch.cfloat
        self.mlp1 = nn.Linear(in_channels, mid_channels, bias=bias, dtype=complex_dtype)
        self.complex_activation = Complex_Activation('gelu')
        self.mlp2 = nn.Linear(mid_channels, out_channels, bias=bias, dtype=complex_dtype)

    def forward(self, x):
        # nn.Linear acts on the last axis, hence the round trip of axis 1.
        channels_last = x.permute(0, 2, 3, 4, 1)
        hidden = self.complex_activation(self.mlp1(channels_last))
        projected = self.mlp2(hidden)
        return projected.permute(0, 4, 1, 2, 3)
    

################################################################
# Position/Momentum Layers
################################################################

class Bi_Momentum_Evolution_3d_Spatial_Low_Frequency(nn.Module):
    """Learned mixing of the eight corner blocks of a full 3-D FFT.

    ``forward`` takes ``fftn`` over the last three axes of the input, multiplies
    each of the eight corner blocks (first or last ``modes_k`` entries along
    each transformed axis) with its own complex weight, leaves every other
    frequency at zero and transforms back with ``ifftn``.

    The layer operates on ``n_hidden // 2 + 1`` channels (stored as
    ``self.n_hidden``), which matches the length of a one-sided FFT of an axis
    of size ``n_hidden``.

    Weight layout, selected through ``module_list``:

    * both ``'FC'`` and ``'Diag'``: ``weights_1 .. weights_8`` are dense
      channel-mixing weights whose channel diagonal received an additional
      random term at initialisation (via ``add_diag_3d``).
    * only ``'FC'``: ``weights_1 .. weights_8`` are dense weights.
    * only ``'Diag'``: ``weights_1_diag .. weights_8_diag`` are per-channel
      (diagonal) weights.
    * neither: no weights; the output spectrum stays zero.

    Args:
        n_hidden: Hidden width from which the channel count is derived.
        modes_1, modes_2, modes_3: Block size along each transformed axis.
        module_list: Container of module names; membership of ``'FC'`` and
            ``'Diag'`` is tested.
    """

    def __init__(self, n_hidden, modes_1, modes_2, modes_3, module_list=('FC', 'Diag')):
        super(Bi_Momentum_Evolution_3d_Spatial_Low_Frequency, self).__init__()

        self.modes_1 = modes_1
        self.modes_2 = modes_2
        self.modes_3 = modes_3

        self.n_hidden = n_hidden // 2 + 1

        self.module_FC = ('FC' in module_list)
        self.module_Diag = ('Diag' in module_list)

        fc_shape = (self.n_hidden, self.n_hidden, self.modes_1, self.modes_2, self.modes_3)
        diag_shape = fc_shape[1:]

        # Weights are created strictly in the order 1..8 (dense draw before
        # diagonal draw in the combined case) so that the random stream, the
        # parameter names and their registration order stay unchanged.
        if self.module_FC and self.module_Diag:
            self.scale = 1 / (self.n_hidden * self.n_hidden)
            for k in range(1, 9):
                dense = torch.rand(*fc_shape, dtype=torch.cfloat)
                diag = torch.rand(*diag_shape, dtype=torch.cfloat)
                combined = add_diag_3d(dense, diag, self.n_hidden)
                setattr(self, 'weights_%d' % k, nn.Parameter(self.scale * combined))

        else:
            if self.module_FC:
                self.scale_FC = 1 / (self.n_hidden * self.n_hidden)
                for k in range(1, 9):
                    dense = torch.rand(*fc_shape, dtype=torch.cfloat)
                    setattr(self, 'weights_%d' % k, nn.Parameter(self.scale_FC * dense))

            if self.module_Diag:
                self.scale_Diag = 1 / self.n_hidden
                for k in range(1, 9):
                    diag = torch.rand(*diag_shape, dtype=torch.cfloat)
                    setattr(self, 'weights_%d_diag' % k, nn.Parameter(self.scale_Diag * diag))

    def _corner_blocks(self):
        """Return the index tuples of the eight corner blocks, in weight order.

        Weight ``k`` (1-based) belongs to the ``k``-th entry; the last axis
        alternates fastest between its low and high block.
        """
        low = [slice(None, m) for m in (self.modes_1, self.modes_2, self.modes_3)]
        high = [slice(-m, None) for m in (self.modes_1, self.modes_2, self.modes_3)]
        return [
            (slice(None), slice(None), block_1, block_2, block_3)
            for block_1 in (low[0], high[0])
            for block_2 in (low[1], high[1])
            for block_3 in (low[2], high[2])
        ]

    def _mix_corners(self, x_ft, out_ft, multiply, name_template):
        """Write ``multiply(x_ft[block], weight_k)`` into every corner of ``out_ft``."""
        for k, block in enumerate(self._corner_blocks(), start=1):
            weight = getattr(self, name_template % k)
            # Plain assignment: if blocks overlap, the later block wins.
            out_ft[block] = multiply(x_ft[block], weight)

    def forward(self, x):
        """Filter ``x`` in the Fourier domain of its last three axes."""
        batchsize = x.shape[0]
        x_ft = torch.fft.fftn(x, dim=[-3, -2, -1])

        out_ft = torch.zeros(batchsize, self.n_hidden, x.size(-3), x.size(-2), x.size(-1), dtype=torch.cfloat, device=x.device)

        if self.module_FC and self.module_Diag:
            # The diagonal part is already folded into the dense weights.
            self._mix_corners(x_ft, out_ft, complex_multiplication_3d_FC, 'weights_%d')

        else:
            if self.module_Diag:
                self._mix_corners(x_ft, out_ft, complex_multiplication_3d_Diag, 'weights_%d_diag')

            if self.module_FC:
                self._mix_corners(x_ft, out_ft, complex_multiplication_3d_FC, 'weights_%d')

        x = torch.fft.ifftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1)))
        return x


class Spatial_Positional_Evolution_3d_High_Frequency(nn.Module):
    """Bias-free 3-D convolution with its kernel-sum response removed.

    ``forward`` returns ``(conv(x) - conv_1x1x1(x)) / grid_width`` where the
    1x1x1 kernel is the sum of the learned kernel over its three spatial
    axes.

    Args:
        in_channels, out_channels, kernel_size, groups: Forwarded to
            ``nn.Conv3d``; ``kernel_size`` must be odd.
        n_dim: Stored on the instance; not used by this class itself.
        padding: ``'periodic'``, ``'replicate'``, ``'reflect'`` or ``'zeros'``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, n_dim=3, groups=1, padding='periodic'):
        super().__init__()

        self.n_dim = n_dim

        assert kernel_size % 2 == 1, "Kernel size should be odd"
        self.kernel_size = kernel_size
        self.padding_mode = self.get_padding_mode(padding)
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            padding='same',
            padding_mode=self.padding_mode,
            bias=False,
            groups=groups,
        )

    def get_padding_mode(self, padding):
        """Translate the user-facing padding name into the ``nn.Conv3d`` one."""
        translations = (
            ('periodic', 'circular'),
            ('replicate', 'replicate'),
            ('reflect', 'reflect'),
            ('zeros', 'zeros'),
        )
        for user_name, torch_name in translations:
            if padding == user_name:
                return torch_name
        raise NotImplementedError("Desired padding mode is not currently supported")

    def forward(self, x, grid_width):
        """Return the kernel response minus the summed-kernel response, scaled by ``1 / grid_width``."""
        full_response = self.conv(x)
        # Collapse the kernel to a single 1x1x1 tap per (out, in) channel pair.
        summed_kernel = torch.sum(self.conv.weight, dim=(2, 3, 4), keepdim=True)
        centre_response = F.conv3d(x, summed_kernel, groups=self.conv.groups)
        return (full_response - centre_response) / grid_width


class Hidden_Position_Space_Momentum_3d_Evolution_Spatial_Low_Frequency(nn.Module):
    """Learned mixing of four corner blocks of a one-sided 3-D FFT.

    ``forward`` takes ``rfftn`` over the last three axes of the input,
    multiplies four blocks of the spectrum (first or last ``modes_1`` /
    ``modes_2`` entries along the first two transformed axes, first
    ``modes_3`` entries along the last one) with their own complex weights,
    leaves all other frequencies at zero and transforms back with ``irfftn``.

    Weight layout, selected through ``module_list``:

    * both ``'FC'`` and ``'Diag'``: ``weights_1 .. weights_4`` are dense
      channel-mixing weights whose channel diagonal received an additional
      random term at initialisation (via ``add_diag_3d``).
    * only ``'FC'``: ``weights_1 .. weights_4`` are dense weights.
    * only ``'Diag'``: ``weights_1_diag .. weights_4_diag`` are per-channel
      (diagonal) weights.
    * neither: no weights; the output spectrum stays zero.

    Args:
        n_hidden: Number of channels the weights are sized for.
        modes_1, modes_2, modes_3: Block size along each transformed axis.
        module_list: Container of module names; membership of ``'FC'`` and
            ``'Diag'`` is tested.
    """

    def __init__(self, n_hidden, modes_1, modes_2, modes_3, module_list=('FC', 'Diag')):
        super(Hidden_Position_Space_Momentum_3d_Evolution_Spatial_Low_Frequency, self).__init__()

        self.modes_1 = modes_1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes_2 = modes_2
        self.modes_3 = modes_3

        self.n_hidden = n_hidden

        self.module_FC = ('FC' in module_list)
        self.module_Diag = ('Diag' in module_list)

        fc_shape = (self.n_hidden, self.n_hidden, self.modes_1, self.modes_2, self.modes_3)
        diag_shape = fc_shape[1:]

        # Weights are created strictly in the order 1..4 (dense draw before
        # diagonal draw in the combined case) so that the random stream, the
        # parameter names and their registration order stay unchanged.
        if self.module_FC and self.module_Diag:
            self.scale = 1 / (self.n_hidden * self.n_hidden)
            for k in range(1, 5):
                dense = torch.rand(*fc_shape, dtype=torch.cfloat)
                diag = torch.rand(*diag_shape, dtype=torch.cfloat)
                combined = add_diag_3d(dense, diag, self.n_hidden)
                setattr(self, 'weights_%d' % k, nn.Parameter(self.scale * combined))

        else:
            if self.module_FC:
                self.scale_FC = 1 / (self.n_hidden * self.n_hidden)
                for k in range(1, 5):
                    dense = torch.rand(*fc_shape, dtype=torch.cfloat)
                    setattr(self, 'weights_%d' % k, nn.Parameter(self.scale_FC * dense))

            if self.module_Diag:
                self.scale_Diag = 1 / self.n_hidden
                for k in range(1, 5):
                    diag = torch.rand(*diag_shape, dtype=torch.cfloat)
                    setattr(self, 'weights_%d_diag' % k, nn.Parameter(self.scale_Diag * diag))

    def _corner_blocks(self):
        """Return the index tuples of the four spectral blocks, in weight order.

        Weight ``k`` (1-based) belongs to the ``k``-th entry; the first
        transformed axis alternates fastest between its low and high block.
        """
        low_1, high_1 = slice(None, self.modes_1), slice(-self.modes_1, None)
        low_2, high_2 = slice(None, self.modes_2), slice(-self.modes_2, None)
        low_3 = slice(None, self.modes_3)
        return [
            (slice(None), slice(None), block_1, block_2, low_3)
            for block_2 in (low_2, high_2)
            for block_1 in (low_1, high_1)
        ]

    def _mix_corners(self, x_ft, out_ft, multiply, name_template, accumulate):
        """Combine ``multiply(x_ft[block], weight_k)`` into ``out_ft`` block by block.

        ``accumulate[k]`` chooses between adding to (True) and overwriting
        (False) the current content of the k-th block. The two differ only
        where blocks overlap.
        """
        blocks = self._corner_blocks()
        for k, (block, add) in enumerate(zip(blocks, accumulate), start=1):
            mixed = multiply(x_ft[block], getattr(self, name_template % k))
            if add:
                out_ft[block] += mixed
            else:
                out_ft[block] = mixed

    def forward(self, x):
        """Filter ``x`` in the one-sided Fourier domain of its last three axes."""
        batchsize = x.shape[0]

        x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1])

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.n_hidden, x.size(-3), x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, device=x.device)

        # NOTE: the overwrite/accumulate pattern differs per branch; it is kept
        # exactly as it was so results are unchanged when blocks overlap.
        if self.module_FC and self.module_Diag:
            # The diagonal part is already folded into the dense weights.
            self._mix_corners(x_ft, out_ft, complex_multiplication_3d_FC, 'weights_%d',
                              (False, False, False, False))

        else:
            if self.module_Diag:
                self._mix_corners(x_ft, out_ft, complex_multiplication_3d_Diag, 'weights_%d_diag',
                                  (False, True, True, True))

            if self.module_FC:
                self._mix_corners(x_ft, out_ft, complex_multiplication_3d_FC, 'weights_%d',
                                  (True, True, True, True))

        # Return to physical space
        x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1)))
        return x



class Recover_to_Sol_Space(nn.Module):

    '''
    Three stacked linear layers acting on the last axis of the input.

    GELU is applied after the first and second layer; dropout (p = 0.1) is
    applied to the output of the third layer.
    '''

    def __init__(self, in_channels, out_channels, mid_channels1, mid_channels2, bias = False):

        super().__init__()

        widths = (in_channels, mid_channels1, mid_channels2, out_channels)
        self.linear_layer_0 = nn.Linear(widths[0], widths[1], bias=bias)
        self.linear_layer_1 = nn.Linear(widths[1], widths[2], bias=bias)
        self.linear_layer_2 = nn.Linear(widths[2], widths[3], bias=bias)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):

        hidden = F.gelu(self.linear_layer_0(x))
        hidden = F.gelu(self.linear_layer_1(hidden))
        # Dropout is only applied once, on the final projection.
        return self.dropout(self.linear_layer_2(hidden))


class Schroedinger_Evolution_Layer_3d_Decompose(nn.Module):
    """One evolution block that sums several parallel branches of its input.

    Branches computed in ``forward`` from the input ``x``:

    * a branch in the one-sided FFT domain of axis -4: spectral mixing plus a
      complex MLP, transformed back with ``irfft`` and passed through ``mlp1``;
    * a spectral mixing branch applied to ``x`` directly;
    * a pointwise ``residual_layer`` MLP;
    * a pointwise ``mlp3`` skip branch that is added after ``mlp2``.

    ``mlp2`` is skipped when ``last_layer`` is True.
    """

    def __init__(self, spatial_modes_1, spatial_modes_2, spatial_modes_3, n_hidden, last_layer = False):

        super().__init__()

        self.n_hidden = n_hidden
        self.last_layer = last_layer

        spatial_modes = (spatial_modes_1, spatial_modes_2, spatial_modes_3)
        spectral_channels = n_hidden // 2 + 1

        self.hidden_momentum_spatial_momentum_evolution_layer = Bi_Momentum_Evolution_3d_Spatial_Low_Frequency(
            self.n_hidden, *spatial_modes, module_list=['FC', 'Diag'])

        self.hidden_position_spatial_momentum_evolution_layer = Hidden_Position_Space_Momentum_3d_Evolution_Spatial_Low_Frequency(
            self.n_hidden, *spatial_modes, module_list=['FC', 'Diag'])

        self.complex_mlp = MLP_Complex(spectral_channels, spectral_channels, spectral_channels * 2, bias=False)

        self.mlp1 = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        self.mlp2 = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        self.mlp3 = MLP(n_hidden, n_hidden, n_hidden, bias=False)

        self.residual_layer = MLP(n_hidden, n_hidden, 2 * n_hidden, bias=False)

    def forward(self, x):

        skip = self.mlp3(x)

        # Branch evaluated on the one-sided spectrum of axis -4.
        spectrum = torch.fft.rfft(x, dim = -4)
        mixed_spectrum = (self.hidden_momentum_spatial_momentum_evolution_layer(spectrum)
                          + self.complex_mlp(spectrum))
        momentum_branch = self.mlp1(torch.fft.irfft(mixed_spectrum, dim = -4, n = self.n_hidden))

        position_branch = self.hidden_position_spatial_momentum_evolution_layer(x)
        pointwise_branch = self.residual_layer(x)

        combined = x + momentum_branch + position_branch + pointwise_branch
        if not self.last_layer:
            combined = self.mlp2(combined)

        return x + (combined + skip)

    

class Schroedinger_NO_3d_time_Diff(nn.Module):
    """Full network: lift, stacked evolution blocks, a final mixing step, projection.

    ``forward`` appends three coordinate channels to the last axis of the
    input, lifts the resulting 7 features to ``n_hidden`` channels, runs
    ``n_layers`` evolution blocks (the final one flagged as ``last_layer``),
    applies a tanh-gated combination of an MLP, a convolutional difference
    layer and a second MLP, and projects to 4 output features.
    """

    def __init__(self, spatial_modes_1, spatial_modes_2, spatial_modes_3, n_hidden, n_layers = 4):

        super().__init__()

        self.lift_layer = nn.Linear(7, n_hidden, bias=False)

        evolution_blocks = []
        for layer_index in range(n_layers):
            is_last = (layer_index == n_layers - 1)
            evolution_blocks.append(
                Schroedinger_Evolution_Layer_3d_Decompose(spatial_modes_1,
                                                          spatial_modes_2,
                                                          spatial_modes_3,
                                                          n_hidden,
                                                          last_layer=is_last))
        self.blocks = nn.ModuleList(evolution_blocks)

        self.proj_layer = nn.Linear(n_hidden, 4, bias=False)

        self.diff_layer_1 = Spatial_Positional_Evolution_3d_High_Frequency(n_hidden, n_hidden)
        self.res_layer_1 = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        self.res_layer_2 = MLP(n_hidden, n_hidden, 2 * n_hidden, bias=False)

    def forward(self, x):
        coordinates = self.get_grid(x.shape, x.device)
        x = self.lift_layer(torch.cat((x, coordinates), dim=-1))
        # Channels move to axis 1 for the blocks and back before projection.
        x = x.permute(0, 4, 1, 2, 3)

        for block in self.blocks:
            x = block(x)

        linear_part = x + self.res_layer_1(x)
        grid_width = 0.5/x.shape[-1]
        gated_part = F.tanh(x + self.diff_layer_1(x, grid_width) + self.res_layer_2(x))
        x = linear_part + gated_part

        x = x.permute(0, 2, 3, 4, 1)
        return self.proj_layer(x)

    def get_grid(self, shape, device):
        """Build a (batch, size_x, size_y, size_z, 3) tensor of coordinates in [0, 1].

        Channel ``k`` of the last axis holds ``linspace(0, 1, size)`` along
        spatial axis ``k`` and is constant along every other axis.
        """
        batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3]
        spatial_sizes = (size_x, size_y, size_z)

        channels = []
        for axis, n_points in enumerate(spatial_sizes, start=1):
            view_shape = [1, 1, 1, 1, 1]
            view_shape[axis] = n_points
            repeats = [batchsize, size_x, size_y, size_z, 1]
            repeats[axis] = 1
            coords = torch.tensor(linspace(0, 1, n_points), dtype=torch.float)
            channels.append(coords.reshape(view_shape).repeat(repeats))

        return torch.cat(channels, dim=-1).to(device)

