import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
from math import pi, sqrt
from kan import KAN


class SpectralConv1d(nn.Module):
    """1D Fourier layer whose spectral kernel is generated by KAN networks.

    Instead of storing one learnable complex weight per Fourier mode, two
    KAN networks map each frequency coordinate to the real and imaginary
    parts of an ``(in_channels, out_channels)`` kernel. The kernel covers
    every rFFT bin of the input (no mode truncation).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        n1: Stored on the module; not used by ``forward``.
        dropout: Accepted for API compatibility; not used.
    """

    def __init__(self, in_channels, out_channels, n1, dropout):
        super(SpectralConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n1 = n1  # stored only; forward uses every rFFT bin
        # One KAN per kernel component: 1 frequency feature -> in*out values.
        self.kanxr = KAN([1, 3, in_channels * out_channels], base_activation=nn.Identity, grid_size=24)
        self.kanxi = KAN([1, 3, in_channels * out_channels], base_activation=nn.Identity, grid_size=24)

    def compl_mul1d(self, input, weights):
        """Contract channels: (B, I, X) x (I, O, X) -> (B, O, X)."""
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x, Tx):
        """Apply the KAN-parameterized spectral convolution.

        Args:
            x: Real tensor of shape ``(B, in_channels, H)``.
            Tx: Frequency coordinates with ``H // 2 + 1`` rows, one per rFFT
                bin; ``FNO1d`` passes ``rfftfreq(H).reshape(-1, 1)``.

        Returns:
            Real tensor of shape ``(B, out_channels, H)``.
        """
        _, _, H = x.shape
        n_freq = H // 2 + 1

        # KAN outputs are indexed (frequency, in*out); move the frequency axis
        # last before splitting the channel axis into (in, out).
        Re = self.kanxr(Tx).permute(1, 0).reshape(self.in_channels, self.out_channels, n_freq)
        Im = self.kanxi(Tx).permute(1, 0).reshape(self.in_channels, self.out_channels, n_freq)
        kernel = Re + 1j * Im

        x_ft = torch.fft.rfft(x)
        out_ft = self.compl_mul1d(x_ft, kernel)
        return torch.fft.irfft(out_ft, n=H)

class MLP1d(nn.Module):
    """Two-layer pointwise (kernel size 1) MLP over ``(B, C, L)`` tensors.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        mid_channels: Hidden channels.
        dropout: Dropout probability applied to the hidden activations in
            training mode (same behavior as ``MLP``). Defaults to 0, which
            leaves the output unchanged.
    """

    def __init__(self, in_channels, out_channels, mid_channels, dropout=0.):
        super(MLP1d, self).__init__()
        self.linear1 = nn.Conv1d(in_channels, mid_channels, 1)
        self.linear2 = nn.Conv1d(mid_channels, out_channels, 1)
        # Previously accepted but silently ignored.
        self.dropout = dropout

        self.act = nn.GELU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.act(x)
        # Tie dropout to the module mode so model.eval() disables it.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.linear2(x)
        return x

