import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import linspace
from timm.models.layers import trunc_normal_

import operator
from functools import reduce

# Count the model's parameters (complex entries count as two real values)
def count_params(model):
    """Return the number of real-valued scalars held by ``model``'s parameters.

    Complex parameters contribute two scalars per element (real and imaginary
    part), so real- and complex-valued models are counted on the same footing.
    Works for parameters of any shape, including 0-dim scalars.
    """
    return sum(p.numel() * (2 if p.is_complex() else 1)
               for p in model.parameters())

def activation_function(name):
    """Return a new activation module for ``name``.

    Supported names: ``'relu'``, ``'gelu'``, ``'tanh'``, ``'silu'``.

    Raises:
        ValueError: if ``name`` is not a supported activation. Failing here,
            at construction time, is clearer than returning ``None`` and
            crashing later when the model calls it in ``forward``.
    """
    factories = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'tanh': nn.Tanh,
        'silu': nn.SiLU,
    }
    try:
        factory = factories[name]
    except KeyError:
        raise ValueError(
            f"Unsupported activation {name!r}; expected one of "
            f"{sorted(factories)}. You should add the activation function in the utils."
        ) from None
    return factory()


class Complex_Activation(nn.Module):
    """Apply a real-valued activation to a complex tensor.

    Two modes:

    * ``use_phase=False`` (default): split form,
      ``act(z.real) + 1j * act(z.imag)``.
    * ``use_phase=True``: modulus/phase form,
      ``act(|z| + b) * exp(1j * angle(z))`` with a learnable bias ``b``;
      the activation acts on the biased modulus and the result is multiplied
      by the unit phasor of ``z``.

    Args:
        act: activation name passed to ``activation_function``.
        use_phase: select the modulus/phase form.
    """

    def __init__(self, act = 'gelu', use_phase=False):
        super().__init__()
        self.act = activation_function(act)
        self.use_phase = use_phase

        # Learnable modulus bias; only read when use_phase is True, but it is
        # always created, so it appears in parameters()/state_dict either way.
        # nn.Parameter already has requires_grad=True, so no extra flag is set.
        self.b = nn.Parameter(torch.tensor([0.1]))

    def forward(self, z):
        if self.use_phase:
            return self.act(torch.abs(z) + self.b) * torch.exp(1.j * torch.angle(z))
        else:
            return self.act(z.real) + 1.j * self.act(z.imag)


def complex_multiplication_1d_FC(input, weights):
    """Mix channels independently for every mode (1D spectral convolution).

    Shapes: ``input`` is ``(batch, in_channel, x)`` and ``weights`` is
    ``(in_channel, out_channel, x)``; the result is ``(batch, out_channel, x)``
    with ``out[b, o, k] = sum_i input[b, i, k] * weights[i, o, k]``.
    """
    return torch.einsum("bix,iox->box", input, weights)


def complex_multiplication_1d_Diag(input, weights):
    """Scale every entry of ``input`` by the matching entry of ``weights``.

    A leading batch axis is added to ``weights`` and broadcasting does the
    rest. This is a diagonal (per-channel) channel map applied mode by mode.
    In this file it is called with ``input`` of shape (batch, C, modes) and
    ``weights`` of shape (C, modes).
    """
    batched_weights = weights[None]
    return input * batched_weights


def add_diag(tensor_FC, tensor_diag, n_hidden):
    """Add ``tensor_diag`` onto the channel diagonal of ``tensor_FC`` in place.

    For every ``k`` in ``range(n_hidden)`` this performs
    ``tensor_FC[k, k, :] += tensor_diag[k, :]``. ``tensor_FC`` is modified
    and also returned, so the call can be used inline in an expression.
    """
    if n_hidden <= 0:
        # range(n_hidden) would be empty: nothing to add.
        return tensor_FC
    diag_idx = torch.arange(n_hidden)
    # Advanced-index "+=" writes back into tensor_FC (index_put_), touching
    # each (k, k, :) slice exactly once.
    tensor_FC[diag_idx, diag_idx, :] += tensor_diag[diag_idx, :]
    return tensor_FC


