import torch
import torch.nn as nn
import torch.nn.functional as F

from numpy import linspace
from timm.models.layers import trunc_normal_

import operator
from functools import reduce


####### Functions #######

# count the real scalars stored in a model's parameters (a complex entry counts as two)
def count_params(model):
    """Return the number of real scalars held by ``model``'s parameters.

    A complex-valued entry is counted as two scalars (real and imaginary
    part); a real-valued entry is counted once.

    Args:
        model: an ``nn.Module``. Every tensor yielded by
            ``model.parameters()`` is counted, whatever its
            ``requires_grad`` flag.

    Returns:
        int: the total count, 0 when the model has no parameters.
    """
    # numel() is 1 for a 0-dim parameter; folding its empty size tuple with
    # reduce() and no initial value would raise TypeError instead.
    return sum(p.numel() * (2 if p.is_complex() else 1)
               for p in model.parameters())


def activation_function(name):
    """Build the activation module registered under ``name``.

    Supported names are ``'relu'``, ``'gelu'`` and ``'tanh'``. For any other
    name a hint is printed and ``None`` is returned (no exception is raised).
    """
    known = (
        ('relu', torch.nn.ReLU),
        ('gelu', torch.nn.GELU),
        ('tanh', torch.nn.Tanh),
    )
    for key, factory in known:
        if name == key:
            # Instantiate lazily so only the requested module is created.
            return factory()

    print("You should add the activation function in the utils.")
    return None


class Complex_Activation(nn.Module):
    """Apply a real-valued activation to a complex tensor.

    Args:
        act: name passed to ``activation_function`` to pick the real
            activation (``'relu'`` by default).
        use_phase: selects the mode used in ``forward``.
            ``False`` (default): the activation is applied to the real and
            imaginary parts independently.
            ``True``: the activation is applied to ``|z| + b`` and the result
            is multiplied by ``exp(i * angle(z))``, so the phase of ``z`` is
            kept. ``b`` is a learnable one-element bias initialised to 0.1.
    """

    def __init__(self, act = 'relu', use_phase=False):
        super(Complex_Activation, self).__init__()
        self.act = activation_function(act)
        self.use_phase = use_phase

        # The autograd flag is spelled `requires_grad`; an attribute named
        # `requiresGrad` would simply be ignored. It is True by default for
        # nn.Parameter and is passed explicitly here to state the intent.
        self.b = nn.Parameter(torch.tensor([0.1]), requires_grad=True)

    def forward(self, z):
        """Return the activated complex tensor, same shape as ``z``."""
        if self.use_phase:
            # Activate the (shifted) modulus, then restore the original phase.
            return self.act(torch.abs(z) + self.b) * torch.exp(1.j * torch.angle(z)) 
        else:
            # Split mode: real and imaginary parts are activated separately.
            return self.act(z.real) + 1.j * self.act(z.imag)


def complex_multiplication_2d_FC(input, weights):
    """Mix channels with a separate dense matrix at every (x, y) position.

    Shapes follow the einsum signature:
    ``input`` is (batch, in_channel, x, y), ``weights`` is
    (in_channel, out_channel, x, y) and the result is
    (batch, out_channel, x, y), summing over ``in_channel``.
    """
    contraction = "bixy,ioxy->boxy"
    mixed = torch.einsum(contraction, input, weights)
    return mixed


def complex_multiplication_2d_Diag(input, weights):
    """Scale ``input`` element-wise by ``weights`` (diagonal channel mixing).

    ``weights`` gets a new leading axis of size 1 so that it broadcasts over
    the first (batch) axis of ``input``; no channel mixing takes place.
    """
    batched_weights = weights[None, ...]
    return torch.mul(input, batched_weights)


def add_diag(tensor_FC, tensor_diag, n_hidden):
    """Add ``tensor_diag`` onto the channel diagonal of ``tensor_FC`` in place.

    For every ``i`` in ``range(n_hidden)`` the slice ``tensor_FC[i, i, :, :]``
    is incremented by ``tensor_diag[i, :, :]``. ``tensor_FC`` itself is
    modified and also returned.
    """
    for channel in range(n_hidden):
        # Basic indexing yields a view, so add_ writes straight into tensor_FC.
        tensor_FC[channel, channel, :, :].add_(tensor_diag[channel, :, :])
    return tensor_FC


