import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import linspace
from timm.models.layers import trunc_normal_

import operator
from functools import reduce

# Count the scalar values stored in a model's parameters.
def count_params(model):
    """Return the number of real scalars held by ``model``'s parameters.

    Each element of a complex parameter is counted twice (real and imaginary
    part); each element of a real parameter is counted once.

    Args:
        model: object exposing ``parameters()`` that yields tensors, e.g. an
            ``nn.Module``.

    Returns:
        int: the total count; 0 when the model has no parameters.
    """
    # ``numel()`` is 1 for 0-dim parameters, whereas reducing over their empty
    # size tuple raises TypeError.
    return sum(p.numel() * (2 if p.is_complex() else 1)
               for p in model.parameters())

def activation_function(name):
    """Build the activation module registered under ``name``.

    Args:
        name: one of ``'relu'``, ``'gelu'``, ``'tanh'`` or ``'silu'``.

    Returns:
        nn.Module: a freshly constructed activation module.

    Raises:
        ValueError: if ``name`` is not one of the supported keys.
    """
    factories = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'tanh': nn.Tanh,
        'silu': nn.SiLU,
    }
    try:
        factory = factories[name]
    except (KeyError, TypeError):
        # Fail at construction time. Returning None here would only surface
        # later as "'NoneType' object is not callable" inside a forward pass.
        raise ValueError(
            "Unknown activation {!r}. You should add the activation function "
            "in the utils. Supported: {}".format(name, sorted(factories))
        ) from None
    return factory()


class Complex_Activation(nn.Module):
    """Apply a real-valued activation to a complex tensor.

    Two modes are available:

    * ``use_phase=False`` (default): the activation is applied to the real
      and imaginary parts independently.
    * ``use_phase=True``: the activation is applied to ``|z| + b`` and the
      result is multiplied by the unit phase factor ``exp(i * angle(z))``.

    Args:
        act: activation key forwarded to ``activation_function``.
        use_phase: selects the magnitude/phase mode described above.
    """

    def __init__(self, act = 'gelu', use_phase=False):
        super().__init__()
        self.act = activation_function(act)
        self.use_phase = use_phase

        # Magnitude bias, only read when ``use_phase`` is True. It is always
        # registered so the state-dict layout does not depend on the mode.
        # ``nn.Parameter`` is trainable by default; the former
        # ``self.b.requiresGrad = True`` was a misspelling of ``requires_grad``
        # that merely attached an unused attribute, so it has been dropped.
        self.b = nn.Parameter(torch.tensor([0.1]))

    def forward(self, z):
        """Return the activated complex tensor for complex input ``z``."""
        if self.use_phase:
            magnitude = self.act(torch.abs(z) + self.b)
            return magnitude * torch.exp(1.j * torch.angle(z))
        return self.act(z.real) + 1.j * self.act(z.imag)


def complex_multiplication_2d_FC(input, weights):
    """Mix channels with a separate dense matrix at every (x, y) location.

    Shapes: ``input`` is (batch, in_channel, x, y), ``weights`` is
    (in_channel, out_channel, x, y); the result is (batch, out_channel, x, y).
    """
    # Contract over the input-channel index ``i`` only; ``x`` and ``y`` are
    # carried through unchanged.
    contraction = "bixy,ioxy->boxy"
    return torch.einsum(contraction, input, weights)


def complex_multiplication_2d_Diag(input, weights):
    """Scale ``input`` element-wise by ``weights`` (diagonal channel mixing).

    ``weights`` gets a leading singleton axis so that it broadcasts across the
    first (batch) axis of ``input``.
    """
    batched_weights = weights.unsqueeze(0)
    return input * batched_weights


def add_diag(tensor_FC, tensor_diag, n_hidden):
    """Add ``tensor_diag[k]`` onto the ``[k, k]`` block of ``tensor_FC``.

    Only the first ``n_hidden`` diagonal blocks are touched. ``tensor_FC`` is
    modified in place and the same object is returned.
    """
    for k in range(n_hidden):
        # In-place add on the view of the k-th diagonal block.
        tensor_FC[k, k, :, :].add_(tensor_diag[k, :, :])
    return tensor_FC