class MLP(nn.Module):
    """Two-layer pointwise MLP over the channel axis of a (batch, C, N) tensor.

    The input is moved to channel-last layout so ``nn.Linear`` acts on the
    channels. The sequence ``mlp1 -> activation -> mlp2`` is applied, and the
    result is moved back to channel-first ``(batch, out_channels, N)``.
    """

    def __init__(self, in_channels, out_channels, mid_channels, bias = False, activation = 'gelu'):
        super().__init__()

        # Keep attribute names and creation order: they determine the
        # state_dict keys and the RNG draws used for initialisation.
        self.mlp1 = nn.Linear(in_channels, mid_channels, bias=bias)
        self.mlp2 = nn.Linear(mid_channels, out_channels, bias=bias)

        self.activation = activation_function(activation)

    def forward(self, x):
        channels_last = x.permute(0, 2, 1)
        hidden = self.activation(self.mlp1(channels_last))
        return self.mlp2(hidden).permute(0, 2, 1)
    
class MLP_Complex(nn.Module):
    """Complex-valued two-layer pointwise MLP over the channel axis.

    Takes a channel-first complex tensor ``(batch, in_channels, N)``.
    Both linear layers use ``torch.cfloat`` weights, and the nonlinearity is a
    ``Complex_Activation`` built from ``activation``. Returns
    ``(batch, out_channels, N)``.
    """

    def __init__(self, in_channels, out_channels, mid_channels, bias = False, activation = 'gelu'):
        super().__init__()

        # Keep attribute names and creation order: they determine the
        # state_dict keys and the RNG draws used for initialisation.
        self.mlp1 = nn.Linear(in_channels, mid_channels, bias=bias, dtype=torch.cfloat)
        self.complex_activation = Complex_Activation(activation)
        self.mlp2 = nn.Linear(mid_channels, out_channels, bias=bias, dtype=torch.cfloat)

    def forward(self, x):
        h = x.permute(0, 2, 1)
        h = self.complex_activation(self.mlp1(h))
        return self.mlp2(h).permute(0, 2, 1)


################################################################
# Position/Momentum Layers
################################################################