class MLP(nn.Module):
    """Two-layer perceptron applied along the channel axis of a 4-D tensor.

    The input is laid out channels-first, (batch, channels, d1, d2). It is
    moved to channels-last for the two ``nn.Linear`` layers (with a GELU in
    between) and moved back to channels-first before being returned.

    Args:
        in_channels: size of the input channel axis.
        out_channels: size of the output channel axis.
        mid_channels: width of the hidden layer.
        bias: whether both linear layers carry a bias term.
    """

    def __init__(self, in_channels, out_channels, mid_channels, bias):
        super().__init__()
        self.mlp1 = nn.Linear(in_channels, mid_channels, bias=bias)
        self.mlp2 = nn.Linear(mid_channels, out_channels, bias=bias)

    def forward(self, x):
        # nn.Linear acts on the last axis, hence the channels-last detour.
        channels_last = x.permute(0, 2, 3, 1)
        hidden = F.gelu(self.mlp1(channels_last))
        return self.mlp2(hidden).permute(0, 3, 1, 2)
    
class MLP_Complex(nn.Module):
    """Complex-valued two-layer perceptron along the channel axis.

    Both ``nn.Linear`` layers are created with ``dtype=torch.cfloat`` and are
    separated by a ``Complex_Activation('gelu')``. The input is expected
    channels-first, (batch, channels, d1, d2), and is returned in the same
    layout.

    Args:
        in_channels: size of the input channel axis.
        out_channels: size of the output channel axis.
        mid_channels: width of the hidden layer.
        bias: whether both linear layers carry a bias term (default False).
    """

    def __init__(self, in_channels, out_channels, mid_channels, bias = False):
        super().__init__()
        self.mlp1 = nn.Linear(in_channels, mid_channels, bias=bias, dtype=torch.cfloat)
        self.complex_activation = Complex_Activation('gelu')
        self.mlp2 = nn.Linear(mid_channels, out_channels, bias=bias, dtype=torch.cfloat)

    def forward(self, x):
        # nn.Linear acts on the last axis, hence the channels-last detour.
        channels_last = x.permute(0, 2, 3, 1)
        hidden = self.complex_activation(self.mlp1(channels_last))
        return self.mlp2(hidden).permute(0, 3, 1, 2)
    

################################################################
# Position/Momentum Layers
################################################################