class MLP(nn.Module):
    """Two-layer perceptron applied along axis 1 of a 4-D tensor.

    ``forward`` moves axis 1 to the end, applies
    ``Linear -> activation -> Linear`` on that last axis, and moves it back
    to position 1.

    Args:
        in_channels: size of axis 1 of the input.
        out_channels: size of axis 1 of the output.
        mid_channels: width of the hidden linear layer.
        bias: whether both linear layers carry a bias term.
        activation: key forwarded to ``activation_function``.
    """

    def __init__(self, in_channels, out_channels, mid_channels, bias = False, activation = 'gelu'):
        super().__init__()
        self.mlp1 = nn.Linear(in_channels, mid_channels, bias=bias)
        self.mlp2 = nn.Linear(mid_channels, out_channels, bias=bias)
        self.activation = activation_function(activation)

    def forward(self, x):
        # nn.Linear acts on the last axis, hence the round trip via permute.
        channels_last = x.permute(0, 2, 3, 1)
        hidden = self.activation(self.mlp1(channels_last))
        return self.mlp2(hidden).permute(0, 3, 1, 2)
    
class MLP_Complex(nn.Module):
    """Complex-valued counterpart of ``MLP``.

    Both linear layers are created with ``dtype=torch.cfloat`` and the
    non-linearity is a ``Complex_Activation`` built from ``activation``
    (with its default ``use_phase`` setting). As in ``MLP``, the layers act
    along axis 1 of a 4-D tensor.

    Args:
        in_channels: size of axis 1 of the input.
        out_channels: size of axis 1 of the output.
        mid_channels: width of the hidden linear layer.
        bias: whether both linear layers carry a bias term.
        activation: activation key passed to ``Complex_Activation``.
    """

    def __init__(self, in_channels, out_channels, mid_channels, bias = False, activation = 'gelu'):
        super().__init__()
        self.mlp1 = nn.Linear(in_channels, mid_channels, bias=bias, dtype=torch.cfloat)
        self.complex_activation = Complex_Activation(activation)
        self.mlp2 = nn.Linear(mid_channels, out_channels, bias=bias, dtype=torch.cfloat)

    def forward(self, x):
        # nn.Linear acts on the last axis, hence the round trip via permute.
        channels_last = x.permute(0, 2, 3, 1)
        hidden = self.complex_activation(self.mlp1(channels_last))
        return self.mlp2(hidden).permute(0, 3, 1, 2)

################################################################
# Position/Momentum Layers
################################################################