class Bi_Momentum_Evolution_1d_Spatial_Low_Freqency(nn.Module):
    """Fourier layer for a complex, channel-first signal of shape (batch, C, N).

    ``forward`` steps:

    * take a full complex FFT along the last axis;
    * multiply the first ``modes_1`` bins (lowest non-negative frequencies)
      and the last ``modes_1`` bins (lowest negative frequencies) by learned
      complex weights;
    * zero every other bin;
    * transform back with an inverse FFT.

    The result is a complex tensor of shape ``(batch, C, N)``.

    In this file the layer is fed ``torch.fft.rfft(h, dim=-2)`` of an
    ``n_hidden``-channel hidden state, hence ``C = n_hidden // 2 + 1``.

    Args:
        n_hidden: channel count of the real hidden state (``C = n_hidden // 2 + 1``).
        modes_1: number of FFT bins kept at each end of the spectrum. If
            ``2 * modes_1`` exceeds ``N``, the two ranges overlap and the
            second (negative-frequency) write wins.
        module_list: any of ``'FC'`` and ``'Diag'``.
            * ``'FC'``: full channel-mixing weights ``(C, C, modes_1)``.
            * ``'Diag'``: per-channel weights ``(C, modes_1)``.
            * Both: a single full tensor whose diagonal receives an extra
              random initialisation term.
            * Neither: the layer has no parameters and outputs zeros.
    """

    def __init__(self, n_hidden, modes_1, module_list=('FC', 'Diag')):
        super().__init__()

        self.modes_1 = modes_1

        # Number of complex channels: length of an rfft over n_hidden values.
        self.n_hidden = n_hidden // 2 + 1

        self.module_FC = ('FC' in module_list)
        self.module_Diag = ('Diag' in module_list)

        c, m = self.n_hidden, self.modes_1

        # Keep the order of the torch.rand calls below: changing it changes
        # seeded initialisations.
        if self.module_FC and self.module_Diag:
            # Fused FC + Diag: one full tensor with a boosted diagonal.
            self.scale = 1 / (c * c)
            self.weights_1 = nn.Parameter(self.scale * add_diag(
                torch.rand(c, c, m, dtype=torch.cfloat),
                torch.rand(c, m, dtype=torch.cfloat), c))
            self.weights_2 = nn.Parameter(self.scale * add_diag(
                torch.rand(c, c, m, dtype=torch.cfloat),
                torch.rand(c, m, dtype=torch.cfloat), c))
        elif self.module_FC:
            self.scale_FC = 1 / (c * c)
            self.weights_1 = nn.Parameter(self.scale_FC * torch.rand(c, c, m, dtype=torch.cfloat))
            self.weights_2 = nn.Parameter(self.scale_FC * torch.rand(c, c, m, dtype=torch.cfloat))
        elif self.module_Diag:
            self.scale_Diag = 1 / c
            self.weights_1_diag = nn.Parameter(self.scale_Diag * torch.rand(c, m, dtype=torch.cfloat))
            self.weights_2_diag = nn.Parameter(self.scale_Diag * torch.rand(c, m, dtype=torch.cfloat))

    def forward(self, x):
        batchsize, n_points = x.shape[0], x.size(-1)
        m = self.modes_1

        x_ft = torch.fft.fft(x, dim=-1)
        out_ft = torch.zeros(batchsize, self.n_hidden, n_points, dtype=torch.cfloat, device=x.device)

        # The fused FC+Diag weights and the FC-only weights are both full
        # (C, C, modes) tensors and use the same contraction.
        if self.module_FC:
            out_ft[:, :, :m] = complex_multiplication_1d_FC(x_ft[:, :, :m], self.weights_1)
            out_ft[:, :, -m:] = complex_multiplication_1d_FC(x_ft[:, :, -m:], self.weights_2)
        elif self.module_Diag:
            out_ft[:, :, :m] = complex_multiplication_1d_Diag(x_ft[:, :, :m], self.weights_1_diag)
            out_ft[:, :, -m:] = complex_multiplication_1d_Diag(x_ft[:, :, -m:], self.weights_2_diag)
        # With neither module enabled the spectrum stays zero.

        return torch.fft.ifft(out_ft, n=n_points)