class Bi_Momentum_Evolution_2d_Spatial_Low_Freqency(nn.Module):
    """Spectral layer acting on the four low-frequency corners of a 2-D FFT.

    The layer works with ``n_hidden // 2 + 1`` channels (stored as
    ``self.n_hidden``). ``forward`` takes the full 2-D FFT over the last two
    axes, multiplies the four ``modes_1 x modes_2`` corner blocks by learned
    complex weights, leaves every other frequency at zero and transforms
    back with the inverse FFT.

    Args:
        n_hidden: channel count before halving; the layer itself uses
            ``n_hidden // 2 + 1`` channels.
        modes_1: number of retained modes along the second-to-last axis.
        modes_2: number of retained modes along the last axis.
        module_list: which weight types to create. With both ``'FC'`` and
            ``'Diag'`` a dense weight with a random diagonal added on top is
            stored as ``weights_1..4``; with only ``'FC'`` plain dense
            weights ``weights_1..4``; with only ``'Diag'`` per-channel
            weights ``weights_1_diag..weights_4_diag``.
    """

    def __init__(self, n_hidden, modes_1, modes_2, module_list = ['FC', 'Diag']):
        super().__init__()

        self.modes_1 = modes_1
        self.modes_2 = modes_2
        self.n_hidden = n_hidden // 2 + 1

        self.module_FC = 'FC' in module_list
        self.module_Diag = 'Diag' in module_list

        width = self.n_hidden
        dense_shape = (width, width, modes_1, modes_2)
        diag_shape = (width, modes_1, modes_2)

        if self.module_FC and self.module_Diag:
            self.scale = 1 / (width * width)

            def fused_kernel():
                # Dense draw first, diagonal draw second, then fold the
                # diagonal into the dense tensor before scaling.
                dense = torch.rand(*dense_shape, dtype=torch.cfloat)
                diagonal = torch.rand(*diag_shape, dtype=torch.cfloat)
                return nn.Parameter(self.scale * add_diag(dense, diagonal, width))

            self.weights_1 = fused_kernel()
            self.weights_2 = fused_kernel()
            self.weights_3 = fused_kernel()
            self.weights_4 = fused_kernel()

        elif self.module_FC:
            self.scale_FC = 1 / (width * width)

            def dense_kernel():
                return nn.Parameter(self.scale_FC * torch.rand(*dense_shape, dtype=torch.cfloat))

            self.weights_1 = dense_kernel()
            self.weights_2 = dense_kernel()
            self.weights_3 = dense_kernel()
            self.weights_4 = dense_kernel()

        elif self.module_Diag:
            self.scale_Diag = 1 / width

            def diag_kernel():
                return nn.Parameter(self.scale_Diag * torch.rand(*diag_shape, dtype=torch.cfloat))

            self.weights_1_diag = diag_kernel()
            self.weights_2_diag = diag_kernel()
            self.weights_3_diag = diag_kernel()
            self.weights_4_diag = diag_kernel()

    def forward(self, x):
        """Filter ``x`` (batch, channel, d1, d2) in 2-D Fourier space.

        Returns the complex result of the inverse FFT, with the same two
        trailing sizes as ``x`` and ``self.n_hidden`` channels.
        """
        size_1, size_2 = x.size(-2), x.size(-1)

        x_ft = torch.fft.fft2(x, dim = [-2, -1])
        out_ft = torch.zeros(x.shape[0], self.n_hidden, size_1, size_2,
                             dtype=torch.cfloat, device=x.device)

        low_1, high_1 = slice(None, self.modes_1), slice(-self.modes_1, None)
        low_2, high_2 = slice(None, self.modes_2), slice(-self.modes_2, None)
        # Corner order matters if the blocks overlap: later writes win.
        corners = ((low_1, low_2), (high_1, high_2), (low_1, high_2), (high_1, low_2))

        if self.module_FC:
            # Covers both 'FC only' and 'FC + Diag' (diagonal already folded in).
            multiply = complex_multiplication_2d_FC
            kernels = (self.weights_1, self.weights_2, self.weights_3, self.weights_4)
        elif self.module_Diag:
            multiply = complex_multiplication_2d_Diag
            kernels = (self.weights_1_diag, self.weights_2_diag,
                       self.weights_3_diag, self.weights_4_diag)
        else:
            # No weights were created: the spectrum stays all-zero.
            multiply = None
            kernels = ()

        for (rows, cols), kernel in zip(corners, kernels):
            out_ft[:, :, rows, cols] = multiply(x_ft[:, :, rows, cols], kernel)

        return torch.fft.ifft2(out_ft, s=(size_1, size_2))


class Hidden_Position_Space_Momentum_2d_Evolution_Spatial_Low_Freqency(nn.Module):
    """Spectral layer on the low-frequency modes of a real 2-D FFT.

    ``forward`` applies ``rfft2`` over the last two axes, multiplies two
    ``modes_1 x modes_2`` blocks (the first and the last ``modes_1`` rows,
    first ``modes_2`` columns) by learned complex weights, leaves all other
    frequencies at zero and returns the ``irfft2`` of the result.

    Args:
        n_hidden: number of channels.
        modes_1: number of Fourier modes to multiply along the
            second-to-last axis, at most floor(N/2) + 1.
        modes_2: number of retained modes along the last axis.
        module_list: which weight types to create. With both ``'FC'`` and
            ``'Diag'`` a dense weight with a random diagonal added on top is
            stored as ``weights_1``/``weights_2``; with only ``'FC'`` plain
            dense weights; with only ``'Diag'`` per-channel weights
            ``weights_1_diag``/``weights_2_diag``.
    """

    def __init__(self, n_hidden, modes_1, modes_2, module_list = ['FC', 'Diag']):
        super().__init__()

        self.modes_1 = modes_1
        self.modes_2 = modes_2
        self.n_hidden = n_hidden

        self.module_FC = 'FC' in module_list
        self.module_Diag = 'Diag' in module_list

        dense_shape = (n_hidden, n_hidden, modes_1, modes_2)
        diag_shape = (n_hidden, modes_1, modes_2)

        if self.module_FC and self.module_Diag:
            self.scale = 1 / (n_hidden * n_hidden)

            def fused_kernel():
                # Dense draw first, diagonal draw second, then fold the
                # diagonal into the dense tensor before scaling.
                dense = torch.rand(*dense_shape, dtype=torch.cfloat)
                diagonal = torch.rand(*diag_shape, dtype=torch.cfloat)
                return nn.Parameter(self.scale * add_diag(dense, diagonal, n_hidden))

            self.weights_1 = fused_kernel()
            self.weights_2 = fused_kernel()

        elif self.module_FC:
            self.scale_FC = 1 / (n_hidden * n_hidden)

            self.weights_1 = nn.Parameter(self.scale_FC * torch.rand(*dense_shape, dtype=torch.cfloat))
            self.weights_2 = nn.Parameter(self.scale_FC * torch.rand(*dense_shape, dtype=torch.cfloat))

        elif self.module_Diag:
            self.scale_Diag = 1 / n_hidden

            self.weights_1_diag = nn.Parameter(self.scale_Diag * torch.rand(*diag_shape, dtype=torch.cfloat))
            self.weights_2_diag = nn.Parameter(self.scale_Diag * torch.rand(*diag_shape, dtype=torch.cfloat))

    def forward(self, x):
        """Filter ``x`` (batch, channel, d1, d2) in Fourier space.

        Returns a real tensor with the same two trailing sizes as ``x`` and
        ``self.n_hidden`` channels.
        """
        size_1, size_2 = x.size(-2), x.size(-1)

        x_ft = torch.fft.rfft2(x)
        # rfft2 keeps only size_2 // 2 + 1 columns of the spectrum.
        out_ft = torch.zeros(x.shape[0], self.n_hidden, size_1, size_2 // 2 + 1,
                             dtype=torch.cfloat, device=x.device)

        first_rows = slice(None, self.modes_1)
        last_rows = slice(-self.modes_1, None)
        cols = slice(None, self.modes_2)

        if self.module_FC:
            # Covers both 'FC only' and 'FC + Diag' (diagonal already folded in).
            multiply = complex_multiplication_2d_FC
            bands = ((first_rows, self.weights_1), (last_rows, self.weights_2))
        elif self.module_Diag:
            multiply = complex_multiplication_2d_Diag
            bands = ((first_rows, self.weights_1_diag), (last_rows, self.weights_2_diag))
        else:
            # No weights were created: the spectrum stays all-zero.
            multiply = None
            bands = ()

        for rows, kernel in bands:
            out_ft[:, :, rows, cols] = multiply(x_ft[:, :, rows, cols], kernel)

        return torch.fft.irfft2(out_ft, s=(size_1, size_2))