class FNO1d(nn.Module):
    """1D Fourier neural operator built from KAN-parameterized spectral layers.

    Pipeline: lift ``(a(x), x)`` to ``width`` channels with ``p``, apply four
    residual blocks ``mlp_i(conv_i(x)) + x`` (GELU after all but the last),
    then project to ``output_dim`` channels with ``q``.

    Args:
        width: Hidden channel count.
        n1: Forwarded to each ``SpectralConv1d``.
        padding: Zero-padding appended to the spatial axis before the
            spectral layers (for non-periodic inputs) and cropped afterwards.
        input_dim: Number of input features per grid point.
        output_dim: Number of output features per grid point.
        mlp_dropout: Forwarded to each ``SpectralConv1d`` as ``dropout``.
        H: Expected (unpadded) spatial resolution; the frequency coordinates
            are built for ``H + padding`` points.
    """

    def __init__(self, width, n1=10, padding=0, input_dim=1, output_dim=1, mlp_dropout=0, H=256):
        super(FNO1d, self).__init__()

        self.width = width
        self.padding = padding  # pad the domain if input is non-periodic

        # Lifting layer: input_dim features of a(x) plus 1 grid coordinate.
        self.p = nn.Linear(input_dim + 1, self.width)
        self.conv0 = SpectralConv1d(self.width, self.width, n1, dropout=mlp_dropout)
        self.conv1 = SpectralConv1d(self.width, self.width, n1, dropout=mlp_dropout)
        self.conv2 = SpectralConv1d(self.width, self.width, n1, dropout=mlp_dropout)
        self.conv3 = SpectralConv1d(self.width, self.width, n1, dropout=mlp_dropout)
        self.mlp0 = MLP1d(self.width, self.width, self.width)
        self.mlp1 = MLP1d(self.width, self.width, self.width)
        self.mlp2 = MLP1d(self.width, self.width, self.width)
        self.mlp3 = MLP1d(self.width, self.width, self.width)

        self.q = MLP1d(self.width, output_dim, self.width * 4)

        # rFFT frequency coordinates for the padded grid, shape
        # ((H + padding) // 2 + 1, 1). A non-persistent buffer follows
        # model.to()/.cuda() (a hard-coded .cuda() tensor fails on CPU-only
        # hosts and ignores the model's device) and keeps the state_dict
        # unchanged.
        self.register_buffer(
            "Tx", torch.fft.rfftfreq(H + padding).reshape(-1, 1), persistent=False
        )

    def forward(self, x, grid=None):
        """Evaluate the operator.

        Args:
            x: Tensor of shape ``(B, X, input_dim)``.
            grid: Optional coordinates of shape ``(B, X, 1)``; defaults to a
                uniform grid on [0, 1] built by ``get_grid``.

        Returns:
            Tensor of shape ``(B, X, output_dim)``.
        """
        if grid is None:
            grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.p(x)
        x = x.permute(0, 2, 1)  # channels-first: (B, width, X)
        x = F.pad(x, [0, self.padding])

        blocks = (
            (self.conv0, self.mlp0),
            (self.conv1, self.mlp1),
            (self.conv2, self.mlp2),
            (self.conv3, self.mlp3),
        )
        for i, (conv, mlp) in enumerate(blocks):
            x = mlp(conv(x, self.Tx)) + x
            if i < len(blocks) - 1:  # no activation after the last block
                x = F.gelu(x)

        if self.padding > 0:
            x = x[..., :-self.padding]
        x = self.q(x)
        x = x.permute(0, 2, 1)
        return x

    def get_grid(self, shape, device):
        """Return a uniform [0, 1] grid of shape ``(shape[0], shape[1], 1)``."""
        batchsize, size_x = shape[0], shape[1]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
        return gridx.to(device)


class MLP(nn.Module):
    """Two-layer pointwise (1x1 convolution) MLP over 4D ``(B, C, H, W)`` tensors.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        mid_channels: Hidden channels.
        dropout: Dropout probability on the hidden activations, applied only
            in training mode.
    """

    def __init__(self, in_channels, out_channels, mid_channels, dropout=0.):
        super(MLP, self).__init__()
        self.linear1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.linear2 = nn.Conv2d(mid_channels, out_channels, 1)
        self.dropout = dropout

        self.act = nn.GELU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.act(x)
        # F.dropout defaults to training=True; tie it to the module mode so
        # that model.eval() actually disables dropout.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.linear2(x)
        return x

    