class Bi_Momentum_Evolution_2d_Spatial_Low_Freqency(nn.Module):
    """Spectral layer acting on the four low-frequency corners of a 2-D FFT.

    ``forward`` takes the full complex ``fft2`` over the last two axes,
    multiplies the four ``modes_1 x modes_2`` corner blocks of the spectrum
    by learned complex weights, leaves every other frequency at zero, and
    transforms back with ``ifft2``. The output is complex.

    The layer works with ``n_hidden // 2 + 1`` channels (stored as
    ``self.n_hidden``); in this file it is fed the ``rfft`` of an
    ``n_hidden``-channel tensor taken along the channel axis.

    Args:
        n_hidden: channel count before the channel-axis ``rfft``.
        modes_1: number of retained modes along axis -2 (per corner).
        modes_2: number of retained modes along axis -1 (per corner).
        module_list: which weight structures to use. ``'FC'`` gives a dense
            channel-mixing matrix per mode, ``'Diag'`` a per-channel scaling
            per mode. With both, the diagonal term is folded into the dense
            weights at initialisation and a single dense parameter is kept.

    Parameters are named ``weights_1`` .. ``weights_4`` (dense) or
    ``weights_1_diag`` .. ``weights_4_diag`` (diagonal). The index maps to
    the corners (low, low), (high, high), (low, high), (high, low) of the
    last two axes.
    """

    # An immutable default replaces the former list literal; only membership
    # tests are performed on it, so any container of strings still works.
    def __init__(self, n_hidden, modes_1, modes_2, module_list=('FC', 'Diag')):
        super().__init__()

        self.modes_1 = modes_1
        self.modes_2 = modes_2

        self.n_hidden = n_hidden // 2 + 1

        self.module_FC = ('FC' in module_list)
        self.module_Diag = ('Diag' in module_list)

        dense_shape = (self.n_hidden, self.n_hidden, self.modes_1, self.modes_2)
        diag_shape = (self.n_hidden, self.modes_1, self.modes_2)

        # Random draws happen in the order weights_1, weights_2, ... (dense
        # before diagonal within one weight) so seeded initialisation is
        # reproducible.
        if self.module_FC and self.module_Diag:
            self.scale = 1 / (self.n_hidden * self.n_hidden)
            for k in range(1, 5):
                dense = torch.rand(*dense_shape, dtype=torch.cfloat)
                diag = torch.rand(*diag_shape, dtype=torch.cfloat)
                merged = add_diag(dense, diag, self.n_hidden)
                setattr(self, f'weights_{k}', nn.Parameter(self.scale * merged))
        elif self.module_FC:
            self.scale_FC = 1 / (self.n_hidden * self.n_hidden)
            for k in range(1, 5):
                dense = torch.rand(*dense_shape, dtype=torch.cfloat)
                setattr(self, f'weights_{k}', nn.Parameter(self.scale_FC * dense))
        elif self.module_Diag:
            self.scale_Diag = 1 / self.n_hidden
            for k in range(1, 5):
                diag = torch.rand(*diag_shape, dtype=torch.cfloat)
                setattr(self, f'weights_{k}_diag', nn.Parameter(self.scale_Diag * diag))

    def forward(self, x):
        """Return the inverse FFT of the corner-filtered spectrum of ``x``."""
        batchsize = x.shape[0]

        x_ft = torch.fft.fft2(x, dim=[-2, -1])

        out_ft = torch.zeros(batchsize, self.n_hidden, x.size(-2), x.size(-1),
                             dtype=torch.cfloat, device=x.device)

        low_1, high_1 = slice(None, self.modes_1), slice(-self.modes_1, None)
        low_2, high_2 = slice(None, self.modes_2), slice(-self.modes_2, None)
        # Corner order matches weights_1..4. Later corners overwrite earlier
        # ones if the mode counts make the blocks overlap.
        corners = ((low_1, low_2), (high_1, high_2), (low_1, high_2), (high_1, low_2))

        if self.module_FC:
            # Covers both 'FC' alone and 'FC' + 'Diag' (diagonal already
            # merged into the dense weights).
            multiply = complex_multiplication_2d_FC
            weights = (self.weights_1, self.weights_2, self.weights_3, self.weights_4)
        elif self.module_Diag:
            multiply = complex_multiplication_2d_Diag
            weights = (self.weights_1_diag, self.weights_2_diag,
                       self.weights_3_diag, self.weights_4_diag)
        else:
            # Neither structure selected: the spectrum stays zero.
            weights = ()

        for (rows, cols), weight in zip(corners, weights):
            out_ft[:, :, rows, cols] = multiply(x_ft[:, :, rows, cols], weight)

        x = torch.fft.ifft2(out_ft, s=(x.size(-2), x.size(-1)))

        return x