class Recover_to_Sol_Space(nn.Module):
    '''
    Projection head: three linear layers with GELU in between, applied along
    the channel axis.

    The input is channels-first, (batch, channels, d1, d2). The output stays
    channels-last, (batch, d1, d2, out_channels); it is not permuted back.
    '''

    def __init__(self, in_channels, out_channels, mid_channels1, mid_channels2, bias = False):
        super().__init__()

        self.linear_layer_0 = nn.Linear(in_channels, mid_channels1, bias=bias)
        self.linear_layer_1 = nn.Linear(mid_channels1, mid_channels2, bias=bias)
        self.linear_layer_2 = nn.Linear(mid_channels2, out_channels, bias=bias)

    def forward(self, x):
        # nn.Linear acts on the last axis, so channels go last first.
        features = x.permute(0, 2, 3, 1)
        features = F.gelu(self.linear_layer_0(features))
        features = F.gelu(self.linear_layer_1(features))
        return self.linear_layer_2(features)


class Schroedinger_Evolution_Layer_2d_Decompose(nn.Module):
    """One evolution block combining several branches on a channels-first input.

    Branches computed from the input ``x`` (batch, n_hidden, d1, d2):
      * a branch in the channel-wise Fourier domain (``rfft`` along axis -3):
        spectral layer plus complex MLP, transformed back and passed
        through ``mlp1``;
      * a spatial spectral layer applied to ``x`` directly;
      * a point-wise MLP (``residual_layer``);
      * a second point-wise MLP (``mlp3``) added after the optional ``mlp2``.

    Args:
        spatial_modes_1: retained modes along the second-to-last axis.
        spatial_modes_2: retained modes along the last axis.
        n_hidden: number of channels.
        last_layer: when True, ``mlp2`` is skipped in ``forward``.
    """

    def __init__(self, spatial_modes_1, spatial_modes_2, n_hidden, last_layer = False):
        super().__init__()

        self.n_hidden = n_hidden
        self.last_layer = last_layer

        # Submodules are created in a fixed order so that random
        # initialisation and parameter registration order stay the same.
        self.hidden_momentum_spatial_momentum_evolution_layer = Bi_Momentum_Evolution_2d_Spatial_Low_Freqency(
            self.n_hidden, spatial_modes_1, spatial_modes_2, module_list=['FC', 'Diag'])

        self.hidden_position_spatial_momentum_evolution_layer = Hidden_Position_Space_Momentum_2d_Evolution_Spatial_Low_Freqency(
            self.n_hidden, spatial_modes_1, spatial_modes_2, module_list=['FC', 'Diag'])

        # Channel count of the one-sided FFT taken along the channel axis.
        half_spectrum = n_hidden // 2 + 1
        self.complex_mlp = MLP_Complex(half_spectrum, half_spectrum, half_spectrum * 2, bias=False)

        self.mlp1 = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        self.mlp2 = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        self.mlp3 = MLP(n_hidden, n_hidden, n_hidden, bias=False)

        self.residual_layer = MLP(n_hidden, n_hidden, 2 * n_hidden, bias=False)

    def forward(self, x):
        skip = self.mlp3(x)

        # Branch in the Fourier domain of the channel axis.
        channel_spectrum = torch.fft.rfft(x, dim = -3)
        mixed_spectrum = (self.hidden_momentum_spatial_momentum_evolution_layer(channel_spectrum)
                          + self.complex_mlp(channel_spectrum))
        momentum_branch = self.mlp1(torch.fft.irfft(mixed_spectrum, dim = -3, n = self.n_hidden))

        # Branches acting on x directly.
        position_branch = self.hidden_position_spatial_momentum_evolution_layer(x)
        pointwise_branch = self.residual_layer(x)

        combined = x + momentum_branch + position_branch + pointwise_branch
        if not self.last_layer:
            combined = self.mlp2(combined)

        # x enters both `combined` and this final sum.
        return x + (combined + skip)