class Hidden_Position_Space_Momentum_1d_Evolution_Spatial_Low_Freqency(nn.Module):
    """Fourier layer for a real, channel-first signal of shape (batch, C, N).

    ``forward`` steps:

    * take ``torch.fft.rfft`` along the last axis (``N // 2 + 1`` bins);
    * multiply the first ``modes_1`` bins and the last ``modes_1`` bins by
      learned complex weights;
    * zero every other bin;
    * return ``torch.fft.irfft(..., n=N)``, a real tensor of shape
      ``(batch, C, N)``.

    Args:
        n_hidden: channel count ``C``.
        modes_1: number of rfft bins kept at each end of the spectrum.
        module_list: any of ``'FC'`` and ``'Diag'``.
            * ``'FC'``: full channel-mixing weights ``(C, C, modes_1)``.
            * ``'Diag'``: per-channel weights ``(C, modes_1)``.
            * Both: a single full tensor whose diagonal receives an extra
              random initialisation term.
            * Neither: the layer has no parameters and outputs zeros.
    """

    def __init__(self, n_hidden, modes_1, module_list=('FC', 'Diag')):
        super().__init__()

        self.modes_1 = modes_1

        self.n_hidden = n_hidden

        self.module_FC = ('FC' in module_list)
        self.module_Diag = ('Diag' in module_list)

        c, m = self.n_hidden, self.modes_1

        # Keep the order of the torch.rand calls below: changing it changes
        # seeded initialisations.
        if self.module_FC and self.module_Diag:
            # Fused FC + Diag: one full tensor with a boosted diagonal.
            self.scale = 1 / (c * c)
            self.weights_1 = nn.Parameter(self.scale * add_diag(
                torch.rand(c, c, m, dtype=torch.cfloat),
                torch.rand(c, m, dtype=torch.cfloat), c))
            self.weights_2 = nn.Parameter(self.scale * add_diag(
                torch.rand(c, c, m, dtype=torch.cfloat),
                torch.rand(c, m, dtype=torch.cfloat), c))
        elif self.module_FC:
            self.scale_FC = 1 / (c * c)
            self.weights_1 = nn.Parameter(self.scale_FC * torch.rand(c, c, m, dtype=torch.cfloat))
            self.weights_2 = nn.Parameter(self.scale_FC * torch.rand(c, c, m, dtype=torch.cfloat))
        elif self.module_Diag:
            self.scale_Diag = 1 / c
            self.weights_1_diag = nn.Parameter(self.scale_Diag * torch.rand(c, m, dtype=torch.cfloat))
            self.weights_2_diag = nn.Parameter(self.scale_Diag * torch.rand(c, m, dtype=torch.cfloat))

    def forward(self, x):
        batchsize, n_points = x.shape[0], x.size(-1)
        m = self.modes_1

        x_ft = torch.fft.rfft(x)
        out_ft = torch.zeros(batchsize, self.n_hidden, n_points // 2 + 1, dtype=torch.cfloat, device=x.device)

        # NOTE(review): on an rfft spectrum the last `m` bins are the highest
        # frequencies (up to Nyquist), not the negative low frequencies that
        # `[-m:]` selects after a full FFT. Confirm this is intended given the
        # "Low_Freqency" class name.
        if self.module_FC:
            # Fused FC+Diag and FC-only weights are both full (C, C, modes).
            out_ft[:, :, :m] = complex_multiplication_1d_FC(x_ft[:, :, :m], self.weights_1)
            out_ft[:, :, -m:] = complex_multiplication_1d_FC(x_ft[:, :, -m:], self.weights_2)
        elif self.module_Diag:
            out_ft[:, :, :m] = complex_multiplication_1d_Diag(x_ft[:, :, :m], self.weights_1_diag)
            out_ft[:, :, -m:] = complex_multiplication_1d_Diag(x_ft[:, :, -m:], self.weights_2_diag)
        # With neither module enabled the spectrum stays zero.

        return torch.fft.irfft(out_ft, n=n_points)



class Recover_to_Sol_Space(nn.Module):
    """Project hidden features back to per-point solution values.

    Input is channel-first ``(batch, in_channels, N)``. The layer applies
    ``linear_layer_0 -> GELU -> linear_layer_1`` pointwise. The result is
    left channel-last, with shape ``(batch, N, out_channels)``.
    """

    def __init__(self, in_channels, out_channels, mid_channels1, bias = False):
        super().__init__()

        # Keep attribute names and creation order (state_dict keys, RNG draws).
        self.linear_layer_0 = nn.Linear(in_channels, mid_channels1, bias=bias)
        self.linear_layer_1 = nn.Linear(mid_channels1, out_channels, bias=bias)

    def forward(self, x):
        hidden = F.gelu(self.linear_layer_0(x.permute(0, 2, 1)))
        return self.linear_layer_1(hidden)


class Schroedinger_Evolution_Layer_1d_Decompose(nn.Module):
    """One evolution block of the 1D Schroedinger neural operator.

    Input and output are real, channel-first tensors ``(batch, n_hidden, N)``.
    The block sums four terms:

    * the input itself (identity skip);
    * a momentum/momentum path:
        - ``rfft`` over the channel axis;
        - a spatial Fourier layer plus a pointwise complex MLP on that
          spectrum;
        - ``irfft`` back to ``n_hidden`` channels;
        - then ``mlp1``;
    * a spatial Fourier layer applied directly to the input;
    * a pointwise MLP residual (``residual_layer``).

    GELU is applied to that sum unless ``last_layer`` is set. Finally the
    ``mlp_res`` branch is added, which therefore bypasses the activation.
    """

    def __init__(self, spatial_modes_1, n_hidden, last_layer = False):
        super().__init__()

        self.n_hidden = n_hidden
        self.last_layer = last_layer

        # Channel count after an rfft over the n_hidden channel axis.
        half = n_hidden // 2 + 1

        # Submodule names and creation order are kept unchanged: they define
        # the state_dict keys and the RNG draws used for initialisation.
        self.hidden_momentum_spatial_momentum_evolution_layer = \
            Bi_Momentum_Evolution_1d_Spatial_Low_Freqency(n_hidden, spatial_modes_1, module_list=['FC', 'Diag'])

        self.hidden_position_spatial_momentum_evolution_layer = \
            Hidden_Position_Space_Momentum_1d_Evolution_Spatial_Low_Freqency(n_hidden, spatial_modes_1, module_list=['Diag', 'FC'])

        self.complex_mlp_residual = MLP_Complex(half, half, half * 2, bias=False)

        self.mlp1 = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        # NOTE(review): mlp2 is never used in forward(). It still consumes RNG
        # at init, is counted in parameters()/state_dict, and may trip DDP's
        # unused-parameter check. Kept so existing checkpoints still load.
        self.mlp2 = MLP(n_hidden, n_hidden, n_hidden, bias=False)
        self.mlp_res = MLP(n_hidden, n_hidden, n_hidden, bias=False)

        self.residual_layer = MLP(n_hidden, n_hidden, 2 * n_hidden, bias=False)

    def forward(self, x):
        skip = self.mlp_res(x)

        # Momentum representation along the hidden (channel) axis.
        spectrum = torch.fft.rfft(x, dim=-2)
        spectrum = (self.hidden_momentum_spatial_momentum_evolution_layer(spectrum)
                    + self.complex_mlp_residual(spectrum))
        momentum_branch = self.mlp1(torch.fft.irfft(spectrum, dim=-2, n=self.n_hidden))

        position_branch = self.hidden_position_spatial_momentum_evolution_layer(x)
        residual_branch = self.residual_layer(x)

        out = x + momentum_branch + position_branch + residual_branch
        if not self.last_layer:
            out = F.gelu(out)

        return out + skip


class Schroedinger_NO_1d_Decompose(nn.Module):
    """1D Schroedinger neural operator built from decomposed evolution blocks.

    Pipeline:

    * append a uniform ``[0, 1]`` grid coordinate to the input
      ``(batch, N, 1)``;
    * lift the resulting 2 features per point to ``n_hidden`` channels;
    * switch to channel-first layout;
    * run ``n_layers`` evolution blocks (no GELU after the last one);
    * project back to one value per point.

    Output shape: ``(batch, N, 1)``.
    """

    def __init__(self, spatial_modes_1, n_hidden, n_layers = 4):
        super().__init__()

        self.n_hidden = n_hidden

        # Keep creation order (lift, blocks, proj): it fixes the RNG draws
        # and the state_dict layout.
        self.lift_layer = nn.Linear(2, n_hidden, bias=False)

        self.blocks = nn.ModuleList(
            Schroedinger_Evolution_Layer_1d_Decompose(spatial_modes_1,
                                                      n_hidden,
                                                      last_layer=(idx == n_layers - 1))
            for idx in range(n_layers)
        )

        self.proj_layer = Recover_to_Sol_Space(n_hidden, 1, n_hidden * 2, bias=False)

    def forward(self, x):
        coords = self.get_grid(x.shape, x.device)
        h = self.lift_layer(torch.cat((x, coords), dim=-1)).permute(0, 2, 1)

        for block in self.blocks:
            h = block(h)

        return self.proj_layer(h)

    def get_grid(self, shape, device):
        """Return a float32 ``(batch, N, 1)`` grid of ``N`` points in [0, 1]."""
        n_batch, n_points = shape[0], shape[1]
        coords = torch.tensor(linspace(0, 1, n_points), dtype=torch.float)
        return coords.view(1, n_points, 1).repeat(n_batch, 1, 1).to(device)