class SpectralConv2d(nn.Module):
    """2D Fourier layer: FFT, KAN-parameterized linear transform, inverse FFT.

    Two KAN networks map each 2D frequency coordinate to the real and
    imaginary parts of an ``(in_channels, out_channels)`` kernel. The kernel
    spans a ``modes1 x (modes2 // 2 + 1)`` grid and is multiplied with the
    whole ``rfft2`` spectrum (no mode truncation), so that grid must have the
    same shape as ``rfft2(x)``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        n1, n2: Stored on the module; not used by ``forward``.
        dropout: Stored on the module; not used by ``forward``.
        padding: Stored on the module; not used by ``forward``.
        modes1, modes2: Spatial size of the (padded) grid the kernel is
            built for.
    """

    def __init__(self, in_channels, out_channels, n1, n2, dropout, padding, modes1, modes2):
        super(SpectralConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n1 = n1
        self.n2 = n2
        self.dropout = dropout

        self.modes1 = modes1
        self.modes2 = modes2

        # One KAN per kernel component: (kx, ky) -> in*out values.
        self.kanr = KAN([2, 5, in_channels * out_channels], base_activation=nn.Identity, grid_size=32)
        self.kani = KAN([2, 5, in_channels * out_channels], base_activation=nn.Identity, grid_size=32)

        self.padding = padding

    def compl_mul2d(self, input, weights):
        """Contract channels: (B, I, X, Y) x (I, O, X, Y) -> (B, O, X, Y)."""
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x, Txy):
        """Apply the KAN-parameterized spectral convolution.

        Args:
            x: Real tensor of shape ``(B, in_channels, H, W)``.
            Txy: Frequency pairs, one row per kernel entry in row-major
                ``(modes1, modes2 // 2 + 1)`` order; ``FNO2d`` passes
                ``cartesian_prod(fftfreq(H), rfftfreq(W))``.

        Returns:
            Real tensor of shape ``(B, out_channels, H, W)``.
        """
        H, W = x.shape[-2], x.shape[-1]

        # KAN outputs are indexed (frequency pair, in*out); move the frequency
        # axis last, then unflatten it into the 2D spectral grid.
        kernel_ft = (self.kanr(Txy) + 1j * self.kani(Txy)).permute(1, 0).reshape(
            self.in_channels, self.out_channels, self.modes1, self.modes2 // 2 + 1
        )

        x_ft = torch.fft.rfft2(x)
        out_ft = self.compl_mul2d(x_ft, kernel_ft)

        # Return to physical space.
        return torch.fft.irfft2(out_ft, s=(H, W))