class Schroedinger_NO_2d_time(nn.Module):
    """Stack of evolution blocks between a lifting and a projection layer.

    ``forward`` expects a channels-last input (batch, size_x, size_y, C).
    Two grid coordinates are appended to the last axis and the lifting layer
    takes 22 features, so C must be 20. The output is channels-last with 2
    features, as produced by ``Recover_to_Sol_Space``.

    Args:
        spatial_modes_1: retained modes along the first spatial axis.
        spatial_modes_2: retained modes along the second spatial axis.
        n_hidden: number of hidden channels.
        n_layers: number of evolution blocks; only the final one is built
            with ``last_layer=True``.
    """

    def __init__(self, spatial_modes_1, spatial_modes_2, n_hidden, n_layers = 4):
        super().__init__()

        self.lift_layer = nn.Linear(22, n_hidden, bias=False)

        self.blocks = nn.ModuleList()
        for depth in range(n_layers):
            is_final = depth == n_layers - 1
            self.blocks.append(
                Schroedinger_Evolution_Layer_2d_Decompose(spatial_modes_1,
                                                          spatial_modes_2,
                                                          n_hidden,
                                                          last_layer=is_final))

        self.proj_layer = Recover_to_Sol_Space(n_hidden, 2, n_hidden * 4, n_hidden * 2, bias=False)
        self.norm = nn.InstanceNorm2d(n_hidden)
        self.initialize_weights()

    def initialize_weights(self):
        """Run ``_init_weights`` on every submodule (and on self)."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Re-initialise real-valued ``nn.Linear`` layers; skip everything else."""
        if not isinstance(m, nn.Linear) or torch.is_complex(m.weight):
            return
        # For real-valued weights, apply trunc_normal_ initialization
        trunc_normal_(m.weight, std=0.21)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        coords = self.get_grid(x.shape, x.device)
        lifted = self.lift_layer(torch.cat((x, coords), dim=-1))

        # Blocks work channels-first; the same norm module wraps each block.
        hidden = lifted.permute(0, 3, 1, 2)
        for block in self.blocks:
            hidden = self.norm(block(self.norm(hidden)))

        return self.proj_layer(hidden)
    
    def get_grid(self, shape, device):
        """Return a (batch, size_x, size_y, 2) float grid of coordinates in [0, 1].

        The last axis holds the x coordinate first and the y coordinate
        second. Sizes are read from ``shape[0]``, ``shape[1]``, ``shape[2]``.
        """
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]

        x_coords = torch.tensor(linspace(0, 1, size_x), dtype=torch.float).reshape(1, size_x, 1, 1)
        y_coords = torch.tensor(linspace(0, 1, size_y), dtype=torch.float).reshape(1, 1, size_y, 1)

        grid = torch.cat((x_coords.repeat([batchsize, 1, size_y, 1]),
                          y_coords.repeat([batchsize, size_x, 1, 1])), dim=-1)
        return grid.to(device)