# cite the localized differential kernel operator: https://arxiv.org/abs/2402.16845
class Spatial_Positional_Evolution_2d_High_Freqency(nn.Module):
    """Bias-free 2-D convolution rescaled into a finite-difference-like stencil.

    ``forward`` returns ``(conv(x) - s * x) / grid_width`` where ``s`` is the
    sum of each kernel over its spatial taps (applied as a 1x1 convolution).
    Subtracting that term makes the effective stencil weights sum to zero.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        kernel_size: odd spatial kernel size.
        n_dim: stored on the instance; not otherwise used by this class.
        groups: ``groups`` argument of the convolution.
        padding: one of ``'periodic'``, ``'replicate'``, ``'reflect'``,
            ``'zeros'``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, n_dim = 2, groups=1, padding='periodic'):
        super().__init__()

        self.n_dim = n_dim

        assert kernel_size % 2 == 1, "Kernel size should be odd"
        self.kernel_size = kernel_size
        self.padding_mode = self.get_padding_mode(padding)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            padding='same',
            padding_mode=self.padding_mode,
            bias=False,
            groups=groups,
        )

    def get_padding_mode(self, padding):
        """Translate this module's padding name into ``nn.Conv2d``'s ``padding_mode``."""
        supported = (
            ('periodic', 'circular'),
            ('replicate', 'replicate'),
            ('reflect', 'reflect'),
            ('zeros', 'zeros'),
        )
        for alias, torch_mode in supported:
            if padding == alias:
                return torch_mode
        raise NotImplementedError("Desired padding mode is not currently supported")

    def forward(self, x, grid_width):
        filtered = self.conv(x)
        # Per (out, in) kernel mass, kept as a 1x1 kernel so it can be applied
        # with the same grouping as the main convolution.
        kernel_mass = self.conv.weight.sum(dim=(2, 3), keepdim=True)
        center_term = F.conv2d(x, kernel_mass, groups=self.conv.groups)
        return (filtered - center_term) / grid_width


class Hidden_Position_Space_Momentum_2d_Evolution_Spatial_Low_Freqency(nn.Module):
    """Spectral layer acting on the low-frequency blocks of a real 2-D FFT.

    ``forward`` takes ``rfft2`` of the input, multiplies the two
    ``modes_1 x modes_2`` blocks at the low positive and low negative
    frequencies of axis -2 (low frequencies of axis -1) by learned complex
    weights, leaves every other frequency at zero, and transforms back with
    ``irfft2``. The output is real.

    Args:
        n_hidden: number of channels.
        modes_1: number of retained modes along axis -2 (per block).
        modes_2: number of retained modes along axis -1.
        module_list: which weight structures to use. ``'FC'`` gives a dense
            channel-mixing matrix per mode, ``'Diag'`` a per-channel scaling
            per mode. With both, the diagonal term is folded into the dense
            weights at initialisation and a single dense parameter is kept.

    Parameters are named ``weights_1``/``weights_2`` (dense) or
    ``weights_1_diag``/``weights_2_diag`` (diagonal). Index 1 is the block at
    the start of axis -2, index 2 the block at its end.
    """

    # An immutable default replaces the former list literal; only membership
    # tests are performed on it, so any container of strings still works.
    def __init__(self, n_hidden, modes_1, modes_2, module_list=('FC', 'Diag')):
        super().__init__()

        self.modes_1 = modes_1
        self.modes_2 = modes_2

        self.n_hidden = n_hidden

        self.module_FC = ('FC' in module_list)
        self.module_Diag = ('Diag' in module_list)

        dense_shape = (self.n_hidden, self.n_hidden, self.modes_1, self.modes_2)
        diag_shape = (self.n_hidden, self.modes_1, self.modes_2)

        # Random draws happen in the order weights_1, weights_2 (dense before
        # diagonal within one weight) so seeded initialisation is reproducible.
        if self.module_FC and self.module_Diag:
            self.scale = 1 / (self.n_hidden * self.n_hidden)
            for k in (1, 2):
                dense = torch.rand(*dense_shape, dtype=torch.cfloat)
                diag = torch.rand(*diag_shape, dtype=torch.cfloat)
                merged = add_diag(dense, diag, self.n_hidden)
                setattr(self, f'weights_{k}', nn.Parameter(self.scale * merged))
        elif self.module_FC:
            self.scale_FC = 1 / (self.n_hidden * self.n_hidden)
            for k in (1, 2):
                dense = torch.rand(*dense_shape, dtype=torch.cfloat)
                setattr(self, f'weights_{k}', nn.Parameter(self.scale_FC * dense))
        elif self.module_Diag:
            self.scale_Diag = 1 / self.n_hidden
            for k in (1, 2):
                diag = torch.rand(*diag_shape, dtype=torch.cfloat)
                setattr(self, f'weights_{k}_diag', nn.Parameter(self.scale_Diag * diag))

    def forward(self, x):
        """Return the inverse real FFT of the block-filtered spectrum of ``x``."""
        batchsize = x.shape[0]

        x_ft = torch.fft.rfft2(x)

        # rfft2 keeps only the non-negative half of the last axis.
        out_ft = torch.zeros(batchsize, self.n_hidden, x.size(-2), x.size(-1) // 2 + 1,
                             dtype=torch.cfloat, device=x.device)

        cols = slice(None, self.modes_2)
        # Row-block order matches weights_1, weights_2. The second block
        # overwrites the first if the mode count makes them overlap.
        row_blocks = (slice(None, self.modes_1), slice(-self.modes_1, None))

        if self.module_FC:
            # Covers both 'FC' alone and 'FC' + 'Diag' (diagonal already
            # merged into the dense weights).
            multiply = complex_multiplication_2d_FC
            weights = (self.weights_1, self.weights_2)
        elif self.module_Diag:
            multiply = complex_multiplication_2d_Diag
            weights = (self.weights_1_diag, self.weights_2_diag)
        else:
            # Neither structure selected: the spectrum stays zero.
            weights = ()

        for rows, weight in zip(row_blocks, weights):
            out_ft[:, :, rows, cols] = multiply(x_ft[:, :, rows, cols], weight)

        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
        return x


class Recover_to_Sol_Space(nn.Module):
    """Three-layer pointwise projection head.

    ``forward`` moves axis 1 of a 4-D input to the end and applies
    ``Linear -> GELU -> Linear -> GELU -> Linear`` on that last axis. The
    result is returned with that axis still last (it is not permuted back).

    Args:
        in_channels: size of axis 1 of the input.
        out_channels: size of the last axis of the output.
        mid_channels1: width after the first linear layer.
        mid_channels2: width after the second linear layer.
        bias: whether the three linear layers carry a bias term.
    """

    def __init__(self, in_channels, out_channels, mid_channels1, mid_channels2, bias = False):
        super().__init__()

        widths = (in_channels, mid_channels1, mid_channels2, out_channels)
        self.linear_layer_0 = nn.Linear(widths[0], widths[1], bias=bias)
        self.linear_layer_1 = nn.Linear(widths[1], widths[2], bias=bias)
        self.linear_layer_2 = nn.Linear(widths[2], widths[3], bias=bias)
        # self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        features = x.permute(0, 2, 3, 1)
        features = F.gelu(self.linear_layer_0(features))
        features = F.gelu(self.linear_layer_1(features))
        # features = self.dropout(features)
        return self.linear_layer_2(features)


class Schroedinger_Evolution_Layer_2d_Decompose(nn.Module):
    """One evolution block that sums three branches plus a skip path.

    For a channels-first input ``x`` with ``n_hidden`` channels, ``forward``
    computes:

    * a branch in the channel-frequency domain: ``rfft`` along the channel
      axis, a dense spatial spectral layer plus a complex MLP, ``irfft`` back
      to ``n_hidden`` channels, then ``mlp1``;
    * a branch applying the diagonal spatial spectral layer directly to
      ``x``, followed by ``mlp2``;
    * a pointwise branch, ``residual_layer(x)``.

    The three are summed, passed through GELU unless ``last_layer`` is set,
    and ``mlp_res(x)`` is added to the result.

    Args:
        spatial_modes_1: retained spatial modes along axis -2.
        spatial_modes_2: retained spatial modes along axis -1.
        n_hidden: number of channels.
        last_layer: when True the GELU after the branch sum is skipped.
    """

    def __init__(self, spatial_modes_1, spatial_modes_2, n_hidden, last_layer = False):
        super().__init__()

        self.n_hidden = n_hidden
        self.last_layer = last_layer

        # Channel count after the rfft taken along the channel axis.
        n_freq = n_hidden // 2 + 1

        # Sub-modules are created in a fixed order; it determines both the
        # state-dict layout and the order of random initialisation.
        self.hidden_momentum_spatial_momentum_evolution_layer = Bi_Momentum_Evolution_2d_Spatial_Low_Freqency(
            self.n_hidden, spatial_modes_1, spatial_modes_2, module_list=['FC'])

        self.hidden_position_spatial_momentum_evolution_layer = Hidden_Position_Space_Momentum_2d_Evolution_Spatial_Low_Freqency(
            self.n_hidden, spatial_modes_1, spatial_modes_2, module_list=['Diag'])

        self.complex_mlp_residual = MLP_Complex(n_freq, n_freq, n_freq * 2, bias=False)

        self.mlp1 = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        self.mlp2 = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        self.mlp_res = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        # mlp3 is registered but not called in forward (its call is commented out).
        self.mlp3 = MLP(n_hidden, n_hidden, n_hidden, bias=False)

        self.residual_layer = MLP(n_hidden, n_hidden, 2 * n_hidden, bias=False)

    def forward(self, x):
        skip = self.mlp_res(x)

        # Branch 1: evolve in the channel-frequency domain.
        spectrum = torch.fft.rfft(x, dim=-3)
        evolved = self.hidden_momentum_spatial_momentum_evolution_layer(spectrum)
        pointwise = self.complex_mlp_residual(spectrum)
        momentum_branch = torch.fft.irfft(evolved + pointwise, dim=-3, n=self.n_hidden)
        momentum_branch = self.mlp1(momentum_branch)

        # Branch 2: spatial spectral layer applied to the channels directly.
        position_branch = self.mlp2(self.hidden_position_spatial_momentum_evolution_layer(x))

        # Branch 3: pointwise MLP.
        local_branch = self.residual_layer(x)

        out = momentum_branch + position_branch + local_branch

        if not self.last_layer:
            out = F.gelu(out)
            # out = self.mlp3(out)

        return out + skip
    


class Schroedinger_NO_2d_Decompose_Diff(nn.Module):
    """Full 2-D operator network: lift, evolution blocks, stencil head, projection.

    ``forward`` expects a channels-last input of shape
    ``(batch, size_x, size_y, 1)``: two grid-coordinate channels are appended
    and the result is fed to a ``Linear(3, n_hidden)`` lifting layer. The
    output keeps the channels-last layout with a single channel.

    Args:
        spatial_modes_1: retained spatial modes along the first spatial axis.
        spatial_modes_2: retained spatial modes along the second spatial axis.
        n_hidden: number of hidden channels.
        n_layers: number of evolution blocks; the final one is built with
            ``last_layer=True``.
        padding: number of zero cells appended at the end of both spatial
            axes before the blocks and removed again before projection
            (used when the input is non-periodic). ``0`` disables padding.

    Raises:
        ValueError: if ``padding`` is negative.
    """

    def __init__(self, spatial_modes_1, spatial_modes_2, n_hidden, n_layers = 4, padding = 9):
        super().__init__()

        if padding < 0:
            raise ValueError("padding must be non-negative, got {}".format(padding))

        self.padding = padding  # pad the domain if input is non-periodic
        self.n_hidden = n_hidden

        # 3 input features: the single input channel plus the two grid coordinates.
        self.lift_layer = nn.Linear(3, n_hidden, bias=False)

        self.blocks = nn.ModuleList([
            Schroedinger_Evolution_Layer_2d_Decompose(
                spatial_modes_1,
                spatial_modes_2,
                n_hidden,
                last_layer=(layer_idx == n_layers - 1),
            )
            for layer_idx in range(n_layers)
        ])

        self.proj_layer = Recover_to_Sol_Space(n_hidden, 1, n_hidden * 2, n_hidden, bias=False)  # output channel_dim is 1: u1(x)

        self.diff_layer_1 = Spatial_Positional_Evolution_2d_High_Freqency(n_hidden, n_hidden)
        self.res_layer_1 = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        self.res_layer_2 = MLP(n_hidden, n_hidden, 2 * n_hidden, bias=False)

        self.initialize_weights()

    def initialize_weights(self):
        """Re-initialise every sub-module via ``_init_weights``."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for real ``nn.Linear`` weights; zero their biases.

        Complex linear layers and all other module types are left untouched.
        """
        if isinstance(m, nn.Linear) and not torch.is_complex(m.weight):
            trunc_normal_(m.weight, std=0.219)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.lift_layer(x)
        x = x.permute(0, 3, 1, 2)
        if self.padding > 0:
            x = F.pad(x, [0, self.padding, 0, self.padding])

        for block in self.blocks:
            x = block(x)

        # torch.tanh is numerically identical to F.tanh, which is a legacy
        # alias that warns on some PyTorch releases. The stencil width is
        # derived from the (padded) size of the last axis.
        x = x + self.res_layer_1(x) + torch.tanh(x + self.diff_layer_1(x, 0.5/x.shape[-1]) + self.res_layer_2(x))

        # Guarded because ``x[..., :-0, :-0]`` would yield an empty tensor.
        if self.padding > 0:
            x = x[..., :-self.padding, :-self.padding]
        x = self.proj_layer(x)
        # Alternative read-outs kept from the original for reference:
        # x = x[:, self.n_hidden//2:self.n_hidden//2 + 1, :, :].permute(0, 2, 3, 1)  # delta
        # x = torch.sum(x[:, :self.n_hidden//2, :, :], dim = 1, keepdim=True).permute(0, 2, 3, 1)  # step
        # x = torch.mean(x, dim = 1, keepdim=True).permute(0, 2, 3, 1)  # mean

        return x

    def get_grid(self, shape, device):
        """Return a ``(batch, size_x, size_y, 2)`` float tensor of grid coordinates.

        Channel 0 holds the first-axis coordinate and channel 1 the
        second-axis coordinate, each sampled uniformly on ``[0, 1]`` with
        both end points included. ``shape`` is read as
        ``(batch, size_x, size_y, ...)``.
        """
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).expand(batchsize, size_x, size_y, 1)
        gridy = torch.tensor(linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).expand(batchsize, size_x, size_y, 1)
        # ``cat`` materialises the broadcast views into one contiguous tensor.
        return torch.cat((gridx, gridy), dim=-1).to(device)