import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class SpectralConv1d(nn.Module):
    """1D Fourier layer: zero-padded real FFT, per-frequency linear map, inverse FFT.

    The last axis of the input is transformed with ``torch.fft.rfft`` using a
    transform length of ``x.size(-1) + modes1`` (i.e. the signal is zero-padded
    by ``modes1`` samples). Every resulting frequency bin is mixed across
    channels by a learned complex weight, the result is transformed back with
    ``torch.fft.irfft`` and cropped to the original length.

    Args:
        in_channels: number of input channels (axis 1 of the input).
        out_channels: number of output channels (axis 1 of the output).
        modes1: number of samples added to the FFT length in ``forward``.
            NOTE(review): despite the name, no frequency truncation is
            performed in this implementation; all rfft bins are multiplied.
        spatial_length: length of the last input axis the layer is built for.
            It fixes the number of frequency bins stored in ``weights1`` as
            ``(modes1 + spatial_length) // 2 + 1``.
    """

    def __init__(self, in_channels, out_channels, modes1, spatial_length):
        super(SpectralConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # padding added to the FFT length in forward()

        self.scale = (1 / (in_channels*out_channels))
        # One complex weight per (in_channel, out_channel, frequency bin); the
        # bin count equals the rfft output size for n = modes1 + spatial_length.
        self.weights1 = nn.Parameter(self.scale * torch.rand(
            in_channels, out_channels, (self.modes1 + spatial_length)//2 + 1, dtype=torch.cfloat))

    def compl_mul1d(self, input, weights):
        """Contract the input-channel axis of ``input`` with ``weights`` per frequency.

        (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """Apply the spectral convolution along the last axis of ``x``.

        Args:
            x: tensor of shape (batch, in_channels, length). The number of
                rfft bins for ``length + modes1`` must equal the last
                dimension of ``weights1`` (true when ``length`` equals the
                ``spatial_length`` given at construction); otherwise the
                einsum raises.

        Returns:
            Tensor of shape (batch, out_channels, length).
        """
        length = x.size(-1)
        padded_length = length + self.modes1

        # Fourier coefficients of the zero-padded signal.
        x_ft = torch.fft.rfft(x, n=padded_length)
        # Channel mixing on every frequency bin.
        out_ft = self.compl_mul1d(x_ft, self.weights1)
        # Back to physical space, dropping the samples added by the padding.
        return torch.fft.irfft(out_ft, n=padded_length)[..., :length]



class MLP1d(nn.Module):
    """Pointwise two-layer MLP built from kernel-size-1 ``nn.Conv1d`` layers.

    Maps ``in_channels`` -> ``mid_channels`` -> ``out_channels`` with a GELU
    between the two convolutions.
    """

    def __init__(self, in_channels, out_channels, mid_channels):
        super().__init__()
        self.mlp1 = nn.Conv1d(in_channels, mid_channels, kernel_size=1)
        self.mlp2 = nn.Conv1d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``mlp2(gelu(mlp1(x)))``."""
        hidden = F.gelu(self.mlp1(x))
        return self.mlp2(hidden)

class FNO1d(nn.Module):
    """Deprecated 1D Fourier neural operator.

    The constructor raises ``NotImplementedError("deprecated")`` immediately
    after initialising ``nn.Module``, so the class cannot currently be
    instantiated through ``__init__``.

    The layout described by the (unreachable) construction code is:

    1. Lift the concatenated ``(x, grid)`` features to ``d_model`` channels
       with the linear layer ``self.p``.
    2. Apply ``n_layers - 1`` blocks of ``gelu(mlp(spectral(u)) + w(u))``
       (``self.fno_layers``, ``self.mlps``, ``self.ws``), followed by one
       final block without activation (``self.conv``, ``self.mlp``,
       ``self.w``).
    3. Project to ``d_output`` channels with ``self.q``.
    """

    def __init__(
            self,
            d_input,
            d_output,
            *,
            spatial_length,
            n_layers=4,
            modes=16,
            d_model=64,
            initial_step=10,
            **kwargs):
        super().__init__()
        raise NotImplementedError("deprecated")

        # Everything below is unreachable while the raise above is in place.
        self.initial_step = initial_step
        self.modes1 = modes
        self.d_model = d_model
        self.padding = 2  # not read anywhere in this class

        width = self.d_model

        # Input features are initial_step * d_input values plus one extra
        # feature supplied by `grid` in forward().
        self.p = nn.Linear(initial_step * d_input + 1, width)

        spectral_blocks = []
        skip_convs = []
        channel_mlps = []
        for _ in range(n_layers - 1):
            spectral_blocks.append(
                SpectralConv1d(width, width, self.modes1, spatial_length=spatial_length))
            skip_convs.append(nn.Conv1d(width, width, 1))
            channel_mlps.append(MLP1d(width, width, width))
        self.fno_layers = nn.ModuleList(spectral_blocks)
        self.ws = nn.ModuleList(skip_convs)
        self.mlps = nn.ModuleList(channel_mlps)

        # Final block (applied without a trailing activation in forward()).
        self.conv = SpectralConv1d(width, width, self.modes1, spatial_length=spatial_length)
        self.w = nn.Conv1d(width, width, 1)
        self.mlp = MLP1d(width, width, width)

        self.q = MLP1d(width, d_output, width * 2)

    def forward(self, x, grid):
        """Run the operator on ``x`` with ``grid`` appended as extra features.

        Args:
            x: 3-D tensor; per the original note its layout is
                ``[b, x1, t*v]`` — TODO confirm against callers.
            grid: tensor concatenated to ``x`` along the last axis.

        Returns:
            Tensor with the channel axis moved back to the last position.
        """
        # Lift to the model width, then move channels to axis 1 for Conv1d.
        h = self.p(torch.cat((x, grid), dim=-1)).permute(0, 2, 1)

        for spectral, skip, mixer in zip(self.fno_layers, self.ws, self.mlps):
            h = F.gelu(mixer(spectral(h)) + skip(h))

        # Last block: same structure, no activation.
        h = self.mlp(self.conv(h)) + self.w(h)

        return self.q(h).permute(0, 2, 1)