class FNO2d(nn.Module):
    """2D Fourier neural operator built from KAN-parameterized spectral layers.

    Pipeline: lift ``(a(x, y), x, y)`` to ``width`` channels with ``p``, apply
    four residual blocks ``x + mlp_i(conv_i(x))`` (GELU after all but the
    last), then project to ``output_dim`` channels with ``q``.

    Args:
        width: Hidden channel count.
        n1, n2: Forwarded to each ``SpectralConv2d``; also size ``grade1`` /
            ``grade2``.
        padding: Zero-padding appended to both spatial axes before the
            spectral layers and cropped afterwards.
        input_dim: Number of input features per grid point.
        output_dim: Number of output features per grid point.
        mlp_dropout: Forwarded to each ``SpectralConv2d`` as ``dropout``.
        H, W: Expected (unpadded) input resolution; kernels and frequency
            coordinates are built for ``(H + padding, W + padding)``.
    """

    def __init__(self, width, n1=10, n2=10, padding=0, input_dim=1, output_dim=1, mlp_dropout=0, H=85, W=85):
        super(FNO2d, self).__init__()

        self.width = width
        self.padding = padding  # pad the domain if input is non-periodic

        # Lifting layer: input_dim features of a(x, y) plus 2 grid coordinates.
        self.p = nn.Linear(input_dim + 2, self.width)
        modes1, modes2 = H + padding, W + padding
        self.conv0 = SpectralConv2d(self.width, self.width, n1, n2, dropout=mlp_dropout, padding=padding, modes1=modes1, modes2=modes2)
        self.conv1 = SpectralConv2d(self.width, self.width, n1, n2, dropout=mlp_dropout, padding=padding, modes1=modes1, modes2=modes2)
        self.conv2 = SpectralConv2d(self.width, self.width, n1, n2, dropout=mlp_dropout, padding=padding, modes1=modes1, modes2=modes2)
        self.conv3 = SpectralConv2d(self.width, self.width, n1, n2, dropout=mlp_dropout, padding=padding, modes1=modes1, modes2=modes2)
        self.mlp0 = MLP(self.width, self.width, 4 * self.width, dropout=0.)
        self.mlp1 = MLP(self.width, self.width, 4 * self.width, dropout=0.)
        self.mlp2 = MLP(self.width, self.width, 4 * self.width, dropout=0.)
        self.mlp3 = MLP(self.width, self.width, 4 * self.width, dropout=0.)

        self.q = MLP(self.width, output_dim, 4 * self.width)  # projection to u(x, y)
        self.n1 = n1
        self.n2 = n2
        # Not read anywhere in this class; kept as public attributes.
        self.grade1 = torch.linspace(1, self.n1, self.n1).view(-1, 1).float()
        self.grade2 = torch.linspace(1, self.n2, self.n2).view(-1, 1).float()
        self.gridx = torch.fft.fftfreq(H + padding)   # full-FFT frequencies along x
        self.gridy = torch.fft.rfftfreq(W + padding)  # rFFT frequencies along y
        # (kx, ky) pairs in row-major order, shape ((H+p) * ((W+p)//2+1), 2).
        # A non-persistent buffer follows model.to()/.cuda() (a hard-coded
        # .cuda() tensor fails on CPU-only hosts and ignores the model's
        # device) and keeps the state_dict unchanged.
        self.register_buffer(
            "Txy", torch.cartesian_prod(self.gridx, self.gridy), persistent=False
        )

    def forward(self, x, grid=None):
        """Evaluate the operator.

        Args:
            x: Tensor of shape ``(B, X, Y, input_dim)``.
            grid: Optional coordinates of shape ``(B, X, Y, 2)``; defaults to
                a uniform grid on [0, 1]^2 built by ``get_grid``.

        Returns:
            Tensor of shape ``(B, X, Y, output_dim)``.
        """
        if grid is None:
            grid = self.get_grid(x.shape, x.device)

        x = torch.cat((x, grid), dim=-1)
        x = self.p(x)
        x = x.permute(0, 3, 1, 2)  # channels-first: (B, width, X, Y)
        x = F.pad(x, [0, self.padding, 0, self.padding])

        blocks = (
            (self.conv0, self.mlp0),
            (self.conv1, self.mlp1),
            (self.conv2, self.mlp2),
            (self.conv3, self.mlp3),
        )
        for i, (conv, mlp) in enumerate(blocks):
            x = x + mlp(conv(x, self.Txy))
            if i < len(blocks) - 1:  # no activation after the last block
                x = F.gelu(x)

        if self.padding > 0:
            x = x[..., :-self.padding, :-self.padding]

        x = self.q(x)
        x = x.permute(0, 2, 3, 1)
        return x

    def get_grid(self, shape, device):
        """Return uniform [0, 1] coordinates of shape ``(B, X, Y, 2)``."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)


class SpectralConv2dMLP(nn.Module):
    """2D Fourier layer with a separable, MLP-generated spectral kernel.

    Two pairs of ``MLP`` heads turn a per-row feature map ``Tx`` (``n1``
    channels) and a per-column feature map ``Ty`` (``n2`` channels) into
    complex factors. Their outer product over the two frequency axes gives
    the full ``(in_channels, out_channels, H, W // 2 + 1)`` kernel, which
    multiplies the whole ``rfft2`` spectrum.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        n1: Input channels of the x-axis heads (channels of ``Tx``).
        n2: Input channels of the y-axis heads (channels of ``Ty``).
        dropout: Dropout probability handed to all four ``MLP`` heads.
        padding, modes1, modes2: Stored on the module; ``forward`` does not
            read them.
    """

    def __init__(self, in_channels, out_channels, n1, n2, dropout, padding, modes1, modes2):
        super(SpectralConv2dMLP, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n1 = n1
        self.n2 = n2
        self.dropout = dropout

        n_kernel = in_channels * out_channels
        # Registration order (xr, xi, yr, yi) fixes parameter init order.
        self.mlpxr = MLP(n1, n_kernel, 2 * n1, dropout=dropout)
        self.mlpxi = MLP(n1, n_kernel, 2 * n1, dropout=dropout)
        self.mlpyr = MLP(n2, n_kernel, 2 * n2, dropout=dropout)
        self.mlpyi = MLP(n2, n_kernel, 2 * n2, dropout=dropout)

        self.modes1 = modes1
        self.modes2 = modes2
        self.padding = padding

    def compl_mul2d(self, input, weights):
        """Contract channels: (B, I, X, Y) x (I, O, X, Y) -> (B, O, X, Y)."""
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x, Tx, Ty):
        """Apply the separable spectral convolution.

        Args:
            x: Real tensor of shape ``(B, in_channels, H, W)``.
            Tx: Row feature map; ``FNO2dMLP`` passes shape ``(1, n1, H, 1)``.
            Ty: Column feature map; ``FNO2dMLP`` passes shape
                ``(1, n2, 1, W // 2 + 1)``.

        Returns:
            Real tensor of shape ``(B, out_channels, H, W)``.
        """
        rows, cols = x.shape[-2], x.shape[-1]
        c_in, c_out = self.in_channels, self.out_channels

        # Keep the head call order (xi, xr, yi, yr): with dropout > 0 in
        # training mode, reordering would change which random masks each
        # head draws.
        factor_x = self.mlpxi(Tx) * 1j + self.mlpxr(Tx)
        factor_x = factor_x.reshape(c_in, c_out, rows, 1)
        factor_y = self.mlpyi(Ty) * 1j + self.mlpyr(Ty)
        factor_y = factor_y.reshape(c_in, c_out, 1, cols // 2 + 1)

        # (..., rows, 1) @ (..., 1, cols//2+1): an outer product per (i, o).
        kernel_ft = torch.matmul(factor_x, factor_y)

        spectrum = torch.fft.rfft2(x)
        mixed = self.compl_mul2d(spectrum, kernel_ft)
        return torch.fft.irfft2(mixed, s=(rows, cols))

    

class FNO2dMLP(nn.Module):
    """2D Fourier neural operator whose spectral kernels come from MLP heads.

    Same residual structure as ``FNO2d``. Each ``SpectralConv2dMLP`` layer
    receives Chebyshev features ``T_k(w) = cos(k * arccos(w))`` of the
    frequencies along each axis: k = 1..n1 along x and k = 1..n2 along y.

    Args:
        width: Hidden channel count.
        n1, n2: Number of Chebyshev features along x and y.
        padding: Zero-padding appended to both spatial axes before the
            spectral layers and cropped afterwards.
        input_dim: Number of input features per grid point.
        output_dim: Number of output features per grid point.
        mlp_dropout: Forwarded to each ``SpectralConv2dMLP`` as ``dropout``.
        H, W: Expected (unpadded) input resolution; features are built for
            ``(H + padding, W + padding)``.
    """

    def __init__(self, width, n1=10, n2=10, padding=0, input_dim=1, output_dim=1, mlp_dropout=0, H=85, W=85):
        super(FNO2dMLP, self).__init__()

        self.width = width
        self.padding = padding

        # Lifting layer: input_dim features of a(x, y) plus 2 grid coordinates.
        self.p = nn.Linear(input_dim + 2, self.width)
        modes1, modes2 = H + padding, W + padding
        self.conv0 = SpectralConv2dMLP(self.width, self.width, n1, n2, dropout=mlp_dropout, padding=padding, modes1=modes1, modes2=modes2)
        self.conv1 = SpectralConv2dMLP(self.width, self.width, n1, n2, dropout=mlp_dropout, padding=padding, modes1=modes1, modes2=modes2)
        self.conv2 = SpectralConv2dMLP(self.width, self.width, n1, n2, dropout=mlp_dropout, padding=padding, modes1=modes1, modes2=modes2)
        self.conv3 = SpectralConv2dMLP(self.width, self.width, n1, n2, dropout=mlp_dropout, padding=padding, modes1=modes1, modes2=modes2)
        self.mlp0 = MLP(self.width, self.width, 4 * self.width, dropout=0.)
        self.mlp1 = MLP(self.width, self.width, 4 * self.width, dropout=0.)
        self.mlp2 = MLP(self.width, self.width, 4 * self.width, dropout=0.)
        self.mlp3 = MLP(self.width, self.width, 4 * self.width, dropout=0.)

        self.q = MLP(self.width, output_dim, 4 * self.width)
        self.n1 = n1
        self.n2 = n2
        self.grade1 = torch.linspace(1, self.n1, self.n1).view(-1, 1).float()  # k = 1..n1
        self.grade2 = torch.linspace(1, self.n2, self.n2).view(-1, 1).float()  # k = 1..n2
        self.gridx = torch.fft.fftfreq(H + padding).unsqueeze(0)   # (1, H+p)
        self.gridy = torch.fft.rfftfreq(W + padding).unsqueeze(0)  # (1, (W+p)//2+1)

        # Chebyshev features shaped as single-sample maps for the Conv2d-based
        # MLP heads: Tx (1, n1, H+p, 1), Ty (1, n2, 1, (W+p)//2+1).
        Tx = torch.cos(self.grade1 @ torch.acos(self.gridx)).reshape(1, self.n1, H + padding, 1)
        Ty = torch.cos(self.grade2 @ torch.acos(self.gridy)).reshape(1, self.n2, 1, (W + padding) // 2 + 1)
        # Non-persistent buffers follow model.to()/.cuda() (hard-coded .cuda()
        # tensors fail on CPU-only hosts and ignore the model's device) and
        # keep the state_dict unchanged.
        self.register_buffer("Tx", Tx, persistent=False)
        self.register_buffer("Ty", Ty, persistent=False)

    def forward(self, x, grid=None):
        """Evaluate the operator.

        Args:
            x: Tensor of shape ``(B, X, Y, input_dim)``.
            grid: Optional coordinates of shape ``(B, X, Y, 2)``; defaults to
                a uniform grid on [0, 1]^2 built by ``get_grid``.

        Returns:
            Tensor of shape ``(B, X, Y, output_dim)``.
        """
        if grid is None:
            grid = self.get_grid(x.shape, x.device)

        x = torch.cat((x, grid), dim=-1)
        x = self.p(x)
        x = x.permute(0, 3, 1, 2)  # channels-first: (B, width, X, Y)
        x = F.pad(x, [0, self.padding, 0, self.padding])

        blocks = (
            (self.conv0, self.mlp0),
            (self.conv1, self.mlp1),
            (self.conv2, self.mlp2),
            (self.conv3, self.mlp3),
        )
        for i, (conv, mlp) in enumerate(blocks):
            x = x + mlp(conv(x, self.Tx, self.Ty))
            if i < len(blocks) - 1:  # no activation after the last block
                x = F.gelu(x)

        if self.padding > 0:
            x = x[..., :-self.padding, :-self.padding]

        x = self.q(x)
        x = x.permute(0, 2, 3, 1)
        return x

    def get_grid(self, shape, device):
        """Return uniform [0, 1] coordinates of shape ``(B, X, Y, 2)``."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)




class SpectralConv1dMLP(nn.Module):
    """1D Fourier layer whose spectral kernel is generated by ``MLP1d`` heads.

    Two ``MLP1d`` heads map an ``n1``-channel feature map over the rFFT bins
    to the real and imaginary parts of an ``(in_channels, out_channels)``
    kernel at every bin (no mode truncation).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        n1: Input channels of the heads, i.e. features per frequency bin in
            ``Tx``.
        dropout: Dropout probability passed to both ``MLP1d`` heads.
    """

    def __init__(self, in_channels, out_channels, n1, dropout):
        super(SpectralConv1dMLP, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n1 = n1
        self.mlpxr = MLP1d(n1, in_channels * out_channels, 2 * n1, dropout=dropout)
        self.mlpxi = MLP1d(n1, in_channels * out_channels, 2 * n1, dropout=dropout)

        self.dropout = dropout

    def compl_mul1d(self, input, weights):
        """Contract channels: (B, I, X) x (I, O, X) -> (B, O, X)."""
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x, Tx):
        """Apply the MLP-parameterized spectral convolution.

        Args:
            x: Real tensor of shape ``(B, in_channels, H)``.
            Tx: Feature map fed to both heads; ``FNO1dMLP`` passes shape
                ``(1, n1, H // 2 + 1)``.

        Returns:
            Real tensor of shape ``(B, out_channels, H)``.
        """
        _, _, H = x.shape
        n_freq = H // 2 + 1

        xr, xi = self.mlpxr(Tx), self.mlpxi(Tx)
        Re = xr.reshape(self.in_channels, self.out_channels, n_freq)
        Im = xi.reshape(self.in_channels, self.out_channels, n_freq)
        kernel = Re + 1j * Im

        x_ft = torch.fft.rfft(x)
        out_ft = self.compl_mul1d(x_ft, kernel)
        return torch.fft.irfft(out_ft, n=H)

class FNO1dMLP(nn.Module):
    """1D Fourier neural operator whose spectral kernels come from MLP heads.

    Same residual structure as ``FNO1d``. Each ``SpectralConv1dMLP`` layer
    receives Chebyshev features ``T_k(w) = cos(k * arccos(w))``, k = 1..n1,
    of the rFFT frequencies ``w``.

    Args:
        width: Hidden channel count.
        n1: Number of Chebyshev features per frequency.
        padding: Zero-padding appended to the spatial axis before the
            spectral layers (for non-periodic inputs) and cropped afterwards.
        input_dim: Number of input features per grid point.
        output_dim: Number of output features per grid point.
        mlp_dropout: Forwarded to each ``SpectralConv1dMLP`` as ``dropout``.
        H: Expected (unpadded) spatial resolution; features are built for
            ``H + padding`` points.
    """

    def __init__(self, width, n1=10, padding=0, input_dim=1, output_dim=1, mlp_dropout=0., H=256):
        super(FNO1dMLP, self).__init__()

        self.width = width
        self.padding = padding  # pad the domain if input is non-periodic

        # Lifting layer: input_dim features of a(x) plus 1 grid coordinate.
        self.p = nn.Linear(input_dim + 1, self.width)
        self.conv0 = SpectralConv1dMLP(self.width, self.width, n1, dropout=mlp_dropout)
        self.conv1 = SpectralConv1dMLP(self.width, self.width, n1, dropout=mlp_dropout)
        self.conv2 = SpectralConv1dMLP(self.width, self.width, n1, dropout=mlp_dropout)
        self.conv3 = SpectralConv1dMLP(self.width, self.width, n1, dropout=mlp_dropout)
        self.mlp0 = MLP1d(self.width, self.width, 4 * self.width)
        self.mlp1 = MLP1d(self.width, self.width, 4 * self.width)
        self.mlp2 = MLP1d(self.width, self.width, 4 * self.width)
        self.mlp3 = MLP1d(self.width, self.width, 4 * self.width)

        # Output projection used at the end of forward (mirrors FNO1d.q).
        self.q = MLP1d(self.width, output_dim, 4 * self.width)

        self.n1 = n1
        self.grade1 = torch.arange(1, self.n1 + 1).reshape(self.n1, 1).float()  # k = 1..n1
        self.gridx = torch.fft.rfftfreq(H + padding).unsqueeze(0)  # (1, (H+p)//2+1)

        # Chebyshev features, shape (1, n1, (H+p)//2+1). A non-persistent
        # buffer follows model.to()/.cuda() (a hard-coded .cuda() tensor fails
        # on CPU-only hosts and ignores the model's device) and keeps the
        # state_dict free of derived constants.
        Tx = torch.cos(self.grade1 @ torch.acos(self.gridx)).reshape(1, self.n1, (H + padding) // 2 + 1)
        self.register_buffer("Tx", Tx, persistent=False)

    def forward(self, x, grid=None):
        """Evaluate the operator.

        Args:
            x: Tensor of shape ``(B, X, input_dim)``.
            grid: Optional coordinates of shape ``(B, X, 1)``; defaults to a
                uniform grid on [0, 1] built by ``get_grid``.

        Returns:
            Tensor of shape ``(B, X, output_dim)``.
        """
        if grid is None:
            grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.p(x)
        x = x.permute(0, 2, 1)  # channels-first: (B, width, X)
        x = F.pad(x, [0, self.padding])

        blocks = (
            (self.conv0, self.mlp0),
            (self.conv1, self.mlp1),
            (self.conv2, self.mlp2),
            (self.conv3, self.mlp3),
        )
        for i, (conv, mlp) in enumerate(blocks):
            x = mlp(conv(x, self.Tx)) + x
            if i < len(blocks) - 1:  # no activation after the last block
                x = F.gelu(x)

        if self.padding > 0:
            x = x[..., :-self.padding]

        x = self.q(x)
        x = x.permute(0, 2, 1)
        return x

    def get_grid(self, shape, device):
        """Return a uniform [0, 1] grid of shape ``(shape[0], shape[1], 1)``."""
        batchsize, size_x = shape[0], shape[1]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
        return gridx.to(device)
